Skip to content

Commit c926554

Browse files
committed
New blosc2.argsort() function
1 parent 193dc76 commit c926554

4 files changed

Lines changed: 107 additions & 0 deletions

File tree

src/blosc2/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def _raise(exc):
533533
eye,
534534
asarray,
535535
astype,
536+
argsort,
536537
indices,
537538
sort,
538539
reshape,

src/blosc2/ndarray.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5218,6 +5218,18 @@ def indices(self, order: str | list[str] | None = None, **kwargs: Any) -> NDArra
52185218
"""
52195219
return indices(self, order, **kwargs)
52205220

5221+
def argsort(self, order: str | list[str] | None = None, **kwargs: Any) -> NDArray:
5222+
"""
5223+
Return the permutation that sorts the array.
5224+
5225+
This follows :func:`numpy.argsort` semantics more closely than
5226+
:meth:`indices`: plain 1-D arrays are supported, and ``order=None``
5227+
means "use the array's natural order" rather than "leave unsorted".
5228+
5229+
See full documentation in :func:`argsort`.
5230+
"""
5231+
return argsort(self, order, **kwargs)
5232+
52215233
def itersorted(
52225234
self,
52235235
order: str | list[str] | None = None,
@@ -6748,6 +6760,60 @@ def indices(array: blosc2.Array, order: str | list[str] | None = None, **kwargs:
67486760
return larr.indices(order).compute(**kwargs)
67496761

67506762

6763+
def argsort(array: blosc2.Array, order: str | list[str] | None = None, **kwargs: Any) -> NDArray:
6764+
"""
6765+
Return the indices that would sort the array.
6766+
6767+
This mirrors :func:`numpy.argsort` for 1-D arrays. Plain arrays sort by
6768+
their values. Structured arrays sort by ``order`` when provided, or by
6769+
their dtype field order when ``order=None``. Expression orders such as
6770+
``"abs(x)"`` are also supported when a matching ``full`` expression index
6771+
exists.
6772+
6773+
Parameters
6774+
----------
6775+
array: :ref:`blosc2.Array`
6776+
The 1-D array to be ordered.
6777+
order: str, list of str, optional
6778+
Primary and optional secondary order keys for structured arrays. When
6779+
omitted, NumPy's default record order is used for structured dtypes and
6780+
the array values themselves are used for plain dtypes.
6781+
kwargs: Any, optional
6782+
Keyword arguments that are supported by the :func:`empty` constructor.
6783+
6784+
Returns
6785+
-------
6786+
out: :ref:`NDArray`
6787+
The ordered logical positions as ``int64``.
6788+
6789+
Notes
6790+
-----
6791+
When the primary order key has a matching ``full`` field or expression
6792+
index, the permutation is returned directly from that index in ascending
6793+
stable order. Secondary keys refine ties after the primary indexed order.
6794+
Without a matching ``full`` index, :func:`argsort` falls back to
6795+
materializing the input values and delegating ordering to
6796+
:func:`numpy.argsort`.
6797+
6798+
The result is always a new array materialization. For persistent inputs,
6799+
the returned permutation is in memory by default; pass storage kwargs such
6800+
as ``urlpath`` (and typically ``mode="w"``) if the permutation should also
6801+
be persisted on disk.
6802+
"""
6803+
if isinstance(array, blosc2.NDArray):
6804+
from . import indexing
6805+
6806+
ordered = indexing.ordered_indices(array, order=order)
6807+
if ordered is not None:
6808+
return blosc2.asarray(ordered, **kwargs)
6809+
if indexing.is_expression_order(array, order):
6810+
raise ValueError("expression order requires a matching full expression index")
6811+
6812+
values = array[:]
6813+
positions = np.argsort(values, order=order, kind="stable")
6814+
return blosc2.asarray(positions.astype(np.int64, copy=False), **kwargs)
6815+
6816+
67516817
def sort(array: blosc2.Array, order: str | list[str] | None = None, **kwargs: Any) -> NDArray:
67526818
"""
67536819
Return a sorted array following the specified order.

tests/ndarray/test_indexing.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,29 @@ def test_full_index_reuses_primary_order_for_indices_and_sort():
674674
np.testing.assert_array_equal(arr.sort(order=["a", "b"])[:], np.sort(data, order=["a", "b"]))
675675

676676

677+
def test_full_index_reuses_primary_order_for_argsort():
678+
dtype = np.dtype([("a", np.int64), ("b", np.int64)])
679+
data = np.array(
680+
[(2, 9), (1, 8), (2, 7), (1, 6), (2, 5), (1, 4), (2, 3), (1, 2), (2, 1), (1, 0)],
681+
dtype=dtype,
682+
)
683+
arr = blosc2.asarray(data, chunks=(4,), blocks=(2,))
684+
arr.create_csindex("a")
685+
686+
np.testing.assert_array_equal(arr.argsort(order=["a", "b"])[:], np.argsort(data, order=["a", "b"]))
687+
688+
689+
def test_persistent_scalar_argsort_uses_full_index(tmp_path):
690+
path = tmp_path / "scalar_argsort.b2nd"
691+
data = np.array([9, 1, 7, 3, 1, 5], dtype=np.int64)
692+
arr = blosc2.asarray(data, urlpath=path, mode="w", chunks=(3,), blocks=(2,))
693+
arr.create_index(kind="full")
694+
695+
result = blosc2.argsort(arr)
696+
697+
np.testing.assert_array_equal(result[:], np.argsort(data, kind="stable"))
698+
699+
677700
def test_filtered_ordered_queries_support_cross_field_exact_indexes():
678701
dtype = np.dtype([("a", np.int64), ("b", np.int64), ("payload", np.int32)])
679702
data = np.array(

tests/ndarray/test_ndarray.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,23 @@ def test_indices(order):
482482
assert np.array_equal(b[:], nb)
483483

484484

485+
@pytest.mark.parametrize("order", ["f0", "f1", "f2", None])
486+
def test_argsort_structured(order):
487+
it = ((x + 1, x - 2, -x) for x in range(10))
488+
a = blosc2.fromiter(it, dtype="i4, i4, i8", shape=(10,))
489+
b = blosc2.argsort(a, order=order)
490+
narr = a[:]
491+
nb = np.argsort(narr, order=order, kind="stable")
492+
assert np.array_equal(b[:], nb)
493+
494+
495+
def test_argsort_scalar():
496+
data = np.array([7, 2, 9, 2, 1, 8], dtype=np.int64)
497+
a = blosc2.asarray(data)
498+
b = a.argsort()
499+
np.testing.assert_array_equal(b[:], np.argsort(data, kind="stable"))
500+
501+
485502
def test_save():
486503
a = blosc2.arange(0, 10, 1, dtype="i4", shape=(10,))
487504
blosc2.save(a, "test.b2nd")

0 commit comments

Comments
 (0)