zoomy_jax.fvm.solver_jax module

JAX FVM solvers: JIT-compiled time stepping.

Solver hierarchy (mirrors NumPy):
HyperbolicSolver (explicit flux + source)

  • setup_simulation: mesh/model → JAX operators (closures)

  • step: single explicit time step

  • run_simulation: jax.lax.while_loop over step

  • solve: initialize + setup + run

Inherits param definitions from the NumPy class HyperbolicSolverNumpy but overrides all computational methods with JAX implementations.

zoomy_jax.fvm.solver_jax.log_callback_hyperbolic(iteration, time, dt, time_stamp, log_every=10)

Logging callback for the hyperbolic solver's time loop.
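Inside a jit-compiled loop, ordinary Python side effects such as print run only once, at trace time, so a logging callback like this one has to be called through a host callback at run time. The sketch below assumes jax.debug.callback is used for that; log_callback_sketch and toy_loop are hypothetical stand-ins (the time_stamp argument is omitted), not the module's implementation.

```python
import jax
import jax.numpy as jnp

records = []  # collected on the host so the logged values can be inspected


def log_callback_sketch(iteration, time, dt, log_every=10):
    """Hypothetical stand-in for log_callback_hyperbolic (time_stamp omitted).

    Runs on the host with concrete NumPy values and logs every
    log_every-th iteration.
    """
    if int(iteration) % log_every == 0:
        records.append((int(iteration), float(time), float(dt)))
        print(f"iteration={int(iteration):5d}  time={float(time):.4f}  dt={float(dt):.2e}")


@jax.jit
def toy_loop(dt):
    """Toy time loop that advances time by a fixed dt until time >= 1."""

    def cond(carry):
        _, time = carry
        return time < 1.0

    def body(carry):
        iteration, time = carry
        # Hand the traced values to the host-side logger at run time
        jax.debug.callback(log_callback_sketch, iteration, time, dt)
        return iteration + 1, time + dt

    return jax.lax.while_loop(cond, body, (jnp.int32(0), jnp.float32(0.0)))


iteration, time = toy_loop(jnp.float32(0.0625))
jax.effects_barrier()  # wait until all pending callbacks have run
```

With dt = 0.0625 the toy loop runs 16 iterations, so only iterations 0 and 10 are logged.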

zoomy_jax.fvm.solver_jax.log_callback_poisson(iteration, res)

Logging callback for the Poisson solver (iteration count and residual).

zoomy_jax.fvm.solver_jax.log_callback_execution_time(time)

Logging callback for the total execution time.

zoomy_jax.fvm.solver_jax.newton_solver(residual)

Newton solver for the given residual function.
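The signature only states that newton_solver takes a residual function. As an illustration of the technique, the following is a minimal sketch of a Newton iteration in JAX, assuming a dense Jacobian from jax.jacfwd and a jax.lax.while_loop stopping test; newton_sketch, tol, and max_iter are hypothetical and may differ from the actual solver.

```python
import jax
import jax.numpy as jnp


def newton_sketch(residual, x0, tol=1e-6, max_iter=20):
    """Hypothetical Newton iteration for residual(x) = 0 (not the module's newton_solver).

    Each iteration solves J(x) dx = -residual(x) with a dense Jacobian from
    jax.jacfwd; the loop stops when ||residual|| <= tol or after max_iter steps.
    """
    jacobian = jax.jacfwd(residual)

    def cond(state):
        _, it, res_norm = state
        return (res_norm > tol) & (it < max_iter)

    def body(state):
        x, it, _ = state
        dx = jnp.linalg.solve(jacobian(x), -residual(x))
        x_new = x + dx
        return x_new, it + 1, jnp.linalg.norm(residual(x_new))

    init = (x0, jnp.int32(0), jnp.linalg.norm(residual(x0)))
    x, iterations, _ = jax.lax.while_loop(cond, body, init)
    return x, iterations


# Usage: componentwise square roots of 2 and 9
target = jnp.array([2.0, 9.0])
x, iterations = newton_sketch(lambda x: x**2 - target, jnp.array([1.0, 1.0]))
```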

class zoomy_jax.fvm.solver_jax.HyperbolicSolver(**kwargs)

Bases: HyperbolicSolver (the NumPy base class)

JAX HyperbolicSolver: JIT-compiled explicit time stepping.

Follows the setup_simulation / step / run_simulation pattern. Inherits param definitions from the NumPy base class.

create_runtime(Q, Qaux, mesh, model)

Create JAX runtime: convert mesh and model to JAX-compatible forms.

update_q(Q, Qaux, mesh, model, parameters)

JIT-compatible version of update_variables, vectorized over cells with jax.vmap (replaces the NumPy per-cell loop).
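The vmap-based update can be sketched as follows, assuming the (n_vars, n_cells) layout documented for step. update_cell is a hypothetical per-cell rule (clamping the first variable at zero); the real update_variables comes from the model.

```python
import jax
import jax.numpy as jnp


def update_cell(q):
    """Hypothetical per-cell update: clamp the first variable (e.g. a water depth) at zero."""
    return q.at[0].set(jnp.maximum(q[0], 0.0))


# Q has shape (n_vars, n_cells): vmap over the cell axis (axis 1) and put the
# mapped axis back in position 1, so the output has the same layout as Q.
update_q_sketch = jax.jit(jax.vmap(update_cell, in_axes=1, out_axes=1))

Q = jnp.array([[1.0, -0.5, 2.0],
               [0.3, 0.1, -0.2]])
Q_new = update_q_sketch(Q)
```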

get_compute_source(mesh, model)

Build the JIT-compiled source operator.
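The get_* methods return operators as closures. A minimal sketch of that builder pattern, assuming a hypothetical linear friction source and replacing the (mesh, model) arguments with a single constant: everything captured by the closure is fixed at build time, and the returned function only needs the state.

```python
import jax
import jax.numpy as jnp


def get_compute_source_sketch(friction_coefficient):
    """Hypothetical builder: the constant is captured in a closure, so the
    returned operator only takes the state and is compiled once per shape."""

    def source_cell(q):
        # Mass has no source; momentum q[1] is damped by linear friction
        return jnp.stack([jnp.zeros_like(q[0]), -friction_coefficient * q[1]])

    # Apply the per-cell source to every column of Q (shape (n_vars, n_cells))
    return jax.jit(jax.vmap(source_cell, in_axes=1, out_axes=1))


compute_source = get_compute_source_sketch(friction_coefficient=0.5)
S = compute_source(jnp.array([[1.0, 2.0], [4.0, -2.0]]))
```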

get_compute_source_jacobian(mesh, model)

Build the JIT-compiled source Jacobian operator.
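A source Jacobian operator can be obtained by automatic differentiation. The sketch assumes a hypothetical linear friction source as a stand-in for the model and returns one (n_vars, n_vars) Jacobian per cell; whether the module uses jax.jacfwd or a symbolic Jacobian from the model is not stated here.

```python
import jax
import jax.numpy as jnp


def get_compute_source_jacobian_sketch(friction_coefficient):
    """Hypothetical builder for a per-cell source Jacobian dS/dq."""

    def source_cell(q):
        # Mass has no source; momentum q[1] is damped by linear friction
        return jnp.stack([jnp.zeros_like(q[0]), -friction_coefficient * q[1]])

    # jax.jacfwd gives the (n_vars, n_vars) Jacobian for one cell; vmap maps it
    # over the cells, yielding an array of shape (n_cells, n_vars, n_vars).
    return jax.jit(jax.vmap(jax.jacfwd(source_cell), in_axes=1, out_axes=0))


compute_source_jacobian = get_compute_source_jacobian_sketch(friction_coefficient=0.5)
J = compute_source_jacobian(jnp.array([[1.0, 2.0, 3.0], [4.0, -2.0, 0.0]]))
```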

get_apply_boundary_conditions(mesh, model)

Build the JIT-compiled boundary condition operator.
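One common way to express boundary conditions as a JIT-compiled operator is a ghost-cell update with index arrays. The sketch below assumes a hypothetical mesh layout in which each ghost cell is paired with one interior neighbor; the actual MeshJAX data structures are not shown in this reference.

```python
import jax
import jax.numpy as jnp


def get_apply_boundary_conditions_sketch(ghost_cells, neighbor_cells, sign):
    """Hypothetical ghost-cell boundary operator.

    ghost_cells[i] receives the values of neighbor_cells[i]; sign (one entry per
    variable) is +1 to copy a variable and -1 to mirror it, e.g. the normal
    momentum at a reflective wall.
    """
    ghost_cells = jnp.asarray(ghost_cells)
    neighbor_cells = jnp.asarray(neighbor_cells)
    sign = jnp.asarray(sign)[:, None]  # broadcast over the boundary cells

    @jax.jit
    def apply_boundary_conditions(Q):
        # Functional update: returns a new array instead of mutating Q
        return Q.at[:, ghost_cells].set(sign * Q[:, neighbor_cells])

    return apply_boundary_conditions


# 1D example with 4 cells: cells 0 and 3 are ghosts, cells 1 and 2 are interior
apply_bc = get_apply_boundary_conditions_sketch([0, 3], [1, 2], [1.0, -1.0])
Q_bc = apply_bc(jnp.array([[0.0, 1.0, 2.0, 0.0], [0.0, 0.5, -0.3, 0.0]]))
```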

get_compute_max_abs_eigenvalue(mesh, model)

Build the JIT-compiled computation of the maximum absolute eigenvalue.
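For a concrete picture of what the max-eigenvalue operator computes, here is a sketch for the 1D shallow-water equations, assumed purely as an example model with Q = (h, hu). The function names and the dry-cell guard are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def get_compute_max_abs_eigenvalue_sketch(g=9.81):
    """Hypothetical example for the 1D shallow-water equations, Q = (h, hu),
    whose eigenvalues are u - sqrt(g h) and u + sqrt(g h)."""

    @jax.jit
    def max_abs_eigenvalue(Q):
        h, hu = Q[0], Q[1]
        wet = h > 0.0
        # Avoid dividing by zero in dry cells (h == 0), where u is set to 0
        u = jnp.where(wet, hu / jnp.where(wet, h, 1.0), 0.0)
        c = jnp.sqrt(g * jnp.maximum(h, 0.0))
        return jnp.max(jnp.abs(u) + c)

    return max_abs_eigenvalue


max_abs_eigenvalue = get_compute_max_abs_eigenvalue_sketch(g=4.0)
lam = max_abs_eigenvalue(jnp.array([[1.0, 4.0, 0.0], [1.0, 0.0, 0.0]]))
```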

get_flux_operator(mesh, model)

Build the flux operator, including reconstruction (conservative and nonconservative contributions).
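A minimal sketch of a flux operator with reconstruction, assuming first-order (piecewise-constant) reconstruction, the Rusanov (local Lax-Friedrichs) numerical flux, and a shallow-water model on a periodic 1D grid. It covers only the conservative part; the nonconservative contribution, the reconstruction order, and the numerical flux used by the module are not specified in this reference.

```python
import jax
import jax.numpy as jnp

G = 9.81  # hypothetical gravitational acceleration


def physical_flux(Q):
    # Shallow-water flux F(h, hu) = (hu, hu^2/h + g h^2 / 2), assuming h > 0
    h, hu = Q[0], Q[1]
    return jnp.stack([hu, hu * hu / h + 0.5 * G * h * h])


def max_wave_speed(Q):
    h, hu = Q[0], Q[1]
    return jnp.abs(hu / h) + jnp.sqrt(G * h)


def get_flux_operator_sketch(dx):
    """Hypothetical first-order Rusanov flux operator on a periodic 1D grid."""

    @jax.jit
    def flux_operator(Q):
        # Piecewise-constant reconstruction: at face i+1/2 the left state is
        # cell i and the right state is cell i+1 (periodic wrap-around).
        QL = Q
        QR = jnp.roll(Q, -1, axis=1)
        a = jnp.maximum(max_wave_speed(QL), max_wave_speed(QR))
        F = 0.5 * (physical_flux(QL) + physical_flux(QR)) - 0.5 * a * (QR - QL)
        # Semi-discrete update dQ_i/dt = -(F_{i+1/2} - F_{i-1/2}) / dx
        return -(F - jnp.roll(F, 1, axis=1)) / dx

    return flux_operator


flux_operator = get_flux_operator_sketch(dx=0.1)
Q = jnp.array([[1.0, 2.0, 1.5, 1.0], [0.1, -0.2, 0.3, 0.0]])
dQdt = flux_operator(Q)
```

On a periodic grid the face fluxes telescope, so the cell sums of dQdt vanish up to round-off, which makes a quick conservation check.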

setup_simulation(mesh, model)

Build all JAX operators from the mesh and model.

Converts mesh to MeshJAX, model to JaxRuntimeModel, and creates closures for flux, source, boundary conditions, and eigenvalue computation. The operators are stored as attributes for use by step and run_simulation.

Returns:

Q, Qaux – Initial state arrays on device.

Return type:

tuple[jnp.ndarray, jnp.ndarray]

step(dt, Q, Qaux)

Perform a single explicit time step.

Pipeline:
  1. Reconstruct + flux operator (RK1 or RK2 depending on order)

  2. Source operator (RK1)

  3. Apply boundary conditions

  4. Update Q (e.g. clamp, ramp)

This method is JIT-compatible when called inside run_simulation.

Parameters:
  • dt (scalar) – Time step size.

  • Q (jnp.ndarray, shape (n_vars, n_cells)) – Conservative state.

  • Qaux (jnp.ndarray, shape (n_aux, n_cells)) – Auxiliary state.

Returns:

Q_new – Updated state after one step.

Return type:

jnp.ndarray
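The pipeline above can be sketched as a single function, assuming Heun's method (two-stage SSP Runge-Kutta) for the RK2 flux update and sequential splitting in which the RK1 source update acts on the flux-updated state. make_step_sketch and the toy operators are hypothetical; boundary conditions and the Q update are left out for brevity.

```python
import jax
import jax.numpy as jnp


def make_step_sketch(flux_operator, source_operator, order=2):
    """Hypothetical explicit step: flux update (RK1 or RK2), then an RK1 source update."""

    @jax.jit
    def step(dt, Q):
        if order == 1:
            # Forward Euler for the flux operator
            Q_flux = Q + dt * flux_operator(Q)
        else:
            # Two-stage, second-order SSP Runge-Kutta (Heun) for the flux operator
            Q_stage = Q + dt * flux_operator(Q)
            Q_flux = 0.5 * (Q + Q_stage + dt * flux_operator(Q_stage))
        # Forward Euler for the source term, applied to the flux-updated state
        return Q_flux + dt * source_operator(Q_flux)

    return step


# Toy check with dQ/dt = -Q from the "flux" and no source
step_rk2 = make_step_sketch(lambda Q: -Q, lambda Q: jnp.zeros_like(Q), order=2)
Q1 = step_rk2(0.1, jnp.array([[1.0, 2.0]]))
```

For dQ/dt = -Q and dt = 0.1, one RK2 step multiplies Q by 1 - dt + dt²/2 = 0.905, matching the second-order Taylor expansion of exp(-dt).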

post_step(time, dt, Q, Qold, Qaux)

Post-step processing: BCs, update_q, update_qaux.

Separated from step so that subclasses (e.g. IMEX) can insert implicit solves between the explicit step and the post-processing.

Parameters:
  • time (scalar) – Current simulation time (after dt advance).

  • dt (scalar) – Time step size.

  • Q (jnp.ndarray) – State after explicit step.

  • Qold (jnp.ndarray) – State before the step (for aux updates).

  • Qaux (jnp.ndarray) – Auxiliary state before the step.

Returns:

Q_new, Qaux_new – Fully updated state and auxiliary arrays.

Return type:

tuple[jnp.ndarray, jnp.ndarray]
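The hook can be illustrated with two toy classes, assumed here only for exposition: the subclass performs an implicit (backward Euler) relaxation solve and then defers to the parent's post-processing through super(). ToyExplicitSolver and ToyIMEXSolver are hypothetical and do not reproduce the module's classes.

```python
import jax.numpy as jnp


class ToyExplicitSolver:
    """Hypothetical stand-in for an explicit solver's post-step hook."""

    def post_step(self, time, dt, Q, Qold, Qaux):
        # Placeholder post-processing: clamp the first variable at zero
        return Q.at[0].set(jnp.maximum(Q[0], 0.0)), Qaux


class ToyIMEXSolver(ToyExplicitSolver):
    """Hypothetical IMEX subclass: implicit solve first, then the parent's post-processing."""

    def __init__(self, relaxation_rate):
        self.relaxation_rate = relaxation_rate

    def post_step(self, time, dt, Q, Qold, Qaux):
        # Backward Euler for the stiff relaxation dq/dt = -k q on the second
        # variable: q_new = q / (1 + dt k), stable for any dt > 0
        Q = Q.at[1].set(Q[1] / (1.0 + dt * self.relaxation_rate))
        return super().post_step(time, dt, Q, Qold, Qaux)


solver = ToyIMEXSolver(relaxation_rate=12.0)
Q = jnp.array([[-1.0, 2.0], [4.0, 8.0]])
Q_new, Qaux_new = solver.post_step(time=0.25, dt=0.25, Q=Q, Qold=Q, Qaux=jnp.zeros((0, 2)))
```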

compute_timestep(Q, Qaux)

Compute the adaptive time step using the stored eigenvalue operator.

JIT-compatible. Uses self.compute_dt (from param) with the precomputed eigenvalue operator and the minimum cell inradius of the mesh.

Returns:

dt

Return type:

scalar
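The documentation does not spell out the formula behind self.compute_dt. A common CFL-type choice, assumed here for illustration, is dt = CFL · r_min / max|λ|, with r_min the minimum cell inradius and λ the eigenvalues; make_compute_timestep_sketch, the cfl value, and the toy eigenvalue function are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_compute_timestep_sketch(max_abs_eigenvalue, min_inradius, cfl=0.45):
    """Hypothetical CFL time step: dt = cfl * min_inradius / max|lambda|."""

    @jax.jit
    def compute_timestep(Q):
        lam = max_abs_eigenvalue(Q)
        # Guard against a zero wave speed (e.g. a quiescent, fully dry state)
        return cfl * min_inradius / jnp.maximum(lam, 1e-12)

    return compute_timestep


compute_timestep = make_compute_timestep_sketch(
    max_abs_eigenvalue=lambda Q: jnp.max(jnp.abs(Q)), min_inradius=0.05, cfl=0.5
)
dt = compute_timestep(jnp.array([[1.0, -4.0]]))
```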

run_simulation(Q, Qaux, write_output=True)

JIT-compiled time loop using jax.lax.while_loop.

Calls compute_timestep -> step -> post_step in a while_loop until time >= time_end.

Parameters:
  • Q (jnp.ndarray) – Initial state (from setup_simulation).

  • Qaux (jnp.ndarray) – Initial auxiliary state (from setup_simulation).

  • write_output (bool) – Whether to write snapshots to HDF5.

Returns:

Q, Qaux – Final state.

Return type:

tuple[jnp.ndarray, jnp.ndarray]
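The while_loop structure can be sketched as follows, assuming that the last step is shortened so the loop ends exactly at time_end, and omitting Qaux, post_step, and output writing. make_run_simulation_sketch and the toy operators are hypothetical.

```python
import jax
import jax.numpy as jnp


def make_run_simulation_sketch(compute_timestep, step, time_end):
    """Hypothetical jitted time loop: compute dt, advance, repeat until time_end."""

    @jax.jit
    def run(Q):
        def cond(carry):
            time, _, _ = carry
            return time < time_end

        def body(carry):
            time, iteration, Q = carry
            # Shorten the last step so the loop ends exactly at time_end
            dt = jnp.minimum(compute_timestep(Q), time_end - time)
            return time + dt, iteration + 1, step(dt, Q)

        init = (jnp.float32(0.0), jnp.int32(0), Q)
        time, iteration, Q = jax.lax.while_loop(cond, body, init)
        return Q, time, iteration

    return run


run = make_run_simulation_sketch(
    compute_timestep=lambda Q: jnp.float32(0.3),
    step=lambda dt, Q: Q * (1.0 - dt),  # forward Euler for dQ/dt = -Q
    time_end=1.0,
)
Q_final, time, iterations = run(jnp.array([1.0, 2.0]))
```

Carrying (time, iteration, Q) through the loop keeps all state on the device. jax.lax.while_loop requires the carry to keep the same shapes and dtypes in every iteration, which is why the initial time and counter are created with explicit dtypes.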

solve(mesh, model, write_output=True)

Full solve: initialize -> setup -> run.

This is the main entry point, compatible with the NumPy solver interface. Calls setup_simulation then run_simulation.

name = 'HyperbolicSolver'

class zoomy_jax.fvm.solver_jax.PoissonSolver(**kwargs)

Bases: Solver

Poisson solver.

get_residual(Qaux, Qold, Qauxold, parameters, mesh, model, boundary_operator, time, dt)

Get the residual of the Poisson problem.
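As a hedged illustration of what a residual for this solver might look like, the sketch below uses a hypothetical 1D model problem -u'' = f with zero Dirichlet values, writes it as a residual, and solves it matrix-free with jax.scipy.sparse.linalg.cg. The module's get_residual takes many more arguments, and its actual solution method (for example newton_solver) may differ.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def poisson_residual(u, f, dx):
    """Hypothetical residual of -u'' = f on a uniform 1D grid with u = 0 at both ends."""
    u_padded = jnp.pad(u, 1)  # zero Dirichlet values outside the interior cells
    laplacian = (u_padded[2:] - 2.0 * u_padded[1:-1] + u_padded[:-2]) / dx**2
    return -laplacian - f


def solve_poisson_sketch(f, dx):
    # The residual is affine in u: r(u) = A u - f with A = -Laplacian (symmetric
    # positive definite), so A u = r(u) - r(0) and r(u) = 0 means A u = -r(0).
    r0 = poisson_residual(jnp.zeros_like(f), f, dx)

    def matvec(u):
        return poisson_residual(u, f, dx) - r0

    u, _ = cg(matvec, -r0)
    return u


n = 15
dx = 1.0 / (n + 1)
x = dx * jnp.arange(1, n + 1)
u = solve_poisson_sketch(jnp.ones(n), dx)
```

The second-order central difference is exact for quadratics, so with f = 1 the discrete solution matches x(1 - x)/2 at the grid points up to the CG tolerance.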

solve(mesh, model, write_output=True)

Solve.

name = 'PoissonSolver'