Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/io4dolfinx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from importlib.metadata import metadata

from .backends import FileMode, get_backend
from .backends import FileMode, get_backend, set_default_backend
from .checkpointing import (
read_attributes,
read_function,
Expand Down Expand Up @@ -55,12 +55,12 @@
"write_function_on_input_mesh",
"write_mesh_input_order",
"write_attributes",
"write_data",
"read_attributes",
"read_function_names",
"read_point_data",
"read_timestamps",
"get_backend",
"set_default_backend",
"write_cell_data",
"write_point_data",
"reconstruct_mesh",
Expand Down
18 changes: 16 additions & 2 deletions src/io4dolfinx/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@

from ..structures import ArrayData, FunctionData, MeshData, MeshTagsData, ReadMeshData

__all__ = ["FileMode", "IOBackend", "get_backend"]
__all__ = ["FileMode", "IOBackend", "get_backend", "set_default_backend"]

_DEFAULT_BACKEND = "adios2"


def set_default_backend(backend: str):
"""Set the global default backend for io4dolfinx."""
global _DEFAULT_BACKEND
if backend not in BUILTIN_BAKENDS:
# Optional: You can choose to warn or raise an error if it's not a known backend
pass
_DEFAULT_BACKEND = backend


class ReadMode(Enum):
Expand Down Expand Up @@ -404,7 +415,7 @@ def write_data(
...


def get_backend(backend: str) -> IOBackend:
def get_backend(backend: str | None = None) -> IOBackend:
"""Get backend class from backend name.

Args:
Expand All @@ -413,6 +424,9 @@ def get_backend(backend: str) -> IOBackend:
Returns:
Backend class
"""
if backend is None:
backend = _DEFAULT_BACKEND

if backend == "h5py":
from .h5py import backend as H5PYInterface

Expand Down
40 changes: 16 additions & 24 deletions src/io4dolfinx/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def write_attributes(
name: str,
attributes: dict[str, np.ndarray],
backend_args: dict[str, typing.Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""Write attributes to file.

Expand All @@ -85,7 +85,7 @@ def read_attributes(
comm: MPI.Intracomm,
name: str,
backend_args: dict[str, typing.Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
) -> dict[str, typing.Any]:
"""Read attributes from file.

Expand All @@ -110,7 +110,7 @@ def read_timestamps(
comm: MPI.Intracomm,
function_name: str,
backend_args: dict[str, typing.Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
) -> npt.NDArray[np.float64 | str]: # type: ignore[type-var]
"""
Read time-stamps from a checkpoint file.
Expand Down Expand Up @@ -138,7 +138,7 @@ def write_meshtags(
meshtags: dolfinx.mesh.MeshTags,
meshtag_name: typing.Optional[str] = None,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
on_input_mesh: bool = False,
):
"""
Expand Down Expand Up @@ -217,7 +217,7 @@ def read_meshtags(
mesh: dolfinx.mesh.Mesh,
meshtag_name: str,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
) -> dolfinx.mesh.MeshTags:
"""
Read meshtags from file and return a :class:`dolfinx.mesh.MeshTags` object.
Expand Down Expand Up @@ -258,7 +258,7 @@ def read_function(
time: float = 0.0,
name: str | None = None,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Read checkpoint from file and fill it into `u`.
Expand Down Expand Up @@ -406,7 +406,7 @@ def read_mesh(
time: float | str | None = 0.0,
read_from_partition: bool = False,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
max_facet_to_cell_links: int = 2,
) -> dolfinx.mesh.Mesh:
"""
Expand All @@ -426,10 +426,8 @@ def read_mesh(
The distributed mesh
"""
logger.debug(f"Reading mesh from {filename}")
logger.debug(
f"Using {backend} backend with arguments {backend_args}, "
f"time {time} and read_from_partition {read_from_partition}"
)
logger.debug(f"Using {backend} backend with arguments {backend_args}")
logger.debug(f"Time {time} and read_from_partition {read_from_partition}")
# Read in data in a distributed fashin
check_file_exists(filename)
backend_cls = get_backend(backend)
Expand Down Expand Up @@ -496,7 +494,7 @@ def write_mesh(
time: float = 0.0,
store_partition_info: bool = False,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Write a mesh to file.
Expand All @@ -510,10 +508,8 @@ def write_mesh(
logger.debug(f"Writing mesh to {filename}")
logger.debug(f"Preparing mesh data for storage storing partition info: {store_partition_info}")
mesh_data = prepare_meshdata_for_storage(mesh=mesh, store_partition_info=store_partition_info)
logger.debug(
f"Write mesh using {backend} backend, with arguments {backend_args}, "
f"mode {mode} and time {time}"
)
logger.debug(f"Write mesh using {backend} backend, with arguments {backend_args}")
logger.debug(f"Mode {mode} and time {time}")
_internal_mesh_writer(
filename,
mesh.comm,
Expand All @@ -532,7 +528,7 @@ def write_function(
mode: FileMode = FileMode.append,
name: str | None = None,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Write function checkpoint to file.
Expand All @@ -546,13 +542,9 @@ def write_function(
backend_args: Arguments to the IO backend.
backend: The backend to use
"""
logger.debug(
f"Writing function checkpoint to {filename} for function {name or u.name} at time {time}"
)
logger.debug(
f"Extracting data from function and dofmap for storage using {backend} "
f"backend with arguments {backend_args}"
)
n = u.name if name is None else name
logger.debug(f"Writing function checkpoint to {filename} for function {n} at time {time}")
logger.debug(f"Using {backend} backend with arguments {backend_args}")
dofmap = u.function_space.dofmap
values = u.x.array
mesh = u.function_space.mesh
Expand Down
4 changes: 2 additions & 2 deletions src/io4dolfinx/original_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def write_function_on_input_mesh(
name: typing.Optional[str] = None,
mode: FileMode = FileMode.append,
backend_args: dict[str, typing.Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Write function checkpoint (to be read with the input mesh).
Expand Down Expand Up @@ -388,7 +388,7 @@ def write_mesh_input_order(
mesh: dolfinx.mesh.Mesh,
time: float = 0.0,
mode: FileMode = FileMode.write,
backend: str = "adios2",
backend: str | None = None,
backend_args: dict[str, typing.Any] | None = None,
):
"""
Expand Down
6 changes: 3 additions & 3 deletions src/io4dolfinx/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def send_cells_and_receive_dofmap_index(
dofmap_path: str,
xdofmap_path: str,
bs: int,
backend: str,
backend: str | None,
) -> npt.NDArray[np.int64]:
"""
Given a set of positions in input dofmap, give the global input index of this dofmap entry
Expand Down Expand Up @@ -151,7 +151,7 @@ def read_mesh_from_legacy_h5(
comm: MPI.Intracomm,
group: str,
cell_type: str = "tetrahedron",
backend: str = "adios2",
backend: str | None = None,
max_facet_to_cell_links: int = 2,
) -> dolfinx.mesh.Mesh:
"""
Expand Down Expand Up @@ -229,7 +229,7 @@ def read_function_from_legacy_h5(
group: str = "mesh",
step: typing.Optional[int] = None,
vector_group: str | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Read function from a `h5`-file generated by legacy DOLFIN `HDF5File.write`
Expand Down
2 changes: 1 addition & 1 deletion src/io4dolfinx/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def snapshot_checkpoint(
file: Path,
mode: FileMode,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""Read or write a snapshot checkpoint

Expand Down
27 changes: 14 additions & 13 deletions src/io4dolfinx/writers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,18 +114,19 @@ def write_mesh(
time: float = 0.0,
mode: FileMode = FileMode.write,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Write a mesh to file using ADIOS2
Write a mesh to file

Args:
comm: MPI communicator used in storage
mesh: Internal data structure for the mesh data to save to file
filename: Path to file to write to
engine: ADIOS2 engine to use
mode: ADIOS2 mode to use (write or append)
io_name: Internal name used for the ADIOS IO object
comm: MPI communicator used in storage
mesh_data: Internal data structure for the mesh data to save to file
time: Time stamp associated with mesh
mode: File mode to use (write or append)
backend_args: Arguments for the backend
backend: Backend to use
"""
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)
Expand All @@ -139,19 +140,19 @@ def write_function(
time: float = 0.0,
mode: FileMode = FileMode.append,
backend_args: dict[str, Any] | None = None,
backend: str = "adios2",
backend: str | None = None,
):
"""
Write a function to file using ADIOS2
Write a function to file

Args:
filename: Path to file to write to
comm: MPI communicator used in storage
u: Internal data structure for the function data to save to file
filename: Path to file to write to
engine: ADIOS2 engine to use
mode: ADIOS2 mode to use (write or append)
time: Time stamp associated with function
io_name: Internal name used for the ADIOS IO object
mode: File mode to use (write or append)
backend_args: Arguments for the backend
backend: Backend to use
"""
backend_cls = get_backend(backend)
backend_args = backend_cls.get_default_backend_args(backend_args)
Expand Down
100 changes: 100 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from mpi4py import MPI

import dolfinx
import h5py
import pytest

import io4dolfinx
from io4dolfinx.backends import BUILTIN_BAKENDS


@pytest.fixture(autouse=True)
def reset_default_backend():
"""
Fixture to ensure the global default backend is always reset to 'adios2'
after each test, preventing state leakage to other tests.
"""
# Setup: Ensure starting state is adios2
io4dolfinx.set_default_backend("adios2")
yield
# Teardown: Reset back to adios2
io4dolfinx.set_default_backend("adios2")


def test_explicit_backend_overrides_default(tmp_path):
"""
Test that explicitly passing `backend="h5py"` overrides the global
default backend (which is "adios2").
"""
comm = MPI.COMM_WORLD

# Ensure default is currently adios2
assert io4dolfinx.backends._DEFAULT_BACKEND == "adios2"

mesh = dolfinx.mesh.create_unit_square(comm, 5, 5)

# We use .h5 suffix, but the backend argument is what actually dictates the writer
fname = comm.bcast(tmp_path, root=0) / "override_test.h5"

# Explicitly pass the h5py backend
io4dolfinx.write_mesh(fname, mesh, backend="h5py")

comm.Barrier()

# Verify that h5py was actually used by attempting to open it as an HDF5 file.
# If adios2 was used, this would raise an OSError/ValueError.
if comm.rank == 0:
assert fname.exists()
with h5py.File(fname, "r") as f:
assert "mesh" in f.keys()


def test_set_default_backend_takes_effect(tmp_path):
"""
Test that calling `set_default_backend("h5py")` successfully changes the
default behavior for API calls where `backend` is not explicitly provided.
"""
comm = MPI.COMM_WORLD

# Update the global default backend to h5py
io4dolfinx.set_default_backend("h5py")

mesh = dolfinx.mesh.create_unit_square(comm, 5, 5)

fname = comm.bcast(tmp_path, root=0) / "default_update_test.h5"

# Call the API without providing the `backend` argument
io4dolfinx.write_mesh(fname, mesh)

comm.Barrier()

# Verify that h5py was implicitly used based on the new default
if comm.rank == 0:
assert fname.exists()
with h5py.File(fname, "r") as f:
assert "mesh" in f.keys()
assert "Topology" in f["mesh"].keys()


def test_list_builtin_backends():
"""
Test that list_builtin_backends returns a valid list containing
a subset of the supported built-in backends based on the current environment.
"""
# Call the function to get the list of available backends
available_backends = io4dolfinx.backends.list_builtin_backends()

# Verify the return type is a list
assert isinstance(available_backends, list)

# Depending on the test environment, at least one backend should be available
assert len(available_backends) > 0

# Verify that all returned backends are recognized as built-in backends
for backend in available_backends:
assert isinstance(backend, str)
assert backend in BUILTIN_BAKENDS

# We can be reasonably certain that 'h5py' or 'adios2' should be
# present if the io4dolfinx test suite is running successfully
assert "h5py" in available_backends or "adios2" in available_backends