Author

Ingo Steldermann

Published

July 10, 2025

VAM Tutorial (Simple)

Reference

The model used below is described in the following paper:

 @article{Escalante_2024, 
    title={Vertically averaged and moment equations: New derivation, efficient numerical solution and comparison with other physical approximations for modeling non-hydrostatic free surface flows}, 
    volume={504}, 
    ISSN={00219991}, 
    DOI={10.1016/j.jcp.2024.112882}, 
    journal={Journal of Computational Physics}, 
    author={Escalante, C. and Morales De Luna, T. and Cantero-Chinchilla, F. and Castro-Orgaz, O.}, 
    year={2024}, 
    month=may, 
    pages={112882}, 
    language={en} 
}

Imports

Load packages
import os
import numpy as np
import jax
from jax import numpy as jnp
import pytest
from types import SimpleNamespace
from sympy import cos, pi, symbols, Derivative, Function, exp, I, Rational, Derivative, init_printing, Matrix, sqrt, diff
from sympy import Matrix, sqrt, Derivative, Rational
from time import time as gettime
from attr import define, field
from typing import Callable
import param
from functools import partial


from zoomy_jax.fvm.solver_jax import HyperbolicSolver, PoissonSolver, Settings
from zoomy_core.fvm.ode import RK1
import zoomy_core.fvm.timestepping as timestepping
import zoomy_core.fvm.flux as flux
import zoomy_core.fvm.nonconservative_flux as nc_flux
from zoomy_core.model.boundary_conditions import BoundaryCondition
from zoomy_core.model.models.basisfunctions import Basisfunction, Legendre_shifted
from zoomy_core.model.models.basismatrices import Basismatrices

from zoomy_jax.fvm.solver_jax import newton_solver, log_callback_hyperbolic, log_callback_execution_time
from zoomy_core.fvm.ode import RK1


from zoomy_core.model.basemodel import Model
from zoomy_core.misc.misc import Zstruct, ZArray
import zoomy_core.model.initial_conditions as IC
import zoomy_core.model.boundary_conditions as BC
import zoomy_core.misc.io as io
from zoomy_core.mesh.mesh import compute_derivatives
from zoomy_tests.swashes import plots_paper
import zoomy_core.model.analysis as analysis

import zoomy_core.mesh.mesh as petscMesh


# Enable LaTeX rendering of sympy expressions in the notebook output.
init_printing(use_latex=True)
class VAMHyperbolic(Model):
    """Hyperbolic part of the 1D VAM model, written as a symbolic ``Model``.

    Variables (accessed by index; naming as in ``constraints``):
    ``q0 = h``, ``q1 = h*u0``, ``q2 = h*u1``, ``q3 = h*w0``, ``q4 = h*w1``,
    ``q5 = b``.

    Auxiliary fields (accessed by name): ``hw2``, the pressures ``p0``/``p1``
    and the spatial derivatives ``dbdx``, ``dhdx``, ``dhp0dx``, ``dhp1dx``.
    The pressures only enter through ``source``.
    """

    # Parameters replace attrs fields
    dimension = param.Integer(default=1)

    # Passing int=6 creates default variables q_0...q_5
    variables = param.Parameter(default=6)

    # Aux variables defined by name (order defines the row index in Qaux)
    aux_variables = param.Parameter(
        default=["hw2", "p0", "p1", "dbdx", "dhdx", "dhp0dx", "dhp1dx"]
    )

    parameters = param.Parameter(default={"g": 9.81})

    def flux(self):
        """Return the x-flux vector F(Q), one entry per variable.

        Entries 0-4 are the fluxes of h, h*u0, h*u1, h*w0 and h*w1.
        Entry 5 (bottom b) is left at zero.
        """
        fx = Matrix([0 for i in range(self.n_variables)])

        # Access aux variables via dot notation (ZStruct)
        hw2 = self.aux_variables.hw2

        # Access variables by index (generated from default=6)
        h = self.variables[0]
        hu0 = self.variables[1]
        hu1 = self.variables[2]
        hw0 = self.variables[3]
        hw1 = self.variables[4]

        u0 = hu0 / h
        u1 = hu1 / h
        w0 = hw0 / h
        w1 = hw1 / h

        fx[0] = hu0
        fx[1] = hu0 * u0 + Rational(1, 3) * hu1 * u1
        fx[2] = 2 * hu0 * u1
        fx[3] = hu0 * w0 + Rational(1, 3) * hu1 * w1
        fx[4] = hu0 * w1 + u1 * (hw0 + Rational(2, 5) * hw2)

        # FIX: Return ZArray(fx) directly for 1D to get shape (n, 1).
        # ZArray([fx]) would create shape (1, n, 1).
        return ZArray(fx)

    def nonconservative_matrix(self):
        """Return the non-conservative matrix with shape (n, n, dimension).

        Only the x-slice ``[:, :, 0]`` is filled. Its non-zero entries are:

        - row 1, columns 0 (h) and 5 (b): ``g*h``
        - row 2, diagonal: ``-u0``
        - row 4, column 2: ``w2/5 - w0``, with ``w2 = hw2/h``
        """
        # We need to construct the matrix with shape (n, n, dimension) = (n, n, 1)
        nc_val = Matrix(
            [[0 for i in range(self.n_variables)] for j in range(self.n_variables)]
        )

        hw2 = self.aux_variables.hw2
        h = self.variables[0]
        hw0 = self.variables[3]

        p = self.parameters

        u0 = self.variables[1] / h
        w0 = hw0 / h
        w2 = hw2 / h

        nc_val[1, 0] = p.g * h
        nc_val[1, 5] = p.g * h
        nc_val[2, 2] = -u0
        nc_val[4, 2] = +Rational(1, 5) * w2 - w0

        # FIX: Ensure output is (n, n, 1)
        # ZArray(nc_val) gives (n, n). We explicitly shape it.
        nc = ZArray.zeros(self.n_variables, self.n_variables, self.dimension)
        nc[:, :, 0] = ZArray(nc_val)
        return nc

    def eigenvalues(self):
        """Return the six eigenvalues.

        These are u0, u0 +/- u1/sqrt(3), u0 +/- sqrt(g*h + u1**2), and 0 for
        the bottom component b.
        """
        ev = Matrix([0 for i in range(self.n_variables)])
        h = self.variables[0]
        hu0 = self.variables[1]
        hu1 = self.variables[2]
        p = self.parameters

        u0 = hu0 / h
        u1 = hu1 / h

        ev[0] = u0
        ev[1] = u0 + 1 / sqrt(3) * u1
        ev[2] = u0 - 1 / sqrt(3) * u1
        ev[3] = u0 + sqrt(p.g * h + u1**2)
        ev[4] = u0 - sqrt(p.g * h + u1**2)
        ev[5] = 0

        return ZArray(ev)

    def source(self):
        """Return the pressure source term as ``-R``.

        ``R`` is built from the aux pressures ``p0``/``p1``, the pressure
        fluxes ``dhp0dx``/``dhp1dx`` and the gradients ``dhdx``/``dbdx``.
        Components 0 (h) and 5 (b) receive no source.
        """
        R = Matrix([0 for i in range(self.n_variables)])

        p0 = self.aux_variables.p0
        p1 = self.aux_variables.p1
        dbdx = self.aux_variables.dbdx
        dhdx = self.aux_variables.dhdx
        dhp0dx = self.aux_variables.dhp0dx
        dhp1dx = self.aux_variables.dhp1dx

        R[0] = 0
        R[1] = dhp0dx + 2 * p1 * dbdx
        R[2] = dhp1dx - (3 * p0 - p1) * dhdx - 6 * (p0 - p1) * dbdx
        R[3] = -2 * p1
        R[4] = 6 * (p0 - p1)
        R[5] = 0

        # Note the sign flip: the model's source is -R.
        return ZArray(-R)

    def constraints(self):
        """Return the three constraint expressions ``C`` (imposed as ``C == 0``).

        The expressions use unevaluated sympy ``Derivative`` objects with
        respect to the first coordinate ``x = self.position[0]``. ``w2`` is
        taken from the aux field ``hw2``.
        """
        C = Matrix([0 for i in range(3)])

        x = self.position[0]

        q0 = self.variables[0]
        q1 = self.variables[1]
        q2 = self.variables[2]
        q3 = self.variables[3]
        q4 = self.variables[4]
        q5 = self.variables[5]

        h = q0
        u0 = q1 / h
        u1 = q2 / h
        w0 = q3 / h
        w1 = q4 / h
        b = q5
        w2 = self.aux_variables.hw2 / q0

        C[0] = (
            h * Derivative(u0, x)
            + Rational(1, 3) * Derivative(h * u1, x)
            + Rational(1, 3) * u1 * Derivative(h, x)
            + 2 * (w0 - u0 * Derivative(b, x))
        )
        C[1] = (
            h * Derivative(u0, x)
            + u1 * Derivative(h, x)
            + 2 * (u1 * Derivative(b, x) - w1)
        )
        C[2] = (
            h * Derivative(u0, x)
            + u1 * Derivative(h, x)
            + 2 * (w0 + w2 - u0 * Derivative(b, x))
        )

        return ZArray(C)


class VAMPoisson(Model):
    """Pressure system of the VAM model for the unknowns ``p0`` and ``p1``.

    ``residual`` combines the unknown pressures and their x-derivatives with
    the state after the hyperbolic step and the time step ``dt``. The state
    after the hyperbolic step (``u0``, ``u1``, ``w0``, ``w1`` and derivatives)
    is supplied as aux fields.

    The aux fields ``d4p0dx4``/``d4p1dx4`` are declared but not referenced
    in ``residual``.
    """

    dimension = param.Integer(default=1)

    # Defined explicitly by name
    variables = param.Parameter(default=["p0", "p1"])

    # Order defines the row index in the aux array: rows 0-5 are pressure
    # derivatives, rows 6-17 ("h" ... "dt") come from the hyperbolic state.
    aux_variables = param.Parameter(
        default=[
            "dp0dx",
            "ddp0dxx",
            "dp1dx",
            "ddp1dxx",
            "d4p0dx4",
            "d4p1dx4",
            "h",
            "dbdx",
            "ddbdxx",
            "dhdx",
            "ddhdxx",
            "u0",
            "du0dx",
            "w0",
            "w1",
            "u1",
            "du1dx",
            "dt",
        ]
    )

    parameters = param.Parameter(default={"g": 9.81})

    def residual(self):
        """Return the residual vector ``[I1, I2]`` in the unknowns (p0, p1).

        ``PredictorCorrectorSolver`` drives this residual to zero with a
        Newton solver.
        """
        R = Matrix([0 for i in range(self.n_variables)])

        h = self.aux_variables.h

        # Access variables via dot notation (ZStruct)
        p0 = self.variables.p0
        p1 = self.variables.p1
        dt = self.aux_variables.dt

        dbdx = self.aux_variables.dbdx
        ddbdxx = self.aux_variables.ddbdxx

        dhdx = self.aux_variables.dhdx
        ddhdxx = self.aux_variables.ddhdxx

        dp0dx = self.aux_variables.dp0dx
        dp1dx = self.aux_variables.dp1dx
        ddp0dxx = self.aux_variables.ddp0dxx
        ddp1dxx = self.aux_variables.ddp1dxx

        # Note: Values from the middle state after hyperbolic step
        oldu0 = self.aux_variables.u0
        doldu0dx = self.aux_variables.du0dx
        oldw1 = self.aux_variables.w1
        oldw0 = self.aux_variables.w0
        oldu1 = self.aux_variables.u1
        doldu1dx = self.aux_variables.du1dx

        I1 = (
            -Rational(1, 3)
            * dt
            * (
                -(3 * p0 - p1) * ddhdxx
                - (6 * p0 - 6 * p1) * ddbdxx
                - (3 * dp0dx - dp1dx) * dhdx
                - (6 * dp0dx - 6 * dp1dx) * dbdx
                + h * ddp1dxx
                + p1 * ddhdxx
                + 2 * dhdx * dp1dx
            )
            - 2 * (-dt * (h * dp0dx + p0 * dhdx + 2 * p1 * dbdx) + h * oldu0) * dbdx / h
            + Rational(1, 3)
            * (
                -dt
                * (
                    -(3 * p0 - p1) * dhdx
                    - (6 * p0 - 6 * p1) * dbdx
                    + h * dp1dx
                    + p1 * dhdx
                )
                + h * oldu1
            )
            * dhdx
            / h
            + 2 * (2 * dt * p1 + h * oldw0) / h
            + (
                -(-dt * (h * dp0dx + p0 * dhdx + 2 * p1 * dbdx) + h * oldu0)
                * dhdx
                / h**2
                + (
                    -dt
                    * (
                        h * ddp0dxx
                        + p0 * ddhdxx
                        + 2 * p1 * ddbdxx
                        + 2 * dbdx * dp1dx
                        + 2 * dhdx * dp0dx
                    )
                    + h * doldu0dx
                    + oldu0 * dhdx
                )
                / h
            )
            * h
            + Rational(1, 3) * h * doldu1dx
            + Rational(1, 3) * oldu1 * dhdx
        )

        I2 = (
            -2 * (-dt * (6 * p0 - 6 * p1) + h * oldw1) / h
            + 2
            * (
                -dt
                * (
                    -(3 * p0 - p1) * dhdx
                    - (6 * p0 - 6 * p1) * dbdx
                    + h * dp1dx
                    + p1 * dhdx
                )
                + h * oldu1
            )
            * dbdx
            / h
            + (
                -dt
                * (
                    -(3 * p0 - p1) * dhdx
                    - (6 * p0 - 6 * p1) * dbdx
                    + h * dp1dx
                    + p1 * dhdx
                )
                + h * oldu1
            )
            * dhdx
            / h
            + (
                -(-dt * (h * dp0dx + p0 * dhdx + 2 * p1 * dbdx) + h * oldu0)
                * dhdx
                / h**2
                + (
                    -dt
                    * (
                        h * ddp0dxx
                        + p0 * ddhdxx
                        + 2 * p1 * ddbdxx
                        + 2 * dbdx * dp1dx
                        + 2 * dhdx * dp0dx
                    )
                    + h * doldu0dx
                    + oldu0 * dhdx
                )
                / h
            )
            * h
        )

        R[0] = I1
        R[1] = I2

        return ZArray(R)

    def eigenvalues(self):
        """Return a zero vector: no characteristic speeds are defined for this system."""
        ev = Matrix([0 for i in range(self.n_variables)])
        return ZArray(ev)

Model definition

# Default-parameter instance, used below for the symbolic linear stability analysis.
model = VAMHyperbolic()

Linear stability analysis


# Linearize around a state at rest with constant depth h_c and constant
# bottom b_c. Every other quantity is an O(eps) perturbation.
analyzer = analysis.ModelAnalyser(model)
h0, b0 = symbols('h_c b_c')
eps = analyzer.get_eps()
h, u0, u1, w0, w1, w2, p0, p1 = analyzer.create_functions_from_list(['h', 'u_0', 'u_1', 'w_0', 'w_1', 'w_2', 'p_0', 'p_1'])
t, x, y, z = analyzer.get_time_space()
# Same ordering as VAMHyperbolic.variables: (h, hu0, hu1, hw0, hw1, b).
Q = Matrix([h0 + eps * h, (h0 + eps * h) * eps * u0, (h0 + eps * h) * eps * u1, (h0 + eps * h) * eps * w0, (h0 + eps * h)* eps * w1, b0])
# Same ordering as VAMHyperbolic.aux_variables:
# (hw2, p0, p1, dbdx, dhdx, dhp0dx, dhp1dx).
Qaux = Matrix([(h0 + eps * h) * eps*w2, eps*p0, eps*p1, diff(b0, x), diff(h0 + eps * h, x), diff((h0 + eps * h) * eps * p0, x), diff((h0 + eps * h) * eps * p1, x)])

linearized_system = analyzer.linearize_system(Q, Qaux, source=model.source(), constraints=model.constraints())
analyzer.print_equations()

\[\begin{align*} & h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial t} h{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \left(g \frac{\partial}{\partial X_{0}} h{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial X_{0}} p_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial t} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)}\right) = 0 \\ & h_{c} \left(\frac{\partial}{\partial X_{0}} p_{1}{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial t} u_{1}{\left(t,X_{0},X_{1},X_{2} \right)}\right) = 0 \\ & h_{c} \frac{\partial}{\partial t} w_{0}{\left(t,X_{0},X_{1},X_{2} \right)} - 2 p_{1}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \frac{\partial}{\partial t} w_{1}{\left(t,X_{0},X_{1},X_{2} \right)} + 6 p_{0}{\left(t,X_{0},X_{1},X_{2} \right)} - 6 p_{1}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & \text{True} \\ & 3 h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + h_{c} \frac{\partial}{\partial X_{0}} u_{1}{\left(t,X_{0},X_{1},X_{2} \right)} + 6 w_{0}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} - 2 w_{1}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + 2 w_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + 2 w_{2}{\left(t,X_{0},X_{1},X_{2} \right)} = 0\end{align*}\]

# Drop equation 5, which linearizes to the trivial "True" (presumably from
# the constant bottom component b).
linearized_system = analyzer.delete_equations([5])
analyzer.print_equations()

\[\begin{align*} & h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial t} h{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \left(g \frac{\partial}{\partial X_{0}} h{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial X_{0}} p_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial t} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)}\right) = 0 \\ & h_{c} \left(\frac{\partial}{\partial X_{0}} p_{1}{\left(t,X_{0},X_{1},X_{2} \right)} + \frac{\partial}{\partial t} u_{1}{\left(t,X_{0},X_{1},X_{2} \right)}\right) = 0 \\ & h_{c} \frac{\partial}{\partial t} w_{0}{\left(t,X_{0},X_{1},X_{2} \right)} - 2 p_{1}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \frac{\partial}{\partial t} w_{1}{\left(t,X_{0},X_{1},X_{2} \right)} + 6 p_{0}{\left(t,X_{0},X_{1},X_{2} \right)} - 6 p_{1}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & 3 h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + h_{c} \frac{\partial}{\partial X_{0}} u_{1}{\left(t,X_{0},X_{1},X_{2} \right)} + 6 w_{0}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} - 2 w_{1}{\left(t,X_{0},X_{1},X_{2} \right)} = 0 \\ & h_{c} \frac{\partial}{\partial X_{0}} u_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + 2 w_{0}{\left(t,X_{0},X_{1},X_{2} \right)} + 2 w_{2}{\left(t,X_{0},X_{1},X_{2} \right)} = 0\end{align*}\]

# Replace each perturbation by an amplitude times exp(i(k.x - omega t)).
analyzer.insert_plane_wave_ansatz([h, u0, u1, w0, w1, w2, p0, p1])
analyzer.print_equations()

\[\begin{align*} & - i \bar{h} \omega e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} + i \bar{u_0} h_{c} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} = 0 \\ & h_{c} \left(i \bar{h} g k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} + i \bar{p_0} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} - i \bar{u_0} \omega e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)}\right) = 0 \\ & h_{c} \left(i \bar{p_1} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} - i \bar{u_1} \omega e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)}\right) = 0 \\ & - 2 \bar{p_1} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} - i \bar{w_0} h_{c} \omega e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} = 0 \\ & 6 \bar{p_0} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} - 6 \bar{p_1} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} - i \bar{w_1} h_{c} \omega e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} = 0 \\ & 3 i \bar{u_0} h_{c} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} + i \bar{u_1} h_{c} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} + 6 \bar{w_0} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} = 0 \\ & i \bar{u_0} h_{c} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} - 2 \bar{w_1} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} = 0 \\ & i \bar{u_0} h_{c} k_{x} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} + 2 \bar{w_0} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} + 2 \bar{w_2} e^{i \left(X_{0} k_{x} + X_{1} k_{y} + X_{2} k_{z} - \omega t\right)} = 0\end{align*}\]

# Divide out the common exponential factor, leaving algebraic equations in the amplitudes.
analyzer.remove_exponential()
analyzer.print_equations()

\[\begin{align*} & i \left(- \bar{h} \omega + \bar{u_0} h_{c} k_{x}\right) = 0 \\ & i h_{c} \left(\bar{h} g k_{x} + \bar{p_0} k_{x} - \bar{u_0} \omega\right) = 0 \\ & i h_{c} \left(\bar{p_1} k_{x} - \bar{u_1} \omega\right) = 0 \\ & 2 \bar{p_1} + i \bar{w_0} h_{c} \omega = 0 \\ & - 6 \bar{p_0} + 6 \bar{p_1} + i \bar{w_1} h_{c} \omega = 0 \\ & 3 i \bar{u_0} h_{c} k_{x} + i \bar{u_1} h_{c} k_{x} + 6 \bar{w_0} = 0 \\ & i \bar{u_0} h_{c} k_{x} - 2 \bar{w_1} = 0 \\ & i \bar{u_0} h_{c} k_{x} + 2 \bar{w_0} + 2 \bar{w_2} = 0\end{align*}\]

# Roots omega(k_x) of the linear system.
dispersion_relation = analyzer.solve_for_dispersion_relation()
dispersion_relation

\(\displaystyle \left[ 0, \ - 2 \sqrt{3} k_{x} \sqrt{\frac{g h_{c} \left(h_{c}^{2} k_{x}^{2} + 12\right)}{h_{c}^{4} k_{x}^{4} + 60 h_{c}^{2} k_{x}^{2} + 144}}, \ 2 \sqrt{3} k_{x} \sqrt{\frac{g h_{c} \left(h_{c}^{2} k_{x}^{2} + 12\right)}{h_{c}^{4} k_{x}^{4} + 60 h_{c}^{2} k_{x}^{2} + 144}}\right]\)

# NOTE(review): despite its name, this is omega**2 / (g * h_c) for the root
# dispersion_relation[1], i.e. k_x**2 times the squared phase speed
# normalized by g*h_c. It is not the group velocity d(omega)/d(k_x).
group_velocity = dispersion_relation[1]**2 / model.parameters.g / h0
group_velocity

\(\displaystyle \frac{12 k_{x}^{2} \left(h_{c}^{2} k_{x}^{2} + 12\right)}{h_{c}^{4} k_{x}^{4} + 60 h_{c}^{2} k_{x}^{2} + 144}\)

Numerical solution

Solver definitions


@define(frozen=True, slots=True, kw_only=True)
class PredictorCorrectorSolver(HyperbolicSolver):
    """Split solver coupling a hyperbolic VAM model with its pressure system.

    Each time step runs ``step`` twice and averages the result with the
    start-of-step state. One ``step``:

    1. advances ``Q`` with the flux operator (``RK1``), keeping the bottom
       ``Q[5]`` fixed, and applies boundary conditions;
    2. copies the needed hyperbolic quantities into the pressure aux fields
       (``map_Q_to_P``) and solves the pressure residual with a Newton solver;
    3. copies the pressures and d(h p)/dx back into ``Qaux`` (``map_P_to_Q``)
       and applies the source operator (``RK1``).
    """

    settings: Zstruct = field(factory=lambda: Settings.default())
    compute_dt: Callable = field(factory=lambda: timestepping.adaptive(CFL=0.9))
    num_flux: Callable = field(factory=lambda: flux.Zero())
    # NOTE(review): constructing this class with the default raised
    # "AttributeError: module 'zoomy_core.fvm.nonconservative_flux' has no
    # attribute 'segmentpath'" in the recorded run. Confirm the current
    # factory name in that module. Inside the lambda, ``nc_flux`` resolves
    # to the imported module, not to this field.
    nc_flux: Callable = field(factory=lambda: nc_flux.segmentpath())
    pressuresolver: Callable = field(factory=lambda: PoissonSolver())
    time_end: float = 0.1

    def map_Q_to_P(self, Q, Qaux, P, Paux, mesh, dt):
        """Copy the hyperbolic quantities needed by the pressure residual into ``Paux``.

        Writes h, dbdx, ddbdxx, dhdx, ddhdxx, u0, du0dx, w0, w1, u1, du1dx
        and dt (in that order) into ``Paux`` rows 6..17. These rows match
        the ``VAMPoisson.aux_variables`` entries "h" ... "dt". Rows 0..5
        (pressure derivatives) are left untouched.

        Returns the updated ``Paux``.
        """
        h = Q[0]
        u0 = Q[1] / h
        u1 = Q[2] / h
        w0 = Q[3] / h
        w1 = Q[4] / h
        b = Q[5]

        dbdx = Qaux[3]
        dhdx = Qaux[4]
        ddbdxx = compute_derivatives(b, mesh, derivatives_multi_index=([[2]]))[:, 0]
        ddhdxx = compute_derivatives(h, mesh, derivatives_multi_index=([[2]]))[:, 0]
        du0dx = compute_derivatives(u0, mesh, derivatives_multi_index=([[1]]))[:, 0]
        du1dx = compute_derivatives(u1, mesh, derivatives_multi_index=([[1]]))[:, 0]

        # Row of "h" in VAMPoisson.aux_variables; the remaining fields follow
        # consecutively in the same order as that list.
        first_row = 6
        fields = (h, dbdx, ddbdxx, dhdx, ddhdxx, u0, du0dx, w0, w1, u1, du1dx, dt)
        for row, value in enumerate(fields, start=first_row):
            Paux = Paux.at[row].set(value)
        return Paux

    def map_P_to_Q(self, Q, Qaux, P, Paux, mesh, dt):
        """Copy the pressure solution back into the hyperbolic aux fields.

        Sets ``Qaux`` rows 1/2 ("p0"/"p1") to ``P[0]``/``P[1]`` and rows 5/6
        ("dhp0dx"/"dhp1dx") to d(h*p0)/dx and d(h*p1)/dx.

        Returns the updated ``Qaux``.
        """
        Qaux = Qaux.at[1].set(P[0])
        Qaux = Qaux.at[2].set(P[1])
        h = Q[0]

        dhp0dx = compute_derivatives(h * P[0], mesh, derivatives_multi_index=([[1]]))[:, 0]
        dhp1dx = compute_derivatives(h * P[1], mesh, derivatives_multi_index=([[1]]))[:, 0]

        Qaux = Qaux.at[5].set(dhp0dx)
        Qaux = Qaux.at[6].set(dhp1dx)
        return Qaux

    def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
        """Recompute the state-derived aux fields of the hyperbolic model.

        Updates rows 3 ("dbdx"), 0 ("hw2") and 4 ("dhdx") of ``Qaux``; the
        pressure rows 1, 2, 5 and 6 are left unchanged. ``Qold``,
        ``Qauxold``, ``model``, ``parameters``, ``time`` and ``dt`` are not
        used.
        """
        h = Q[0]
        hu0 = Q[1]
        hu1 = Q[2]
        hw0 = Q[3]
        hw1 = Q[4]
        b = Q[5]

        w0 = hw0 / h
        w1 = hw1 / h
        u0 = hu0 / h
        u1 = hu1 / h
        # aux_fields=['hw2', 'p0', 'p1', 'dbdx', 'dhdx', 'dhp0dx', 'dhp1dx']

        dbdx = compute_derivatives(b, mesh, derivatives_multi_index=([[1]]))[:, 0]
        Qaux = Qaux.at[3].set(dbdx)

        # hw2 is diagnosed from the other moments and the bottom slope.
        hw2 = h * (-(w0 + w1) + (u0 + u1) * dbdx)
        Qaux = Qaux.at[0].set(hw2)

        dhdx = compute_derivatives(h, mesh, derivatives_multi_index=([[1]]))[:, 0]
        Qaux = Qaux.at[4].set(dhdx)
        return Qaux

    def compute_source_pressure(self, mesh, model):
        """Return a jitted explicit-Euler update ``Q + dt * residual`` on inner cells.

        The residual is ``model.residual`` evaluated on the first
        ``mesh.n_inner_cells`` cells; other cells get a zero increment.
        Not called by ``solve``.
        """
        @jax.jit
        def f(dt, Q, Qaux, parameters):
            dQ = jnp.zeros_like(Q)
            dQ = dQ.at[:, : mesh.n_inner_cells].set(
                model.residual(
                    Q[:, : mesh.n_inner_cells],
                    Qaux[:, : mesh.n_inner_cells],
                    parameters,
                )
            )
            return Q + dt * dQ
        return f

    def solve(self, mesh, model, pressure_model, write_output=True):
        """Integrate ``model`` coupled with ``pressure_model`` up to ``self.time_end``.

        Parameters
        ----------
        mesh : mesh on which both systems are discretized.
        model : hyperbolic model (e.g. ``VAMHyperbolic``).
        pressure_model : pressure model (e.g. ``VAMPoisson``).
        write_output : if True, snapshots and settings are written to
            ``settings.output.directory``; otherwise saving is skipped.

        Returns
        -------
        (Q, Qaux) at the final time.
        """
        modelP = pressure_model
        Q, Qaux = self.initialize(mesh, model)
        Q, Qaux, parameters, mesh, model = self.create_runtime(Q, Qaux, mesh, model)
        P, Paux = self.pressuresolver.initialize(mesh, modelP)
        P, Paux, parametersP, mesh, modelP = self.pressuresolver.create_runtime(
            P, Paux, mesh, modelP
        )

        if write_output:
            output_hdf5_path = os.path.join(
                self.settings.output.directory, f"{self.settings.output.filename}.h5"
            )
            save_fields = io.get_save_fields(output_hdf5_path)
        else:
            def skip_save(time, time_stamp, i_snapshot, Q, Qaux):
                return i_snapshot
            save_fields = skip_save

        def run(Q, Qaux, parameters, model, P, Paux, parametersP, modelP):
            iteration = 0.0
            time = 0.0
            assert model.dimension == mesh.dimension

            i_snapshot = 0.0
            dt_snapshot = self.time_end / (self.settings.output.snapshots - 1)
            if write_output:
                io.init_output_directory(
                    self.settings.output.directory, self.settings.output.clean_directory
                )
                mesh.write_to_hdf5(output_hdf5_path)
                io.save_settings(self.settings)
            i_snapshot = save_fields(time, 0.0, i_snapshot, Q, Qaux)

            min_inradius = jnp.min(mesh.cell_inradius)

            compute_max_abs_eigenvalue = self.get_compute_max_abs_eigenvalue(mesh, model)
            flux_operator = self.get_flux_operator(mesh, model)
            source_operator = self.get_compute_source(mesh, model)
            boundary_operator = self.get_apply_boundary_conditions(mesh, model)
            boundary_operatorP = self.get_apply_boundary_conditions(mesh, modelP)

            @jax.jit
            @partial(jax.named_call, name="time loop")
            def time_loop(time, iteration, i_snapshot, Qnew, Qaux, P, Paux):
                loop_val = (time, iteration, i_snapshot, Qnew, Qaux, P, Paux)

                @partial(jax.named_call, name="time_step")
                def loop_body(init_value):
                    time, iteration, i_snapshot, Qnew, Qauxnew, P, Paux = init_value
                    Q = Qnew
                    Qaux = Qauxnew

                    dt = self.compute_dt(
                        Q, Qaux, parameters, min_inradius, compute_max_abs_eigenvalue
                    )

                    def step(Q, Qaux, P, Paux):
                        # Hyperbolic (flux) update.
                        Qnew = RK1(flux_operator, Q, Qaux, parameters, dt)

                        # TODO remove: keep the bottom b (row 5) fixed.
                        Qnew = Qnew.at[5].set(Q[5])
                        Qnew = boundary_operator(time, Qnew, Qaux, parameters)
                        Qauxnew = self.update_qaux(Qnew, Qaux, Q, Qaux, mesh, model, parameters, time, dt)

                        # Pressure solve on the post-hyperbolic state.
                        Paux = self.map_Q_to_P(Qnew, Qauxnew, P, Paux, mesh, dt)
                        Paux = self.pressuresolver.update_qaux(P, Paux, P, Paux, mesh, modelP, parametersP, time, dt)
                        residual = self.pressuresolver.get_residual(Paux, P, Paux, parametersP, mesh, modelP, boundary_operatorP, time, dt)
                        newton_solve = newton_solver(residual)
                        P = newton_solve(P)
                        Paux = self.pressuresolver.update_qaux(P, Paux, P, Paux, mesh, modelP, parametersP, time, dt)

                        # Pressure source update.
                        Qauxnew = self.map_P_to_Q(Qnew, Qauxnew, P, Paux, mesh, dt)
                        Qauxnew = self.update_qaux(Qnew, Qauxnew, Q, Qaux, mesh, model, parameters, time, dt)

                        Qnew = RK1(
                            source_operator,
                            Qnew,
                            Qauxnew,
                            parameters,
                            dt,
                        )

                        Qnew = boundary_operator(time, Qnew, Qauxnew, parameters)
                        return Qnew, Qauxnew

                    # NOTE(review): ``step`` does not return the solved P/Paux,
                    # so every Newton solve starts from the pressure stored in
                    # the loop state, which is never updated. Confirm intended.
                    Q1, Qaux1 = step(Q, Qaux, P, Paux)
                    Q2, Qaux2 = step(Q1, Qaux1, P, Paux)
                    Qnew = 0.5 * (Q2 + Q)
                    Qauxnew = 0.5 * (Qaux2 + Qaux)

                    # Update solution and time
                    time += dt
                    iteration += 1

                    time_stamp = (i_snapshot) * dt_snapshot

                    i_snapshot = save_fields(time, time_stamp, i_snapshot, Qnew, Qauxnew)

                    jax.experimental.io_callback(
                        log_callback_hyperbolic,
                        None,
                        iteration, time, dt, time_stamp
                    )

                    return (time, iteration, i_snapshot, Qnew, Qauxnew, P, Paux)

                def proceed(loop_val):
                    time, iteration, i_snapshot, Qnew, Qaux, P, Paux = loop_val
                    return time < self.time_end

                (time, iteration, i_snapshot, Qnew, Qaux, P, Paux) = jax.lax.while_loop(
                    proceed, loop_body, loop_val
                )

                # Return the aux fields carried through the loop (previously the
                # closed-over initial ``Qauxnew`` of ``run`` was returned).
                return Qnew, Qaux

            # time_loop returns a (Q, Qaux) pair; unpack it instead of
            # returning the pair as Q together with the initial Qaux.
            Qfinal, Qauxfinal = time_loop(time, iteration, i_snapshot, Q, Qaux, P, Paux)
            return Qfinal, Qauxfinal

        time_start = gettime()
        Qnew, Qaux = run(Q, Qaux, parameters, model, P, Paux, parametersP, modelP)
        # JAX dispatches asynchronously; wait for the results so the logged
        # wall time covers the computation, not only its dispatch.
        Qnew, Qaux = jax.block_until_ready((Qnew, Qaux))
        jax.experimental.io_callback(
            log_callback_execution_time,
            None,
            gettime() - time_start
        )
        return Qnew, Qaux
    
class MyPoissonSolver(PoissonSolver):
    """Poisson solver whose aux update refreshes the first and second pressure derivatives."""

    # (aux row, pressure component, derivative order): rows 0-3 hold
    # dp0dx, ddp0dxx, dp1dx, ddp1dxx in VAMPoisson.aux_variables.
    _DERIVATIVE_LAYOUT = ((0, 0, 1), (1, 0, 2), (2, 1, 1), (3, 1, 2))

    def update_qaux(self, Q, Qaux, Qold, Qauxold, mesh, model, parameters, time, dt):
        """Return ``Qaux`` with rows 0-3 set to the x-derivatives of p0 = Q[0] and p1 = Q[1].

        The other rows, and the arguments ``Qold``, ``Qauxold``, ``model``,
        ``parameters``, ``time`` and ``dt``, are not touched.
        """
        for row, component, order in self._DERIVATIVE_LAYOUT:
            derivative = compute_derivatives(
                Q[component], mesh, derivatives_multi_index=([[order]])
            )[:, 0]
            Qaux = Qaux.at[row].set(derivative)
        return Qaux

Simulation


# Output goes to outputs/vam/vam.h5. The recorded log shows that
# 'clean_directory' and 'snapshots' are not set here and fall back to
# their defaults (False and 2).
settings = Settings(
    name="VAM",
    output = Zstruct(directory="outputs/vam", filename='vam')
)

# NOTE(review): not referenced elsewhere in this section of the notebook.
bc_tags = ["left", "right"]
bc_tags_periodic_to = ["right", "left"]

# Hyperbolic model: on the left boundary, prescribe h*u0 (field 1) = 0.11197
# and h*u1, h*w0, h*w1 (fields 2-4) = 0. Extrapolate on the right.
bcs1 = BC.BoundaryConditions(
    [
        BC.Lambda(tag='left', prescribe_fields={
            1: lambda t, x, dx, q, qaux, p, n: .11197,
            2: lambda t, x, dx, q, qaux, p, n: 0.,
            3: lambda t, x, dx, q, qaux, p, n: 0.,
            4: lambda t, x, dx, q, qaux, p, n: 0.
        }),
        BC.Extrapolation(tag='right')

    ]
)

# Pressure model: extrapolation on both boundaries.
bcs2 = BC.BoundaryConditions(
    [
        BC.Extrapolation(tag='left'),
        BC.Extrapolation(tag='right'),
    ]
)

def custom_ic1(x):
    """Return the initial 6-component state at the point ``x``.

    - The bottom ``b`` (index 5) is a Gaussian bump of height 0.2 and
      width 0.2 centred at x = 0.
    - The free-surface level ``h + b`` is 0.34 for x < 1 and 0.015
      otherwise.
    - The depth ``h`` (index 0) is floored at 0.015.
    - All momenta are zero.
    """
    state = np.zeros(6, dtype=float)
    coord = x[0]

    state[1] = np.where(coord - 5 < 1, 0.0, 0.)

    amplitude, width = 0.20, 0.2
    state[5] = amplitude * np.exp(-(coord - 0.) ** 2 / (2 * width ** 2))

    min_depth = 0.015
    surface = np.where(coord < 1, 0.34, min_depth)
    depth = surface - state[5]
    state[0] = np.where(depth > min_depth, depth, min_depth)
    return state


# Wrap the pointwise initial state as an initial condition object.
ic1 = IC.UserFunction(custom_ic1)


model1 = VAMHyperbolic(
    boundary_conditions=bcs1,
    initial_conditions=ic1,
)

# The pressure model has no initial condition passed here.
model2 = VAMPoisson(
    boundary_conditions=bcs2,
)

# 1D mesh on [-1.5, 1.5]; 60 is presumably the number of cells (verify).
mesh = petscMesh.Mesh.create_1d((-1.5, 1.5), 60, lsq_degree=2)

# NOTE(review): in the recorded run, this constructor raised AttributeError
# from the default nc_flux factory ('segmentpath' missing in
# zoomy_core.fvm.nonconservative_flux).
solver = PredictorCorrectorSolver(time_end = 20, settings=settings, pressuresolver=MyPoissonSolver())

Q, Qaux = solver.solve(mesh, model1, model2, write_output=True)
2026-02-26 15:54:46.983 | WARNING  | zoomy_core.misc.misc:__init__:384 - No 'clean_directory' attribute found in output Zstruct. Default: 'False'

2026-02-26 15:54:46.987 | WARNING  | zoomy_core.misc.misc:__init__:384 - No 'snapshots' attribute found in output Zstruct. Default: '2'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[16], line 53
     47 model2 = VAMPoisson(
     48     boundary_conditions=bcs2,
     49 )
     51 mesh = petscMesh.Mesh.create_1d((-1.5, 1.5), 60, lsq_degree=2)
---> 53 solver = PredictorCorrectorSolver(time_end = 20, settings=settings, pressuresolver=MyPoissonSolver())
     55 Q, Qaux = solver.solve(mesh, model1, model2, write_output=True)

File <attrs generated methods __main__.PredictorCorrectorSolver>:60, in __init__(self, flux, settings, compute_dt, num_flux, nc_flux, pressuresolver, time_end)
     58     _setattr('nc_flux', nc_flux)
     59 else:
---> 60     _setattr('nc_flux', __attr_factory_nc_flux())
     61 if pressuresolver is not NOTHING:
     62     _setattr('pressuresolver', pressuresolver)

Cell In[14], line 6, in PredictorCorrectorSolver.<lambda>()
      4 compute_dt: Callable = field(factory=lambda: timestepping.adaptive(CFL=0.9))
      5 num_flux: Callable = field(factory=lambda: flux.Zero())
----> 6 nc_flux: Callable = field(factory=lambda: nc_flux.segmentpath())
      7 pressuresolver: Callable = field(factory=lambda: PoissonSolver())
      8 time_end: float = 0.1

AttributeError: module 'zoomy_core.fvm.nonconservative_flux' has no attribute 'segmentpath'

Visualization

# Post-process the HDF5 output written by the solver: presumably export
# to VTK, then plot with the paper's VAM plotting helper.
io.generate_vtk(os.path.join(settings.output.directory, f"{settings.output.filename}.h5"))
fig = plots_paper.plot_vam(os.path.join(settings.output.directory, settings.output.filename + ".h5"))