|
3 | 3 | # https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 |
4 | 4 | from __future__ import annotations |
5 | 5 |
|
| 6 | +import math |
6 | 7 | import operator |
7 | 8 | import warnings |
8 | 9 | from collections.abc import Callable |
|
25 | 26 | "create_diagonal", |
26 | 27 | "expand_dims", |
27 | 28 | "kron", |
| 29 | + "nunique", |
28 | 30 | "pad", |
29 | 31 | "setdiff1d", |
30 | 32 | "sinc", |
@@ -638,6 +640,42 @@ def pad( |
638 | 640 | return at(padded, tuple(slices)).set(x) |
639 | 641 |
|
640 | 642 |
|
| 643 | +def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array: |
| 644 | + """ |
| 645 | + Count the number of unique elements in an array. |
| 646 | +
|
| 647 | + Compatible with JAX and Dask, whose laziness would be otherwise |
| 648 | + problematic. |
| 649 | +
|
| 650 | + Parameters |
| 651 | + ---------- |
| 652 | + x : Array |
| 653 | + Input array. |
| 654 | + xp : array_namespace, optional |
| 655 | + The standard-compatible namespace for `x`. Default: infer. |
| 656 | +
|
| 657 | + Returns |
| 658 | + ------- |
| 659 | + array: 0-dimensional integer array |
| 660 | + The number of unique elements in `x`. It can be lazy. |
| 661 | + """ |
| 662 | + if xp is None: |
| 663 | + xp = array_namespace(x) |
| 664 | + |
| 665 | + if is_jax_array(x): |
| 666 | + # size= is JAX-specific |
| 667 | + # https://github.com/data-apis/array-api/issues/883 |
| 668 | + _, counts = xp.unique_counts(x, size=_compat.size(x)) |
| 669 | + return xp.astype(counts, xp.bool).sum() |
| 670 | + |
| 671 | + _, counts = xp.unique_counts(x) |
| 672 | + n = _compat.size(counts) |
| 673 | + # FIXME https://github.com/data-apis/array-api-compat/pull/231 |
| 674 | + if n is None or math.isnan(n): # e.g. Dask, ndonnx |
| 675 | + return xp.astype(counts, xp.bool).sum() |
| 676 | + return xp.asarray(n, device=_compat.device(x)) |
| 677 | + |
| 678 | + |
641 | 679 | class _AtOp(Enum): |
642 | 680 | """Operations for use in `xpx.at`.""" |
643 | 681 |
|
|
0 commit comments