zoomy_jax.transformation.to_jax module
JAX code transformation: compile symbolic Model/Numerics/Kernel to JAX.
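Compiling a symbolic expression to JAX usually works like this: a printer turns the expression into Python source that calls helper functions by name, and that source is executed against a namespace that binds those names to `jax.numpy` callables. The snippet below is a minimal, hypothetical sketch of that pattern. The source string, the `flux` expression, and the `compile_source` helper are assumptions made for illustration and are not the `zoomy_jax` implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical printer output: Python source for a simple 1D flux,
# F(h, hu) = [hu, hu**2 / h]. It refers to `array` by name, so the
# namespace must supply that function.
SOURCE = """
def flux(h, hu):
    return array([hu, hu**2 / h])
"""


def compile_source(source, namespace, name):
    """Execute generated source against `namespace`; return function `name`."""
    scope = dict(namespace)  # copy so the caller's namespace is not mutated
    exec(source, scope)
    return scope[name]


# Binding `array` to jax.numpy makes the generated function JAX-traceable,
# so it can be wrapped in jax.jit (or jax.grad, jax.vmap).
flux = jax.jit(compile_source(SOURCE, {"array": jnp.array}, "flux"))

print(flux(2.0, 4.0))  # values: [4., 8.]
```

Because the namespace, not the generated source, decides which library is called, the same printed code could in principle target NumPy or JAX simply by swapping the bindings.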
- class zoomy_jax.transformation.to_jax.JaxRuntimeModel(model, module=None, printer=None, kernel=None)
Bases: NumpyRuntimeModel
JAX-backed runtime model compiled from symbolic functions.
- Parameters:
model (Model) –
module (Optional[Dict[str, Callable]]) –
printer (Optional[str]) –
- module = {'array': <function array>, 'clamp_momentum': <function JaxRuntimeModel.<lambda>>, 'clamp_positive': <function JaxRuntimeModel.<lambda>>, 'conditional': <function JaxRuntimeModel.<lambda>>, 'max_wavespeed': None, 'ones_like': <function ones_like>, 'squeeze': <function squeeze>, 'zeros_like': <function zeros_like>}
- printer = 'jax'
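The `module` attribute above maps names that appear in generated code to callables. The sketch below builds a comparable namespace by hand and uses it inside a jitted function. It assumes that `array`, `zeros_like`, `ones_like`, and `squeeze` are the `jax.numpy` functions of the same name; the bodies given for `conditional` and `clamp_positive` are guesses, because the listing only shows them as lambdas. `clamp_momentum` is omitted for the same reason, and `max_wavespeed` is `None` in the listing, so it is not modeled here. The `velocity` function is hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the `module` namespace listed above. The array
# helpers are assumed to be the jax.numpy functions of the same name; the
# lambda entries are guesses, since their bodies are not documented.
MODULE = {
    "array": jnp.array,
    "zeros_like": jnp.zeros_like,
    "ones_like": jnp.ones_like,
    "squeeze": jnp.squeeze,
    # Assumed: elementwise select that stays traceable under jax.jit,
    # unlike a Python `if` on array values.
    "conditional": lambda cond, a, b: jnp.where(cond, a, b),
    # Assumed: clip negative values (e.g. a depth) to zero.
    "clamp_positive": lambda x: jnp.maximum(x, 0.0),
}


@jax.jit
def velocity(h, hu):
    """Return u = hu / h where h > 0 and 0 elsewhere (hypothetical example)."""
    conditional = MODULE["conditional"]
    h = MODULE["clamp_positive"](h)
    wet = h > 0.0
    # Divide by 1 where h == 0 so no inf/NaN appears, then zero those entries.
    safe_h = conditional(wet, h, MODULE["ones_like"](h))
    return conditional(wet, hu / safe_h, MODULE["zeros_like"](hu))


print(velocity(jnp.array([2.0, 0.0, -1.0]), jnp.array([4.0, 3.0, 5.0])))
# values: [2., 0., 0.]
```

A select such as `jnp.where` is used instead of a Python `if` because branching on a traced array value raises an error under `jax.jit`; this is one plausible reason a JAX backend would supply its own `conditional`.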
- class zoomy_jax.transformation.to_jax.JaxRuntimeSymbolic(symbolic_obj, module=None, printer=None)
Bases: NumpyRuntimeSymbolic
JAX-backed runtime wrapper for symbolic registrars (e.g. Numerics, Kernel).
- Parameters:
module (Optional[Dict[str, Callable]]) –
printer (Optional[str]) –
- module = {'array': <function array>, 'clamp_momentum': <function JaxRuntimeSymbolic.<lambda>>, 'clamp_positive': <function JaxRuntimeSymbolic.<lambda>>, 'conditional': <function JaxRuntimeSymbolic.<lambda>>, 'max_wavespeed': None, 'ones_like': <function ones_like>, 'squeeze': <function squeeze>, 'zeros_like': <function zeros_like>}
- printer = 'jax'
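A symbolic registrar such as Numerics or Kernel presumably holds several named functions, so a runtime wrapper compiles each of them against one shared namespace. The sketch below illustrates that idea under stated assumptions: the registered sources (`rusanov`, `upwind`), the `DEFAULT_MODULE` dict, and `compile_registrar` are all hypothetical, and treating a `module` argument as an override of the defaults is an assumed reading of the constructor parameter, not documented behavior.

```python
import jax
import jax.numpy as jnp

# Hypothetical default namespace, analogous to the class-level `module` dict.
DEFAULT_MODULE = {
    "array": jnp.array,
    "conditional": lambda c, a, b: jnp.where(c, a, b),  # assumed semantics
}

# Hypothetical registrar contents: generated source for each registered function.
REGISTERED = {
    # Rusanov (local Lax-Friedrichs) flux with maximum wave speed s.
    "rusanov": (
        "def rusanov(qL, qR, fL, fR, s):\n"
        "    return 0.5 * (fL + fR) - 0.5 * s * (qR - qL)\n"
    ),
    # Upwind flux for linear advection f(q) = a * q.
    "upwind": (
        "def upwind(qL, qR, a):\n"
        "    return conditional(a > 0.0, a * qL, a * qR)\n"
    ),
}


def compile_registrar(sources, module=None):
    """Compile every registered source into a jitted callable.

    Entries in `module` (if given) override the defaults; this merge rule is
    an assumption made for the sketch.
    """
    namespace = {**DEFAULT_MODULE, **(module or {})}
    compiled = {}
    for name, src in sources.items():
        scope = dict(namespace)  # fresh copy per function
        exec(src, scope)
        compiled[name] = jax.jit(scope[name])
    return compiled


funcs = compile_registrar(REGISTERED)
print(funcs["upwind"](1.0, 2.0, -0.5))  # value: -1.0
```

Keeping the namespace separate from the registered sources lets a caller swap a single helper (for example a different `conditional`) without regenerating any code.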