zoomy_jax.fvm.halo_exchange module
JIT-compatible halo exchange for MPI-parallel JAX solves.
The HaloExchange object is initialised once with the static
partition topology and MPI communicator. Each Runge-Kutta stage calls
it to synchronise ghost cells between ranks using mpi4jax.sendrecv
inside jax.lax.fori_loop so that the exchange is fully
JIT-traceable.
If mpi4jax is not installed, the module still imports, but
create_halo_exchange() returns a no-op callable so that serial
code keeps working.
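The per-stage call pattern described above can be sketched without MPI. The following is a minimal illustration, not the module's implementation: `periodic_exchange` is a hypothetical serial stand-in for the exchange callable (it fills two ghost cells from the opposite interior edge), `rhs` is an assumed first-order upwind discretisation, and the state layout `(n_vars, n_cells)` is also an assumption. It only shows an exchange being invoked before each Runge-Kutta stage inside a `jax.lax.fori_loop` that is traced under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def periodic_exchange(Q):
    # Hypothetical serial stand-in for the halo exchange:
    # ghost cells 0 and -1 copy the opposite interior edge cells.
    Q = Q.at[:, 0].set(Q[:, -2])
    Q = Q.at[:, -1].set(Q[:, 1])
    return Q


def rhs(Q, dx):
    # Assumed model problem: first-order upwind for u_t + u_x = 0.
    # Ghost cells get a zero update; they are refreshed by the exchange.
    dQ = -(Q[:, 1:-1] - Q[:, :-2]) / dx
    return jnp.pad(dQ, ((0, 0), (1, 1)))


@jax.jit
def run(Q, dt, dx, n_steps):
    def step(_, Q):
        # Stage 1 of Heun's method: synchronise ghosts, then evaluate.
        Q = periodic_exchange(Q)
        Q1 = Q + dt * rhs(Q, dx)
        # Stage 2: the intermediate state needs fresh ghosts as well.
        Q1 = periodic_exchange(Q1)
        return 0.5 * (Q + Q1 + dt * rhs(Q1, dx))

    return jax.lax.fori_loop(0, n_steps, step, Q)


# One variable, 8 interior cells plus 2 ghost cells, unit pulse in cell 3.
Q_init = jnp.zeros((1, 10)).at[0, 3].set(1.0)
out = run(Q_init, 0.05, 0.1, 20)
```

Because the exchange is an ordinary traced function, replacing the stand-in with a callable of the same `Q -> Q` shape would leave the time loop unchanged.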
- class zoomy_jax.fvm.halo_exchange.HaloExchange(send_indices, recv_indices, neighbor_ranks, comm)
Bases: object
JIT-compatible ghost-cell synchronisation.
- Parameters:
send_indices (list[jnp.ndarray]) – send_indices[i] contains the local cell indices to pack and send to neighbor_ranks[i].
recv_indices (list[jnp.ndarray]) – recv_indices[i] contains the local cell indices where received data from neighbor_ranks[i] is written.
neighbor_ranks (list[int]) – MPI ranks of the communication partners (ordered).
comm (MPI.Comm) – The MPI communicator.
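To make the roles of the send and receive index arrays concrete, here is a single-process sketch that imitates two ranks exchanging one ghost cell each. It is an illustration under stated assumptions, not the class's code: the state layout `(n_vars, n_cells)`, the index values, and the function `exchange_pair` are all hypothetical, and the direct buffer hand-over stands in for what an MPI send/receive would do between processes.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: each "rank" holds 4 cells.
# Rank 0 owns cells 0..2 and keeps a ghost in cell 3.
# Rank 1 owns cells 1..3 and keeps a ghost in cell 0.
SEND_0, RECV_0 = jnp.array([2]), jnp.array([3])
SEND_1, RECV_1 = jnp.array([1]), jnp.array([0])


@jax.jit
def exchange_pair(Q0, Q1):
    # Pack: gather the cells each neighbour needs into a send buffer.
    buf_0_to_1 = Q0[:, SEND_0]
    buf_1_to_0 = Q1[:, SEND_1]
    # Unpack: scatter the received buffer into the local ghost cells.
    # .at[...].set(...) returns a new array, which keeps this JIT-traceable.
    Q0 = Q0.at[:, RECV_0].set(buf_1_to_0)
    Q1 = Q1.at[:, RECV_1].set(buf_0_to_1)
    return Q0, Q1


# Ghost cells start with a placeholder value of -1.
Q0 = jnp.array([[0.0, 1.0, 2.0, -1.0]])
Q1 = jnp.array([[-1.0, 10.0, 11.0, 12.0]])
Q0_new, Q1_new = exchange_pair(Q0, Q1)
```

After the call, rank 0's ghost cell holds rank 1's first owned value (10.0) and rank 1's ghost cell holds rank 0's last owned value (2.0).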
- zoomy_jax.fvm.halo_exchange.create_halo_exchange(partition_info, comm=None)
Build a HaloExchange (or no-op) from partition info.
- Parameters:
partition_info (PartitionInfo) – The partition descriptor for the local rank.
comm (MPI.Comm or None) – MPI communicator. If None or if mpi4jax is not available, a no-op callable is returned.
- Returns:
A function Q -> Q that performs the halo exchange.
- Return type:
callable
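The fallback rule for this factory can be illustrated with a small, self-contained sketch. The names `make_exchange_or_noop` and `build_exchange` are hypothetical and are not part of this module; the availability check via `importlib.util.find_spec` is one possible way to detect an optional dependency and is an assumption about how such a check might be written.

```python
import importlib.util

# True only if the optional dependency can be located; nothing is imported here.
HAVE_MPI4JAX = importlib.util.find_spec("mpi4jax") is not None


def make_exchange_or_noop(build_exchange, comm=None):
    """Return a Q -> Q callable.

    build_exchange is a hypothetical factory that takes the communicator
    and returns the real exchange; it is only called when MPI is usable.
    """
    if comm is None or not HAVE_MPI4JAX:
        # Serial fallback: hand the state back unchanged.
        return lambda Q: Q
    return build_exchange(comm)


# Serial use: no communicator, so the factory is never invoked.
exchange = make_exchange_or_noop(build_exchange=None, comm=None)
state = [1.0, 2.0, 3.0]
result = exchange(state)
```

Since both branches return a callable with the same `Q -> Q` shape, calling code does not need to distinguish serial from parallel runs.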
- zoomy_jax.fvm.halo_exchange.allreduce_min(value, comm=None)
Global minimum via mpi4jax.allreduce(MPI_MIN).
Falls back to identity when MPI is unavailable or comm size is 1.
- Parameters:
value (Array)
- Return type:
Array
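A typical use for a global minimum in a parallel solver is agreeing on one time step across ranks. The sketch below covers only the identity fallback described above and leaves the MPI branch out; `allreduce_min_sketch`, the CFL number, and the wave speeds are hypothetical illustration values, and the time-step use case itself is an assumption rather than something this page states.

```python
import jax.numpy as jnp


def allreduce_min_sketch(value, comm=None):
    # Identity when there is no communicator or only one rank.
    if comm is None or comm.Get_size() == 1:
        return value
    # The MPI reduction is intentionally omitted from this sketch.
    raise NotImplementedError("MPI branch omitted in this sketch")


# Hypothetical local data: maximum wave speed per cell on this rank.
wave_speeds = jnp.array([1.0, 4.0, 2.0])
dx, cfl = 0.1, 0.5

# Local stable time step, then the (here trivial) global minimum.
local_dt = jnp.min(cfl * dx / wave_speeds)
dt = allreduce_min_sketch(local_dt)
```

In a serial run `dt` equals `local_dt`; in a parallel run every rank would receive the smallest of the local values, so all ranks advance with the same step.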