diff --git a/docs/api.rst b/docs/api.rst index 9cdaf9992..f595a487d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -552,6 +552,7 @@ Zonal Average UxDataArray.zonal_average UxDataArray.zonal_mean + UxDataArray.zonal_anomaly Weighted diff --git a/docs/user-guide/zonal-average.ipynb b/docs/user-guide/zonal-average.ipynb index 2eab4a091..88471d8f2 100644 --- a/docs/user-guide/zonal-average.ipynb +++ b/docs/user-guide/zonal-average.ipynb @@ -745,6 +745,62 @@ "preview_levels = per_level_max.isel({level_dim: slice(0, 5)})\n", "preview_levels" ] + }, + { + "cell_type": "markdown", + "id": "1af6beaf", + "source": "## 7. Zonal Anomalies\n\nA zonal anomaly is the per-face departure from the mean of its latitude band. `zonal_anomaly` returns a `UxDataArray` with the same dims and dtype as the input (integer dtypes are promoted to float so empty bands can hold `NaN`).\n\n- **Centroid mode** (`conservative=False`, default): each face is assigned to one band by its centroid latitude (`np.digitize`). The unweighted per-band mean is exactly zero.\n- **Conservative mode** (`conservative=True`): faces straddling band edges contribute to multiple bands by area overlap (reusing the `zonal_mean` weight kernel), so per-band means are small but not exactly zero.\n\n### Step 7.1: Compute the centroid-mode anomaly", + "metadata": {} + }, + { + "cell_type": "code", + "id": "c0347122", + "source": "anomaly = uxds[\"psi\"].zonal_anomaly(lat=(-90, 90, 10))\nanomaly", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "1783125a", + "source": "### Step 7.2: Verify the per-band sum-to-zero property\n\nIn centroid mode every populated band has an unweighted mean of exactly zero. We can confirm by binning faces the same way `zonal_anomaly` does (`np.digitize`) and reducing each band.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "ed0f731f", + "source": "bands = np.arange(-90, 91, 10)\nface_lat = uxds.uxgrid.face_lat.values\nband_idx = np.clip(np.digitize(face_lat, bands) - 1, 0, len(bands) - 2)\n\nfor bi in range(len(bands) - 1):\n mask = band_idx == bi\n if not mask.any():\n continue\n band_mean = float(anomaly.values[mask].mean())\n print(\n f\"band [{bands[bi]:+4d}, {bands[bi + 1]:+4d}) \"\n f\"n_faces={int(mask.sum()):4d} mean={band_mean:+.2e}\"\n )", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "41dbb45d", + "source": "### Step 7.3: Visualize the anomaly field\n\nPlot the anomaly map and the zonal mean it was subtracted from side by side using a diverging colormap centered at zero.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "84cd278e", + "source": "vmax = float(np.nanmax(np.abs(anomaly.values)))\nanomaly_map = anomaly.plot(\n cmap=\"RdBu_r\",\n periodic_elements=\"split\",\n clim=(-vmax, vmax),\n title=\"Zonal Anomaly (psi - zonal mean)\",\n).opts(width=525, height=400, colorbar=True)\n\nzm = uxds[\"psi\"].zonal_mean(lat=(-90, 90, 10))\nzm_df = zm.to_dataframe(name=\"zonal_mean\").reset_index()\nzm_panel = zm_df.hvplot.line(\n x=\"zonal_mean\",\n y=\"latitudes\",\n line_width=2,\n title=\"Zonal Mean (subtracted)\",\n ylim=(-90, 90),\n width=400,\n height=400,\n).opts(show_grid=True)\n\n(anomaly_map + zm_panel).cols(2)", + "metadata": {}, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "id": "15c921d4", + "source": "### Step 7.4: Compare centroid vs conservative anomaly\n\nConservative mode blends straddling faces across band edges, so its anomaly differs slightly from the centroid version. The difference shows where face geometry crosses band boundaries.", + "metadata": {} + }, + { + "cell_type": "code", + "id": "36a9096c", + "source": "anomaly_cons = uxds[\"psi\"].zonal_anomaly(lat=(-90, 90, 10), conservative=True)\ndiff = anomaly_cons - anomaly\n\nprint(f\"max |centroid| = {float(np.nanmax(np.abs(anomaly.values))):.4f}\")\nprint(f\"max |conservative| = {float(np.nanmax(np.abs(anomaly_cons.values))):.4f}\")\nprint(f\"max |difference| = {float(np.nanmax(np.abs(diff.values))):.4f}\")", + "metadata": {}, + "execution_count": null, + "outputs": [] } ], "metadata": { diff --git a/test/grid/integrate/test_zonal.py b/test/grid/integrate/test_zonal.py index 9fc400d59..75c3921a7 100644 --- a/test/grid/integrate/test_zonal.py +++ b/test/grid/integrate/test_zonal.py @@ -242,3 +242,161 @@ def test_conservative_vs_nonconservative_comparison(self, gridpath, datasetpath) # Check they are in the same ballpark assert np.all(np.abs(conservative.values - non_conservative.values) < np.abs(non_conservative.values) * 0.5) + + +class TestZonalAnomaly: + """Tests for UxDataArray.zonal_anomaly.""" + + def _open(self, gridpath, datasetpath): + grid_path = gridpath("ugrid", "outCSne30", "outCSne30.ug") + data_path = datasetpath("ugrid", "outCSne30", "outCSne30_vortex.nc") + return ux.open_dataset(grid_path, data_path) + + def test_output_dims_match_input(self, gridpath, datasetpath): + """Output shape and dims must equal input (face axis preserved).""" + uxds = self._open(gridpath, datasetpath) + psi = uxds["psi"] + res = psi.zonal_anomaly(lat=(-90, 90, 30)) + assert res.shape == psi.shape + assert res.dims == psi.dims + assert "n_face" in res.dims + + def test_conservative_anomaly_band_mean_small(self, gridpath, datasetpath): + """Conservative anomaly: per-band area-weighted mean is small. + + Faces that straddle a band boundary are intentionally blended across + neighbouring band means (sharing the same weight kernel as + zonal_mean), so the per-band mean is not exactly zero — but it must + be small relative to the raw signal magnitude. + """ + uxds = self._open(gridpath, datasetpath) + bands = np.array([-90.0, -30.0, 30.0, 90.0]) + anom = uxds["psi"].zonal_anomaly(lat=bands, conservative=True) + + raw_std = float(uxds["psi"].values.std()) + per_band = _compute_face_band_weights(uxds["psi"].uxgrid, bands) + for overlapping, w in per_band: + if overlapping.size == 0: + continue + vals = anom.isel(n_face=overlapping, ignore_grid=True).values + weighted = abs((w * vals).sum() / w.sum()) + assert weighted < raw_std * 0.05 + + def test_band_anomaly_centroid_sums_to_zero(self, gridpath, datasetpath): + """Non-conservative anomaly: simple mean within each band ≈ 0.""" + uxds = self._open(gridpath, datasetpath) + bands = np.array([-90.0, -30.0, 30.0, 90.0]) + psi = uxds["psi"] + anom = psi.zonal_anomaly(lat=bands, conservative=False) + + face_lats = psi.uxgrid.face_lat.values + for bi in range(len(bands) - 1): + mask = (face_lats >= bands[bi]) & (face_lats < bands[bi + 1]) + if bi == len(bands) - 2: + mask |= face_lats == bands[bi + 1] + if not mask.any(): + continue + assert anom.values[mask].mean() == pytest.approx(0.0, abs=1e-12) + + def test_multidim_face_not_last_axis(self): + """Works when n_face is not the last axis and preserves other dims.""" + uxgrid = ux.Grid.from_healpix(zoom=1) + # shape (T, n_face, L); face is axis=1 + T, L = 3, 4 + rng = np.random.default_rng(0) + data = rng.standard_normal((T, uxgrid.n_face, L)) + uxda = ux.UxDataArray( + data, dims=["time", "n_face", "level"], uxgrid=uxgrid + ) + + anom = uxda.zonal_anomaly(lat=(-90, 90, 30)) + assert anom.shape == uxda.shape + assert anom.dims == uxda.dims + + # Per band, per (t, l), the anomaly mean should be ~0. + face_lats = uxgrid.face_lat.values + bands = np.linspace(-90, 90, int(round(180 / 30)) + 1) + for bi in range(len(bands) - 1): + mask = (face_lats >= bands[bi]) & (face_lats < bands[bi + 1]) + if bi == len(bands) - 2: + mask |= face_lats == bands[bi + 1] + if not mask.any(): + continue + band_vals = anom.values[:, mask, :] + # Mean across face dim per (t, l) should be ~0 + nt.assert_allclose(band_vals.mean(axis=1), 0.0, atol=1e-12) + + def test_dask_input_stays_lazy(self, gridpath, datasetpath): + """Centroid path keeps dask laziness when input is chunked.""" + uxds = self._open(gridpath, datasetpath) + uxds["psi"] = uxds["psi"].chunk() + res = uxds["psi"].zonal_anomaly(lat=(-90, 90, 30)) + assert isinstance(res.data, da.Array) + # Verify computation still produces finite numbers + computed = res.compute() + assert np.all(np.isfinite(computed.values)) + + def test_dask_input_conservative_lazy(self, gridpath, datasetpath): + """Conservative path keeps dask laziness for the subtract step.""" + uxds = self._open(gridpath, datasetpath) + uxds["psi"] = uxds["psi"].chunk() + res = uxds["psi"].zonal_anomaly(lat=(-90, 90, 30), conservative=True) + assert isinstance(res.data, da.Array) + computed = res.compute() + assert np.all(np.isfinite(computed.values)) + + def test_conservative_vs_centroid_close(self, gridpath, datasetpath): + """Conservative and centroid anomalies should be comparable in magnitude.""" + uxds = self._open(gridpath, datasetpath) + bands = np.array([-90.0, -30.0, 30.0, 90.0]) + a_cons = uxds["psi"].zonal_anomaly(lat=bands, conservative=True) + a_cent = uxds["psi"].zonal_anomaly(lat=bands, conservative=False) + # Same shape + assert a_cons.shape == a_cent.shape + # Same order of magnitude (allow generous tolerance — methods differ) + std_cons = float(np.nanstd(a_cons.values)) + std_cent = float(np.nanstd(a_cent.values)) + assert std_cons > 0 and std_cent > 0 + assert 0.25 < std_cons / std_cent < 4.0 + + def test_int_input_promotes_dtype(self): + """Integer inputs are promoted so NaN-bearing anomalies fit.""" + uxgrid = ux.Grid.from_healpix(zoom=1) + uxda = ux.UxDataArray( + np.ones(uxgrid.n_face, dtype=np.int32), + dims=["n_face"], + uxgrid=uxgrid, + ) + res = uxda.zonal_anomaly(lat=(-90, 90, 30)) + assert np.issubdtype(res.dtype, np.floating) + # All-ones input → all-zero anomalies wherever defined + finite = res.values[np.isfinite(res.values)] + assert finite.size > 0 + nt.assert_allclose(finite, 0.0, atol=1e-12) + + def test_non_face_centered_raises(self, gridpath, datasetpath): + """Only face-centered data is supported.""" + uxgrid = ux.Grid.from_healpix(zoom=1) + uxda = ux.UxDataArray( + np.zeros(uxgrid.n_node), dims=["n_node"], uxgrid=uxgrid + ) + with pytest.raises(ValueError, match="face-centered"): + uxda.zonal_anomaly() + + def test_invalid_lat_input_raises(self): + """Invalid lat specs raise ValueError.""" + uxgrid = ux.Grid.from_healpix(zoom=1) + uxda = ux.UxDataArray( + np.zeros(uxgrid.n_face), dims=["n_face"], uxgrid=uxgrid + ) + with pytest.raises(ValueError, match="Step size"): + uxda.zonal_anomaly(lat=(-90, 90, 0)) + with pytest.raises(ValueError, match="Step size"): + uxda.zonal_anomaly(lat=(-90, 90, -1)) + with pytest.raises(ValueError): + uxda.zonal_anomaly(lat=[42.0]) # too few edges + with pytest.raises(ValueError, match="monotonic"): + uxda.zonal_anomaly(lat=[10.0, -10.0, 30.0]) + + +from uxarray.core.zonal import _compute_face_band_weights # noqa: E402 diff --git a/uxarray/core/dataarray.py b/uxarray/core/dataarray.py index b75199cc7..8ae2315a1 100644 --- a/uxarray/core/dataarray.py +++ b/uxarray/core/dataarray.py @@ -24,6 +24,7 @@ from uxarray.core.zonal import ( _compute_conservative_zonal_mean_bands, _compute_non_conservative_zonal_mean, + _compute_zonal_anomaly, ) from uxarray.cross_sections import UxDataArrayCrossSectionAccessor from uxarray.formatting_html import array_repr @@ -767,6 +768,70 @@ def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs) """Alias of zonal_mean; prefer `zonal_mean` for primary API.""" return self.zonal_mean(lat=lat, conservative=conservative, **kwargs) + def zonal_anomaly(self, lat=(-90, 90, 10), conservative: bool = False): + """Compute the zonal anomaly: each face value minus the mean of its latitude band. + + Returns a new ``UxDataArray`` with the same dimensions as the input, + where each face holds its original value minus the zonal mean of the + latitude band it belongs to. + + Parameters + ---------- + lat : tuple or array-like, default=(-90, 90, 10) + Latitude band specification: + - tuple (start, end, step): band edges via np.linspace(start, end, n) + - array-like: explicit band edges in degrees + conservative : bool, default=False + If True, uses area-weighted band means and blends across bands for + faces that straddle a band boundary, reusing the face-band weight + matrix computed for zonal_mean so no geometry is duplicated. + If False, assigns each face to a band by its centroid latitude. + + Returns + ------- + UxDataArray + Same dimensions as input with per-face band mean subtracted. + + Examples + -------- + >>> uxds["var"].zonal_anomaly() + >>> uxds["var"].zonal_anomaly(lat=(-60, 60, 5), conservative=True) + """ + if not self._face_centered(): + raise ValueError( + "Zonal anomaly is only supported for face-centered data variables." + ) + + if isinstance(lat, tuple): + start, end, step = lat + if step <= 0: + raise ValueError("Step size must be positive.") + num_points = int(round((end - start) / step)) + 1 + edges = np.linspace(start, end, num_points) + edges = np.clip(edges, -90, 90) + elif isinstance(lat, (list, np.ndarray)): + edges = np.asarray(lat, dtype=float) + else: + raise ValueError( + "Invalid value for 'lat'. Must be a tuple (start, end, step) or array-like band edges." + ) + + if edges.ndim != 1 or edges.size < 2: + raise ValueError("Band edges must be 1D with at least two values.") + + res = _compute_zonal_anomaly(self, edges, conservative=conservative) + + return UxDataArray( + res, + dims=self.dims, + coords=self.coords, + name=self.name + "_zonal_anomaly" + if self.name is not None + else "zonal_anomaly", + attrs={"zonal_anomaly": True, "conservative": conservative}, + uxgrid=self.uxgrid, + ) + def azimuthal_mean( self, center_coord, diff --git a/uxarray/core/zonal.py b/uxarray/core/zonal.py index 173bb3f30..24a957c26 100644 --- a/uxarray/core/zonal.py +++ b/uxarray/core/zonal.py @@ -225,31 +225,39 @@ def _compute_band_overlap_area( return area -def _compute_conservative_zonal_mean_bands(uxda, bands): - """ - Compute conservative zonal mean over latitude bands. +def _compute_face_band_weights(uxgrid, bands): + """Compute overlap area between every face and every latitude band. + + Shared geometry kernel used by both zonal_mean and zonal_anomaly so the + expensive intersection calculations are never duplicated. - Uses get_faces_between_latitudes to optimize computation by avoiding - overlap area calculations for fully contained faces. + Returns a sparse per-band representation so memory scales with the number + of faces that overlap each band (typically O(n_face) total) rather than + O(n_face * n_bands), which would OOM on large grids with fine bands. Parameters ---------- - uxda : UxDataArray - The data array to compute zonal means for + uxgrid : Grid bands : array-like - Latitude band edges in degrees + Latitude band edges in degrees, shape (n_bands + 1,). Must be + monotonic non-decreasing. Returns ------- - result : array - Zonal means for each band + per_band : list of (indices, weights) tuples, length n_bands + For band ``bi``: ``indices`` is an int ndarray of face indices that + overlap the band, and ``weights`` is the corresponding overlap-area + ndarray. Fully-contained faces carry their full face area; partially- + overlapping faces carry the exact intersection area. """ - import dask.array as da - - uxgrid = uxda.uxgrid - face_axis = uxda.get_axis_num("n_face") + bands = np.asarray(bands, dtype=float) + if bands.ndim != 1 or bands.size < 2: + raise ValueError("bands must be 1D with at least two edges") + if np.any(np.diff(bands) < 0): + raise ValueError( + f"bands must be monotonic non-decreasing; got diff(bands)={np.diff(bands)}" + ) - # Pre-compute face properties faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array( uxgrid.face_node_connectivity.values, uxgrid.n_face, @@ -262,25 +270,12 @@ def _compute_conservative_zonal_mean_bands(uxda, bands): face_bounds_lat = uxgrid.face_bounds_lat.values face_areas = uxgrid.face_areas.values - bands = np.asarray(bands, dtype=float) - if bands.ndim != 1 or bands.size < 2: - raise ValueError("bands must be 1D with at least two edges") - nb = bands.size - 1 - - # Initialize result array - shape = list(uxda.shape) - shape[face_axis] = nb - if isinstance(uxda.data, da.Array): - result = da.zeros(shape, dtype=uxda.dtype) - else: - result = np.zeros(shape, dtype=uxda.dtype) + per_band = [] for bi in range(nb): lat0 = float(np.clip(bands[bi], -90.0, 90.0)) lat1 = float(np.clip(bands[bi + 1], -90.0, 90.0)) - - # Ensure lat0 <= lat1 if lat0 > lat1: lat0, lat1 = lat1, lat0 @@ -288,55 +283,219 @@ def _compute_conservative_zonal_mean_bands(uxda, bands): z1 = np.sin(np.deg2rad(lat1)) zmin, zmax = (z0, z1) if z0 <= z1 else (z1, z0) - # Step 1: Get fully contained faces - fully_contained_faces = uxgrid.get_faces_between_latitudes((lat0, lat1)) - - # Step 2: Get all overlapping faces (including partial) mask = ~((face_bounds_lat[:, 1] < lat0) | (face_bounds_lat[:, 0] > lat1)) - all_overlapping_faces = np.nonzero(mask)[0] + all_overlapping = np.nonzero(mask)[0] - if all_overlapping_faces.size == 0: - # No faces in this band - idx = [slice(None)] * result.ndim - idx[face_axis] = bi - result[tuple(idx)] = np.nan + if all_overlapping.size == 0: + per_band.append((np.empty(0, dtype=np.int64), np.empty(0, dtype=float))) continue - # Step 3: Partition faces into fully contained vs partially overlapping - is_fully_contained = np.isin(all_overlapping_faces, fully_contained_faces) - partially_overlapping_faces = all_overlapping_faces[~is_fully_contained] - - # Step 4: Compute weights - all_weights = np.zeros(all_overlapping_faces.size, dtype=float) - - # For fully contained faces, use their full area - if fully_contained_faces.size > 0: - fully_contained_indices = np.where(is_fully_contained)[0] - all_weights[fully_contained_indices] = face_areas[fully_contained_faces] - - # For partially overlapping faces, compute fractional area - if partially_overlapping_faces.size > 0: - partial_indices = np.where(~is_fully_contained)[0] - for i, face_idx in enumerate(partially_overlapping_faces): - nedge = n_nodes_per_face[face_idx] - face_edges = faces_edge_nodes_xyz[face_idx, :nedge] - overlap_area = _compute_band_overlap_area(face_edges, zmin, zmax) - all_weights[partial_indices[i]] = overlap_area - - # Step 5: Compute weighted average - data_slice = uxda.isel(n_face=all_overlapping_faces, ignore_grid=True).data - total_weight = all_weights.sum() - - if total_weight == 0.0: - weighted = np.nan * data_slice[..., 0] - else: - w_shape = [1] * data_slice.ndim - w_shape[face_axis] = all_weights.size - w_reshaped = all_weights.reshape(w_shape) - weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total_weight + fully_contained = uxgrid.get_faces_between_latitudes((lat0, lat1)) + is_fully_contained = np.isin(all_overlapping, fully_contained) + + weights = np.empty(all_overlapping.size, dtype=float) + + fc_mask = is_fully_contained + fc = all_overlapping[fc_mask] + weights[fc_mask] = face_areas[fc] + + partial = all_overlapping[~fc_mask] + partial_pos = np.nonzero(~fc_mask)[0] + for pos, f in zip(partial_pos, partial): + nedge = n_nodes_per_face[f] + weights[pos] = _compute_band_overlap_area( + faces_edge_nodes_xyz[f, :nedge], zmin, zmax + ) + + per_band.append((all_overlapping.astype(np.int64), weights)) + + return per_band + + +def _compute_conservative_zonal_mean_bands(uxda, bands): + """Compute conservative zonal mean over latitude bands. + + Parameters + ---------- + uxda : UxDataArray + bands : array-like + Latitude band edges in degrees + + Returns + ------- + result : array + Zonal means for each band, with n_face axis replaced by n_bands + """ + import dask.array as da + + bands = np.asarray(bands, dtype=float) + per_band = _compute_face_band_weights(uxda.uxgrid, bands) + nb = len(per_band) + face_axis = uxda.get_axis_num("n_face") + + if np.issubdtype(uxda.dtype, np.integer) or np.issubdtype(uxda.dtype, np.bool_): + result_dtype = np.float64 + else: + result_dtype = uxda.dtype + + shape = list(uxda.shape) + shape[face_axis] = nb + if isinstance(uxda.data, da.Array): + result = da.full(shape, np.nan, dtype=result_dtype) + else: + result = np.full(shape, np.nan, dtype=result_dtype) + + for bi, (overlapping, w) in enumerate(per_band): + if overlapping.size == 0: + continue + + total = w.sum() + if total == 0.0 or not np.isfinite(total): + continue + + data_slice = uxda.isel(n_face=overlapping, ignore_grid=True).data + w_shape = [1] * data_slice.ndim + w_shape[face_axis] = w.size + weighted = (data_slice * w.reshape(w_shape)).sum(axis=face_axis) / total idx = [slice(None)] * result.ndim idx[face_axis] = bi result[tuple(idx)] = weighted return result + + +def _compute_zonal_anomaly(uxda, bands, conservative=False): + """Compute zonal anomaly: each face value minus the mean of its latitude band. + + Preserves the input dtype (promoting only integer/bool inputs so NaNs can + fit), the input shape (n_face axis stays in place even if it is not the + last axis), and dask laziness when ``uxda`` is chunked. + + Parameters + ---------- + uxda : UxDataArray + bands : array-like + Latitude band edges in degrees. Must be monotonic non-decreasing. + conservative : bool + If True, uses area-weighted band means and blends across bands for + faces that straddle a boundary, reusing the same sparse weight kernel + as zonal_mean so geometry is computed only once. + If False, assigns each face to a band by centroid latitude. + + Returns + ------- + array-like + Same shape and axis order as ``uxda.data``. Returns a dask array when + ``uxda.data`` is a dask array; otherwise a numpy array. + """ + import dask.array as da + + bands = np.asarray(bands, dtype=float) + if bands.ndim != 1 or bands.size < 2: + raise ValueError("Band edges must be 1D with at least two values.") + if np.any(np.diff(bands) < 0): + raise ValueError( + "Band edges must be monotonic non-decreasing; got " + f"diff(bands)={np.diff(bands)}" + ) + + face_axis = uxda.get_axis_num("n_face") + n_face = uxda.uxgrid.n_face + nb = bands.size - 1 + is_dask = isinstance(uxda.data, da.Array) + + if np.issubdtype(uxda.dtype, np.integer) or np.issubdtype(uxda.dtype, np.bool_): + out_dtype = np.float64 + else: + out_dtype = uxda.dtype + + reduced_shape = list(uxda.shape) + reduced_shape.pop(face_axis) + + def _reshape_along_face(w_1d): + s = [1] * uxda.ndim + s[face_axis] = w_1d.size + return w_1d.reshape(s) + + if conservative: + per_band = _compute_face_band_weights(uxda.uxgrid, bands) + + # Compute per-band means along the n_face axis, preserving other dims. + # band_means is a list of length nb; entries are arrays with shape + # reduced_shape (or None when no overlap). They are small relative to + # uxda, so materializing them is cheap. + band_means = [None] * nb + face_totals = np.zeros(n_face, dtype=float) + + for bi, (overlapping, w) in enumerate(per_band): + if overlapping.size == 0: + continue + total = w.sum() + if total == 0.0 or not np.isfinite(total): + continue + face_totals[overlapping] += w + data_slice = uxda.isel(n_face=overlapping, ignore_grid=True).data + band_mean = (data_slice * _reshape_along_face(w)).sum( + axis=face_axis + ) / total + if isinstance(band_mean, da.Array): + band_mean = band_mean.compute() + band_means[bi] = band_mean.astype(out_dtype, copy=False) + + # face_means_num[..., f, ...] = sum_b W[f,b] * band_mean[b] + # This is the output-shaped per-face mean field. Built eagerly because + # the scatter pattern is awkward in dask; uxda.data itself is not + # touched so its laziness is preserved by the final subtract. + face_means_num = np.zeros(uxda.shape, dtype=out_dtype) + for bi, (overlapping, w) in enumerate(per_band): + if overlapping.size == 0 or band_means[bi] is None: + continue + bm_expanded = np.expand_dims(band_means[bi], face_axis) + contrib = bm_expanded * _reshape_along_face(w) + idx = [slice(None)] * uxda.ndim + idx[face_axis] = overlapping + face_means_num[tuple(idx)] += contrib + + valid = face_totals > 0 + face_means = np.full(uxda.shape, np.nan, dtype=out_dtype) + if valid.any(): + valid_idx = np.nonzero(valid)[0] + idx = [slice(None)] * uxda.ndim + idx[face_axis] = valid_idx + face_means[tuple(idx)] = face_means_num[tuple(idx)] / _reshape_along_face( + face_totals[valid_idx] + ) + + else: + # Centroid-based: fast, no intersection geometry needed. + face_lats = uxda.uxgrid.face_lat.values + band_indices = np.clip(np.digitize(face_lats, bands) - 1, 0, nb - 1) + + # Compute per-band mean reducing only over the face axis. Build a + # stack of shape (nb, *reduced_shape); preserve dask laziness. + per_band_means = [] + for bi in range(nb): + sel = np.nonzero(band_indices == bi)[0] + if sel.size == 0: + if is_dask: + per_band_means.append( + da.full(tuple(reduced_shape), np.nan, dtype=out_dtype) + ) + else: + per_band_means.append( + np.full(tuple(reduced_shape), np.nan, dtype=out_dtype) + ) + else: + sub = uxda.isel(n_face=sel, ignore_grid=True).data + per_band_means.append(sub.mean(axis=face_axis)) + + if is_dask: + band_means = da.stack(per_band_means, axis=0) + face_means_face_first = band_means[band_indices] + else: + band_means = np.stack(per_band_means, axis=0) + face_means_face_first = np.take(band_means, band_indices, axis=0) + face_means = np.moveaxis(face_means_face_first, 0, face_axis) + + return uxda.data - face_means