Source code for lab.tensorflow.custom

from functools import wraps

import tensorflow as tf
from plum import convert, Dispatcher

from . import B

__all__ = ['tensorflow_register', 'as_tf']

_dispatch = Dispatcher()


@_dispatch(B.Numeric)
def as_tf(x):
    """Convert object to TensorFlow.

    Args:
        x (object): Object to convert.

    Returns:
        object: `x` as a TensorFlow object.
    """
    dtype = convert(B.dtype(x), B.TFDType)
    return tf.constant(x, dtype=dtype)


[docs]@_dispatch(tuple) def as_tf(xs): return tuple([as_tf(x) for x in xs])
def _np_apply(f, out_dtypes, *args, **kw_args): """Apply a NumPy function in TensorFlow. Args: f (function): NumPy function. out_dtypes (list[dtype]): List of data types of the output. *args (object): Argument to `f`. **kw_args (object): Keyword arguments to `f`. Returns: tensor: Result as a TensorFlow operation. """ return tf.py_function(lambda *args_: f(*[arg.numpy() for arg in args_], **kw_args), args, out_dtypes)
[docs]def tensorflow_register(f, s_f): """Register a function and its sensitivity for TensorFlow. Args: f (function): Function to register. s_f (function): Sensitivity of `f`. Returns: function: TensorFlow primitive. """ @wraps(f) def primitive(*args, **kw_args): # TODO: This assumes that the output is of the data type of the first # input. Generally, this is *not* true. How to best approach this? y = _np_apply(f, args[0].dtype, *args, **kw_args) def grad(s_y): # TODO: This assumes that the sensitivities of the inputs are of # the data types of the inputs. Again, generally, this is *not* # true. How to best approach this? return _np_apply(s_f, [arg.dtype for arg in args], *((s_y, y) + args), **kw_args) return y, grad return tf.custom_gradient(primitive)