lab.jax.custom module

lab.jax.custom.jax_register(f, s_f)

Register a function together with its sensitivity for JAX, so that JAX uses s_f instead of tracing through f when differentiating it.

Parameters
  • f (function) – Function to register.

  • s_f (function) – Sensitivity of f, used to compute the gradient of f.

Returns

JAX function that evaluates f and uses s_f for its gradient.

Return type

function
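The following is a minimal sketch of what registering a function and its sensitivity can look like, built directly on JAX's jax.custom_vjp. It is not lab's implementation. The helper name register_sketch is hypothetical, and the calling convention s_f(s_y, y, *args) (where s_y is the sensitivity of the output y, and the return value holds one sensitivity per positional argument of f) is an assumption made for illustration.

```python
import jax


def register_sketch(f, s_f):
    """Hypothetical stand-in for jax_register: attach s_f to f as a custom VJP.

    Assumption: s_f(s_y, y, *args) returns one sensitivity per positional
    argument of f (a tuple), or a single value if f takes one argument.
    """
    f_jax = jax.custom_vjp(f)

    def fwd(*args):
        y = f(*args)
        # Keep the output and the inputs as residuals for the backward pass.
        return y, (y, args)

    def bwd(res, s_y):
        y, args = res
        s_args = s_f(s_y, y, *args)
        # custom_vjp requires a tuple with one entry per primal input.
        if isinstance(s_args, (tuple, list)):
            return tuple(s_args)
        return (s_args,)

    f_jax.defvjp(fwd, bwd)
    return f_jax


def square(x):
    return x**2


def s_square(s_y, y, x):
    # d(x^2)/dx = 2x, scaled by the incoming output sensitivity.
    return (2 * x * s_y,)


square_jax = register_sketch(square, s_square)
print(jax.grad(square_jax)(3.0))
```

The wrapped function evaluates like the original f, while jax.grad calls the supplied sensitivity rather than differentiating f itself. This is useful when f is not traceable by JAX or when a hand-written gradient is more stable or efficient.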