zoomy_jax.fvm.reconstruction_jax module


JIT-compatible FVM face reconstruction and diffusion for JAX.

Mirrors the NumPy reconstruction classes in zoomy_core.fvm.reconstruction with the same interface: recon(Q) → (Q_L, Q_R).

All operations use JAX primitives (the array method x.at[idx].add and jnp.where) for full JIT and autodiff compatibility.
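The two primitives named above can be illustrated with a minimal sketch on a hypothetical 1D periodic mesh. The connectivity arrays `left`/`right` and the helpers `upwind_flux` and `accumulate` are assumptions for illustration, not zoomy_jax API. `.at[].add` scatters face fluxes into cells and accumulates repeated indices correctly, and `jnp.where` replaces a Python branch on traced values, so both functions can be wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp

# Hypothetical 1D periodic mesh: face f separates cell left[f] from cell right[f].
n_cells = 5
left = jnp.arange(n_cells)
right = (left + 1) % n_cells

@jax.jit
def upwind_flux(q, a):
    # Branch-free upwinding: jnp.where instead of a Python `if` on a traced value.
    return a * jnp.where(a >= 0.0, q[left], q[right])

@jax.jit
def accumulate(face_flux):
    # Scatter-add: each face flux leaves its left cell and enters its right cell.
    div = jnp.zeros(n_cells)
    div = div.at[left].add(-face_flux)
    return div.at[right].add(face_flux)

q = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
div = accumulate(upwind_flux(q, 1.0))  # net inflow per cell; sums to zero
```

Because every flux is added to one cell and subtracted from another, the scatter-add form is conservative by construction.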

Classes

  • ConstantReconstruction: 1st-order piecewise-constant.

  • MUSCLReconstruction: 2nd-order MUSCL with slope limiting.

  • FreeSurfaceMUSCL: MUSCL with wet-dry fallback and h-positivity.

  • DiffusionOperatorJAX: Sparse discrete Laplacian with Crank-Nicolson implicit solve.

class zoomy_jax.fvm.reconstruction_jax.ConstantReconstruction(mesh, dim)

Bases: object

First-order piecewise-constant reconstruction (JAX).
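A minimal sketch of first-order face reconstruction, assuming a hypothetical face-to-cell connectivity (`face_left`, `face_right`) and a state array `Q` of shape `(n_vars, n_cells)`. The actual layout used by MeshJAX and ConstantReconstruction may differ, and `constant_recon` is a hypothetical name. Both face states are plain gathers of the neighboring cell averages:

```python
import jax
import jax.numpy as jnp

# Hypothetical connectivity for 4 cells and 3 interior faces:
# face f separates cell face_left[f] from cell face_right[f].
face_left = jnp.array([0, 1, 2])
face_right = jnp.array([1, 2, 3])

@jax.jit
def constant_recon(Q):
    # First order: each face state is the adjacent cell average.
    # Q is assumed to have shape (n_vars, n_cells); we gather along the cell axis.
    Q_L = Q[:, face_left]
    Q_R = Q[:, face_right]
    return Q_L, Q_R

Q = jnp.array([[1.0, 2.0, 4.0, 8.0],    # e.g. water depth h
               [0.0, 0.5, 1.0, 1.5]])   # e.g. momentum hu
Q_L, Q_R = constant_recon(Q)
```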

class zoomy_jax.fvm.reconstruction_jax.MUSCLReconstruction(mesh, dim, limiter='venkatakrishnan')

Bases: object

Second-order MUSCL reconstruction with slope limiting (JAX, JIT-compatible).

Parameters:
  • mesh (MeshJAX)

  • dim (int)

  • limiter (str) – Slope limiter: "venkatakrishnan" (default), "barth_jespersen", or "minmod".
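For intuition, here is a sketch of MUSCL with the minmod limiter on a uniform 1D grid. It is a deliberate simplification under stated assumptions: the class works on MeshJAX meshes and defaults to the Venkatakrishnan limiter, neither of which is reproduced here, and `muscl_minmod_1d` is a hypothetical name. Slopes are undivided differences, boundary cells keep zero slope, and face i+1/2 takes Q_L from cell i and Q_R from cell i+1.

```python
import jax
import jax.numpy as jnp

def minmod(a, b):
    # Zero where the one-sided slopes disagree in sign (local extremum),
    # otherwise the slope of smaller magnitude.
    return jnp.where(a * b > 0.0, jnp.sign(a) * jnp.minimum(jnp.abs(a), jnp.abs(b)), 0.0)

@jax.jit
def muscl_minmod_1d(q):
    # q: cell averages on a uniform 1D grid, shape (n_cells,).
    dq = jnp.diff(q)  # dq[i] = q[i+1] - q[i]
    # Interior cell i uses backward dq[i-1] and forward dq[i]; boundary slopes stay 0.
    slope = jnp.zeros_like(q).at[1:-1].set(minmod(dq[:-1], dq[1:]))
    # Face i+1/2 between cells i and i+1.
    Q_L = q[:-1] + 0.5 * slope[:-1]
    Q_R = q[1:] - 0.5 * slope[1:]
    return Q_L, Q_R

Q_L, Q_R = muscl_minmod_1d(jnp.array([0.0, 1.0, 2.0, 3.0, 4.0]))
```

With linear data the limiter is inactive and the interior faces receive matching left and right states; at a step the limited slopes vanish, so no new extrema appear.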

class zoomy_jax.fvm.reconstruction_jax.FreeSurfaceMUSCL(mesh, dim, h_index, eps_wet=0.001, limiter='venkatakrishnan')

Bases: MUSCLReconstruction

MUSCL with wet-dry fallback for free-surface flows (JAX).

In dry cells (h < eps_wet), the reconstruction falls back to first order (φ = 0). After reconstruction, h is clamped to h ≥ 0 at the face states.
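The wet-dry fallback and positivity clamp can be sketched in a 1D minmod setting. Here φ is modeled as a wet/dry mask multiplying the limited slope, which is an assumption about how the class applies it, and `free_surface_faces_1d` is a hypothetical name. With minmod the clamp rarely triggers, but it is kept to mirror the documented behavior.

```python
import jax
import jax.numpy as jnp

def minmod(a, b):
    return jnp.where(a * b > 0.0, jnp.sign(a) * jnp.minimum(jnp.abs(a), jnp.abs(b)), 0.0)

@jax.jit
def free_surface_faces_1d(h, eps_wet=1e-3):
    # h: water depth per cell on a uniform 1D grid.
    dh = jnp.diff(h)
    slope = jnp.zeros_like(h).at[1:-1].set(minmod(dh[:-1], dh[1:]))
    # Wet-dry fallback (assumed form): phi = 0 in dry cells gives first order there.
    phi = jnp.where(h < eps_wet, 0.0, 1.0)
    slope = phi * slope
    h_L = h[:-1] + 0.5 * slope[:-1]
    h_R = h[1:] - 0.5 * slope[1:]
    # Positivity: clamp reconstructed depths at the faces.
    return jnp.maximum(h_L, 0.0), jnp.maximum(h_R, 0.0)

h = jnp.array([0.0, 0.0005, 0.5, 1.0, 1.0])  # cells 0 and 1 are dry
h_L, h_R = free_surface_faces_1d(h)
```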

class zoomy_jax.fvm.reconstruction_jax.DiffusionOperatorJAX(mesh, dim, nu=1.0)

Bases: object

Sparse discrete diffusion operator for JAX: L(u) = ∇ · (ν ∇u).

Assembled once per mesh and viscosity as a dense (nc, nc) matrix, where nc is the number of interior cells. Provides:

  • explicit(u): L @ u, for explicit stepping.

  • implicit_solve(u_star, dt): Crank-Nicolson implicit solve.

The dense matrix approach is JIT-compatible and works inside jax.lax.while_loop. For typical 1D/2D FVM grids (up to ~1000 cells) this is efficient; for larger grids a matrix-free GMRES variant is also available via implicit_solve_gmres.
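A sketch of the dense assembly and of explicit stepping inside `jax.lax.while_loop`, assuming a uniform 1D grid, constant ν, and zero-flux ends. The boundary treatment of DiffusionOperatorJAX is not documented here, and `assemble_diffusion_1d` and `explicit_steps` are hypothetical names. Each interior face adds a two-point flux to both neighbors, so rows sum to zero and L is symmetric.

```python
import jax
import jax.numpy as jnp

def assemble_diffusion_1d(nc, dx, nu):
    # Dense (nc, nc) matrix for L(u) = d/dx (nu du/dx), zero-flux ends.
    # Face i+1/2 contributes nu * (u[i+1] - u[i]) / dx**2 to cell i and minus that to cell i+1.
    i = jnp.arange(nc - 1)
    c = nu / dx**2
    L = jnp.zeros((nc, nc))
    L = L.at[i, i].add(-c).at[i, i + 1].add(c)
    L = L.at[i + 1, i + 1].add(-c).at[i + 1, i].add(c)
    return L

@jax.jit
def explicit_steps(u0, L, dt, n_steps):
    # Forward Euler u <- u + dt * (L @ u), looped with jax.lax.while_loop.
    def cond(state):
        k, _ = state
        return k < n_steps

    def body(state):
        k, u = state
        return k + 1, u + dt * (L @ u)

    _, u = jax.lax.while_loop(cond, body, (jnp.int32(0), u0))
    return u

nc, dx, nu = 20, 0.05, 1.0
L = assemble_diffusion_1d(nc, dx, nu)
u0 = jnp.where(jnp.arange(nc) < nc // 2, 1.0, 0.0)
u = explicit_steps(u0, L, 0.001, 50)  # dt * nu / dx**2 = 0.4
```

Forward Euler for this stencil is stable only for dt ≤ dx²/(2ν), which is why the usage above keeps dt·ν/dx² at 0.4.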

explicit(u)

Compute L @ u[:nc] (for explicit stepping). Returns shape (nc,).

implicit_solve(u_star, dt)

Crank-Nicolson step: solves (I - dt/2 * L) u^{n+1} = (I + dt/2 * L) u* for u^{n+1}.

Second-order accurate in time for the diffusion term. Uses a dense linear solve (jnp.linalg.solve), which is fully JIT-compatible.

Parameters:
  • u_star (jnp.ndarray, shape (n_cells,)) – State after explicit advection step (includes ghost cells).

  • dt (scalar) – Time step.

Returns:

u_new – Updated state with ghost cells copied from inner neighbors.

Return type:

jnp.ndarray, shape (n_cells,)
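A sketch of the Crank-Nicolson update with `jnp.linalg.solve`, under stated assumptions: a uniform 1D zero-flux stencil, interior cells stored first (`u[:nc]`), and two ghost cells whose interior neighbors are listed in a hypothetical `ghost_to_inner` array. None of these helper names come from zoomy_jax. Because the columns of L sum to zero, the solve conserves the interior total up to round-off.

```python
import jax
import jax.numpy as jnp

def assemble_diffusion_1d(nc, dx, nu):
    # Two-point zero-flux stencil as a dense (nc, nc) matrix.
    i = jnp.arange(nc - 1)
    c = nu / dx**2
    L = jnp.zeros((nc, nc))
    L = L.at[i, i].add(-c).at[i, i + 1].add(c)
    return L.at[i + 1, i + 1].add(-c).at[i + 1, i].add(c)

# Hypothetical layout: u[:nc] interior cells, u[nc:] ghost cells whose
# interior neighbors are given by ghost_to_inner.
nc = 8
ghost_to_inner = jnp.array([0, nc - 1])

@jax.jit
def crank_nicolson(u_star, L, dt):
    I = jnp.eye(L.shape[0])
    rhs = (I + 0.5 * dt * L) @ u_star[:nc]
    u_inner = jnp.linalg.solve(I - 0.5 * dt * L, rhs)
    # Refresh ghost cells by copying from their interior neighbors.
    return u_star.at[:nc].set(u_inner).at[nc:].set(u_inner[ghost_to_inner])

L = assemble_diffusion_1d(nc, 0.1, 1.0)
u_star = jnp.concatenate([jnp.where(jnp.arange(nc) < nc // 2, 1.0, 0.0), jnp.zeros(2)])
u_new = crank_nicolson(u_star, L, 0.01)
```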

implicit_solve_gmres(u_star, dt, tol=1e-08, maxiter=100)

Crank-Nicolson via matrix-free GMRES. JIT-compatible.

Use this for larger grids where the dense solve becomes expensive.
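A matrix-free variant can be sketched with `jax.scipy.sparse.linalg.gmres`, which accepts a function in place of a matrix. The stencil below (uniform 1D grid, zero-flux ends) and the names `apply_diffusion_1d` and `crank_nicolson_gmres` are assumptions; the real method's operator and GMRES settings may differ. In default float32, a tolerance of 1e-08 is below machine epsilon (about 1.2e-7), so GMRES may not reach it and run until `maxiter` unless `jax_enable_x64` is enabled; the sketch therefore defaults to a looser tol.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def apply_diffusion_1d(u, dx, nu):
    # Matrix-free L @ u: face fluxes nu * (u[i+1] - u[i]) / dx, divided once more by dx.
    flux = nu * jnp.diff(u) / dx**2
    # Zero-flux ends: pad with zero boundary fluxes, then difference the face fluxes.
    flux = jnp.concatenate([jnp.zeros(1), flux, jnp.zeros(1)])
    return jnp.diff(flux)

@functools.partial(jax.jit, static_argnames=("tol", "maxiter"))
def crank_nicolson_gmres(u, dt, dx, nu, tol=1e-6, maxiter=100):
    # Solve (I - dt/2 L) u_new = (I + dt/2 L) u without forming L.
    def lhs(v):
        return v - 0.5 * dt * apply_diffusion_1d(v, dx, nu)

    rhs = u + 0.5 * dt * apply_diffusion_1d(u, dx, nu)
    u_new, _ = gmres(lhs, rhs, x0=u, tol=tol, maxiter=maxiter)
    return u_new

nc, dx, nu, dt = 8, 0.1, 1.0, 0.01
u = jnp.linspace(0.0, 1.0, nc) ** 2
u_new = crank_nicolson_gmres(u, dt, dx, nu)
```

Only the stencil application is stored, so memory grows linearly with the number of cells instead of quadratically as with the dense (nc, nc) matrix.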