Skip to content

Commit e775855

Browse files
authored
ENH: Add vectorized searchsorted (#531)
1 parent 763ea1a commit e775855

4 files changed

Lines changed: 321 additions & 3 deletions

File tree

src/array_api_extra/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
one_hot,
1313
pad,
1414
partition,
15+
searchsorted,
1516
setdiff1d,
1617
sinc,
1718
union1d,
@@ -49,6 +50,7 @@
4950
"one_hot",
5051
"pad",
5152
"partition",
53+
"searchsorted",
5254
"setdiff1d",
5355
"sinc",
5456
"union1d",

src/array_api_extra/_delegation.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
"nan_to_num",
2828
"one_hot",
2929
"pad",
30+
"searchsorted",
3031
"sinc",
3132
]
3233

@@ -632,6 +633,85 @@ def pad(
632633
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)
633634

634635

636+
def searchsorted(
    x1: Array,
    x2: Array,
    /,
    *,
    side: Literal["left", "right"] = "left",
    xp: ModuleType | None = None,
) -> Array:
    """
    Find indices where elements should be inserted to maintain order.

    Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
    were inserted before the indices, the resulting array would remain sorted.

    The behavior of this function is similar to that of `array_api.searchsorted`,
    but it relaxes the requirement that `x1` must be one-dimensional.
    This function is vectorized, treating slices along the last axis
    as elements and preceding axes as batch (or "loop") dimensions.

    Parameters
    ----------
    x1 : Array
        Input array. Should have a real-valued data type. Must be sorted in ascending
        order along the last axis.
    x2 : Array
        Array containing search values. Should have a real-valued data type. Must have
        the same shape as ``x1`` except along the last axis.
    side : {'left', 'right'}, optional
        Argument controlling which index is returned if an element of ``x2`` is equal to
        one or more elements of ``x1``: ``'left'`` returns the index of the first of
        these elements; ``'right'`` returns the next index after the last of these
        elements. Default: ``'left'``.
    xp : array_namespace, optional
        The standard-compatible namespace for the array arguments. Default: infer.

    Returns
    -------
    Array
        An integer array of indices with the same shape as ``x2``.

    Raises
    ------
    ValueError
        If ``side`` is neither ``'left'`` nor ``'right'``.

    Examples
    --------
    >>> import array_api_strict as xp
    >>> import array_api_extra as xpx
    >>> x = xp.asarray([11, 12, 13, 13, 14, 15])
    >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp)
    Array([0, 1, 5, 6], dtype=array_api_strict.int64)
    >>> xpx.searchsorted(x, xp.asarray(13), xp=xp)
    Array(2, dtype=array_api_strict.int64)
    >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp)
    Array(4, dtype=array_api_strict.int64)

    `searchsorted` is vectorized along the last axis.

    >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]])
    >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]])
    >>> xpx.searchsorted(x1, x2, xp=xp)
    Array([[1, 3],
           [2, 4]], dtype=array_api_strict.int64)
    """
    if xp is None:
        xp = array_namespace(x1, x2)

    if side not in {"left", "right"}:
        message = "`side` must be either 'left' or 'right'."
        raise ValueError(message)

    xp_default_int = _funcs.default_dtype(xp, kind="integral")
    x2_0d = x2.ndim == 0
    x1_1d = x1.ndim <= 1

    # Native path: the namespace's own `searchsorted` is used whenever `x1` has at
    # most one dimension, and unconditionally for torch (presumably because torch's
    # `searchsorted` already accepts N-D `x1` -- confirm against the torch docs).
    if x1_1d or is_torch_namespace(xp):
        # NOTE(review): `x2_0d` means `x2.ndim == 0`, so reshaping to `()` leaves
        # the shape unchanged; kept as-is in case a backend relies on the call.
        x2 = xp.reshape(x2, ()) if (x2_0d and x1_1d) else x2
        out = xp.searchsorted(x1, x2, side=side)
        # Normalize the index dtype to the namespace's default integer dtype.
        return xp.astype(out, xp_default_int, copy=False)

    # N-D `x1` on every other backend: fall back to the generic implementation.
    return _funcs.searchsorted(x1, x2, side=side, xp=xp)
713+
714+
635715
def setdiff1d(
636716
x1: Array | complex,
637717
x2: Array | complex,

src/array_api_extra/_lib/_funcs.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@
88

99
from ._at import at
1010
from ._utils import _compat, _helpers
11-
from ._utils._compat import array_namespace, is_dask_namespace, is_jax_array
11+
from ._utils._compat import (
12+
array_namespace,
13+
is_dask_namespace,
14+
is_jax_array,
15+
)
1216
from ._utils._helpers import (
1317
asarrays,
1418
capabilities,
@@ -28,6 +32,7 @@
2832
"kron",
2933
"nunique",
3034
"pad",
35+
"searchsorted",
3136
"setdiff1d",
3237
"sinc",
3338
]
@@ -693,6 +698,40 @@ def pad(
693698
return at(padded, tuple(slices)).set(x)
694699

695700

701+
def searchsorted(
    x1: Array,
    x2: Array,
    /,
    *,
    side: Literal["left", "right"] = "left",
    xp: ModuleType,
) -> Array:
    # numpydoc ignore=PR01,RT01
    """See docstring in `array_api_extra._delegation.py`."""
    int_dtype = default_dtype(xp, kind="integral")

    # Lower bound of the bisection window, one entry per search value.
    lo = xp.full(x2.shape, 0, device=_compat.device(x1))

    if x1.shape[-1] == 0:
        # Nothing to search: every value is assigned index 0. Cast like the
        # main path so both exits return the same integer dtype.
        return xp.astype(lo, int_dtype, copy=False)

    # Upper bound starts at the number of non-NaN entries in each slice, so NaNs
    # in `x1` stay outside the window (assumes they sort to the end of each
    # slice -- TODO confirm).
    n_valid = xp.count_nonzero(~xp.isnan(x1), axis=-1, keepdims=True)
    hi = xp.broadcast_to(n_valid, x2.shape)

    # 'left' pulls the upper bound down when x2 <= x1[mid]; 'right' only when
    # x2 < x1[mid], which pushes ties past the last equal element.
    compare = xp.less_equal if side == "left" else xp.less

    # A data-dependent `while xp.any(hi - lo > 1)` loop was replaced by a fixed
    # ~log2(n) iteration count for JAX JIT.
    for _ in range(int(math.log2(x1.shape[-1])) + 1):  # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
        mid = (lo + hi) // 2
        pivot = xp.take_along_axis(x1, mid, axis=-1)
        go_left = compare(x2, pivot)
        hi = xp.where(go_left, mid, hi)
        lo = xp.where(go_left, lo, mid)

    # Values that compare at or before the slice minimum are forced to index 0.
    below_min = compare(x2, xp.min(x1, axis=-1, keepdims=True))
    out = xp.where(below_min, 0, hi)
    if side == "right":
        # NaN search values go to the very end of the slice.
        out = xp.where(xp.isnan(x2), x1.shape[-1], out)
    return xp.astype(out, int_dtype, copy=False)
733+
734+
696735
def setdiff1d(
697736
x1: Array | complex,
698737
x2: Array | complex,

0 commit comments

Comments
 (0)