Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 630a144

Browse files
authored
[API NEW][SET FUNC] Add set functions (#20693)
* [API] Add set functions * update tests * fix lint
1 parent 9e6dd92 commit 630a144

3 files changed

Lines changed: 238 additions & 0 deletions

File tree

python/mxnet/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from .utils import * # pylint: disable=wildcard-import
2828
from .function_base import * # pylint: disable=wildcard-import
2929
from .stride_tricks import * # pylint: disable=wildcard-import
30+
from .set_functions import * # pylint: disable=wildcard-import
3031
from .io import * # pylint: disable=wildcard-import
3132
from .arrayprint import * # pylint: disable=wildcard-import
3233

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Standard Array API for creating and operating on sets."""
19+
20+
from collections import namedtuple
21+
22+
from ..ndarray import numpy as _mx_nd_np
23+
24+
25+
__all__ = ['unique_all', 'unique_inverse', 'unique_values']
26+
27+
28+
def unique_all(x):
29+
"""
30+
Returns the unique elements of an input array `x`
31+
32+
Notes
33+
-----
34+
`unique_all` is a standard API in
35+
https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-all-x
36+
instead of an official NumPy operator.
37+
38+
Parameters
39+
----------
40+
x : ndarray
41+
Input array. This will be flattened if it is not already 1-D.
42+
43+
Returns
44+
-------
45+
out : Tuple[ndarray, ndarray, ndarray, ndarray]
46+
a namedtuple (values, indices, inverse_indices, counts):
47+
values : ndarray
48+
The sorted unique values.
49+
indices : ndarray, optional
50+
The indices of the first occurrences of the unique values in the
51+
original array.
52+
inverse_indices : ndarray
53+
The indices to reconstruct the original array from the
54+
unique array.
55+
counts : ndarray
56+
The number of times each of the unique values comes up in the
57+
original array.
58+
"""
59+
UniqueAll = namedtuple('UniqueAll', ['values', 'indices', 'inverse_indices', 'counts'])
60+
return UniqueAll(*_mx_nd_np.unique(x, True, True, True))
61+
62+
63+
def unique_inverse(x):
64+
"""
65+
Returns the unique elements of an input array `x` and the indices
66+
from the set of unique elements that reconstruct `x`.
67+
68+
Notes
69+
-----
70+
`unique_inverse` is a standard API in
71+
https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-inverse-x
72+
instead of an official NumPy operator.
73+
74+
Parameters
75+
----------
76+
x : ndarray
77+
Input array. This will be flattened if it is not already 1-D.
78+
79+
Returns
80+
-------
81+
out : Tuple[ndarray, ndarray]
82+
a namedtuple (values, inverse_indices):
83+
values : ndarray
84+
The sorted unique values.
85+
inverse_indices : ndarray
86+
The indices to reconstruct the original array from the
87+
unique array.
88+
"""
89+
UniqueInverse = namedtuple('UniqueInverse', ['values', 'inverse_indices'])
90+
return UniqueInverse(*_mx_nd_np.unique(x, False, True, False))
91+
92+
93+
def unique_values(x):
94+
"""
95+
Returns the unique elements of an input array `x`.
96+
97+
Notes
98+
-----
99+
`unique_values` is a standard API in
100+
https://data-apis.org/array-api/latest/API_specification/set_functions.html#unique-values-x
101+
instead of an official NumPy operator.
102+
103+
Parameters
104+
----------
105+
x : ndarray
106+
Input array. This will be flattened if it is not already 1-D.
107+
108+
Returns
109+
-------
110+
out : ndarray
111+
The sorted unique values.
112+
"""
113+
return _mx_nd_np.unique(x, False, False, False)

tests/python/unittest/test_numpy_op.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8079,6 +8079,130 @@ def forward(self, a):
80798079
assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)
80808080

80818081

8082+
@use_np
8083+
@pytest.mark.parametrize('shape,index,inverse,counts', [
8084+
((), True, True, True),
8085+
((1, ), True, True, True),
8086+
((5, ), True, True, True),
8087+
((5, ), True, True, True),
8088+
((5, 4), True, True, True),
8089+
((5, 0, 4), True, True, True),
8090+
((0, 0, 0), True, True, True),
8091+
((5, 3, 4), True, True, True),
8092+
])
8093+
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64'])
8094+
@pytest.mark.parametrize('hybridize', [False, True])
8095+
def test_np_unique_all(shape, index, inverse, counts, dtype, hybridize):
8096+
class TestUniqueAll(HybridBlock):
8097+
def __init__(self):
8098+
super(TestUniqueAll, self).__init__()
8099+
8100+
def forward(self, a):
8101+
return np.unique_all(a)
8102+
8103+
test_unique = TestUniqueAll()
8104+
if hybridize:
8105+
test_unique.hybridize()
8106+
x = onp.random.uniform(-8.0, 8.0, size=shape)
8107+
x = np.array(x, dtype=dtype)
8108+
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
8109+
mx_out = test_unique(x)
8110+
for i in range(len(mx_out)):
8111+
assert mx_out[i].shape == np_out[i].shape
8112+
assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)
8113+
8114+
# Test imperative once again
8115+
mx_out = np.unique_all(x)
8116+
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
8117+
assert mx_out.values.shape == np_out[0].shape
8118+
assert_almost_equal(mx_out.values.asnumpy(), np_out[0], rtol=1e-3, atol=1e-5)
8119+
assert mx_out.indices.shape == np_out[1].shape
8120+
assert_almost_equal(mx_out.indices.asnumpy(), np_out[1], rtol=1e-3, atol=1e-5)
8121+
assert mx_out.inverse_indices.shape == np_out[2].shape
8122+
assert_almost_equal(mx_out.inverse_indices.asnumpy(), np_out[2], rtol=1e-3, atol=1e-5)
8123+
assert mx_out.counts.shape == np_out[3].shape
8124+
assert_almost_equal(mx_out.counts.asnumpy(), np_out[3], rtol=1e-3, atol=1e-5)
8125+
8126+
8127+
@use_np
8128+
@pytest.mark.parametrize('shape,index,inverse,counts', [
8129+
((), False, True, False),
8130+
((1, ), False, True, False),
8131+
((5, ), False, True, False),
8132+
((5, ), False, True, False),
8133+
((5, 4), False, True, False),
8134+
((5, 0, 4), False, True, False),
8135+
((0, 0, 0), False, True, False),
8136+
((5, 3, 4), False, True, False),
8137+
])
8138+
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64'])
8139+
@pytest.mark.parametrize('hybridize', [False, True])
8140+
def test_np_unique_inverse(shape, index, inverse, counts, dtype, hybridize):
8141+
class TestUniqueInverse(HybridBlock):
8142+
def __init__(self):
8143+
super(TestUniqueInverse, self).__init__()
8144+
8145+
def forward(self, a):
8146+
return np.unique_inverse(a)
8147+
8148+
test_unique = TestUniqueInverse()
8149+
if hybridize:
8150+
test_unique.hybridize()
8151+
x = onp.random.uniform(-8.0, 8.0, size=shape)
8152+
x = np.array(x, dtype=dtype)
8153+
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
8154+
mx_out = test_unique(x)
8155+
for i in range(len(mx_out)):
8156+
assert mx_out[i].shape == np_out[i].shape
8157+
assert_almost_equal(mx_out[i].asnumpy(), np_out[i], rtol=1e-3, atol=1e-5)
8158+
8159+
# Test imperative once again
8160+
mx_out = np.unique_inverse(x)
8161+
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
8162+
assert mx_out.values.shape == np_out[0].shape
8163+
assert_almost_equal(mx_out.values.asnumpy(), np_out[0], rtol=1e-3, atol=1e-5)
8164+
assert mx_out.inverse_indices.shape == np_out[1].shape
8165+
assert_almost_equal(mx_out.inverse_indices.asnumpy(), np_out[1], rtol=1e-3, atol=1e-5)
8166+
8167+
8168+
@use_np
8169+
@pytest.mark.parametrize('shape,index,inverse,counts', [
8170+
((), False, False, False),
8171+
((1, ), False, False, False),
8172+
((5, ), False, False, False),
8173+
((5, ), False, False, False),
8174+
((5, 4), False, False, False),
8175+
((5, 0, 4), False, False, False),
8176+
((0, 0, 0), False, False, False),
8177+
((5, 3, 4), False, False, False),
8178+
])
8179+
@pytest.mark.parametrize('dtype', ['float32', 'float64', 'int8', 'uint8', 'int32', 'int64'])
8180+
@pytest.mark.parametrize('hybridize', [False, True])
8181+
def test_np_unique_values(shape, index, inverse, counts, dtype, hybridize):
8182+
class TestUniqueValues(HybridBlock):
8183+
def __init__(self):
8184+
super(TestUniqueValues, self).__init__()
8185+
8186+
def forward(self, a):
8187+
return np.unique_values(a)
8188+
8189+
test_unique = TestUniqueValues()
8190+
if hybridize:
8191+
test_unique.hybridize()
8192+
x = onp.random.uniform(-8.0, 8.0, size=shape)
8193+
x = np.array(x, dtype=dtype)
8194+
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
8195+
mx_out = test_unique(x)
8196+
assert mx_out.shape == np_out.shape
8197+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
8198+
8199+
# Test imperative once again
8200+
mx_out = np.unique_values(x)
8201+
np_out = onp.unique(x.asnumpy(), return_index=index, return_inverse=inverse, return_counts=counts)
8202+
assert mx_out.shape == np_out.shape
8203+
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
8204+
8205+
80828206
@use_np
80838207
def test_np_take():
80848208
configs = [

0 commit comments

Comments
 (0)