|
27 | 27 | "nan_to_num", |
28 | 28 | "one_hot", |
29 | 29 | "pad", |
| 30 | + "searchsorted", |
30 | 31 | "sinc", |
31 | 32 | ] |
32 | 33 |
|
@@ -632,6 +633,85 @@ def pad( |
632 | 633 | return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) |
633 | 634 |
|
634 | 635 |
|
| 636 | +def searchsorted( |
| 637 | + x1: Array, |
| 638 | + x2: Array, |
| 639 | + /, |
| 640 | + *, |
| 641 | + side: Literal["left", "right"] = "left", |
| 642 | + xp: ModuleType | None = None, |
| 643 | +) -> Array: |
| 644 | + """ |
| 645 | + Find indices where elements should be inserted to maintain order. |
| 646 | +
|
| 647 | + Find the indices into a sorted array ``x1`` such that if the elements in ``x2`` |
| 648 | + were inserted before the indices, the resulting array would remain sorted. |
| 649 | +
|
| 650 | + The behavior of this function is similar to that of `array_api.searchsorted`, |
| 651 | + but it relaxes the requirement that `x1` must be one-dimensional. |
| 652 | + This function is vectorized, treating slices along the last axis |
| 653 | + as elements and preceding axes as batch (or "loop") dimensions. |
| 654 | +
|
| 655 | + Parameters |
| 656 | + ---------- |
| 657 | + x1 : Array |
| 658 | + Input array. Should have a real-valued data type. Must be sorted in ascending |
| 659 | + order along the last axis. |
| 660 | + x2 : Array |
| 661 | + Array containing search values. Should have a real-valued data type. Must have |
| 662 | + the same shape as ``x1`` except along the last axis. |
| 663 | + side : {'left', 'right'}, optional |
| 664 | + Argument controlling which index is returned if an element of ``x2`` is equal to |
| 665 | + one or more elements of ``x1``: ``'left'`` returns the index of the first of |
| 666 | + these elements; ``'right'`` returns the next index after the last of these |
| 667 | + elements. Default: ``'left'``. |
| 668 | + xp : array_namespace, optional |
| 669 | + The standard-compatible namespace for the array arguments. Default: infer. |
| 670 | +
|
| 671 | + Returns |
| 672 | + ------- |
| 673 | + Array: integer array |
| 674 | + An array of indices with the same shape as ``x2``. |
| 675 | +
|
| 676 | + Examples |
| 677 | + -------- |
| 678 | + >>> import array_api_strict as xp |
| 679 | + >>> import array_api_extra as xpx |
| 680 | + >>> x = xp.asarray([11, 12, 13, 13, 14, 15]) |
| 681 | + >>> xpx.searchsorted(x, xp.asarray([10, 11.5, 14.5, 16]), xp=xp) |
| 682 | + Array([0, 1, 5, 6], dtype=array_api_strict.int64) |
| 683 | + >>> xpx.searchsorted(x, xp.asarray(13), xp=xp) |
| 684 | + Array(2, dtype=array_api_strict.int64) |
| 685 | + >>> xpx.searchsorted(x, xp.asarray(13), side='right', xp=xp) |
| 686 | + Array(4, dtype=array_api_strict.int64) |
| 687 | +
|
| 688 | + `searchsorted` is vectorized along the last axis. |
| 689 | +
|
| 690 | + >>> x1 = xp.asarray([[1., 2., 3., 4.], [5., 6., 7., 8.]]) |
| 691 | + >>> x2 = xp.asarray([[1.1, 3.3], [6.6, 8.8]]) |
| 692 | + >>> xpx.searchsorted(x1, x2, xp=xp) |
| 693 | + Array([[1, 3], |
| 694 | + [2, 4]], dtype=array_api_strict.int64) |
| 695 | + """ |
| 696 | + if xp is None: |
| 697 | + xp = array_namespace(x1, x2) |
| 698 | + |
| 699 | + if side not in {"left", "right"}: |
| 700 | + message = "`side` must be either 'left' or 'right'." |
| 701 | + raise ValueError(message) |
| 702 | + |
| 703 | + xp_default_int = _funcs.default_dtype(xp, kind="integral") |
| 704 | + x2_0d = x2.ndim == 0 |
| 705 | + x1_1d = x1.ndim <= 1 |
| 706 | + |
| 707 | + if x1_1d or is_torch_namespace(xp): |
| 708 | + x2 = xp.reshape(x2, ()) if (x2_0d and x1_1d) else x2 |
| 709 | + out = xp.searchsorted(x1, x2, side=side) |
| 710 | + return xp.astype(out, xp_default_int, copy=False) |
| 711 | + |
| 712 | + return _funcs.searchsorted(x1, x2, side=side, xp=xp) |
| 713 | + |
| 714 | + |
635 | 715 | def setdiff1d( |
636 | 716 | x1: Array | complex, |
637 | 717 | x2: Array | complex, |
|
0 commit comments