import lab as B
import matrix # noqa
from .. import _dispatch
from ..aggregate import AggregateInput
from ..parallel import Parallel
from ..util import register_composite_coder, register_module
__all__ = ["MapDiagonal"]
[docs]@register_composite_coder
@register_module
class MapDiagonal:
"""Map to the diagonal of the squared space.
Args:
coder (coder): Coder to apply the mapped values to.
Attributes:
coder (coder): Coder to apply the mapped values to.
"""
def __init__(self, coder):
self.coder = coder
@_dispatch
def code(coder: MapDiagonal, xz, z, x, **kw_args):
x, d = _mapdiagonal_duplicate_target(x)
# The encoding might already be on the diagonal. Therefore, only duplicate the
# inputs if the dimensionalities don't line up.
xz = _mapdiagonal_possibly_duplicate_context(xz, d)
return code(coder.coder, xz, z, x, **kw_args)
@_dispatch
def code_track(coder: MapDiagonal, xz, z, x, h, **kw_args):
x, d = _mapdiagonal_duplicate_target(x)
xz = _mapdiagonal_possibly_duplicate_context(xz, d)
return code_track(coder.coder, xz, z, x, h + [(x, d)], **kw_args)
@_dispatch
def recode(coder: MapDiagonal, xz, z, h, **kw_args):
(_, d), h = h[0], h[1:]
xz = _mapdiagonal_possibly_duplicate_context(xz, d)
return recode(coder.coder, xz, z, h, **kw_args)
@_dispatch
def _mapdiagonal_duplicate_target(x: B.Numeric):
return (x, x), 2
@_dispatch
def _mapdiagonal_duplicate_target(x: AggregateInput):
xis, ds = zip(*(_mapdiagonal_duplicate_target(xi) for xi, _ in x))
if not all([d == ds[0] for d in ds[1:]]):
raise NotImplementedError("All data dimensionalities must be equal.")
else:
d = ds[0]
return AggregateInput(*((xi, i) for xi, (_, i) in zip(xis, x))), d
@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: B.Numeric, d: B.Int):
if B.shape(xz, -2) != d:
return B.concat(xz, xz, axis=-2)
else:
return xz
@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: tuple, d: B.Int):
if len(xz) != d:
return xz * 2
else:
return xz
@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: Parallel, d: B.Int):
return Parallel(*(_mapdiagonal_possibly_duplicate_context(xzi, d) for xzi in xz))