import lab as B
import matrix # noqa
from .. import _dispatch
from ..aggregate import AggregateInput
from ..parallel import Parallel
from ..util import register_composite_coder, register_module
# Names exported by `from <module> import *`; only the coder class is listed.
__all__ = ["MapDiagonal"]
# NOTE(review): both decorators are imported from `..util` and otherwise unused in
# this module; presumably they register this class with the package's module and
# composite-coder machinery — confirm against `..util`.
@register_composite_coder
@register_module
class MapDiagonal:
    """Map to the diagonal of the squared space.

    Args:
        coder (coder): Coder to apply the mapped values to.

    Attributes:
        coder (coder): Coder to apply the mapped values to.
    """

    def __init__(self, coder):
        self.coder = coder
@_dispatch
def code(coder: MapDiagonal, xz, z, x, **kw_args):
    """Code through a :class:`MapDiagonal` by duplicating the inputs.

    The target inputs `x` are paired with themselves via
    `_mapdiagonal_duplicate_target`. The inputs `xz` are then duplicated to the
    same dimensionality if they do not already match. The result is handed to the
    wrapped coder `coder.coder`.

    Args:
        coder (:class:`MapDiagonal`): Coder to code through.
        xz (input): Inputs accompanying the encoding `z` (presumably). They are
            duplicated only if their dimensionality differs from that of the
            duplicated targets.
        z (object): Encoding, passed through unchanged.
        x (input): Target inputs.
        **kw_args: Passed on to the wrapped coder.

    Returns:
        object: Whatever `code` returns for `coder.coder`.
    """
    # `@_dispatch` is required: without it, the name `code` below would refer to this
    # very function and the call would recurse forever instead of reaching
    # `coder.coder`.
    x, d = _mapdiagonal_duplicate_target(x)
    # The encoding might already be on the diagonal. Therefore, only duplicate the
    # inputs if the dimensionalities don't line up.
    xz = _mapdiagonal_possibly_duplicate_context(xz, d)
    return code(coder.coder, xz, z, x, **kw_args)
@_dispatch
def code_track(coder: MapDiagonal, xz, z, x, h, **kw_args):
    """Like `code` for :class:`MapDiagonal`, but record history for `recode`.

    Appends `(x, d)` to the history. Here `x` is the duplicated target inputs and
    `d` is their dimensionality. `recode` later reads this entry to repeat the same
    duplication of `xz`.

    Args:
        coder (:class:`MapDiagonal`): Coder to code through.
        xz (input): Inputs accompanying the encoding `z` (presumably). They are
            duplicated only if their dimensionality differs from `d`.
        z (object): Encoding, passed through unchanged.
        x (input): Target inputs.
        h (list): History so far (presumably a list; it must support `+ [...]`).
            It is not mutated.
        **kw_args: Passed on to the wrapped coder.

    Returns:
        object: Whatever `code_track` returns for `coder.coder`.
    """
    # `@_dispatch` is required so that the recursive call below reaches
    # `coder.coder` rather than this function.
    x, d = _mapdiagonal_duplicate_target(x)
    xz = _mapdiagonal_possibly_duplicate_context(xz, d)
    # `h + [...]` builds a new list, so the caller's history is left untouched.
    return code_track(coder.coder, xz, z, x, h + [(x, d)], **kw_args)
@_dispatch
def recode(coder: MapDiagonal, xz, z, h, **kw_args):
    """Recode through a :class:`MapDiagonal` using history from `code_track`.

    Consumes the first history entry and re-applies the duplication of `xz` with the
    recorded dimensionality. The remaining history goes to the wrapped coder.

    Args:
        coder (:class:`MapDiagonal`): Coder to recode through.
        xz (input): Inputs accompanying the encoding `z` (presumably).
        z (object): Encoding, passed through unchanged.
        h (sequence): History. Its first entry must be the `(x, d)` pair recorded
            by `code_track`.
        **kw_args: Passed on to the wrapped coder.

    Returns:
        object: Whatever `recode` returns for `coder.coder`.
    """
    # `@_dispatch` is required so that the recursive call below reaches
    # `coder.coder` rather than this function.
    (_, d), rest = h[0], h[1:]
    xz = _mapdiagonal_possibly_duplicate_context(xz, d)
    return recode(coder.coder, xz, z, rest, **kw_args)
@_dispatch
def _mapdiagonal_duplicate_target(x: B.Numeric):
    """Pair numeric target inputs with themselves.

    `@_dispatch` lets this overload coexist with the same-named `AggregateInput`
    overload instead of being replaced by it.

    Args:
        x (tensor): Target inputs.

    Returns:
        tuple: The pair `(x, x)` together with its dimensionality, `2`.
    """
    return (x, x), 2
@_dispatch
def _mapdiagonal_duplicate_target(x: AggregateInput):
    """Duplicate every element of an aggregate of target inputs.

    Each element of `x` is iterated as a pair `(xi, i)`. Only `xi` is duplicated;
    `i` (presumably an output index — confirm) is carried over unchanged.

    Args:
        x (:class:`AggregateInput`): Aggregate of target inputs.

    Returns:
        tuple: A new :class:`AggregateInput` of `(duplicated xi, i)` pairs, together
        with the common dimensionality of the duplicated elements.

    Raises:
        NotImplementedError: If the elements do not all have the same dimensionality.
    """
    # The recursive call relies on `@_dispatch` to select the per-element overload.
    xis, ds = zip(*(_mapdiagonal_duplicate_target(xi) for xi, _ in x))
    # Generator instead of a list: stops at the first mismatch.
    if not all(d == ds[0] for d in ds[1:]):
        raise NotImplementedError("All data dimensionalities must be equal.")
    d = ds[0]
    return AggregateInput(*((xi, i) for xi, (_, i) in zip(xis, x))), d
@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: B.Numeric, d: B.Int):
    """Duplicate numeric inputs along axis -2 unless that axis already has size `d`.

    `@_dispatch` lets this overload coexist with the `tuple` and `Parallel`
    overloads of the same name.

    Args:
        xz (tensor): Inputs. Assumes axis -2 indexes input dimensions — TODO confirm.
        d (int): Required dimensionality.

    Returns:
        tensor: `xz` unchanged if `B.shape(xz, -2) == d`. Otherwise `xz` concatenated
        with itself along axis -2, which doubles that axis regardless of how far it
        was from `d`.
    """
    if B.shape(xz, -2) == d:
        return xz
    return B.concat(xz, xz, axis=-2)
@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: tuple, d: B.Int):
    """Duplicate a tuple of inputs unless it already has length `d`.

    `@_dispatch` lets this overload coexist with the numeric and `Parallel`
    overloads of the same name.

    Args:
        xz (tuple): Inputs.
        d (int): Required length.

    Returns:
        tuple: `xz` unchanged if `len(xz) == d`. Otherwise `xz` repeated twice, e.g.
        `(a,)` becomes `(a, a)`.
    """
    if len(xz) == d:
        return xz
    return xz * 2
@_dispatch
def _mapdiagonal_possibly_duplicate_context(xz: Parallel, d: B.Int):
    """Apply the possible duplication to every element of a :class:`Parallel`.

    The per-element call relies on `@_dispatch` to pick the overload matching each
    element's type. Without it, every call would land back in this function.

    Args:
        xz (:class:`Parallel`): Parallel collection of inputs.
        d (int): Required dimensionality, forwarded to every element.

    Returns:
        :class:`Parallel`: A new collection of the processed elements, in order.
    """
    return Parallel(*(_mapdiagonal_possibly_duplicate_context(xzi, d) for xzi in xz))