zoomy_jax.fvm.solver_jax module
JAX FVM solvers: JIT-compiled time stepping.
Solver hierarchy (mirrors the NumPy solvers):
- HyperbolicSolver (explicit flux + source)
  - setup_simulation: mesh/model → JAX operators (closures)
  - step: a single explicit time step
  - run_simulation: jax.lax.while_loop over step
  - solve: init + setup + run
Inherits param definitions from the NumPy base class HyperbolicSolverNumpy but overrides all computational methods with JAX implementations.
- zoomy_jax.fvm.solver_jax.log_callback_hyperbolic(iteration, time, dt, time_stamp, log_every=10)
Logging callback for the hyperbolic time loop.
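Host-side logging from inside a compiled loop has to go through a callback, because ordinary Python code runs only once, at trace time. A minimal sketch, assuming `jax.debug.callback` as the mechanism; `log_callback_sketch`, the `logged` list, and the interval check are hypothetical and not necessarily how `log_callback_hyperbolic` is invoked:

```python
import jax
import jax.numpy as jnp

logged = []  # host-side record of logged iterations (illustration only)


def log_callback_sketch(iteration, time, dt, log_every=10):
    # Hypothetical host-side logger: receives concrete values and
    # records only every `log_every`-th iteration.
    if int(iteration) % log_every == 0:
        logged.append((int(iteration), float(time), float(dt)))


@jax.jit
def run(time_end, dt):
    def cond(state):
        _, t = state
        return t < time_end

    def body(state):
        it, t = state
        it, t = it + 1, t + dt
        # Fires on every loop iteration, even though the loop is compiled.
        jax.debug.callback(log_callback_sketch, it, t, dt)
        return it, t

    init = (jnp.array(0, dtype=jnp.int32), jnp.array(0.0, dtype=jnp.float32))
    return jax.lax.while_loop(cond, body, init)


n_iter, t_final = run(0.35, 0.01)
jax.effects_barrier()  # wait until all pending callbacks have run
print(sorted(i for i, _, _ in logged))  # [10, 20, 30]
```

`jax.effects_barrier()` is called before reading `logged` because debug callbacks may complete asynchronously.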
- zoomy_jax.fvm.solver_jax.log_callback_poisson(iteration, res)
Logging callback for the Poisson solver (iteration count and residual).
- zoomy_jax.fvm.solver_jax.log_callback_execution_time(time)
Logging callback that reports the execution time.
- zoomy_jax.fvm.solver_jax.newton_solver(residual)
Newton solver for a given residual function.
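Newton's method for a residual R(x) = 0 repeatedly solves J(x) dx = -R(x) and updates x ← x + dx. A minimal JAX sketch of that iteration, assuming a dense Jacobian from `jax.jacfwd` and a fixed iteration count; `newton_solve_sketch` and its arguments are hypothetical and do not reproduce the signature of `newton_solver`:

```python
import jax
import jax.numpy as jnp


def newton_solve_sketch(residual, x0, n_iter=20):
    # Hypothetical helper: plain Newton's method on R(x) = 0 for a vector x.
    jac = jax.jacfwd(residual)  # dense Jacobian dR/dx via forward-mode AD

    def body(_, x):
        # Solve J(x) dx = -R(x) and take the full Newton step.
        dx = jnp.linalg.solve(jac(x), -residual(x))
        return x + dx

    # A fixed iteration count keeps the loop JIT-friendly (no data-dependent exit).
    return jax.lax.fori_loop(0, n_iter, body, x0)


# Example: intersect the circle x^2 + y^2 = 4 with the line y = x.
def residual(v):
    x, y = v
    return jnp.array([x**2 + y**2 - 4.0, y - x])


root = jax.jit(lambda x0: newton_solve_sketch(residual, x0))(jnp.array([1.0, 0.5]))
print(root)  # approximately [1.4142, 1.4142]
```

A convergence-based exit would instead need `jax.lax.while_loop`, since the number of iterations would then depend on the data.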
- class zoomy_jax.fvm.solver_jax.HyperbolicSolver(**kwargs)
Bases: HyperbolicSolver (NumPy)
JAX HyperbolicSolver: JIT-compiled explicit time stepping.
Follows the setup_simulation / step / run_simulation pattern. Inherits param definitions from the NumPy base class.
- create_runtime(Q, Qaux, mesh, model)
Create JAX runtime: convert mesh and model to JAX-compatible forms.
- update_q(Q, Qaux, mesh, model, parameters)
JIT-compatible update_variables, applied per cell via jax.vmap (replaces the NumPy cell loop).
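With vmap, the update is written for a single cell and mapped over the cell axis. Since Q has shape (n_vars, n_cells), the cell axis is axis 1, hence `in_axes=1` and `out_axes=1`. A minimal sketch; the per-cell rule (clamping a depth h and zeroing momentum hu in dry cells) and `update_q_sketch` are hypothetical stand-ins for the model's update_variables:

```python
import jax
import jax.numpy as jnp


def update_cell(q, eps=1e-6):
    # Hypothetical per-cell update for q = [h, hu] of a single cell.
    h = jnp.maximum(q[0], 0.0)          # clamp negative depths
    hu = jnp.where(h > eps, q[1], 0.0)  # zero momentum in dry cells
    return jnp.stack([h, hu])


# Map over the cell axis (axis 1) of Q with shape (n_vars, n_cells).
update_q_sketch = jax.jit(jax.vmap(update_cell, in_axes=1, out_axes=1))

Q = jnp.array([[1.0, -0.2, 0.0, 2.0],
               [0.5, 0.3, 0.1, -1.0]])
Q_new = update_q_sketch(Q)
print(Q_new)  # depths clamped, momentum zeroed in the two dry cells
```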
- get_compute_source(mesh, model)
Build the JIT-compiled source operator.
- get_compute_source_jacobian(mesh, model)
Build the JIT-compiled source Jacobian operator.
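A per-cell source Jacobian can be obtained by automatic differentiation instead of by hand. A minimal sketch, assuming a hypothetical linear friction source S(h, hu) = (0, -K hu / h); the model, `K`, and the output layout (n_vars, n_vars, n_cells) are assumptions:

```python
import jax
import jax.numpy as jnp

K = 0.5  # hypothetical friction coefficient


def source_cell(q):
    # Hypothetical linear friction source for q = [h, hu]: S = [0, -K * hu / h].
    h, hu = q[0], q[1]
    return jnp.stack([jnp.zeros_like(h), -K * hu / h])


# Per-cell Jacobian dS/dq (n_vars x n_vars), mapped over the cell axis of Q.
# out_axes=2 keeps cells on the last axis: shape (n_vars, n_vars, n_cells).
source_jacobian = jax.jit(jax.vmap(jax.jacfwd(source_cell), in_axes=1, out_axes=2))

Q = jnp.array([[1.0, 2.0],
               [0.4, -1.0]])
J = source_jacobian(Q)
print(J.shape)  # (2, 2, 2)
```

For this source the exact entries are dS_1/dh = K hu / h^2 and dS_1/dhu = -K / h, which the AD result reproduces.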
- get_apply_boundary_conditions(mesh, model)
Build the JIT-compiled boundary condition operator.
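Inside a JIT-compiled operator, boundary values are written functionally with `.at[...].set(...)` rather than by in-place assignment. A minimal sketch, assuming a hypothetical 1D layout whose first and last columns are ghost cells, with a zero-gradient left boundary and a reflective right wall; MeshJAX's actual boundary handling is not shown here:

```python
import jax
import jax.numpy as jnp


@jax.jit
def apply_bcs_sketch(Q):
    # Hypothetical 1D layout: columns 0 and -1 of Q (n_vars, n_cells) are ghost cells.
    # Left boundary: zero-gradient extrapolation (copy the first interior cell).
    Q = Q.at[:, 0].set(Q[:, 1])
    # Right boundary: reflective wall for q = [h, hu] (copy h, flip momentum).
    Q = Q.at[0, -1].set(Q[0, -2])
    Q = Q.at[1, -1].set(-Q[1, -2])
    return Q


Q = jnp.array([[0.0, 1.0, 2.0, 0.0],
               [0.0, 0.5, 0.7, 0.0]])
Q_bc = apply_bcs_sketch(Q)
print(Q_bc)
```

`.at[...].set(...)` returns a new array; under `jax.jit`, XLA can usually perform the update in place.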
- get_compute_max_abs_eigenvalue(mesh, model)
Build the JIT-compiled operator for the maximum absolute eigenvalue.
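For explicit schemes, the relevant quantity is the largest wave speed over all cells. A minimal sketch, assuming a 1D shallow-water model with Q = (h, hu), whose eigenvalues are u ± sqrt(g h); the model and `max_abs_eigenvalue_sketch` are illustrative, since the solver takes its eigenvalues from the model:

```python
import jax
import jax.numpy as jnp

G = 9.81  # gravitational acceleration (assumed shallow-water model)


@jax.jit
def max_abs_eigenvalue_sketch(Q, eps=1e-8):
    # Hypothetical 1D shallow-water model, Q = [h, hu] with shape (2, n_cells).
    h, hu = Q[0], Q[1]
    u = jnp.where(h > eps, hu / jnp.maximum(h, eps), 0.0)  # safe velocity
    c = jnp.sqrt(G * jnp.maximum(h, 0.0))                  # gravity wave speed
    # Eigenvalues are u - c and u + c, so the largest modulus per cell is |u| + c.
    return jnp.max(jnp.abs(u) + c)


Q = jnp.array([[1.0, 4.0],
               [2.0, 0.0]])
lam = max_abs_eigenvalue_sketch(Q)
print(lam)  # max(2 + sqrt(9.81), sqrt(4 * 9.81)), approximately 6.264
```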
- get_flux_operator(mesh, model)
Build the flux operator with reconstruction (conservative and nonconservative parts).
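A sketch of the conservative part of a flux operator, assuming a swapped-in first-order Rusanov (local Lax-Friedrichs) flux for Burgers' equation on a periodic 1D grid. It omits reconstruction and nonconservative terms and is not the scheme built by get_flux_operator:

```python
import jax
import jax.numpy as jnp


def burgers_flux(u):
    return 0.5 * u**2


@jax.jit
def flux_divergence_sketch(u, dx):
    # First-order Rusanov flux on a periodic 1D grid (hypothetical example).
    uL, uR = u, jnp.roll(u, -1)                # states left/right of face i+1/2
    a = jnp.maximum(jnp.abs(uL), jnp.abs(uR))  # local max wave speed, |f'(u)| = |u|
    F = 0.5 * (burgers_flux(uL) + burgers_flux(uR)) - 0.5 * a * (uR - uL)
    # Net flux per cell: (F_{i+1/2} - F_{i-1/2}) / dx
    return (F - jnp.roll(F, 1)) / dx


n = 64
dx = 1.0 / n
x = jnp.arange(n) * dx
u = 1.0 + 0.5 * jnp.sin(2 * jnp.pi * x)
dt = 0.4 * dx / jnp.max(jnp.abs(u))  # CFL-limited step
u_new = u - dt * flux_divergence_sketch(u, dx)
```

Because each face flux is added to one cell and subtracted from its neighbor, the update conserves the total of u on a periodic grid.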
- setup_simulation(mesh, model)
Build all JAX operators from mesh and model.
Converts mesh to MeshJAX and model to JaxRuntimeModel, and creates closures for flux, source, boundary conditions, and eigenvalue computation. The operators are stored as attributes for use by step and run_simulation.
- Returns:
Q, Qaux – Initial state arrays on device.
- Return type:
tuple[jnp.ndarray, jnp.ndarray]
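The closure pattern described for setup_simulation can be sketched as follows: mesh-dependent data is captured once when the operators are built, and the stored, jitted functions then take only the evolving arrays. `SolverSketch`, its arguments, and the damping source are hypothetical:

```python
import jax
import jax.numpy as jnp


class SolverSketch:
    # Hypothetical stand-in for the setup_simulation pattern: build closures once,
    # store them as attributes, and call them later from step / run_simulation.
    def setup_simulation(self, cell_volumes, k_friction):
        volumes = jnp.asarray(cell_volumes)  # captured by the closures below

        @jax.jit
        def compute_source(Q):
            # Linear damping of the second variable (assumed model).
            return jnp.stack([jnp.zeros_like(Q[0]), -k_friction * Q[1]])

        @jax.jit
        def total_mass(Q):
            # Volume-weighted sum of the first variable, using the captured volumes.
            return jnp.sum(Q[0] * volumes)

        self.compute_source = compute_source
        self.total_mass = total_mass


solver = SolverSketch()
solver.setup_simulation(cell_volumes=[0.5, 0.5, 1.0], k_friction=0.1)
Q = jnp.array([[1.0, 2.0, 3.0], [1.0, 0.0, -2.0]])
print(solver.total_mass(Q))      # 0.5 + 1.0 + 3.0 = 4.5
print(solver.compute_source(Q))  # second row: [-0.1, 0.0, 0.2]
```

Captured values are baked into the compiled functions as constants, so changing them requires rebuilding the operators.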
- step(dt, Q, Qaux)
Perform a single explicit time step.
- Pipeline:
  1. Reconstruct + flux operator (RK1 or RK2, depending on order)
  2. Source operator (RK1)
  3. Apply boundary conditions
  4. Update Q (e.g. clamp, ramp)
This method is JIT-compatible when called inside run_simulation.
- Parameters:
dt (scalar) – Time step size.
Q (jnp.ndarray, shape (n_vars, n_cells)) – Conservative state.
Qaux (jnp.ndarray, shape (n_aux, n_cells)) – Auxiliary state.
- Returns:
Q_new – Updated state after one step.
- Return type:
jnp.ndarray
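For the RK2 branch, a common choice is the two-stage SSP-RK2 (Heun) scheme; whether step uses exactly this variant is an assumption. A minimal sketch comparing it with forward Euler (RK1) on the hypothetical test problem dQ/dt = -Q:

```python
import jax
import jax.numpy as jnp


def forward_euler_step(rhs, dt, Q):
    # RK1: a single forward-Euler stage.
    return Q + dt * rhs(Q)


def ssp_rk2_step(rhs, dt, Q):
    # Heun / SSP-RK2: two forward-Euler stages averaged with the initial state.
    Q1 = Q + dt * rhs(Q)
    return 0.5 * Q + 0.5 * (Q1 + dt * rhs(Q1))


rhs = lambda Q: -Q  # hypothetical linear decay, dQ/dt = -Q
Q0 = jnp.ones((2, 3))  # shape (n_vars, n_cells)
dt = 0.1
Q_rk1 = jax.jit(lambda Q: forward_euler_step(rhs, dt, Q))(Q0)
Q_rk2 = jax.jit(lambda Q: ssp_rk2_step(rhs, dt, Q))(Q0)
```

With dt = 0.1, RK1 gives 0.9 and RK2 gives 1 - dt + dt^2/2 = 0.905, against the exact value e^(-0.1) ≈ 0.9048.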
- post_step(time, dt, Q, Qold, Qaux)
Post-step processing: BCs, update_q, update_qaux.
Separated from step so that subclasses (e.g. IMEX) can insert implicit solves between the explicit step and the post-processing.
- Parameters:
time (scalar) – Current simulation time (after dt advance).
dt (scalar) – Time step size.
Q (jnp.ndarray) – State after explicit step.
Qold (jnp.ndarray) – State before the step (for aux updates).
Qaux (jnp.ndarray) – Auxiliary state before the step.
- Returns:
Q_new, Qaux_new – Fully updated state and auxiliary arrays.
- Return type:
tuple[jnp.ndarray, jnp.ndarray]
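The benefit of keeping post_step separate can be sketched with a hypothetical IMEX-style subclass that inserts an implicit stage for a stiff relaxation term dq/dt = -k q, solved with backward Euler (closed form q / (1 + k dt)). All class and method names other than step and post_step are assumptions, and the signatures are simplified:

```python
import jax.numpy as jnp


class ExplicitSolverSketch:
    # Hypothetical base: advance = step, then post_step.
    def step(self, dt, Q):
        return Q  # explicit flux/source update omitted to isolate the pattern

    def post_step(self, Q):
        # e.g. clamp the first variable (a depth) to be non-negative
        return Q.at[0].set(jnp.maximum(Q[0], 0.0))

    def advance(self, dt, Q):
        return self.post_step(self.step(dt, Q))


class ImexSolverSketch(ExplicitSolverSketch):
    k = 1000.0  # stiff relaxation rate for the second variable (assumed)

    def implicit_solve(self, dt, Q):
        # Backward Euler for dq/dt = -k q: q_new = q / (1 + k dt), stable for any dt.
        return Q.at[1].set(Q[1] / (1.0 + self.k * dt))

    def advance(self, dt, Q):
        # The implicit stage sits between the explicit step and post_step.
        return self.post_step(self.implicit_solve(dt, self.step(dt, Q)))


Q = jnp.array([[-1.0, 2.0], [10.0, 10.0]])
Q_new = ImexSolverSketch().advance(0.01, Q)  # second row: 10 / 11
```

Here k dt = 10, so a forward-Euler treatment of the relaxation term would give 10 (1 - 10) = -90, while the implicit stage gives 10/11.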
- compute_timestep(Q, Qaux)
Compute the adaptive time step using the stored eigenvalue operator.
JIT-compatible. Uses self.compute_dt (from param) with the precomputed eigenvalue operator and the minimum cell inradius.
- Returns:
dt
- Return type:
scalar
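A typical adaptive time step follows the CFL condition dt = CFL * r_min / lambda_max, with r_min the minimum cell inradius and lambda_max the maximum absolute eigenvalue. A minimal sketch, assuming compute_dt has this form; `make_compute_dt` and the CFL value are hypothetical:

```python
import jax
import jax.numpy as jnp


def make_compute_dt(cfl):
    # Hypothetical stand-in for a compute_dt parameter: standard CFL condition.
    def compute_dt(max_abs_eigenvalue, min_inradius):
        return cfl * min_inradius / max_abs_eigenvalue
    return compute_dt


compute_dt = make_compute_dt(cfl=0.45)
dt = jax.jit(compute_dt)(jnp.float32(6.0), jnp.float32(0.02))
print(dt)  # 0.45 * 0.02 / 6.0 = 0.0015
```

Since the eigenvalue changes with Q while the inradius is fixed by the mesh, only lambda_max has to be recomputed each step.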
- run_simulation(Q, Qaux, write_output=True)
JIT-compiled time loop using jax.lax.while_loop.
Calls compute_timestep -> step -> post_step in a while_loop until time >= time_end.
- Parameters:
Q (jnp.ndarray) – Initial state (from setup_simulation).
Qaux (jnp.ndarray) – Initial auxiliary state (from setup_simulation).
write_output (bool) – Whether to write snapshots to HDF5.
- Returns:
Q, Qaux – Final state.
- Return type:
tuple[jnp.ndarray, jnp.ndarray]
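The compute_timestep -> step -> post_step loop can be sketched with jax.lax.while_loop, carrying (time, iteration, Q). A minimal sketch, assuming forward-Euler stepping, clamping as post-processing, and a shortened final step so that the loop stops exactly at time_end (the solver's own handling of the last step is not documented here); `make_run_simulation` and its arguments are hypothetical:

```python
import jax
import jax.numpy as jnp


def make_run_simulation(rhs, compute_timestep, time_end):
    # Hypothetical while_loop driver: compute_timestep -> step -> post_step.
    def post_step(Q):
        return jnp.maximum(Q, 0.0)  # e.g. clamping (assumed post-processing)

    def cond(carry):
        time, _, _ = carry
        return time < time_end

    def body(carry):
        time, iteration, Q = carry
        # Shorten the last step so the loop lands exactly on time_end (assumption).
        dt = jnp.minimum(compute_timestep(Q), time_end - time)
        Q = post_step(Q + dt * rhs(Q))  # forward-Euler step + post-processing
        return time + dt, iteration + 1, Q

    @jax.jit
    def run_simulation(Q):
        init = (jnp.asarray(0.0, dtype=Q.dtype), jnp.asarray(0, dtype=jnp.int32), Q)
        return jax.lax.while_loop(cond, body, init)

    return run_simulation


run = make_run_simulation(rhs=lambda Q: -Q, compute_timestep=lambda Q: 0.03, time_end=0.1)
time, n_steps, Q = run(jnp.ones(4))  # steps of 0.03, 0.03, 0.03, then 0.01
```

The carry must keep fixed shapes and dtypes across iterations, which is why time and iteration are initialized as typed arrays.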
- solve(mesh, model, write_output=True)
Full solve: initialize -> setup -> run.
This is the main entry point, compatible with the NumPy solver interface. Calls setup_simulation, then run_simulation.
- name = 'HyperbolicSolver'