zoomy_jax.fvm.halo_exchange module

JIT-compatible halo exchange for MPI-parallel JAX solves.

The HaloExchange object is initialised once with the static partition topology and the MPI communicator. Each Runge-Kutta stage then calls it to synchronise ghost cells between ranks. The exchange uses mpi4jax.sendrecv inside jax.lax.fori_loop, so it is fully JIT-traceable.

If mpi4jax is not installed, the module still imports, but create_halo_exchange() returns a no-op callable so that serial code keeps working.
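The optional-dependency behaviour described above can be sketched as follows. This is a minimal illustration of the general pattern, not the module's actual code: the names HAVE_MPI4JAX and make_exchange_sketch are hypothetical, and the parallel branch is deliberately left out.

```python
# Optional import: the module stays importable when mpi4jax is missing.
try:
    import mpi4jax  # noqa: F401  (optional dependency)
    HAVE_MPI4JAX = True
except ImportError:
    mpi4jax = None
    HAVE_MPI4JAX = False


def make_exchange_sketch(comm=None):
    """Hypothetical stand-in for create_halo_exchange().

    Returns an identity function Q -> Q when no communicator is given
    or mpi4jax is unavailable, so serial code runs unchanged.
    """
    if comm is None or not HAVE_MPI4JAX:
        return lambda Q: Q
    # The real parallel exchange is outside the scope of this sketch.
    raise NotImplementedError("parallel branch omitted in this sketch")


exchange = make_exchange_sketch()  # serial case: comm is None
state = [1.0, 2.0, 3.0]
result = exchange(state)           # the no-op returns its input unchanged
```

Because the fallback has the same Q -> Q shape as the parallel exchange, the calling solver code does not need a separate serial code path.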

class zoomy_jax.fvm.halo_exchange.HaloExchange(send_indices, recv_indices, neighbor_ranks, comm)

Bases: object

JIT-compatible ghost-cell synchronisation.

Parameters:
  • send_indices (list[jnp.ndarray]) – send_indices[i] contains the local cell indices to pack and send to neighbor_ranks[i].

  • recv_indices (list[jnp.ndarray]) – recv_indices[i] contains the local cell indices where received data from neighbor_ranks[i] is written.

  • neighbor_ranks (list[int]) – MPI ranks of the communication partners, in the same order as the entries of send_indices and recv_indices.

  • comm (MPI.Comm) – The MPI communicator.
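The roles of send_indices and recv_indices can be illustrated without MPI by keeping two subdomains in a single process. The sketch below is an assumption-laden illustration, not the class's implementation: the 1D layout (four owned cells followed by one ghost cell), the index values, and the function exchange_pair are all hypothetical, and plain array copies stand in for mpi4jax.sendrecv.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: each local array holds 4 owned cells followed by
# 1 ghost cell at local index 4.
send_idx_0 = jnp.array([3])  # "rank 0" sends its last owned cell
recv_idx_0 = jnp.array([4])  # and writes received data into its ghost cell
send_idx_1 = jnp.array([0])  # "rank 1" sends its first owned cell
recv_idx_1 = jnp.array([4])  # and writes received data into its ghost cell


@jax.jit
def exchange_pair(Q0, Q1):
    """Pack, swap, and unpack between two subdomains in one process."""
    # Pack: gather the cells listed in the send indices.
    buf_0_to_1 = Q0[send_idx_0]
    buf_1_to_0 = Q1[send_idx_1]
    # Unpack: scatter the partner's buffer into the recv indices.
    # In a parallel run a send/receive call would sit between these steps.
    Q0 = Q0.at[recv_idx_0].set(buf_1_to_0)
    Q1 = Q1.at[recv_idx_1].set(buf_0_to_1)
    return Q0, Q1


# Ghost cells start with a placeholder value of -1.
Q0 = jnp.array([0.0, 1.0, 2.0, 3.0, -1.0])
Q1 = jnp.array([4.0, 5.0, 6.0, 7.0, -1.0])
Q0_new, Q1_new = exchange_pair(Q0, Q1)
```

Gathering with integer index arrays and scattering with .at[...].set(...) are both functional updates, which is what keeps this kind of pack/unpack step traceable under jax.jit.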

zoomy_jax.fvm.halo_exchange.create_halo_exchange(partition_info, comm=None)

Build a HaloExchange (or no-op) from partition info.

Parameters:
  • partition_info (PartitionInfo) – The partition descriptor for the local rank.

  • comm (MPI.Comm or None) – MPI communicator. If comm is None or mpi4jax is not available, a no-op callable is returned.

Returns:

A function Q -> Q that performs the halo exchange.

Return type:

callable
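Since the returned callable maps Q to Q, it can be applied before each stage of a Runge-Kutta step. The following is a minimal sketch under stated assumptions: noop_exchange stands in for the serial fallback, rhs is a toy right-hand side (dQ/dt = -Q), and make_step with its two-stage Heun scheme is hypothetical rather than the solver's actual time integrator.

```python
import jax
import jax.numpy as jnp


def noop_exchange(Q):
    """Stand-in for the callable returned in the serial case."""
    return Q


def rhs(Q):
    """Toy right-hand side, dQ/dt = -Q (hypothetical)."""
    return -Q


def make_step(exchange, dt):
    """Build a jitted two-stage (Heun) step that syncs ghosts per stage."""

    @jax.jit
    def step(Q):
        Q = exchange(Q)              # synchronise before stage 1
        Q1 = Q + dt * rhs(Q)
        Q1 = exchange(Q1)            # synchronise before stage 2
        return 0.5 * Q + 0.5 * (Q1 + dt * rhs(Q1))

    return step


step = make_step(noop_exchange, dt=0.1)
Q_next = step(jnp.ones(4))
```

Passing the exchange in as an argument lets the same step function run with either the no-op or a parallel exchange.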

zoomy_jax.fvm.halo_exchange.allreduce_min(value, comm=None)

Global minimum via mpi4jax.allreduce (MPI_MIN).

Falls back to identity when MPI is unavailable or comm size is 1.

Parameters:
  • value (Array) – The local value to reduce.

  • comm (MPI.Comm or None) – MPI communicator.

Return type:

Array