lab.torch.custom module

lab.torch.custom.as_torch(x, grad=False)

Convert an object to a PyTorch object.

Parameters
  • x (object) – Object to convert.

  • grad (bool, optional) – Whether the result requires a gradient. Defaults to False.

Returns

x as a PyTorch object.

Return type

object
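The following is a minimal sketch of what such a conversion might look like, not the library's implementation. The name as_torch_sketch is hypothetical, and the use of torch.as_tensor and requires_grad_ is an assumption about one reasonable way to honour the grad flag.

```python
import torch


def as_torch_sketch(x, grad=False):
    """Hypothetical stand-in for a converter like as_torch (assumption)."""
    # torch.as_tensor accepts scalars, lists, NumPy arrays, and tensors.
    t = torch.as_tensor(x)
    if grad:
        # detach() yields a leaf tensor, which may be marked as requiring
        # a gradient. Only floating-point tensors can require gradients.
        t = t.detach().requires_grad_(True)
    return t


x = as_torch_sketch([1.0, 2.0, 3.0], grad=True)
loss = (x ** 2).sum()
loss.backward()  # Populates x.grad with d(loss)/dx = 2 * x.
```

With grad left at its default of False, the returned tensor in this sketch does not track gradients unless the input already did.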

lab.torch.custom.torch_register(f, s_f)

Register a function and its sensitivity for PyTorch.

Parameters
  • f (function) – Function to register.

  • s_f (function) – Sensitivity of f.

Returns

PyTorch primitive.

Return type

function
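To illustrate the idea of pairing a function with its sensitivity, here is a small pure-Python sketch that does not use PyTorch or this module. The helper register_sketch is hypothetical, and the calling convention s_f(s_y, y, *args), returning one sensitivity per argument, is an assumption made for this example rather than something the documentation above states.

```python
import math


def register_sketch(f, s_f):
    """Hypothetical helper: pair a function with its sensitivity.

    Assumed convention: s_f(s_y, y, *args) returns a tuple holding the
    sensitivity of each argument, given the output sensitivity s_y.
    """

    def primitive(*args):
        y = f(*args)

        def pullback(s_y):
            # Map the sensitivity of the output back to the inputs.
            return s_f(s_y, y, *args)

        return y, pullback

    return primitive


def f(x):
    return math.exp(2.0 * x)


def s_f(s_y, y, x):
    # d/dx exp(2x) = 2 * exp(2x) = 2 * y, scaled by the output sensitivity.
    return (s_y * 2.0 * y,)


exp2 = register_sketch(f, s_f)
y, pullback = exp2(0.0)
(s_x,) = pullback(1.0)
```

In a PyTorch setting, a pair like this would typically be wrapped in a torch.autograd.Function so that the sensitivity is called during the backward pass. Whether torch_register does exactly that is not confirmed by the text above.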