zoomy_jax.transformation.to_jax module

JAX code transformation: compile symbolic Model/Numerics/Kernel to JAX.
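
The general idea can be sketched outside zoomy_jax. This is a minimal illustration, not the module's implementation, and it assumes the symbolic layer is SymPy. A SymPy expression is turned into plain Python source with `sympy.lambdify`, names in that source are resolved against `jax.numpy`, and the result is compiled with `jax.jit`. The symbols `h`, `hu` and the shallow-water-style flux are hypothetical examples chosen for illustration.

```python
import jax
import jax.numpy as jnp
import sympy as sp

# Hypothetical conserved variables: water depth h and momentum hu.
h, hu = sp.symbols("h hu")
g = sp.Float(9.81)

# Illustrative symbolic flux of a 1D shallow-water system.
flux_expr = [hu, hu**2 / h + g * h**2 / 2]

# lambdify generates Python source for the expressions; function names in
# that source (e.g. sin, exp) are looked up in jax.numpy, so the generated
# function operates on JAX arrays.
flux_fn = sp.lambdify((h, hu), flux_expr, modules=[jnp])

# Trace and compile the generated function with XLA.
flux_jit = jax.jit(flux_fn)

h_vals = jnp.array([1.0, 2.0])
hu_vals = jnp.array([0.5, 1.0])
f_mass, f_mom = flux_jit(h_vals, hu_vals)
```

Because the generated function returns a Python list, `jax.jit` treats the output as a pytree and returns a list of arrays, one per flux component.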

class zoomy_jax.transformation.to_jax.JaxRuntimeModel(model, module=None, printer=None, kernel=None)

Bases: NumpyRuntimeModel

JAX-backed runtime model compiled from symbolic functions.

Parameters:
  • model (Model)

  • module (Optional[Dict[str, Callable]])

  • printer (Optional[str])

module = {'array': <function array>, 'clamp_momentum': <function JaxRuntimeModel.<lambda>>, 'clamp_positive': <function JaxRuntimeModel.<lambda>>, 'conditional': <function JaxRuntimeModel.<lambda>>, 'max_wavespeed': None, 'ones_like': <function ones_like>, 'squeeze': <function squeeze>, 'zeros_like': <function zeros_like>}
printer = 'jax'
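
The `module` attribute maps function names that appear in generated code to JAX callables. The sketch below shows why such a mapping matters under `jax.jit`. The bodies given for `clamp_positive` and `conditional` are hypothetical guesses at their semantics, not the lambdas zoomy_jax actually defines, and `velocity` is an illustrative helper. The key point is that a Python `if` on a traced array fails inside `jax.jit`, so branching has to be expressed with an elementwise select such as `jnp.where`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for a name -> callable mapping like
# JaxRuntimeModel.module; the real zoomy_jax lambdas may differ.
module = {
    "zeros_like": jnp.zeros_like,
    # Assumed semantics: clip negative values to zero.
    "clamp_positive": lambda x: jnp.maximum(x, 0.0),
    # Assumed semantics: elementwise select. A Python `if` on `cond` would
    # fail under jax.jit because `cond` is a traced array.
    "conditional": lambda cond, a, b: jnp.where(cond, a, b),
}


@jax.jit
def velocity(h, hu):
    # Velocity u = hu / h in wet cells (h > 0), zero in dry cells.
    h_safe = module["clamp_positive"](h)
    wet = h_safe > 0.0
    # The inner jnp.where replaces zero depths by 1.0 so the division
    # never produces inf/NaN, even in cells that are discarded afterwards.
    u_wet = hu / jnp.where(wet, h_safe, 1.0)
    return module["conditional"](wet, u_wet, module["zeros_like"](hu))


u = velocity(jnp.array([2.0, 0.0, -1.0]), jnp.array([1.0, 3.0, 4.0]))
```

Here `u` evaluates to `[0.5, 0.0, 0.0]`: only the first cell is wet.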

class zoomy_jax.transformation.to_jax.JaxRuntimeSymbolic(symbolic_obj, module=None, printer=None)

Bases: NumpyRuntimeSymbolic

JAX-backed runtime wrapper for symbolic registrars (e.g. Numerics, Kernel).

Parameters:
  • module (Optional[Dict[str, Callable]])

  • printer (Optional[str])

module = {'array': <function array>, 'clamp_momentum': <function JaxRuntimeSymbolic.<lambda>>, 'clamp_positive': <function JaxRuntimeSymbolic.<lambda>>, 'conditional': <function JaxRuntimeSymbolic.<lambda>>, 'max_wavespeed': None, 'ones_like': <function ones_like>, 'squeeze': <function squeeze>, 'zeros_like': <function zeros_like>}
printer = 'jax'
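
To illustrate what wrapping a symbolic registrar could involve, the sketch below makes a loud assumption. It treats a registrar as a plain dictionary from function names to `(argument symbols, SymPy expression)` pairs and compiles every entry into a jitted JAX function. The `registrar` layout, the `wrap` helper, and both kernels are hypothetical and are not the Numerics or Kernel API. The upwind flux for linear advection is written with `Abs` rather than `Piecewise`, because SymPy's default Python printer emits a `Piecewise` as a Python conditional expression, which cannot branch on traced arrays.

```python
import jax
import jax.numpy as jnp
import sympy as sp

# Hypothetical registrar: name -> (argument symbols, SymPy expression).
qL, qR, a = sp.symbols("qL qR a")
registrar = {
    # Upwind flux for linear advection with speed a:
    # F = (a*(qL + qR) - |a|*(qR - qL)) / 2, branch-free.
    "upwind_flux": ((qL, qR, a), (a * (qL + qR) - sp.Abs(a) * (qR - qL)) / 2),
    "average": ((qL, qR), (qL + qR) / 2),
}


def wrap(registrar):
    # Compile every registered expression into a jitted JAX function.
    return {
        name: jax.jit(sp.lambdify(args, expr, modules=[jnp]))
        for name, (args, expr) in registrar.items()
    }


kernels = wrap(registrar)

# Evaluate the kernels at all interfaces of a small 1D grid at once.
q = jnp.array([1.0, 2.0, 4.0])
flux_right = kernels["upwind_flux"](q[:-1], q[1:], 1.0)   # takes left states
flux_left = kernels["upwind_flux"](q[:-1], q[1:], -1.0)   # takes -right states
avg = kernels["average"](q[:-1], q[1:])
```

Since the generated kernels use only elementwise operations, passing whole arrays of left and right states evaluates every interface in one compiled call without an explicit loop.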