From dafc919f1fb135f61546947ddd48d99b5df9bf2f Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Wed, 1 Apr 2026 12:54:23 -0500 Subject: [PATCH 1/4] Add reusable remap weight support --- docs/api.rst | 12 +++ docs/user-guide/remap-weights.rst | 120 +++++++++++++++++++++++ docs/userguide.rst | 4 + test/precomputed_weights_test.py | 96 +++++++++++++++++++ uxarray/__init__.py | 3 + uxarray/remap/__init__.py | 5 + uxarray/remap/accessor.py | 41 ++++++++ uxarray/remap/precomputed.py | 105 +++++++++++++++++++++ uxarray/remap/weights.py | 152 ++++++++++++++++++++++++++++++ 9 files changed, 538 insertions(+) create mode 100644 docs/user-guide/remap-weights.rst create mode 100644 test/precomputed_weights_test.py create mode 100644 uxarray/remap/precomputed.py create mode 100644 uxarray/remap/weights.py diff --git a/docs/api.rst b/docs/api.rst index a68f91ba9..bcca37eb8 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -21,6 +21,7 @@ Top Level Functions open_multigrid open_mfdataset concat + load_remap_weights Grid @@ -380,6 +381,15 @@ Remapping .. seealso:: `Remapping User Guide Section `_ + `Applying External Remap Weights `_ + +Helpers +~~~~~~~ + +.. autosummary:: + :toctree: generated/ + + RemapWeights UxDataArray ~~~~~~~~~~~ @@ -389,6 +399,7 @@ UxDataArray :template: autosummary/accessor_method.rst UxDataArray.remap + UxDataArray.remap.apply_weights UxDataArray.remap.nearest_neighbor UxDataArray.remap.inverse_distance_weighted UxDataArray.remap.bilinear @@ -401,6 +412,7 @@ UxDataset :template: autosummary/accessor_method.rst UxDataset.remap + UxDataset.remap.apply_weights UxDataset.remap.nearest_neighbor UxDataset.remap.inverse_distance_weighted UxDataset.remap.bilinear diff --git a/docs/user-guide/remap-weights.rst b/docs/user-guide/remap-weights.rst new file mode 100644 index 000000000..4257d66b9 --- /dev/null +++ b/docs/user-guide/remap-weights.rst @@ -0,0 +1,120 @@ +.. currentmodule:: uxarray + +Remap Weights +============= + +UXarray can apply precomputed offline remapping weights produced outside of UXarray. +This is useful when weights are generated once with tools such as ESMF or +TempestRemap and then reused many times across multiple ensemble members, time +slices, or variables. + +The core workflow is: + +1. Generate a weight file for a specific source grid and destination grid. +2. Load the weight file once with :func:`load_remap_weights`. +3. Reuse the loaded :class:`RemapWeights` object with :meth:`UxDataArray.remap.apply_weights` + or :meth:`UxDataset.remap.apply_weights`. + +Basic Usage +----------- + +.. code-block:: python + + import uxarray as ux + + src = ux.open_dataset("source_grid.nc", "source_data.nc") + dst = ux.open_grid("destination_grid.nc") + + weights = ux.load_remap_weights("map.nc") + + remapped_temperature = src["temperature"].remap.apply_weights( + weights, destination_grid=dst + ) + + remapped_dataset = src.remap.apply_weights(weights, destination_grid=dst) + +Repeated calls with the same path reuse a cached sparse operator, so loading the +same file again in one Python session avoids rebuilding the matrix. + +What A Weight File Represents +----------------------------- + +A remap weight file represents a linear operator from one grid to another: + +.. code-block:: text + + target_values = W @ source_values + +If the source grid has ``4800`` elements and the destination grid has ``11000`` +elements, then: + +- ``source_values.shape = (4800,)`` +- ``W.shape = (11000, 4800)`` +- ``target_values.shape = (11000,)`` + +So the weight file necessarily encodes both the source grid and the destination +grid. It is specific to that grid pair and to the ordering of the source and +destination degrees of freedom. + +Supported File Structure +------------------------ + +UXarray currently supports the standard sparse offline-map structure used by +ESMF-style and TempestRemap-style map files. The essential pieces are: + +- ``n_a``: source size +- ``n_b``: destination size +- ``n_s``: number of nonzero entries +- ``row``: destination indices +- ``col``: source indices +- ``S``: sparse weight values + +Common aliases are also accepted: + +- ``src_grid_size`` and ``dst_grid_size`` +- ``src_address`` and ``dst_address`` +- ``weights`` instead of ``S`` + +In full offline map files, these sparse arrays are typically accompanied by +source and destination metadata such as center coordinates, corner coordinates, +areas, masks, and grid-dimension metadata. + +Tool Compatibility +------------------ + +This implementation was verified against real files from both families: + +- ESMF-generated offline map files created with ``ESMF_RegridWeightGen`` +- TempestRemap-generated offline map files created with ``GenerateOfflineMap`` + +In practice, UXarray supports the standard full offline map format used by both +tools. + +Current caveats: + +- The source data ordering must match the source ordering encoded in the weight file. +- Not every possible file variant is guaranteed yet. +- ESMF ``weight_only`` outputs may require additional handling if they omit + source and destination size metadata. + +How It Applies Data +------------------- + +When remapping a :class:`UxDataArray` or :class:`UxDataset`, UXarray identifies a +single spatial dimension whose size matches the source size in the loaded +weights. That dimension is remapped to the requested destination dimension +(``faces``, ``edges``, or ``nodes``). + +Non-spatial dimensions are preserved, which makes this workflow suitable for +reusing one operator across many time steps, ensemble members, or variables. + +Why Use This Workflow +--------------------- + +This path is useful when: + +- weight generation is expensive and should be done once +- remapping needs to be repeated many times +- external tools already produce trusted offline maps +- you want to stay in Python for applying the map and preserving array metadata + diff --git a/docs/userguide.rst b/docs/userguide.rst index c281805b3..d185a4d59 100644 --- a/docs/userguide.rst +++ b/docs/userguide.rst @@ -64,6 +64,9 @@ These user guides provide detailed explanations of the core functionality in UXa `Remapping `_ Remap (a.k.a Regrid) between unstructured grids +`Applying External Remap Weights `_ + Apply precomputed ESMF or TempestRemap offline map files + `Topological Aggregations `_ Aggregate data across grid dimensions @@ -119,6 +122,7 @@ These user guides provide additional details about specific features in UXarray. user-guide/zonal-average.ipynb user-guide/azimuthal-average.ipynb user-guide/remapping.ipynb + user-guide/remap-weights.rst user-guide/topological-aggregations.ipynb user-guide/weighted_mean.ipynb user-guide/vector_calculus.ipynb diff --git a/test/precomputed_weights_test.py b/test/precomputed_weights_test.py new file mode 100644 index 000000000..0ea243db5 --- /dev/null +++ b/test/precomputed_weights_test.py @@ -0,0 +1,96 @@ +from pathlib import Path + +import numpy as np +import numpy.testing as nt +import uxarray as ux +import xarray as xr + + +def _write_sparse_map(path: Path, source_size: int, destination_size: int) -> Path: + rows = np.arange(1, destination_size + 1, dtype=np.int32) + cols = np.arange(source_size, 0, -1, dtype=np.int32) + values = np.ones(destination_size, dtype=np.float64) + + ds = xr.Dataset( + data_vars={ + "row": (("n_s",), rows), + "col": (("n_s",), cols), + "S": (("n_s",), values), + }, + coords={"n_s": np.arange(destination_size, dtype=np.int32)}, + ) + ds = ds.assign_coords( + n_a=np.arange(source_size, dtype=np.int32), + n_b=np.arange(destination_size, dtype=np.int32), + ) + ds.to_netcdf(path) + return path + + +def test_load_remap_weights_and_apply_vector(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + weight_file = _write_sparse_map( + tmp_path / "reverse_map.nc", grid.n_face, grid.n_face + ) + + weights = ux.load_remap_weights(weight_file) + result = weights.apply(np.arange(grid.n_face, dtype=np.float64)) + + nt.assert_equal(weights.source_size, grid.n_face) + nt.assert_equal(weights.destination_size, grid.n_face) + nt.assert_array_equal(result, np.arange(grid.n_face, dtype=np.float64)[::-1]) + assert isinstance(weights, ux.RemapWeights) + + +def test_apply_weights_to_uxdataarray(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + weight_file = _write_sparse_map( + tmp_path / "reverse_map.nc", grid.n_face, grid.n_face + ) + + source = ux.UxDataArray( + xr.DataArray( + np.arange(grid.n_face, dtype=np.float64), + dims=["n_face"], + name="temperature", + attrs={"units": "K"}, + ), + uxgrid=grid, + ) + + remapped = source.remap.apply_weights(weight_file, grid) + + nt.assert_array_equal(remapped.values, source.values[::-1]) + nt.assert_equal(remapped.attrs["units"], "K") + nt.assert_equal(remapped.uxgrid, grid) + + +def test_apply_weights_reuses_loaded_operator(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + weight_file = _write_sparse_map( + tmp_path / "reverse_map.nc", grid.n_face, grid.n_face + ) + weights = ux.load_remap_weights(weight_file) + cached_weights = ux.load_remap_weights(weight_file) + + source = ux.UxDataset( + xr.Dataset( + data_vars={ + "a": ( + ("time", "n_face"), + np.arange(2 * grid.n_face).reshape(2, grid.n_face), + ), + "flag": (("time",), np.array([1, 0], dtype=np.int32)), + }, + coords={"time": np.array([0, 1], dtype=np.int32)}, + ), + uxgrid=grid, + ) + + remapped = source.remap.apply_weights(weights, grid) + remapped_again = source["a"].remap.apply_weights(weights, grid) + + assert cached_weights is weights + nt.assert_array_equal(remapped["a"].values, source["a"].values[:, ::-1]) + nt.assert_array_equal(remapped["flag"].values, source["flag"].values) + nt.assert_array_equal(remapped_again.values, source["a"].values[:, ::-1]) diff --git a/uxarray/__init__.py b/uxarray/__init__.py index 80b0dcc9e..dc6623f37 100644 --- a/uxarray/__init__.py +++ b/uxarray/__init__.py @@ -10,6 +10,7 @@ from .core.dataarray import UxDataArray from .core.dataset import UxDataset from .grid import Grid +from .remap import RemapWeights, load_remap_weights try: from importlib.metadata import version as _version @@ -35,4 +36,6 @@ "INT_DTYPE", "INT_FILL_VALUE", "Grid", + "RemapWeights", + "load_remap_weights", ) diff --git a/uxarray/remap/__init__.py b/uxarray/remap/__init__.py index 8eabc7bd7..2c7adf907 100644 --- a/uxarray/remap/__init__.py +++ b/uxarray/remap/__init__.py @@ -1,7 +1,12 @@ from .inverse_distance_weighted import _inverse_distance_weighted_remap from .nearest_neighbor import _nearest_neighbor_remap +from .precomputed import _apply_weights +from .weights import RemapWeights, load_remap_weights __all__ = ( + "RemapWeights", + "load_remap_weights", + "_apply_weights", "_nearest_neighbor_remap", "_inverse_distance_weighted_remap", ) diff --git a/uxarray/remap/accessor.py b/uxarray/remap/accessor.py index ebf74ffa4..a65dc3ef6 100644 --- a/uxarray/remap/accessor.py +++ b/uxarray/remap/accessor.py @@ -10,6 +10,7 @@ from uxarray.remap.bilinear import _bilinear from uxarray.remap.inverse_distance_weighted import _inverse_distance_weighted_remap from uxarray.remap.nearest_neighbor import _nearest_neighbor_remap +from uxarray.remap.precomputed import _apply_weights class RemapAccessor: @@ -25,6 +26,7 @@ def __repr__(self) -> str: + "Supported methods:\n" + " • nearest_neighbor(destination_grid, remap_to='faces')\n" + " • inverse_distance_weighted(destination_grid, remap_to='faces', power=2, k=8)\n" + + " • apply_weights(weights, destination_grid, remap_to='faces')\n" ) def __call__(self, *args, **kwargs) -> UxDataArray | UxDataset: @@ -110,3 +112,42 @@ def bilinear( """ return _bilinear(self.ux_obj, destination_grid, remap_to) + + def apply_weights( + self, + weights, + destination_grid: Grid, + remap_to: str = "faces", + source_dim: str | None = None, + ) -> UxDataArray | UxDataset: + """ + Apply a sparse remap operator loaded from disk. + + Parameters + ---------- + weights : str, PathLike, xr.Dataset, or RemapWeights + Weight file or reusable loaded weights. Standard SCRIP/ESMF sparse + map files are expected to provide ``row``, ``col``, ``S``, and + dimensions ``n_a``/``n_b``. + destination_grid : Grid + Grid representing the destination topology and coordinates. + remap_to : {'nodes', 'edges', 'faces'}, default='faces' + Which destination grid element receives the remapped values. + source_dim : {'n_node', 'n_edge', 'n_face'}, optional + Explicit source spatial dimension to remap along. If omitted, UXarray + infers it from variables whose trailing spatial dimension matches + the loaded weight source size. + + Returns + ------- + UxDataArray or UxDataset + A new object with data mapped onto ``destination_grid``. + """ + + return _apply_weights( + self.ux_obj, + weights=weights, + destination_grid=destination_grid, + remap_to=remap_to, + source_dim=source_dim, + ) diff --git a/uxarray/remap/precomputed.py b/uxarray/remap/precomputed.py new file mode 100644 index 000000000..e83bcf295 --- /dev/null +++ b/uxarray/remap/precomputed.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from os import PathLike + +import numpy as np +import xarray as xr + +import uxarray.core.dataarray + +from .utils import ( + LABEL_TO_COORD, + SPATIAL_DIMS, + _assert_dimension, + _construct_remapped_ds, + _to_dataset, +) +from .weights import RemapWeights, load_remap_weights + + +def _get_source_dim( + da: xr.DataArray, + weights: RemapWeights, + source_dim: str | None, +) -> str | None: + spatial_dims = [dim for dim in da.dims if dim in SPATIAL_DIMS] + + if len(spatial_dims) > 1: + raise ValueError( + f"Precomputed weight application does not support variables with multiple " + f"spatial dimensions. Got {spatial_dims!r} for variable {da.name!r}." + ) + + if source_dim is not None: + if source_dim not in da.dims: + return None + if da.sizes[source_dim] != weights.source_size: + raise ValueError( + f"Variable {da.name!r} dimension {source_dim!r} has size " + f"{da.sizes[source_dim]}, expected {weights.source_size}." + ) + return source_dim + + matches = [dim for dim in spatial_dims if da.sizes[dim] == weights.source_size] + return matches[0] if matches else None + + +def _apply_weights( + source, + weights: str | PathLike[str] | xr.Dataset | RemapWeights, + destination_grid, + remap_to: str = "faces", + source_dim: str | None = None, +): + """Apply a sparse remap operator to UXarray data.""" + _assert_dimension(remap_to) + + weights_obj = load_remap_weights(weights) + destination_dim = LABEL_TO_COORD[remap_to] + destination_size = destination_grid.sizes[destination_dim] + + if destination_size != weights_obj.destination_size: + raise ValueError( + f"Destination grid size for {destination_dim!r} is {destination_size}, " + f"but weights target size is {weights_obj.destination_size}." + ) + + ds, is_da, name = _to_dataset(source) + remapped_vars = {} + remapped_any = False + + for var_name, da in ds.data_vars.items(): + variable_source_dim = _get_source_dim(da, weights_obj, source_dim) + if variable_source_dim is None: + remapped_vars[var_name] = da + continue + + remapped_any = True + other_dims = [dim for dim in da.dims if dim != variable_source_dim] + da_t = da.transpose(*other_dims, variable_source_dim) + remapped_values = weights_obj.apply(np.asarray(da_t.values)) + + coords = {dim: da.coords[dim] for dim in other_dims if dim in da.coords} + da_out = uxarray.core.dataarray.UxDataArray( + remapped_values, + dims=other_dims + [destination_dim], + coords=coords, + name=da.name, + attrs=da.attrs, + uxgrid=destination_grid, + ) + remapped_vars[var_name] = da_out + + if not remapped_any: + if is_da: + raise ValueError( + f"No spatial dimension matched the weight source size {weights_obj.source_size}." + ) + raise ValueError( + "No dataset variables matched the supplied weight source size." + ) + + ds_remapped = _construct_remapped_ds( + source, remapped_vars, destination_grid, remap_to + ) + return ds_remapped[name] if is_da else ds_remapped diff --git a/uxarray/remap/weights.py b/uxarray/remap/weights.py new file mode 100644 index 000000000..c0b141555 --- /dev/null +++ b/uxarray/remap/weights.py @@ -0,0 +1,152 @@ +from __future__ import annotations + +from dataclasses import dataclass +from os import PathLike +from pathlib import Path +from typing import Any + +import numpy as np +import xarray as xr +from scipy import sparse + +_WEIGHTS_CACHE: dict[tuple[str, int, int], "RemapWeights"] = {} + + +def _first_present(mapping, names: tuple[str, ...], kind: str): + for name in names: + if name in mapping: + return mapping[name] + raise ValueError( + f"Could not find {kind}. Expected one of: {', '.join(names)}." + ) + + +def _normalize_indices(indices: np.ndarray, size: int, label: str) -> np.ndarray: + indices = np.asarray(indices, dtype=np.int64).ravel() + if indices.size == 0: + return indices + + if indices.min() >= 1 and indices.max() <= size: + indices = indices - 1 + elif indices.min() < 0 or indices.max() >= size: + raise ValueError( + f"{label} indices are out of bounds for size {size}. " + f"Found min={indices.min()}, max={indices.max()}." + ) + + return indices + + +@dataclass(frozen=True) +class RemapWeights: + """Reusable sparse remapping operator loaded from a standard weight file.""" + + matrix: sparse.csr_matrix + source_size: int + destination_size: int + path: str | None = None + + @classmethod + def from_file(cls, filename_or_obj: str | PathLike[str] | xr.Dataset): + """Load a standard sparse remap-weight file into memory once.""" + if isinstance(filename_or_obj, xr.Dataset): + ds = filename_or_obj + close_ds = False + path = None + else: + ds = xr.open_dataset(filename_or_obj) + close_ds = True + path = str(filename_or_obj) + + try: + source_size = int( + _first_present( + ds.sizes, ("n_a", "src_grid_size"), "source dimension" + ) + ) + destination_size = int( + _first_present( + ds.sizes, ("n_b", "dst_grid_size"), "destination dimension" + ) + ) + + row = _normalize_indices( + _first_present(ds.variables, ("row", "dst_address"), "row indices"), + destination_size, + "Row", + ) + col = _normalize_indices( + _first_present(ds.variables, ("col", "src_address"), "column indices"), + source_size, + "Column", + ) + values = np.asarray( + _first_present(ds.variables, ("S", "weights"), "weight values"), + dtype=np.float64, + ).ravel() + + if not (row.size == col.size == values.size): + raise ValueError( + "Remap weights require row, col, and weight arrays of equal length." + ) + + matrix = sparse.coo_matrix( + (values, (row, col)), + shape=(destination_size, source_size), + ).tocsr() + finally: + if close_ds: + ds.close() + + return cls( + matrix=matrix, + source_size=source_size, + destination_size=destination_size, + path=path, + ) + + def apply(self, values: np.ndarray) -> np.ndarray: + """Apply the sparse remap operator along the trailing dimension.""" + values = np.asarray(values) + + if values.ndim == 0: + raise ValueError("Remap weights require at least a 1-D input array.") + + if values.shape[-1] != self.source_size: + raise ValueError( + f"Expected trailing dimension of size {self.source_size}, " + f"got {values.shape[-1]}." + ) + + flat_values = values.reshape(-1, values.shape[-1]) + remapped = (self.matrix @ flat_values.T).T + return remapped.reshape(values.shape[:-1] + (self.destination_size,)) + + +def _cache_key(filename_or_obj: str | PathLike[str]) -> tuple[str, int, int]: + path = Path(filename_or_obj).expanduser().resolve() + stat = path.stat() + return str(path), stat.st_mtime_ns, stat.st_size + + +def load_remap_weights( + filename_or_obj: str | PathLike[str] | xr.Dataset | RemapWeights, +) -> RemapWeights: + """Load or normalize reusable remap weights. + + Path-based inputs are cached by resolved path, mtime, and file size so + repeated loads avoid rebuilding the sparse matrix. + """ + if isinstance(filename_or_obj, RemapWeights): + return filename_or_obj + + if isinstance(filename_or_obj, xr.Dataset): + return RemapWeights.from_file(filename_or_obj) + + cache_key = _cache_key(filename_or_obj) + weights = _WEIGHTS_CACHE.get(cache_key) + if weights is None: + weights = RemapWeights.from_file(filename_or_obj) + _WEIGHTS_CACHE[cache_key] = weights + + return weights From e1e01b614c95ff51202149af7ec3fa943c6f7ac9 Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Wed, 1 Apr 2026 13:05:19 -0500 Subject: [PATCH 2/4] Rename remap apply module --- docs/user-guide/remap-weights.rst | 1 - uxarray/remap/__init__.py | 2 +- uxarray/remap/accessor.py | 2 +- uxarray/remap/{precomputed.py => apply_weights.py} | 2 +- uxarray/remap/weights.py | 9 ++------- 5 files changed, 5 insertions(+), 11 deletions(-) rename uxarray/remap/{precomputed.py => apply_weights.py} (97%) diff --git a/docs/user-guide/remap-weights.rst b/docs/user-guide/remap-weights.rst index 4257d66b9..f9afcb93d 100644 --- a/docs/user-guide/remap-weights.rst +++ b/docs/user-guide/remap-weights.rst @@ -117,4 +117,3 @@ This path is useful when: - remapping needs to be repeated many times - external tools already produce trusted offline maps - you want to stay in Python for applying the map and preserving array metadata - diff --git a/uxarray/remap/__init__.py b/uxarray/remap/__init__.py index 2c7adf907..8ae0e2fec 100644 --- a/uxarray/remap/__init__.py +++ b/uxarray/remap/__init__.py @@ -1,6 +1,6 @@ +from .apply_weights import _apply_weights from .inverse_distance_weighted import _inverse_distance_weighted_remap from .nearest_neighbor import _nearest_neighbor_remap -from .precomputed import _apply_weights from .weights import RemapWeights, load_remap_weights __all__ = ( diff --git a/uxarray/remap/accessor.py b/uxarray/remap/accessor.py index a65dc3ef6..2e5485836 100644 --- a/uxarray/remap/accessor.py +++ b/uxarray/remap/accessor.py @@ -7,10 +7,10 @@ from uxarray.core.dataset import UxDataset from uxarray.grid.grid import Grid +from uxarray.remap.apply_weights import _apply_weights from uxarray.remap.bilinear import _bilinear from uxarray.remap.inverse_distance_weighted import _inverse_distance_weighted_remap from uxarray.remap.nearest_neighbor import _nearest_neighbor_remap -from uxarray.remap.precomputed import _apply_weights class RemapAccessor: diff --git a/uxarray/remap/precomputed.py b/uxarray/remap/apply_weights.py similarity index 97% rename from uxarray/remap/precomputed.py rename to uxarray/remap/apply_weights.py index e83bcf295..3d7b84573 100644 --- a/uxarray/remap/precomputed.py +++ b/uxarray/remap/apply_weights.py @@ -26,7 +26,7 @@ def _get_source_dim( if len(spatial_dims) > 1: raise ValueError( - f"Precomputed weight application does not support variables with multiple " + f"Weight application does not support variables with multiple " f"spatial dimensions. Got {spatial_dims!r} for variable {da.name!r}." ) diff --git a/uxarray/remap/weights.py b/uxarray/remap/weights.py index c0b141555..a6f95beae 100644 --- a/uxarray/remap/weights.py +++ b/uxarray/remap/weights.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from os import PathLike from pathlib import Path -from typing import Any import numpy as np import xarray as xr @@ -16,9 +15,7 @@ def _first_present(mapping, names: tuple[str, ...], kind: str): for name in names: if name in mapping: return mapping[name] - raise ValueError( - f"Could not find {kind}. Expected one of: {', '.join(names)}." - ) + raise ValueError(f"Could not find {kind}. Expected one of: {', '.join(names)}.") def _normalize_indices(indices: np.ndarray, size: int, label: str) -> np.ndarray: @@ -60,9 +57,7 @@ def from_file(cls, filename_or_obj: str | PathLike[str] | xr.Dataset): try: source_size = int( - _first_present( - ds.sizes, ("n_a", "src_grid_size"), "source dimension" - ) + _first_present(ds.sizes, ("n_a", "src_grid_size"), "source dimension") ) destination_size = int( _first_present( From 9049dcd56e90744dc7576dea890fd8317406b2bf Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Thu, 21 May 2026 10:42:07 -0500 Subject: [PATCH 3/4] address copilot review on apply_weights/remap weights - prefer start_index attr; treat as 1-based only when max==size - reject non-spatial source_dim - preserve aux coords whose dims are subset of other_dims - document eager dask materialization - LRU-bound weights cache; add clear_remap_weights_cache() - use _open_dataset_with_fallback in RemapWeights.from_file --- test/precomputed_weights_test.py | 100 +++++++++++++++++++++++++++++++ uxarray/__init__.py | 3 +- uxarray/remap/__init__.py | 3 +- uxarray/remap/accessor.py | 6 ++ uxarray/remap/apply_weights.py | 19 +++++- uxarray/remap/weights.py | 68 +++++++++++++++++---- 6 files changed, 182 insertions(+), 17 deletions(-) diff --git a/test/precomputed_weights_test.py b/test/precomputed_weights_test.py index 0ea243db5..8ba5ce501 100644 --- a/test/precomputed_weights_test.py +++ b/test/precomputed_weights_test.py @@ -2,9 +2,16 @@ import numpy as np import numpy.testing as nt +import pytest import uxarray as ux import xarray as xr +from uxarray.remap.weights import ( + _WEIGHTS_CACHE, + _WEIGHTS_CACHE_MAXSIZE, + _normalize_indices, +) + def _write_sparse_map(path: Path, source_size: int, destination_size: int) -> Path: rows = np.arange(1, destination_size + 1, dtype=np.int32) @@ -94,3 +101,96 @@ def test_apply_weights_reuses_loaded_operator(tmp_path, gridpath): nt.assert_array_equal(remapped["a"].values, source["a"].values[:, ::-1]) nt.assert_array_equal(remapped["flag"].values, source["flag"].values) nt.assert_array_equal(remapped_again.values, source["a"].values[:, ::-1]) + + +def test_normalize_indices_respects_start_index_attr(): + # 0-based array with an explicit start_index=0 attr — must not shift. + arr = xr.DataArray(np.array([0, 1, 2], dtype=np.int32), attrs={"start_index": 0}) + nt.assert_array_equal(_normalize_indices(arr, 4, "Row"), np.array([0, 1, 2])) + + # 1-based array with explicit start_index=1 attr. + arr1 = xr.DataArray(np.array([1, 2, 3], dtype=np.int32), attrs={"start_index": 1}) + nt.assert_array_equal(_normalize_indices(arr1, 3, "Row"), np.array([0, 1, 2])) + + +def test_normalize_indices_partial_zero_based_not_shifted(): + # 0-based partial coverage: min=1, max < size. Previous heuristic + # would have wrongly shifted to -1; new heuristic keeps as 0-based. + arr = np.array([1, 2, 3], dtype=np.int32) + nt.assert_array_equal(_normalize_indices(arr, 10, "Row"), arr) + + +def test_normalize_indices_one_based_detected_by_max(): + arr = np.array([1, 2, 3, 4], dtype=np.int32) + nt.assert_array_equal( + _normalize_indices(arr, 4, "Row"), np.array([0, 1, 2, 3]) + ) + + +def test_normalize_indices_out_of_bounds_raises(): + with pytest.raises(ValueError, match="out of bounds"): + _normalize_indices(np.array([-1, 0, 1]), 4, "Row") + + +def test_apply_weights_rejects_non_spatial_source_dim(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + weight_file = _write_sparse_map( + tmp_path / "reverse_map.nc", grid.n_face, grid.n_face + ) + + source = ux.UxDataArray( + xr.DataArray( + np.arange(grid.n_face, dtype=np.float64), + dims=["n_face"], + name="t", + ), + uxgrid=grid, + ) + + with pytest.raises(ValueError, match="not a spatial dimension"): + source.remap.apply_weights(weight_file, grid, source_dim="time") + + +def test_apply_weights_preserves_aux_coords(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + weight_file = _write_sparse_map( + tmp_path / "reverse_map.nc", grid.n_face, grid.n_face + ) + + nt_steps = 3 + da = xr.DataArray( + np.arange(nt_steps * grid.n_face, dtype=np.float64).reshape( + nt_steps, grid.n_face + ), + dims=("time", "n_face"), + coords={ + "time": np.array([10, 20, 30], dtype=np.int64), + "time_label": ("time", np.array(["a", "b", "c"])), + }, + name="t", + ) + source = ux.UxDataArray(da, uxgrid=grid) + remapped = source.remap.apply_weights(weight_file, grid) + assert "time_label" in remapped.coords + nt.assert_array_equal(remapped["time_label"].values, np.array(["a", "b", "c"])) + + +def test_clear_remap_weights_cache(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + weight_file = _write_sparse_map( + tmp_path / "reverse_map.nc", grid.n_face, grid.n_face + ) + ux.load_remap_weights(weight_file) + assert len(_WEIGHTS_CACHE) > 0 + ux.clear_remap_weights_cache() + assert len(_WEIGHTS_CACHE) == 0 + + +def test_remap_weights_cache_is_lru_bounded(tmp_path, gridpath): + grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) + ux.clear_remap_weights_cache() + for i in range(_WEIGHTS_CACHE_MAXSIZE + 5): + path = tmp_path / f"map_{i}.nc" + _write_sparse_map(path, grid.n_face, grid.n_face) + ux.load_remap_weights(path) + assert len(_WEIGHTS_CACHE) == _WEIGHTS_CACHE_MAXSIZE diff --git a/uxarray/__init__.py b/uxarray/__init__.py index cceb73bf7..a5ab902cb 100644 --- a/uxarray/__init__.py +++ b/uxarray/__init__.py @@ -11,7 +11,7 @@ from .core.dataarray import UxDataArray from .core.dataset import UxDataset from .grid import Grid -from .remap import RemapWeights, load_remap_weights +from .remap import RemapWeights, clear_remap_weights_cache, load_remap_weights try: from importlib.metadata import version as _version @@ -40,4 +40,5 @@ "Grid", "RemapWeights", "load_remap_weights", + "clear_remap_weights_cache", ) diff --git a/uxarray/remap/__init__.py b/uxarray/remap/__init__.py index 8ae0e2fec..765f10425 100644 --- a/uxarray/remap/__init__.py +++ b/uxarray/remap/__init__.py @@ -1,11 +1,12 @@ from .apply_weights import _apply_weights from .inverse_distance_weighted import _inverse_distance_weighted_remap from .nearest_neighbor import _nearest_neighbor_remap -from .weights import RemapWeights, load_remap_weights +from .weights import RemapWeights, clear_remap_weights_cache, load_remap_weights __all__ = ( "RemapWeights", "load_remap_weights", + "clear_remap_weights_cache", "_apply_weights", "_nearest_neighbor_remap", "_inverse_distance_weighted_remap", diff --git a/uxarray/remap/accessor.py b/uxarray/remap/accessor.py index 7fb48dd86..b2d3bd3d4 100644 --- a/uxarray/remap/accessor.py +++ b/uxarray/remap/accessor.py @@ -259,6 +259,12 @@ def apply_weights( ------- UxDataArray or UxDataset A new object with data mapped onto ``destination_grid``. + + Notes + ----- + Dask-backed inputs are materialized in memory before the sparse + operator is applied. For lazy/chunked execution, prefer + ``nearest_neighbor`` or ``inverse_distance_weighted``. """ return _apply_weights( diff --git a/uxarray/remap/apply_weights.py b/uxarray/remap/apply_weights.py index 3d7b84573..4db8d2b71 100644 --- a/uxarray/remap/apply_weights.py +++ b/uxarray/remap/apply_weights.py @@ -31,6 +31,11 @@ def _get_source_dim( ) if source_dim is not None: + if source_dim not in SPATIAL_DIMS: + raise ValueError( + f"source_dim {source_dim!r} is not a spatial dimension. " + f"Expected one of {sorted(SPATIAL_DIMS)}." + ) if source_dim not in da.dims: return None if da.sizes[source_dim] != weights.source_size: @@ -51,7 +56,12 @@ def _apply_weights( remap_to: str = "faces", source_dim: str | None = None, ): - """Apply a sparse remap operator to UXarray data.""" + """Apply a sparse remap operator to UXarray data. + + Note: this path materializes dask-backed inputs eagerly when applying + the sparse operator. For lazy/chunked execution, use one of the other + remap methods (e.g., ``nearest_neighbor``, ``inverse_distance_weighted``). + """ _assert_dimension(remap_to) weights_obj = load_remap_weights(weights) @@ -79,7 +89,12 @@ def _apply_weights( da_t = da.transpose(*other_dims, variable_source_dim) remapped_values = weights_obj.apply(np.asarray(da_t.values)) - coords = {dim: da.coords[dim] for dim in other_dims if dim in da.coords} + other_dims_set = set(other_dims) + coords = { + coord_name: coord + for coord_name, coord in da.coords.items() + if set(coord.dims).issubset(other_dims_set) + } da_out = uxarray.core.dataarray.UxDataArray( remapped_values, dims=other_dims + [destination_dim], diff --git a/uxarray/remap/weights.py b/uxarray/remap/weights.py index a6f95beae..7ce8a3051 100644 --- a/uxarray/remap/weights.py +++ b/uxarray/remap/weights.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections import OrderedDict from dataclasses import dataclass from os import PathLike from pathlib import Path @@ -8,7 +9,16 @@ import xarray as xr from scipy import sparse -_WEIGHTS_CACHE: dict[tuple[str, int, int], "RemapWeights"] = {} +from uxarray.core.utils import _open_dataset_with_fallback + +# LRU-bounded cache for loaded remap operators. +_WEIGHTS_CACHE_MAXSIZE = 32 +_WEIGHTS_CACHE: "OrderedDict[tuple[str, int, int], RemapWeights]" = OrderedDict() + + +def clear_remap_weights_cache() -> None: + """Drop all cached remap weight operators.""" + _WEIGHTS_CACHE.clear() def _first_present(mapping, names: tuple[str, ...], kind: str): @@ -18,20 +28,47 @@ def _first_present(mapping, names: tuple[str, ...], kind: str): raise ValueError(f"Could not find {kind}. Expected one of: {', '.join(names)}.") -def _normalize_indices(indices: np.ndarray, size: int, label: str) -> np.ndarray: - indices = np.asarray(indices, dtype=np.int64).ravel() - if indices.size == 0: - return indices +def _normalize_indices(indices, size: int, label: str) -> np.ndarray: + """Convert SCRIP/ESMF-style index arrays to 0-based ``np.int64``. + + Prefers an explicit ``start_index`` attribute (SCRIP/ESMF convention). + Without one, treats indices as 1-based only when ``max == size`` (the + only unambiguous evidence); otherwise treats them as 0-based when in + range, and raises on anything else. + """ + start_index = None + if hasattr(indices, "attrs"): + start_index = indices.attrs.get("start_index") + + arr = np.asarray(indices, dtype=np.int64).ravel() + if arr.size == 0: + return arr + + arr_min = int(arr.min()) + arr_max = int(arr.max()) + + if start_index is None: + if arr_max == size: + start_index = 1 + elif arr_min >= 0 and arr_max < size: + start_index = 0 + else: + raise ValueError( + f"{label} indices are out of bounds for size {size}. " + f"Found min={arr_min}, max={arr_max}." + ) + else: + start_index = int(start_index) - if indices.min() >= 1 and indices.max() <= size: - indices = indices - 1 - elif indices.min() < 0 or indices.max() >= size: + arr = arr - start_index + if arr.min() < 0 or arr.max() >= size: raise ValueError( - f"{label} indices are out of bounds for size {size}. " - f"Found min={indices.min()}, max={indices.max()}." + f"{label} indices are out of bounds for size {size} " + f"with start_index={start_index}. " + f"Found min={int(arr.min())}, max={int(arr.max())}." ) - return indices + return arr @dataclass(frozen=True) @@ -51,7 +88,7 @@ def from_file(cls, filename_or_obj: str | PathLike[str] | xr.Dataset): close_ds = False path = None else: - ds = xr.open_dataset(filename_or_obj) + ds = _open_dataset_with_fallback(filename_or_obj) close_ds = True path = str(filename_or_obj) @@ -130,7 +167,8 @@ def load_remap_weights( """Load or normalize reusable remap weights. Path-based inputs are cached by resolved path, mtime, and file size so - repeated loads avoid rebuilding the sparse matrix. + repeated loads avoid rebuilding the sparse matrix. The cache is + LRU-bounded; call :func:`clear_remap_weights_cache` to drop all entries. """ if isinstance(filename_or_obj, RemapWeights): return filename_or_obj @@ -143,5 +181,9 @@ def load_remap_weights( if weights is None: weights = RemapWeights.from_file(filename_or_obj) _WEIGHTS_CACHE[cache_key] = weights + while len(_WEIGHTS_CACHE) > _WEIGHTS_CACHE_MAXSIZE: + _WEIGHTS_CACHE.popitem(last=False) + else: + _WEIGHTS_CACHE.move_to_end(cache_key) return weights From 9361209d96910c5dc78b663bcb811acb42a9e03f Mon Sep 17 00:00:00 2001 From: Rajeev Jain Date: Fri, 22 May 2026 11:24:26 -0500 Subject: [PATCH 4/4] address remap weights review comments --- docs/api.rst | 1 - docs/user-guide/remap-weights.rst | 21 ++++++++------------ docs/user-guide/remapping.ipynb | 6 ++++++ test/precomputed_weights_test.py | 33 ++++++++++++++----------------- uxarray/__init__.py | 4 ---- uxarray/remap/accessor.py | 8 ++++---- 6 files changed, 33 insertions(+), 40 deletions(-) diff --git a/docs/api.rst b/docs/api.rst index ba3baa4f9..6b3986cbb 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -22,7 +22,6 @@ Top Level Functions open_multigrid open_mfdataset concat - load_remap_weights Tutorial -------- diff --git a/docs/user-guide/remap-weights.rst b/docs/user-guide/remap-weights.rst index f9afcb93d..1f21496e3 100644 --- a/docs/user-guide/remap-weights.rst +++ b/docs/user-guide/remap-weights.rst @@ -1,7 +1,7 @@ .. currentmodule:: uxarray -Remap Weights -============= +Remap with Weights +================== UXarray can apply precomputed offline remapping weights produced outside of UXarray. This is useful when weights are generated once with tools such as ESMF or @@ -11,9 +11,7 @@ slices, or variables. The core workflow is: 1. Generate a weight file for a specific source grid and destination grid. -2. Load the weight file once with :func:`load_remap_weights`. -3. Reuse the loaded :class:`RemapWeights` object with :meth:`UxDataArray.remap.apply_weights` - or :meth:`UxDataset.remap.apply_weights`. +2. Apply it with :meth:`UxDataArray.remap.apply_weights` or :meth:`UxDataset.remap.apply_weights`. Basic Usage ----------- @@ -25,15 +23,10 @@ Basic Usage src = ux.open_dataset("source_grid.nc", "source_data.nc") dst = ux.open_grid("destination_grid.nc") - weights = ux.load_remap_weights("map.nc") + remapped_temperature = src["temperature"].remap.apply_weights(dst, "map.nc") + remapped_dataset = src.remap.apply_weights(dst, "map.nc") - remapped_temperature = src["temperature"].remap.apply_weights( - weights, destination_grid=dst - ) - - remapped_dataset = src.remap.apply_weights(weights, destination_grid=dst) - -Repeated calls with the same path reuse a cached sparse operator, so loading the +Repeated calls with the same path reuse a cached sparse operator, so applying the same file again in one Python session avoids rebuilding the matrix. What A Weight File Represents @@ -90,6 +83,8 @@ This implementation was verified against real files from both families: In practice, UXarray supports the standard full offline map format used by both tools. +Currently, this API applies externally generated sparse remap files. Generating reusable UXarray weight maps can be added as a future extension. + Current caveats: - The source data ordering must match the source ordering encoded in the weight file. diff --git a/docs/user-guide/remapping.ipynb b/docs/user-guide/remapping.ipynb index e1d41c775..3e291a083 100644 --- a/docs/user-guide/remapping.ipynb +++ b/docs/user-guide/remapping.ipynb @@ -620,6 +620,12 @@ "It would be helpful to recall the data array contents again:" ] }, + { + "cell_type": "markdown", + "id": "d5559d1e", + "source": "## Remap with Precomputed Weights\n\nUse `.remap.apply_weights(destination_grid, weight_file)` when weights were generated externally with tools such as ESMF or TempestRemap. This path reuses a sparse offline map instead of constructing weights inside UXarray.\n\n```python\nremapped = source.remap.apply_weights(destination_grid, \"map.nc\")\n```\n\nSee [Remap with Weights](./remap-weights.rst) for file format details and examples.", + "metadata": {} + }, { "cell_type": "code", "execution_count": null, diff --git a/test/precomputed_weights_test.py b/test/precomputed_weights_test.py index 8ba5ce501..8b86b96f7 100644 --- a/test/precomputed_weights_test.py +++ b/test/precomputed_weights_test.py @@ -6,11 +6,8 @@ import uxarray as ux import xarray as xr -from uxarray.remap.weights import ( - _WEIGHTS_CACHE, - _WEIGHTS_CACHE_MAXSIZE, - _normalize_indices, -) +from uxarray.remap import RemapWeights, clear_remap_weights_cache, load_remap_weights +from uxarray.remap.weights import _WEIGHTS_CACHE, _WEIGHTS_CACHE_MAXSIZE, _normalize_indices def _write_sparse_map(path: Path, source_size: int, destination_size: int) -> Path: @@ -40,13 +37,13 @@ def test_load_remap_weights_and_apply_vector(tmp_path, gridpath): tmp_path / "reverse_map.nc", grid.n_face, grid.n_face ) - weights = ux.load_remap_weights(weight_file) + weights = load_remap_weights(weight_file) result = weights.apply(np.arange(grid.n_face, dtype=np.float64)) nt.assert_equal(weights.source_size, grid.n_face) nt.assert_equal(weights.destination_size, grid.n_face) nt.assert_array_equal(result, np.arange(grid.n_face, dtype=np.float64)[::-1]) - assert isinstance(weights, ux.RemapWeights) + assert isinstance(weights, RemapWeights) def test_apply_weights_to_uxdataarray(tmp_path, gridpath): @@ -65,7 +62,7 @@ def test_apply_weights_to_uxdataarray(tmp_path, gridpath): uxgrid=grid, ) - remapped = source.remap.apply_weights(weight_file, grid) + remapped = source.remap.apply_weights(grid, weight_file) nt.assert_array_equal(remapped.values, source.values[::-1]) nt.assert_equal(remapped.attrs["units"], "K") @@ -77,8 +74,8 @@ def test_apply_weights_reuses_loaded_operator(tmp_path, gridpath): weight_file = _write_sparse_map( tmp_path / "reverse_map.nc", grid.n_face, grid.n_face ) - weights = ux.load_remap_weights(weight_file) - cached_weights = ux.load_remap_weights(weight_file) + weights = load_remap_weights(weight_file) + cached_weights = load_remap_weights(weight_file) source = ux.UxDataset( xr.Dataset( @@ -94,8 +91,8 @@ def test_apply_weights_reuses_loaded_operator(tmp_path, gridpath): uxgrid=grid, ) - remapped = source.remap.apply_weights(weights, grid) - remapped_again = source["a"].remap.apply_weights(weights, grid) + remapped = source.remap.apply_weights(grid, weights) + remapped_again = source["a"].remap.apply_weights(grid, weights) assert cached_weights is weights nt.assert_array_equal(remapped["a"].values, source["a"].values[:, ::-1]) @@ -148,7 +145,7 @@ def test_apply_weights_rejects_non_spatial_source_dim(tmp_path, gridpath): ) with pytest.raises(ValueError, match="not a spatial dimension"): - source.remap.apply_weights(weight_file, grid, source_dim="time") + source.remap.apply_weights(grid, weight_file, source_dim="time") def test_apply_weights_preserves_aux_coords(tmp_path, gridpath): @@ -170,7 +167,7 @@ def test_apply_weights_preserves_aux_coords(tmp_path, gridpath): name="t", ) source = ux.UxDataArray(da, uxgrid=grid) - remapped = source.remap.apply_weights(weight_file, grid) + remapped = source.remap.apply_weights(grid, weight_file) assert "time_label" in remapped.coords nt.assert_array_equal(remapped["time_label"].values, np.array(["a", "b", "c"])) @@ -180,17 +177,17 @@ def test_clear_remap_weights_cache(tmp_path, gridpath): weight_file = _write_sparse_map( tmp_path / "reverse_map.nc", grid.n_face, grid.n_face ) - ux.load_remap_weights(weight_file) + load_remap_weights(weight_file) assert len(_WEIGHTS_CACHE) > 0 - ux.clear_remap_weights_cache() + clear_remap_weights_cache() assert len(_WEIGHTS_CACHE) == 0 def test_remap_weights_cache_is_lru_bounded(tmp_path, gridpath): grid = ux.open_grid(gridpath("ugrid", "quad-hexagon", "grid.nc")) - ux.clear_remap_weights_cache() + clear_remap_weights_cache() for i in range(_WEIGHTS_CACHE_MAXSIZE + 5): path = tmp_path / f"map_{i}.nc" _write_sparse_map(path, grid.n_face, grid.n_face) - ux.load_remap_weights(path) + load_remap_weights(path) assert len(_WEIGHTS_CACHE) == _WEIGHTS_CACHE_MAXSIZE diff --git a/uxarray/__init__.py b/uxarray/__init__.py index a5ab902cb..c6cb8ec3d 100644 --- a/uxarray/__init__.py +++ b/uxarray/__init__.py @@ -11,7 +11,6 @@ from .core.dataarray import UxDataArray from .core.dataset import UxDataset from .grid import Grid -from .remap import RemapWeights, clear_remap_weights_cache, load_remap_weights try: from importlib.metadata import version as _version @@ -38,7 +37,4 @@ "INT_DTYPE", "INT_FILL_VALUE", "Grid", - "RemapWeights", - "load_remap_weights", - "clear_remap_weights_cache", ) diff --git a/uxarray/remap/accessor.py b/uxarray/remap/accessor.py index b2d3bd3d4..ee6dd3b09 100644 --- a/uxarray/remap/accessor.py +++ b/uxarray/remap/accessor.py @@ -35,7 +35,7 @@ def __repr__(self) -> str: + "Supported methods:\n" + " • nearest_neighbor(destination_grid, remap_to='faces')\n" + " • inverse_distance_weighted(destination_grid, remap_to='faces', power=2, k=8)\n" - + " • apply_weights(weights, destination_grid, remap_to='faces')\n" + + " • apply_weights(destination_grid, weights, remap_to='faces')\n" ) def __call__( @@ -232,8 +232,8 @@ def bilinear( def apply_weights( self, - weights, destination_grid: Grid, + weights, remap_to: str = "faces", source_dim: str | None = None, ) -> UxDataArray | UxDataset: @@ -242,12 +242,12 @@ def apply_weights( Parameters ---------- + destination_grid : Grid + Grid representing the destination topology and coordinates. weights : str, PathLike, xr.Dataset, or RemapWeights Weight file or reusable loaded weights. Standard SCRIP/ESMF sparse map files are expected to provide ``row``, ``col``, ``S``, and dimensions ``n_a``/``n_b``. - destination_grid : Grid - Grid representing the destination topology and coordinates. remap_to : {'nodes', 'edges', 'faces'}, default='faces' Which destination grid element receives the remapped values. source_dim : {'n_node', 'n_edge', 'n_face'}, optional