import lab as B
from algebra import Element, Join, Wrapped
from mlkernels import pairwise, elwise, num_elements, Kernel
from .. import PromisedFDD as FDD, _dispatch
# Names exported by a star-import of this module. The `pairwise`, `elwise`, and
# `num_elements` methods defined below are registered on the functions imported
# from `mlkernels`, so they are not listed here.
__all__ = ["infer_size", "dimensionality"]
# Both inputs are tuples: evaluate `k` for every (element of `x`, element of `y`)
# pair and hand the results to `B.block`, one list of blocks per element of `x`.
@pairwise.dispatch(precedence=1)
def pairwise(k, x: tuple, y: tuple):
    block_rows = [[pairwise(k, x_part, y_part) for y_part in y] for x_part in x]
    return B.block(*block_rows)
# Only `x` is a tuple: wrap `y` in a one-element tuple and dispatch again so that
# both arguments are tuples.
@pairwise.dispatch(precedence=1)
def pairwise(k, x: tuple, y):
    y_as_tuple = (y,)
    return pairwise(k, x, y_as_tuple)
# Only `y` is a tuple: wrap `x` in a one-element tuple and dispatch again so that
# both arguments are tuples.
@pairwise.dispatch(precedence=1)
def pairwise(k, x, y: tuple):
    x_as_tuple = (x,)
    return pairwise(k, x_as_tuple, y)
# Both inputs are tuples: evaluate `k` elementwise on matching tuple entries and
# concatenate the results along axis 0. The tuples must have the same length,
# because entries are paired up positionally.
@elwise.dispatch(precedence=1)
def elwise(k, x: tuple, y: tuple):
    if len(x) == len(y):
        parts = [elwise(k, x_part, y_part) for x_part, y_part in zip(x, y)]
        return B.concat(*parts, axis=0)
    raise ValueError('"elwise" must be called with similarly sized tuples.')
# Only `x` is a tuple: wrap `y` in a one-element tuple and dispatch again so that
# both arguments are tuples.
@elwise.dispatch(precedence=1)
def elwise(k, x: tuple, y):
    y_as_tuple = (y,)
    return elwise(k, x, y_as_tuple)
# Only `y` is a tuple: wrap `x` in a one-element tuple and dispatch again so that
# both arguments are tuples.
@elwise.dispatch(precedence=1)
def elwise(k, x, y: tuple):
    x_as_tuple = (x,)
    return elwise(k, x_as_tuple, y)
# A tuple of inputs has as many elements as all of its entries combined.
@num_elements.dispatch
def num_elements(x: tuple):
    return sum(num_elements(x_part) for x_part in x)
@_dispatch
def infer_size(k: Kernel, x: tuple):
    """Infer the size of `k` evaluated at `x`.

    For a tuple input, the size is the sum of the sizes inferred for each entry
    of the tuple.

    Args:
        k (:class:`mlkernels.Kernel`): Kernel to evaluate.
        x (input): Input to evaluate kernel at.

    Returns:
        int: Size of kernel matrix.
    """
    return sum(infer_size(k, x_part) for x_part in x)
# Numeric input: the size is the number of elements of `x` multiplied by the
# output dimensionality of `k`.
@_dispatch
def infer_size(k: Kernel, x: B.Numeric):
    count = num_elements(x)
    out_dim = dimensionality(k)
    return count * out_dim
# FDD input: the size is taken to be the number of elements of `x` itself; the
# dimensionality of `k` is not consulted here.
@_dispatch
def infer_size(k: Kernel, x: FDD):
    return num_elements(x)
@_dispatch
def dimensionality(k: Join):
    """Infer the output dimensionality of `k`.

    For a join, the dimensionalities of both joined elements are inferred and
    required to agree.

    Args:
        k (:class:`mlkernels.Kernel`): Kernel to get the output dimensionality of.

    Returns:
        int: Output dimensionality of `k`.

    Raises:
        RuntimeError: If the two joined elements have different inferred
            dimensionalities.
    """
    dim_first, dim_second = dimensionality(k[0]), dimensionality(k[1])
    if dim_first != dim_second:
        raise RuntimeError(
            f"Inferred dimensionalities {dim_first} and {dim_second} do not match. "
            "Did you join incompatible elements?"
        )
    return dim_first
# A wrapper has the dimensionality of the element it wraps (`k[0]`).
@_dispatch
def dimensionality(k: Wrapped):
    wrapped = k[0]
    return dimensionality(wrapped)
# Fallback for a plain element: the output dimensionality is one.
@_dispatch
def dimensionality(k: Element):
    return 1