diff --git a/pyproject.toml b/pyproject.toml index 9a67f5d83..897ec77f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,7 @@ dependencies = [ "matplotlib-scalebar>=0.8", "networkx>=2.6", "numba>=0.56.4", + "numba-progress>=1.1", "numpy>=1.23", "omnipath>=1.0.7", "pandas>=2.1", @@ -66,7 +67,7 @@ dependencies = [ "scikit-image>=0.25", # due to https://github.com/scikit-image/scikit-image/issues/6850 breaks rescale ufunc "scikit-learn>=0.24", - "spatialdata>=0.7.2", # 0.7.2 dropped xarray-schema (pkg_resources break, #1115) + "spatialdata>=0.7.2", # 0.7.2 dropped xarray-schema (pkg_resources break, #1115) "spatialdata-plot>=0.3.3", "statsmodels>=0.12", # https://github.com/scverse/squidpy/issues/526 diff --git a/src/squidpy/datasets/__init__.py b/src/squidpy/datasets/__init__.py index 9d56f44e2..41c1b2f16 100644 --- a/src/squidpy/datasets/__init__.py +++ b/src/squidpy/datasets/__init__.py @@ -6,6 +6,7 @@ ImageDatasets, SpatialDataDatasets, VisiumDatasets, + cells, # AnnData datasets four_i, imc, @@ -25,7 +26,6 @@ visium_hne_image, visium_hne_image_crop, visium_hne_sdata, - cells, ) __all__ = [ diff --git a/src/squidpy/gr/_nhood.py b/src/squidpy/gr/_nhood.py index c679f42dd..f9348a3c9 100644 --- a/src/squidpy/gr/_nhood.py +++ b/src/squidpy/gr/_nhood.py @@ -2,16 +2,17 @@ from __future__ import annotations +import warnings from collections.abc import Callable, Iterable, Sequence from functools import partial from typing import Any, NamedTuple import networkx as nx -import numba.types as nt import numpy as np import pandas as pd from anndata import AnnData -from numba import njit +from numba import njit, prange, set_num_threads +from numba_progress import ProgressBar from numpy.typing import NDArray from pandas import CategoricalDtype from scanpy import logging as logg @@ -26,7 +27,6 @@ _assert_categorical_obs, _assert_connectivity_key, _save_data, - _shuffle_group, extract_adata_if_sdata, ) @@ -37,22 +37,20 @@ class NhoodEnrichmentResult(NamedTuple): 
"""Result of nhood_enrichment function.""" zscore: NDArray[np.number] + """Z-score values of enrichment statistic.""" counts: NDArray[np.number] # NamedTuple inherits from tuple so cannot use 'count' as attribute name + """Enrichment count.""" + conditional_ratio: NDArray[np.number] | None = None + """Conditional ratio. Only present if ``normalization='conditional'``.""" -# data type aliases (both for numpy and numba should match) -dt = nt.uint32 +# integer dtype used for cluster labels and CSR index arrays (numpy/numba must match) ndt = np.uint32 -_template = """ -from __future__ import annotations -from numba import njit, prange -import numpy as np -@njit(dt[:, :](dt[:], dt[:], dt[:]), parallel={parallel}, fastmath=True) -def _nenrich_{n_cls}_{parallel}(indices: NDArrayA, indptr: NDArrayA, clustering: NDArrayA) -> np.ndarray: - ''' - Count how many times clusters :math:`i` and :math:`j` are connected. +@njit(nogil=True, cache=True) +def _nenrich(indices: NDArrayA, indptr: NDArrayA, clustering: NDArrayA, n_cls: int) -> NDArrayA: + """Count how many times clusters are connected. Parameters ---------- @@ -61,76 +59,149 @@ def _nenrich_{n_cls}_{parallel}(indices: NDArrayA, indptr: NDArrayA, clustering: indptr :attr:`scipy.sparse.csr_matrix.indptr`. clustering - Array of shape ``(n_cells,)`` containig cluster labels ranging from `0` to `n_clusters - 1` inclusive. + Array of shape ``(n_cells,)`` containing cluster labels ranging from ``0`` to ``n_cls - 1`` inclusive. + n_cls + Number of clusters. Returns ------- - :class:`numpy.ndarray` - Array of shape ``(n_clusters, n_clusters)`` containing the pairwise counts. 
- ''' - res = np.zeros((indptr.shape[0] - 1, {n_cls}), dtype=ndt) - - for i in prange(res.shape[0]): - xs, xe = indptr[i], indptr[i + 1] - cols = indices[xs:xe] - for c in cols: - res[i, clustering[c]] += 1 - {init} - {loop} - {finalize} -""" - - -def _create_function(n_cls: int, parallel: bool = False) -> Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA]: + Array of shape ``(n_cls, n_cls)`` where entry ``(a, b)`` is the number of directed edges + from a cluster-``a`` cell to a cluster-``b`` neighbor. + """ + out = np.zeros((n_cls, n_cls), dtype=np.uint32) + for i in range(indptr.shape[0] - 1): + a = clustering[i] + for c in indices[indptr[i] : indptr[i + 1]]: + out[a, clustering[c]] += 1 + return out + + +@njit(nogil=True, cache=True) +def _conditional_counts(indices: NDArrayA, indptr: NDArrayA, clustering: NDArrayA, n_cls: int) -> NDArrayA: + """Count, per cluster pair ``(a, b)``, how many cluster-``a`` cells have at least one cluster-``b`` neighbor. + + This is the COZI conditional denominator. Returns an ``(n_cls, n_cls)`` float array. + """ + cond = np.zeros((n_cls, n_cls), dtype=np.float64) + seen = np.zeros(n_cls, dtype=np.bool_) + for i in range(indptr.shape[0] - 1): + seen[:] = False + for c in indices[indptr[i] : indptr[i + 1]]: + seen[clustering[c]] = True + a = clustering[i] + for b in range(n_cls): + if seen[b]: + cond[a, b] += 1.0 + return cond + + +@njit(parallel=True, nogil=True) +def _permutation_counts( + indices: NDArrayA, + indptr: NDArrayA, + int_clust: NDArrayA, + group_offsets: NDArrayA, + group_indices: NDArrayA, + n_cls: int, + n_perms: int, + seed: int, + norm_code: int, + progress: ProgressBar, +) -> NDArrayA: + """Compute the (normalized) cluster-connection counts for ``n_perms`` label permutations. + + Parallelizes over permutations with ``prange``: each iteration ``p`` reseeds its thread-local + RNG with ``seed + p`` and shuffles a private copy of the labels, so the result is reproducible + and independent of the thread count. 
``numba``'s ``np.random`` reproduces numpy's legacy + ``RandomState`` bit-for-bit, so this matches a ``RandomState(seed + p)`` reference exactly. + ``norm_code`` is ``0`` (none), ``1`` (total) or ``2`` (conditional). + + Labels are shuffled *within groups*: ``group_indices[group_offsets[g]:group_offsets[g + 1]]`` + are the cell indices of group ``g`` (e.g. one library/slide). A single group spanning all cells + reproduces a plain global shuffle, so this handles the no-``library_key`` case too. Groups are + shuffled in order off one RNG stream, matching :func:`squidpy.gr._utils._shuffle_group`. + + ``progress`` is a :class:`numba_progress.ProgressBar` proxy, ticked once per permutation from + inside the parallel loop (atomic, GIL-free); pass a ``disable=True`` bar to suppress it. + """ + perms = np.empty((n_perms, n_cls, n_cls), dtype=np.float64) + for p in prange(n_perms): + np.random.seed(seed + p) + shuffled = int_clust.copy() + for g in range(group_offsets.shape[0] - 1): + s, e = group_offsets[g], group_offsets[g + 1] + sub = np.empty(e - s, dtype=int_clust.dtype) + for t in range(e - s): + sub[t] = int_clust[group_indices[s + t]] + np.random.shuffle(sub) + for t in range(e - s): + shuffled[group_indices[s + t]] = sub[t] + + out = np.zeros((n_cls, n_cls), dtype=np.float64) + for i in range(indptr.shape[0] - 1): + a = shuffled[i] + for c in indices[indptr[i] : indptr[i + 1]]: + out[a, shuffled[c]] += 1.0 + + if norm_code == 1: # total + for a in range(n_cls): + s = 0.0 + for b in range(n_cls): + s += out[a, b] + if s == 0.0: + s = 1.0 + for b in range(n_cls): + out[a, b] /= s + elif norm_code == 2: # conditional + cond = _conditional_counts(indices, indptr, shuffled, n_cls) + for a in range(n_cls): + for b in range(n_cls): + d = cond[a, b] if cond[a, b] != 0.0 else 1.0 + out[a, b] /= d + + perms[p] = out + progress.update(1) + return perms + + +_NORM_CODES = {"none": 0, "total": 1, "conditional": 2} + + +def filter_clusters_by_min_cell_count( + adata: AnnData, + 
int_clust: NDArrayA, + connectivity_key: str, + min_cell_count: int, +) -> tuple[NDArrayA, NDArrayA]: """ - Create a :mod:`numba` function which counts the number of connections between clusters. + Filter clusters by minimum cell count. Parameters ---------- - n_cls - Number of clusters. We're assuming that cluster labels are `0`, `1`, ..., `n_cls - 1`. - parallel - Whether to enable :mod:`numba` parallelization. + %(adata)s + int_clust + Array of cluster labels per cell + connectivity_key + Key in adata.obsp with adjacency matrix + min_cell_count + Minimum number of cells required to keep a cluster Returns ------- - The aforementioned function. + int_clust_filtered + Filtered cluster labels + adj + Adjacency matrix corresponding to filtered cells """ - if n_cls <= 1: - raise ValueError(f"Expected at least `2` clusters, found `{n_cls}`.") - - rng = range(n_cls) - init = "".join( - f""" - g{i} = np.zeros(({n_cls},), dtype=ndt)""" - for i in rng - ) - - loop_body = """ - if cl == 0: - g0 += res[row]""" - loop_body = loop_body + "".join( - f""" - elif cl == {i}: - g{i} += res[row]""" - for i in range(1, n_cls) - ) - loop = f""" - for row in prange(res.shape[0]): - cl = clustering[row] - {loop_body} - else: - assert False, "Unhandled case." 
- """ - finalize = ", ".join(f"g{i}" for i in rng) - finalize = f"return np.stack(({finalize}))" # must really be a tuple + clust_sizes = pd.Series(int_clust).value_counts() + valid_clusters = clust_sizes[clust_sizes >= min_cell_count].index.to_numpy() - fn_key = f"_nenrich_{n_cls}_{parallel}" - if fn_key not in globals(): - template = _template.format(init=init, loop=loop, finalize=finalize, n_cls=n_cls, parallel=parallel) - exec(compile(template, "", "exec"), globals()) + valid_mask = np.isin(int_clust, valid_clusters) + valid_cells_idx = np.where(valid_mask)[0] + int_clust = int_clust[valid_mask] - return globals()[fn_key] # type: ignore[no-any-return] + adj = adata.obsp[connectivity_key][np.ix_(valid_cells_idx, valid_cells_idx)] + return int_clust, adj @d.get_sections(base="nhood_ench", sections=["Parameters"]) @@ -145,7 +216,10 @@ def nhood_enrichment( seed: int | None = None, copy: bool = False, n_jobs: int | None = None, - backend: str = "loky", + backend: str | None = None, + normalization: str = "none", + min_cell_count: int = 0, + handle_nan: str = "keep", show_progress_bar: bool = True, *, table_key: str | None = None, @@ -165,15 +239,29 @@ def nhood_enrichment( %(seed)s %(copy)s %(parallelize)s + normalization + Normalization mode to use: + - ``'none'``: No normalization of neighbor counts + - ``'total'``: Normalize neighbor counts by total number of cells per cluster (SEA) + - ``'conditional'``: Normalize neighbor counts by number of cells with at least one neighbor of given type (COZI) + min_cell_count + Minimum number of cells a cluster must contain to be included. Clusters with fewer cells are + dropped before counting (default ``0`` keeps all clusters). + handle_nan + How to handle NaN values in z-scores: + - ``'zero'``: Replace NaN values with 0 + - ``'keep'``: Keep NaN values (undefined enrichment) Returns ------- If ``copy = True``, returns a :class:`~squidpy.gr.NhoodEnrichmentResult` with the z-score and the enrichment count. 
+ If normalization = "conditional", also contains the conditional ratio, otherwise it is None. Otherwise, modifies the ``adata`` with the following keys: - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['zscore']`` - the enrichment z-score. - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['count']`` - the enrichment count. + - :attr:`anndata.AnnData.uns` ``['{cluster_key}_nhood_enrichment']['conditional_ratio']`` - the ratio of cells of type A that neighbor type B. """ adata = extract_adata_if_sdata(adata, table_key=table_key) connectivity_key = Key.obsp.spatial_conn(connectivity_key) @@ -181,11 +269,33 @@ def nhood_enrichment( _assert_connectivity_key(adata, connectivity_key) assert_positive(n_perms, name="n_perms") + if numba_parallel: + warnings.warn( + "`numba_parallel` is deprecated and no longer has any effect; permutations are now " + "parallelized across threads. It will be removed in a future version.", + FutureWarning, + stacklevel=2, + ) + if backend is not None: + warnings.warn( + "`backend` is deprecated and no longer has any effect; permutations now run on a " + "thread pool. 
It will be removed in a future version.", + FutureWarning, + stacklevel=2, + ) + adj = adata.obsp[connectivity_key] original_clust = adata.obs[cluster_key] - clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} # map categories + clust_map = {v: i for i, v in enumerate(original_clust.cat.categories.values)} int_clust = np.array([clust_map[c] for c in original_clust], dtype=ndt) + n_total_cells = len(int_clust) + int_clust, adj = filter_clusters_by_min_cell_count( + adata=adata, + int_clust=int_clust, + connectivity_key=connectivity_key, + min_cell_count=min_cell_count, + ) if library_key is not None: _assert_categorical_obs(adata, key=library_key) libraries: pd.Series | None = adata.obs[library_key] @@ -194,39 +304,93 @@ def nhood_enrichment( indices, indptr = (adj.indices.astype(ndt), adj.indptr.astype(ndt)) n_cls = len(clust_map) + if n_cls <= 1: + raise ValueError(f"Expected at least `2` clusters, found `{n_cls}`.") + + count = _nenrich(indices, indptr, int_clust, n_cls) + conditional_ratio = np.full((n_cls, n_cls), np.nan, dtype=np.float64) + + # `min_cell_count` filtering applies to every normalization mode, so report it here. + n_filtered = n_total_cells - len(int_clust) + if n_filtered > 0: + warnings.warn( + f"{n_filtered / n_total_cells * 100:.3f}% of cells were excluded because their clusters had fewer than {min_cell_count} cells.", + UserWarning, + stacklevel=2, + ) + if libraries is not None: # keep the shuffle groups aligned with the retained cells + libraries = libraries[np.isin(original_clust.cat.codes.to_numpy(), np.unique(int_clust))] + + if normalization == "total": + row_sums = count.sum(axis=1, keepdims=True) + row_sums[row_sums == 0] = 1 + count_normalized = count / row_sums + elif normalization == "conditional": + cond_counts = _conditional_counts(indices, indptr, int_clust, n_cls) + + cluster_sizes = np.bincount(int_clust, minlength=n_cls).astype(np.float64) + nonempty = cluster_sizes > 0 + conditional_ratio[nonempty] = cond_counts[nonempty] / cluster_sizes[nonempty, None] + + safe_cond_counts = cond_counts.copy() + safe_cond_counts[safe_cond_counts == 0] = 1.0 - _test = _create_function(n_cls, parallel=numba_parallel) - count = _test(indices, indptr, int_clust) + count_normalized = count / safe_cond_counts
+ + elif normalization == "none": + count_normalized = count.copy() + else: + raise ValueError(f"Invalid normalization mode `{normalization}`. Choose from 'none', 'total', 'conditional'.") n_jobs = _get_n_cores(n_jobs) start = logg.info(f"Calculating neighborhood enrichment using `{n_jobs}` core(s)") - perms = parallelize( - _nhood_enrichment_helper, - collection=np.arange(n_perms).tolist(), - extractor=np.vstack, - n_jobs=n_jobs, - backend=backend, - show_progress_bar=show_progress_bar, - )( - callback=_test, - indices=indices, - indptr=indptr, - int_clust=int_clust, - libraries=libraries, - n_cls=n_cls, - seed=seed, - ) - zscore = (count - perms.mean(axis=0)) / perms.std(axis=0) + # Every permutation is seeded by its global index (``seed + p``), so results are reproducible + # and independent of the thread count. A user seed is used directly; otherwise a random base + # seed is drawn once so a single call is internally consistent. + norm_code = _NORM_CODES[normalization] + base_seed = int(seed) if seed is not None else int(np.random.SeedSequence().generate_state(1)[0]) % 1_000_000_000 + + # Group structure for within-group shuffling, as a CSR-like (offsets, indices) pair in category + # order with ascending indices per group (matching `_shuffle_group`). Without a `library_key`, + # a single group spanning all cells reproduces a plain global shuffle. + group_offsets, group_indices = _build_shuffle_groups(libraries, len(int_clust)) + + # A single numba ``prange`` kernel shuffles + counts + normalizes per thread with the GIL + # released, and ticks the progress bar from inside the loop; numba owns the parallelism.
+ set_num_threads(n_jobs) + with ProgressBar(total=n_perms, unit="perm", desc="nhood_enrichment", disable=not show_progress_bar) as progress: + perms = _permutation_counts( + indices, indptr, int_clust, group_offsets, group_indices, n_cls, n_perms, base_seed, norm_code, progress + ) + + std = perms.std(axis=0) + std[std == 0] = np.nan + zscore = (count_normalized - perms.mean(axis=0)) / std + + if handle_nan == "zero": + zscore = np.nan_to_num(zscore, nan=0.0) + elif handle_nan == "keep": + pass + else: + raise ValueError("handle_nan must be 'keep' or 'zero'") + + result_kwargs = {"zscore": zscore, "count": count} + if normalization == "conditional": + result_kwargs["conditional_ratio"] = conditional_ratio if copy: - return NhoodEnrichmentResult(zscore=zscore, counts=count) + return NhoodEnrichmentResult( + zscore=result_kwargs["zscore"], + counts=result_kwargs["count"], + conditional_ratio=result_kwargs.get("conditional_ratio"), + ) _save_data( adata, attr="uns", key=Key.uns.nhood_enrichment(cluster_key), - data={"zscore": zscore, "count": count}, + data=result_kwargs, time=start, ) @@ -436,32 +600,21 @@ def _centrality_scores_helper( return pd.DataFrame(res_list, columns=[method], index=cat) -def _nhood_enrichment_helper( - ixs: NDArrayA, - callback: Callable[[NDArrayA, NDArrayA, NDArrayA], NDArrayA], - indices: NDArrayA, - indptr: NDArrayA, - int_clust: NDArrayA, +def _build_shuffle_groups( libraries: pd.Series[CategoricalDtype] | None, - n_cls: int, - seed: int | None = None, - queue: SigQueue | None = None, -) -> NDArrayA: - perms = np.empty((len(ixs), n_cls, n_cls), dtype=np.float64) - int_clust = int_clust.copy() # threading - rs = np.random.RandomState(seed=None if seed is None else seed + ixs[0]) - - for i in range(len(ixs)): - if libraries is not None: - int_clust = _shuffle_group(int_clust, libraries, rs) - else: - rs.shuffle(int_clust) - perms[i, ...] 
= callback(indices, indptr, int_clust) + n_cells: int, +) -> tuple[NDArrayA, NDArrayA]: + """Build a CSR-like ``(offsets, indices)`` description of the within-group shuffling. - if queue is not None: - queue.put(Signal.UPDATE) - - if queue is not None: - queue.put(Signal.FINISH) - - return perms + ``indices[offsets[g]:offsets[g + 1]]`` are the cell indices of group ``g`` in ascending order, + with groups in category order — matching :func:`squidpy.gr._utils._shuffle_group`. Without a + ``library_key`` there is a single group spanning all cells, which reproduces a global shuffle. + """ + if libraries is None: + return np.array([0, n_cells], dtype=np.int64), np.arange(n_cells, dtype=np.int64) + + codes = libraries.cat.codes.to_numpy() + n_groups = len(libraries.cat.categories) + group_indices = np.argsort(codes, kind="stable").astype(np.int64) + group_offsets = np.concatenate(([0], np.cumsum(np.bincount(codes, minlength=n_groups)))).astype(np.int64) + return group_offsets, group_indices diff --git a/src/squidpy/pl/__init__.py b/src/squidpy/pl/__init__.py index 9ddbf2386..313ff7364 100644 --- a/src/squidpy/pl/__init__.py +++ b/src/squidpy/pl/__init__.py @@ -7,6 +7,7 @@ co_occurrence, interaction_matrix, nhood_enrichment, + nhood_enrichment_dotplot, ripley, ) from squidpy.pl._ligrec import ligrec diff --git a/src/squidpy/pl/_graph.py b/src/squidpy/pl/_graph.py index f2077603b..92117320f 100644 --- a/src/squidpy/pl/_graph.py +++ b/src/squidpy/pl/_graph.py @@ -2,6 +2,7 @@ from __future__ import annotations +import warnings from collections.abc import Mapping, Sequence from pathlib import Path from types import MappingProxyType @@ -13,6 +14,7 @@ import seaborn as sns from anndata import AnnData from matplotlib.axes import Axes +from matplotlib.lines import Line2D from squidpy._constants._constants import RipleyStat from squidpy._constants._pkg_constants import Key @@ -237,6 +239,160 @@ def nhood_enrichment( save_fig(fig, path=save) +@d.dedent +def 
nhood_enrichment_dotplot( + adata: AnnData, + cluster_key: str, + zscore_key: str = "ct_nhood_enrichment", + annotate: bool = False, + title: str | None = None, + cmap: str = "YlGnBu", + palette: Palette_t = None, + cbar_kwargs: Mapping[str, Any] = MappingProxyType({}), + figsize: tuple[float, float] | None = None, + dpi: int | None = None, + size_range: tuple[float, float] = (10, 200), + save: str | Path | None = None, + ax: Axes | None = None, + **kwargs: Any, +) -> None: + """ + Dot plot of neighborhood enrichment. + + This plots the result of :func:`squidpy.gr.nhood_enrichment`, using: + - Color for z-score of enrichment + - Dot size for conditional cell ratio (CCR), scaled continuously + + Parameters + ---------- + adata : AnnData + Annotated data matrix. + cluster_key : str + Key in `adata.obs` where the cluster (cell type) annotation is stored. + zscore_key : str, optional + Key in `adata.uns` where the enrichment results are stored. Currently unused: results are looked up via ``cluster_key``. + annotate : bool, optional + Whether to annotate dots with CCR values. + title : str, optional + Title of the plot. + cmap : str, optional + Colormap used for the z-score values. + palette : Palette_t, optional + Not used, reserved for compatibility. + cbar_kwargs : dict, optional + Keyword arguments for `fig.colorbar`. + figsize : tuple, optional + Figure size. + dpi : int, optional + Dots per inch for the figure. + size_range : tuple of float, optional + Min and max dot sizes for conditional cell ratio scaling. + save : str | Path, optional + Path to save the figure. + ax : matplotlib.axes.Axes, optional + Axes object to draw the plot onto, otherwise a new figure is created. + **kwargs : Any + Additional keyword arguments passed to `plt.scatter`.
+ + Returns + ------- + None + """ + _assert_categorical_obs(adata, key=cluster_key) + enrichment = _get_data(adata, cluster_key=cluster_key, func_name="nhood_enrichment") + + zscore = enrichment["zscore"] + ccr = enrichment.get("conditional_ratio") + + if ccr is None: + warnings.warn( + "'conditional_ratio' is None in nhood_enrichment results. Please run nhood_enrichment with normalization = 'conditional'. " + "Dot size will not reflect conditional cell ratios.", + UserWarning, + stacklevel=2, + ) + ccr = np.ones_like(zscore) + + cats = adata.obs[cluster_key].cat.categories + + df = pd.DataFrame( + { + "x": np.tile(np.arange(len(cats)), len(cats)), + "y": np.repeat(np.arange(len(cats)), len(cats)), + "zscore": zscore.flatten(), + "ccr": ccr.flatten(), + } + ) + + size_min, size_max = size_range + ccr_norm = (df["ccr"] - df["ccr"].min()) / (df["ccr"].max() - df["ccr"].min() + 1e-10) + df["size"] = size_min + ccr_norm * (size_max - size_min) + + fig, ax = plt.subplots(figsize=figsize, dpi=dpi) if ax is None else (ax.figure, ax) + # NOTE: ``cmap`` is used as passed; its default matches the previously hard-coded "YlGnBu". + sc = ax.scatter( + df["x"], + df["y"], + c=df["zscore"], + s=df["size"], + cmap=cmap, + edgecolors="black", + linewidths=0.3, + **kwargs, + ) + + ax.set_xticks(np.arange(len(cats))) + ax.set_yticks(np.arange(len(cats))) + ax.set_xticklabels(cats, rotation=90) + ax.set_yticklabels(cats) + ax.set_xlabel("Neighbor cell type") + ax.set_ylabel("Index cell type") + + ax.set_title(title or "Neighborhood enrichment (dot plot)") + + # Colorbar + cbar = fig.colorbar(sc, ax=ax, **cbar_kwargs) + cbar.set_label("Z-score") + + legend_ccr_vals = np.linspace(df["ccr"].min(), df["ccr"].max(), 5) + legend_sizes = size_min + (legend_ccr_vals - df["ccr"].min()) / (df["ccr"].max() - df["ccr"].min() + 1e-10) * ( + size_max - size_min + ) + + legend_elements = [ + Line2D( + [0], + [0], + marker="o", + color="w", + label=f"{v:.2f}", + markerfacecolor="gray", + markersize=np.sqrt(s), # scatter size is area → sqrt for legend
markeredgecolor="black", + ) + for v, s in zip(legend_ccr_vals, legend_sizes, strict=True) + ] + + ax.legend( + handles=legend_elements, + title="CCR", + loc="center left", + bbox_to_anchor=(1.3, 0.5), + borderaxespad=0.0, + frameon=False, + ) + + if annotate: + for _, row in df.iterrows(): + ax.text(row["x"], row["y"], f"{row['ccr']:.2f}", ha="center", va="center") + + ax.invert_yaxis() + ax.set_aspect("equal") + + if save is not None: + save_fig(fig, path=save) + + @d.dedent def ripley( adata: AnnData, diff --git a/tests/_images/Graph_nhood_enrichment_dotplot.png b/tests/_images/Graph_nhood_enrichment_dotplot.png new file mode 100644 index 000000000..325a25b25 Binary files /dev/null and b/tests/_images/Graph_nhood_enrichment_dotplot.png differ diff --git a/tests/graph/test_nhood.py b/tests/graph/test_nhood.py index 74764f236..0a0d5e337 100644 --- a/tests/graph/test_nhood.py +++ b/tests/graph/test_nhood.py @@ -1,5 +1,7 @@ from __future__ import annotations +import warnings + import numpy as np import pandas as pd import pytest @@ -30,11 +32,37 @@ def test_nhood_enrichment(self, adata: AnnData): self._assert_common(adata) + @pytest.mark.parametrize("n_jobs", [1, 2, 3]) + def test_parallel_works(self, adata: AnnData, n_jobs: int): + spatial_neighbors(adata) + + nhood_enrichment(adata, cluster_key=_CK, n_jobs=n_jobs, n_perms=20) + + self._assert_common(adata) + @pytest.mark.parametrize("backend", ["threading", "multiprocessing", "loky"]) - def test_parallel_works(self, adata: AnnData, backend: str): + def test_backend_is_deprecated(self, adata: AnnData, backend: str): spatial_neighbors(adata) - nhood_enrichment(adata, cluster_key=_CK, n_jobs=2, n_perms=20, backend=backend) + with pytest.warns(FutureWarning, match="`backend` is deprecated"): + nhood_enrichment(adata, cluster_key=_CK, n_jobs=2, n_perms=20, backend=backend) + + self._assert_common(adata) + + def test_numba_parallel_is_deprecated(self, adata: AnnData): + spatial_neighbors(adata) + + with 
pytest.warns(FutureWarning, match="`numba_parallel` is deprecated"): + nhood_enrichment(adata, cluster_key=_CK, n_perms=20, numba_parallel=True) + + self._assert_common(adata) + + def test_no_deprecation_warning_by_default(self, adata: AnnData): + spatial_neighbors(adata) # kept outside the block: it emits its own FutureWarning + + with warnings.catch_warnings(): + warnings.simplefilter("error", FutureWarning) + nhood_enrichment(adata, cluster_key=_CK, n_perms=20) self._assert_common(adata) @@ -143,3 +171,74 @@ def test_interaction_matrix_nan_values(adata_intmat: AnnData): np.testing.assert_array_equal(expected_weighted, result_weighted) np.testing.assert_array_equal(expected_unweighted, result_unweighted) + + +@pytest.mark.parametrize("normalization", ["none", "total", "conditional"]) +def test_nhood_enrichment_normalization_modes(adata: AnnData, normalization: str): + spatial_neighbors(adata) + result = nhood_enrichment(adata, cluster_key=_CK, normalization=normalization, n_jobs=1, n_perms=20, copy=True) + + z, count, ccr = result + + assert isinstance(z, np.ndarray) + assert isinstance(count, np.ndarray) + if normalization == "conditional": + assert isinstance(ccr, np.ndarray) + assert z.shape == ccr.shape + assert count.shape == ccr.shape + assert z.shape == count.shape + assert z.shape[0] == adata.obs[_CK].cat.categories.shape[0] + + +def test_conditional_normalization_zero_division(adata: AnnData): + adata = adata.copy() + min_cells = 10 + if _CK not in adata.obs: + raise ValueError(f"Cluster key '{_CK}' not in adata.obs") + if not pd.api.types.is_categorical_dtype(adata.obs[_CK]): + adata.obs[_CK] = adata.obs[_CK].astype("category") + adata.obs[_CK] = adata.obs[_CK].cat.add_categories("isolated") + adata.obs.loc[adata.obs.index[0], _CK] = "isolated" + spatial_neighbors(adata) + valid_clusters = [c for c, count in adata.obs[_CK].value_counts().items() if count >= min_cells] + valid_idx = [i for i, cat in enumerate(adata.obs[_CK].cat.categories) if cat in 
valid_clusters] + + result = nhood_enrichment(adata, cluster_key=_CK, normalization="conditional", copy=True) + assert result is not None + zscore, count_normalized, conditional_ratio = result + assert not np.any(np.isinf(zscore)) + assert not np.any(np.isinf(count_normalized)) + assert not np.any(np.isinf(conditional_ratio)) + assert not np.isnan(zscore[np.ix_(valid_idx, valid_idx)]).any() + assert not np.isnan(count_normalized[np.ix_(valid_idx, valid_idx)]).any() + assert not np.isnan(conditional_ratio[np.ix_(valid_idx, valid_idx)]).any() + + +@pytest.mark.parametrize( + "normalization, expected_dtype", + [ + ("none", np.uint32), + ("total", np.uint32), + ("conditional", np.uint32), + ], +) +def test_output_dtype(adata: AnnData, normalization: str, expected_dtype): + spatial_neighbors(adata) + result = nhood_enrichment( + adata, + cluster_key=_CK, + normalization=normalization, + n_jobs=1, + n_perms=20, + copy=True, + ) + + count = result.counts + + assert count.dtype == expected_dtype + + +def test_invalid_normalization_raises(adata: AnnData): + spatial_neighbors(adata) + with pytest.raises(ValueError, match="Invalid normalization mode"): + nhood_enrichment(adata, cluster_key=_CK, normalization="invalid_mode", copy=True) diff --git a/tests/graph/test_nhood_correctness.py b/tests/graph/test_nhood_correctness.py new file mode 100644 index 000000000..7ab2526e7 --- /dev/null +++ b/tests/graph/test_nhood_correctness.py @@ -0,0 +1,344 @@ +"""Correctness tests for :func:`squidpy.gr.nhood_enrichment`. + +These tests pin down the *numerical* behaviour of neighborhood enrichment, as +opposed to the shape/dtype smoke tests in ``test_nhood.py``. + +The reference implementations below are written from scratch with plain Python +loops and share no code with the numba kernels in ``squidpy.gr._nhood``. Their +only purpose is to be an independent specification: when the reference and the +real implementation agree, the numba kernels *and* the normalization math are +validated. 
They are deliberately slow and simple. + +They also act as the regression guard for the permutation / parallelization +machinery. Each permutation is seeded by its global index (``seed + p``), and +numba's ``np.random`` reproduces numpy's ``RandomState`` bit-for-bit, so +:func:`_reference_nhood_enrichment` reproduces the exact z-score independently +of the thread count. +""" + +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from scipy.sparse import csr_matrix + +from squidpy.gr import nhood_enrichment, spatial_neighbors +from squidpy.gr._utils import _shuffle_group + +_CK = "leiden" + + +# --------------------------------------------------------------------------- # +# Independent reference implementations (plain Python, no numba) +# --------------------------------------------------------------------------- # +def _ref_count(adj: csr_matrix, int_clust: np.ndarray, n_cls: int) -> np.ndarray: + """``count[a, b]`` = number of directed edges from a cluster-``a`` cell to a cluster-``b`` neighbor.""" + adj = adj.tocsr() + count = np.zeros((n_cls, n_cls), dtype=np.uint32) + for i in range(adj.shape[0]): + a = int_clust[i] + for j in adj.indices[adj.indptr[i] : adj.indptr[i + 1]]: + count[a, int_clust[j]] += 1 + return count + + +def _ref_total(count: np.ndarray) -> np.ndarray: + row_sums = count.sum(axis=1, keepdims=True).astype(np.float64) + row_sums[row_sums == 0] = 1 + return count / row_sums + + +def _ref_conditional(adj: csr_matrix, int_clust: np.ndarray, n_cls: int) -> tuple[np.ndarray, np.ndarray]: + """Return ``(count_normalized, conditional_ratio)`` for the conditional mode.""" + adj = adj.tocsr() + count = _ref_count(adj, int_clust, n_cls) + + # per-cell, per-cluster neighbor counts -> boolean "has at least one neighbor of type b" + has_neighbor = np.zeros((len(int_clust), n_cls), dtype=bool) + for i in range(len(int_clust)): + for j in adj.indices[adj.indptr[i] : adj.indptr[i + 1]]: + 
has_neighbor[i, int_clust[j]] = True + + cond_counts = np.zeros((n_cls, n_cls), dtype=np.float64) + conditional_ratio = np.full((n_cls, n_cls), np.nan, dtype=np.float64) + for a in range(n_cls): + a_cells = int_clust == a + n_a = int(a_cells.sum()) + if n_a == 0: + continue + for b in range(n_cls): + cond_counts[a, b] = has_neighbor[a_cells, b].sum() + conditional_ratio[a, :] = cond_counts[a, :] / n_a + + safe = cond_counts.copy() + safe[safe == 0] = 1.0 + return count / safe, conditional_ratio + + +def _ref_normalize(adj: csr_matrix, int_clust: np.ndarray, n_cls: int, normalization: str) -> np.ndarray: + """``count_normalized`` for an arbitrary normalization mode.""" + count = _ref_count(adj, int_clust, n_cls) + if normalization == "none": + return count.astype(np.float64) + if normalization == "total": + return _ref_total(count) + if normalization == "conditional": + return _ref_conditional(adj, int_clust, n_cls)[0] + raise ValueError(normalization) + + +def _reference_nhood_enrichment( + adj: csr_matrix, + int_clust: np.ndarray, + n_cls: int, + *, + n_perms: int, + seed: int, + normalization: str, + libraries: pd.Series | None = None, +) -> np.ndarray: + """Full z-score reference replicating the production per-permutation seeding scheme. + + Each permutation ``p`` gets its own ``RandomState(seed + p)`` and shuffles a private copy of + the original labels once. Because the seed depends only on the global permutation index, the + result is independent of how permutations are grouped across workers (i.e. of ``n_jobs``). + Numba's ``np.random`` reproduces numpy's legacy ``RandomState`` bit-for-bit, so this matches + :func:`squidpy.gr._nhood._permutation_counts` exactly. 
+ """ + observed = _ref_normalize(adj, int_clust, n_cls, normalization) + + perms = np.empty((n_perms, n_cls, n_cls), dtype=np.float64) + for p in range(n_perms): + rs = np.random.RandomState(seed + p) + if libraries is not None: + shuffled = _shuffle_group(int_clust, libraries, rs) + else: + shuffled = int_clust.copy() + rs.shuffle(shuffled) + perms[p] = _ref_normalize(adj, shuffled, n_cls, normalization) + + std = perms.std(axis=0) + std[std == 0] = np.nan + return (observed - perms.mean(axis=0)) / std + + +# --------------------------------------------------------------------------- # +# Tiny, fully deterministic graph +# --------------------------------------------------------------------------- # +@pytest.fixture() +def adata_tiny() -> AnnData: + """A 6-cell, 3-cluster graph with a hand-chosen, symmetric adjacency. + + Layout (clusters in brackets):: + + 0[0] - 1[0] - 2[1] - 3[1] - 4[2] - 5[2] + |_______________________________________| (0-5 edge closes the ring) + |_______| (2-4 cross edge) + """ + n = 6 + edges = [(0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 0), (2, 4)] + rows, cols = [], [] + for i, j in edges: + rows += [i, j] + cols += [j, i] + data = np.ones(len(rows), dtype=np.float64) + adj = csr_matrix((data, (rows, cols)), shape=(n, n)) + + clusters = pd.Categorical.from_codes([0, 0, 1, 1, 2, 2], categories=["A", "B", "C"]) + adata = AnnData( + np.zeros((n, n)), + obs={_CK: clusters}, + obsp={"spatial_connectivities": adj}, + ) + return adata + + +def _int_clust(adata: AnnData) -> np.ndarray: + return adata.obs[_CK].cat.codes.to_numpy() + + +# --------------------------------------------------------------------------- # +# Deterministic output: counts +# --------------------------------------------------------------------------- # +def test_counts_match_reference_tiny(adata_tiny: AnnData): + adj = adata_tiny.obsp["spatial_connectivities"] + int_clust = _int_clust(adata_tiny) + n_cls = 3 + + result = nhood_enrichment(adata_tiny, cluster_key=_CK, 
n_perms=20, seed=0, copy=True) + + expected = _ref_count(adj, int_clust, n_cls) + np.testing.assert_array_equal(result.counts, expected) + # the counts matrix is independent of the normalization mode + assert result.counts.dtype == np.uint32 + + +def test_counts_hardcoded_tiny(adata_tiny: AnnData): + """A fully hand-computable expectation, so the reference itself can't drift silently.""" + result = nhood_enrichment(adata_tiny, cluster_key=_CK, n_perms=20, seed=0, copy=True) + # cluster A = {0,1}, B = {2,3}, C = {4,5} + # directed edges by (src cluster -> dst cluster), counting both directions of each undirected edge: + # A-A: 0-1 -> 2 + # A-B: 1-2 -> 1 each way + # A-C: 5-0 -> 1 each way + # B-B: 2-3 -> 2 + # B-C: 3-4, 2-4 -> 2 each way + # C-C: 4-5 -> 2 + expected = np.array( + [ + [2, 1, 1], + [1, 2, 2], + [1, 2, 2], + ], + dtype=np.uint32, + ) + np.testing.assert_array_equal(result.counts, expected) + + +@pytest.mark.parametrize("normalization", ["none", "total", "conditional"]) +def test_counts_invariant_to_normalization(adata_tiny: AnnData, normalization: str): + """Raw ``counts`` must be the observed edge counts regardless of normalization.""" + adj = adata_tiny.obsp["spatial_connectivities"] + expected = _ref_count(adj, _int_clust(adata_tiny), 3) + result = nhood_enrichment(adata_tiny, cluster_key=_CK, normalization=normalization, n_perms=20, seed=0, copy=True) + np.testing.assert_array_equal(result.counts, expected) + + +# --------------------------------------------------------------------------- # +# Deterministic output: conditional_ratio +# --------------------------------------------------------------------------- # +def test_conditional_ratio_matches_reference_tiny(adata_tiny: AnnData): + adj = adata_tiny.obsp["spatial_connectivities"] + int_clust = _int_clust(adata_tiny) + _, expected_ratio = _ref_conditional(adj, int_clust, 3) + + result = nhood_enrichment(adata_tiny, cluster_key=_CK, normalization="conditional", n_perms=20, seed=0, copy=True) + 
np.testing.assert_allclose(result.conditional_ratio, expected_ratio) + + +def test_conditional_ratio_is_a_fraction(adata_tiny: AnnData): + """Conditional ratios are fractions of cells, so they live in ``[0, 1]``.""" + result = nhood_enrichment(adata_tiny, cluster_key=_CK, normalization="conditional", n_perms=20, seed=0, copy=True) + ratio = result.conditional_ratio + assert np.all((ratio >= 0) & (ratio <= 1)) + + +def test_conditional_ratio_none_for_other_modes(adata_tiny: AnnData): + for normalization in ("none", "total"): + result = nhood_enrichment( + adata_tiny, cluster_key=_CK, normalization=normalization, n_perms=20, seed=0, copy=True + ) + assert result.conditional_ratio is None + + +# --------------------------------------------------------------------------- # +# Full z-score: independent numpy reference for the numba permutation kernel +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("n_jobs", [1, 2, 3]) +@pytest.mark.parametrize("normalization", ["none", "total", "conditional"]) +def test_zscore_matches_reference_tiny(adata_tiny: AnnData, normalization: str, n_jobs: int): + adj = adata_tiny.obsp["spatial_connectivities"] + int_clust = _int_clust(adata_tiny) + seed, n_perms = 0, 50 + + result = nhood_enrichment( + adata_tiny, + cluster_key=_CK, + normalization=normalization, + n_perms=n_perms, + seed=seed, + n_jobs=n_jobs, + copy=True, + ) + expected = _reference_nhood_enrichment(adj, int_clust, 3, n_perms=n_perms, seed=seed, normalization=normalization) + np.testing.assert_allclose(result.zscore, expected, equal_nan=True) + + +@pytest.mark.parametrize("n_jobs", [1, 2, 3]) +def test_zscore_reference_holds_on_real_data(adata: AnnData, n_jobs: int): + """The numpy per-index reference must reproduce z-scores on a realistic graph, for every n_jobs. + + Numba's ``np.random`` matches numpy's ``RandomState`` bit-for-bit, so the kernel and this + reference agree exactly regardless of thread count. 
+ """ + spatial_neighbors(adata) + adj = adata.obsp["spatial_connectivities"] + int_clust = adata.obs[_CK].cat.codes.to_numpy() + n_cls = adata.obs[_CK].cat.categories.shape[0] + seed, n_perms = 7, 30 + + result = nhood_enrichment(adata, cluster_key=_CK, n_perms=n_perms, seed=seed, n_jobs=n_jobs, copy=True) + expected = _reference_nhood_enrichment(adj, int_clust, n_cls, n_perms=n_perms, seed=seed, normalization="none") + np.testing.assert_allclose(result.zscore, expected, equal_nan=True) + + +@pytest.mark.parametrize("normalization", ["none", "total", "conditional"]) +def test_zscore_independent_of_n_jobs(adata_tiny: AnnData, normalization: str): + """Per-index seeding makes the z-score identical regardless of the worker/thread count.""" + kw = {"cluster_key": _CK, "normalization": normalization, "n_perms": 50, "seed": 0, "copy": True} + r1 = nhood_enrichment(adata_tiny, n_jobs=1, **kw) + r8 = nhood_enrichment(adata_tiny, n_jobs=8, **kw) + np.testing.assert_array_equal(r1.zscore, r8.zscore) + + +@pytest.mark.parametrize("n_jobs", [1, 3]) +@pytest.mark.parametrize("normalization", ["none", "total", "conditional"]) +def test_zscore_library_key_matches_reference(adata_tiny: AnnData, normalization: str, n_jobs: int): + """The numba within-group shuffle must reproduce the ``_shuffle_group`` reference, per ``library_key``.""" + adata = adata_tiny.copy() + # two libraries over the six cells, intentionally uneven and interleaved + adata.obs["library"] = pd.Categorical.from_codes([0, 0, 1, 1, 0, 1], categories=["s1", "s2"]) + adj = adata.obsp["spatial_connectivities"] + int_clust = _int_clust(adata) + seed, n_perms = 0, 50 + + result = nhood_enrichment( + adata, + cluster_key=_CK, + library_key="library", + normalization=normalization, + n_perms=n_perms, + seed=seed, + n_jobs=n_jobs, + copy=True, + ) + expected = _reference_nhood_enrichment( + adj, int_clust, 3, n_perms=n_perms, seed=seed, normalization=normalization, libraries=adata.obs["library"] + ) + 
np.testing.assert_allclose(result.zscore, expected, equal_nan=True) + + +# --------------------------------------------------------------------------- # +# Statistical sanity: structure that survives RNG changes +# --------------------------------------------------------------------------- # +def test_self_enrichment_positive_for_clustered_graph(): + """Cells wired preferentially to their own cluster must show positive diagonal z-scores.""" + rng = np.random.RandomState(0) + n_per = 40 + n_cls = 3 + n = n_per * n_cls + labels = np.repeat(np.arange(n_cls), n_per) + + rows, cols = [], [] + for c in range(n_cls): + members = np.where(labels == c)[0] + # connect each cell to several same-cluster neighbors -> strong self-enrichment + for i in members: + for j in rng.choice(members[members != i], size=4, replace=False): + rows += [i, j] + cols += [j, i] + adj = csr_matrix((np.ones(len(rows)), (rows, cols)), shape=(n, n)) + + adata = AnnData( + np.zeros((n, 1)), + obs={_CK: pd.Categorical.from_codes(labels, categories=list("ABC"))}, + obsp={"spatial_connectivities": adj}, + ) + result = nhood_enrichment(adata, cluster_key=_CK, n_perms=200, seed=0, n_jobs=1, copy=True) + + diag = np.diag(result.zscore) + off_diag = result.zscore[~np.eye(n_cls, dtype=bool)] + assert np.all(diag > 0), diag + assert np.all(diag > off_diag.max()), (diag, off_diag) diff --git a/tests/graph/test_sepal.py b/tests/graph/test_sepal.py index 32310d779..81506877c 100644 --- a/tests/graph/test_sepal.py +++ b/tests/graph/test_sepal.py @@ -74,4 +74,3 @@ def test_sepal_dense(adata: AnnData): # Assert results are identical assert_frame_equal(df_sparse, df_dense) - diff --git a/tests/plotting/test_graph.py b/tests/plotting/test_graph.py index 6e1c20f7d..27e7f4380 100644 --- a/tests/plotting/test_graph.py +++ b/tests/plotting/test_graph.py @@ -66,6 +66,12 @@ def test_plot_nhood_enrichment_ax(self, adata: AnnData): fig, ax = plt.subplots(figsize=(2, 2), constrained_layout=True) pl.nhood_enrichment(adata, 
cluster_key=C_KEY, ax=ax) + def test_plot_nhood_enrichment_dotplot(self, adata: AnnData): + gr.spatial_neighbors(adata) + gr.nhood_enrichment(adata, cluster_key=C_KEY, normalization="conditional") + + pl.nhood_enrichment_dotplot(adata, cluster_key=C_KEY) + def test_plot_nhood_enrichment_dendro(self, adata: AnnData): gr.spatial_neighbors(adata) gr.nhood_enrichment(adata, cluster_key=C_KEY)