zoomy_jax.fvm.solver_imex_jax module

JAX IMEX solver: explicit flux + implicit diffusion + implicit source.

Extends HyperbolicSolver (JAX) with:
  • RK2 for the explicit stage when reconstruction_order >= 2
  • Implicit diffusion via Crank-Nicolson (DiffusionOperatorJAX)
  • Implicit source stepping: local (cell-wise Newton) or global (Newton-GMRES)
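Crank-Nicolson treats diffusion implicitly by averaging the diffusion operator between the old and new time levels. The sketch below shows one such step for the 1D model problem du/dt = nu * L u. It is only an illustration: it builds a dense matrix and calls jnp.linalg.solve, and it makes no assumption about how DiffusionOperatorJAX actually assembles or solves its system. The helpers laplacian_1d and crank_nicolson_step are hypothetical.

```python
import jax
import jax.numpy as jnp


def laplacian_1d(n: int, dx: float) -> jnp.ndarray:
    """Hypothetical helper: dense second-difference matrix, homogeneous Dirichlet BCs."""
    main = -2.0 * jnp.ones(n)
    off = jnp.ones(n - 1)
    return (jnp.diag(main) + jnp.diag(off, k=1) + jnp.diag(off, k=-1)) / dx**2


@jax.jit
def crank_nicolson_step(u, L, nu, dt):
    """Hypothetical helper: one Crank-Nicolson step for du/dt = nu * L u,
    i.e. solve (I - dt/2 * nu L) u^{n+1} = (I + dt/2 * nu L) u^n."""
    eye = jnp.eye(u.shape[0], dtype=u.dtype)
    lhs = eye - 0.5 * dt * nu * L
    rhs = (eye + 0.5 * dt * nu * L) @ u
    return jnp.linalg.solve(lhs, rhs)


# Usage: a discrete sine mode is an eigenvector of L, so one CN step only rescales it.
n = 50
dx = 1.0 / (n + 1)
x = dx * jnp.arange(1, n + 1)
u0 = jnp.sin(jnp.pi * x)
u1 = crank_nicolson_step(u0, laplacian_1d(n, dx), 0.1, 0.01)
```

For an eigenvalue lam < 0 of L, the step multiplies that mode by (1 + dt*nu*lam/2) / (1 - dt*nu*lam/2). This factor has magnitude below 1 for every dt, so the diffusion stage does not limit the time step.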

All operations are JIT-compatible inside jax.lax.while_loop.
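Running the time loop inside jax.lax.while_loop lets the number of steps depend on runtime data while the whole loop still compiles only once. The sketch below shows such a loop under a few assumptions. It uses a generic step_fn(u, dt) and a fixed maximum step dt_max, and the names integrate, step_fn and dt_max are hypothetical. A real solver would typically derive dt from a stability (CFL) condition instead.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=(0,))
def integrate(step_fn, u0, t_end, dt_max):
    """Hypothetical driver: advance u from t = 0 to t_end in one XLA program."""

    def cond(carry):
        t, _, _ = carry
        return t < t_end  # keep stepping until the final time is reached

    def body(carry):
        t, u, n = carry
        dt = jnp.minimum(dt_max, t_end - t)  # clip the last step so t never overshoots
        return t + dt, step_fn(u, dt), n + 1

    # Carry types must stay fixed across iterations, so use explicit dtypes.
    init = (jnp.asarray(0.0, dtype=u0.dtype), u0, jnp.asarray(0, dtype=jnp.int32))
    return jax.lax.while_loop(cond, body, init)


# Usage: exact decay u' = -u as a stand-in for one solver step.
t, u, n = integrate(lambda u, dt: u * jnp.exp(-dt), jnp.ones(4), 1.0, 0.1)
```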

class zoomy_jax.fvm.solver_imex_jax.IMEXStats(n_steps=0, source_mode='auto', implicit_calls=0, implicit_time_s=0.0, init_time_s=0.0, compile_time_s=0.0, runtime_only_s=0.0, total_time_s=0.0)

Bases: object

Statistics from an IMEX solve.

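Parameters:
  • n_steps (int) – number of time steps taken.
  • source_mode (str) – source-stepping mode used for the run (default 'auto').
  • implicit_calls (int) – number of implicit solver calls.
  • implicit_time_s (float) – time spent in implicit solves, in seconds.
  • init_time_s (float) – setup (initialization) time, in seconds.
  • compile_time_s (float) – JIT compilation time, in seconds.
  • runtime_only_s (float) – run time excluding compilation, in seconds.
  • total_time_s (float) – total solve time, in seconds.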

n_steps: int = 0
source_mode: str = 'auto'
implicit_calls: int = 0
implicit_time_s: float = 0.0
init_time_s: float = 0.0
compile_time_s: float = 0.0
runtime_only_s: float = 0.0
total_time_s: float = 0.0
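The compile_time_s and runtime_only_s fields separate XLA compilation from execution. The sketch below shows one way to measure that split with JAX's ahead-of-time lower/compile API. The helper profile_jit is hypothetical, and this is not necessarily how IMEXStats is populated. The block_until_ready call matters because JAX dispatches work asynchronously, so without it the timer would stop before the computation finishes.

```python
import time

import jax
import jax.numpy as jnp


def profile_jit(fn, *args):
    """Hypothetical helper: split a jitted call into compile time and execution time."""
    t_start = time.perf_counter()
    compiled = jax.jit(fn).lower(*args).compile()  # trace + XLA compilation
    t_compiled = time.perf_counter()
    out = compiled(*args).block_until_ready()  # execution only; wait for async dispatch
    t_end = time.perf_counter()
    timings = {
        "compile_time_s": t_compiled - t_start,
        "runtime_only_s": t_end - t_compiled,
        "total_time_s": t_end - t_start,
    }
    return out, timings


out, timings = profile_jit(lambda x: jnp.sin(x) ** 2 + jnp.cos(x) ** 2, jnp.linspace(0.0, 1.0, 8))
```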
class zoomy_jax.fvm.solver_imex_jax.IMEXSourceSolverJax(**kwargs)

Bases: DerivativeAwareSolverMixin, HyperbolicSolver

Pure-JAX IMEX solver with implicit diffusion and source stepping.

The entire time integration compiles into a single XLA program via jax.lax.while_loop.

Features:
  • RK2 for explicit advection when reconstruction_order >= 2
  • Crank-Nicolson implicit diffusion via DiffusionOperatorJAX
  • Local sources: cell-wise Newton via jax.lax.fori_loop
  • Global sources: Newton-GMRES via jax.lax.while_loop
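For a local source, each cell's implicit update is an independent small nonlinear system. The sketch below shows cell-wise Newton for a backward-Euler source step. It uses jax.lax.fori_loop with a fixed iteration count and jax.vmap over cells. The two-field source, the rate K and all function names are hypothetical, and backward Euler is assumed as the implicit scheme for illustration.

```python
import jax
import jax.numpy as jnp

K = 50.0  # hypothetical stiff relaxation rate


def source(q):
    """Hypothetical local source S(q) for one cell with two fields."""
    return jnp.stack([-K * (q[0] - q[1] ** 2), -q[1]])


def backward_euler_cell(q_old, dt, n_iter=6):
    """Solve R(q) = q - q_old - dt * S(q) = 0 in one cell with a fixed number
    of Newton iterations; fori_loop keeps the trip count static for XLA."""

    def residual(q):
        return q - q_old - dt * source(q)

    jac = jax.jacfwd(residual)  # small dense per-cell Jacobian

    def newton_iter(_, q):
        return q + jnp.linalg.solve(jac(q), -residual(q))

    return jax.lax.fori_loop(0, n_iter, newton_iter, q_old)


# vmap maps the per-cell solve over the leading (cell) axis; dt is shared.
implicit_source_step = jax.jit(jax.vmap(backward_euler_cell, in_axes=(0, None)))

q_old = jnp.array([[1.0, 0.5], [0.2, -1.0], [0.0, 2.0]])
q_new = implicit_source_step(q_old, 0.1)
```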

source_mode = 'auto'
implicit_tol = 1e-08
implicit_maxiter = 6
gmres_tol = 1e-07
gmres_maxiter = 30
jv_backend = 'ad'
fd_eps = 1e-07
create_runtime(Q, Qaux, mesh, model)

Create runtime.

solve(mesh, model, write_output=False)

Full IMEX solve: setup + JIT-compiled time loop.
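To show how the explicit and implicit parts can be combined within one time step, the sketch below applies a simple Lie splitting to a 1D periodic advection problem with a stiff linear source S(q) = -k q. First an SSP-RK2 (Heun) stage handles first-order upwind advection, then a backward-Euler step handles the source. The splitting order, the upwind flux and the source are all illustrative assumptions, not necessarily how solve() orders its stages.

```python
import jax
import jax.numpy as jnp


def upwind_rhs(q, dx, a=1.0):
    """Hypothetical explicit operator: first-order upwind for q_t + a q_x = 0 (a > 0, periodic)."""
    return -a * (q - jnp.roll(q, 1)) / dx


@jax.jit
def imex_step(q, dt, dx, k):
    """Hypothetical IMEX step: SSP-RK2 for advection, then backward Euler for S(q) = -k q."""
    q1 = q + dt * upwind_rhs(q, dx)                   # first Euler stage
    q_star = 0.5 * (q + q1 + dt * upwind_rhs(q1, dx))  # SSP-RK2 average
    # Backward Euler: q_new = q_star - dt * k * q_new  =>  q_new = q_star / (1 + k dt)
    return q_star / (1.0 + k * dt)


# Usage: advection alone conserves the total; a stiff source (k * dt = 10) stays stable.
dx = 1.0 / 64
x = jnp.linspace(0.0, 1.0, 64, endpoint=False)
q = 1.0 + 0.5 * jnp.sin(2 * jnp.pi * x)
q_adv = imex_step(q, 0.5 * dx, dx, 0.0)
q_src = imex_step(jnp.ones(8), 0.01, 0.1, 1000.0)
```

Treating the source implicitly means k * dt can far exceed 1 without instability, while the explicit advection stage still needs a CFL-limited dt.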

name = 'IMEXSourceSolverJax'