# Source code for lab.jax.custom

from functools import wraps

import jax.numpy as jnp
from jax import custom_vjp
from jax.core import Primitive
from plum import Dispatcher

from . import B
from ..util import as_tuple

__all__ = ['jax_register']
_dispatch = Dispatcher()


@_dispatch(B.Numeric)
def as_jax(x):
    """Convert a numeric object to a Jax array.

    Args:
        x (object): Numeric object to convert.

    Returns:
        object: Result of applying `jax.numpy.asarray` to `x`.
    """
    converted = jnp.asarray(x)
    return converted


@_dispatch(tuple)
def as_jax(xs):
    """Convert every element of a tuple to Jax.

    Args:
        xs (tuple): Objects to convert.

    Returns:
        tuple: New tuple whose elements are `as_jax` applied to each element
            of `xs`, in order.
    """
    return tuple(map(as_jax, xs))


def _as_primitive(f):
    """Wrap a NumPy-level function as a Jax primitive.

    The primitive's implementation converts its positional arguments with
    `B.to_numpy`, calls `f` on them (keyword arguments are passed through
    unchanged), and converts the result back with `as_jax`.

    Args:
        f (function): Function to wrap. Its `__name__` is used as the name of
            the primitive.

    Returns:
        function: Function that binds the primitive and carries `f`'s
            metadata (name, docstring, ...).
    """
    primitive = Primitive(f.__name__)

    def impl(*args, **kw_args):
        numpy_args = B.to_numpy(args)
        return as_jax(f(*numpy_args, **kw_args))

    primitive.def_impl(impl)

    # `primitive.bind` itself has no useful metadata; wrap it so the returned
    # function looks like `f` to callers.
    @wraps(f)
    def bind_wrapped(*args, **kw_args):
        return primitive.bind(*args, **kw_args)

    return bind_wrapped


def jax_register(f, s_f):
    """Register a function and its sensitivity for Jax.

    Both `f` and `s_f` are wrapped as Jax primitives evaluated through NumPy
    (see `_as_primitive`). The result is a `jax.custom_vjp` function whose
    backward pass calls `s_f(s_y, y, *args, **kw_args)`.

    Args:
        f (function): Function to register.
        s_f (function): Sensitivity of `f`. It is called with the output
            sensitivity `s_y`, the forward output `y`, and the original
            arguments. It must return the input sensitivities, either as a
            single value or as a tuple with one entry per positional argument.

    Returns:
        function: Jax function.
    """
    # Keep each stage under its own name. Rebinding `f` hides which object
    # the forward pass actually evaluates.
    f_primitive = _as_primitive(f)
    s_f_primitive = _as_primitive(s_f)
    f_jax = custom_vjp(f_primitive)

    # Define and register the forward and backward pass.

    def forward(*args, **kw_args):
        # Evaluate the primitive directly, rather than re-entering the
        # `custom_vjp` wrapper from inside its own forward rule.
        y = f_primitive(*args, **kw_args)
        return y, (y, args, kw_args)

    def backward(res, s_y):
        y, args, kw_args = res
        # `custom_vjp` expects a tuple of input sensitivities, so wrap a single
        # returned value in a tuple.
        return as_tuple(s_f_primitive(s_y, y, *args, **kw_args))

    f_jax.defvjp(forward, backward)

    return f_jax