Skip to content

Commit 763ea1a

Browse files
authored
ENH: apply_where: add kwargs support (#624)
* ENH: apply_where: add kwargs support * TST: apply_where: improve tests per review suggestions
1 parent edc111b commit 763ea1a

2 files changed

Lines changed: 65 additions & 18 deletions

File tree

src/array_api_extra/_lib/_funcs.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08
4141
f2: Callable[..., Array],
4242
/,
4343
*,
44+
kwargs: dict[str, Array] | None = None,
4445
xp: ModuleType | None = None,
4546
) -> Array: ...
4647

@@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08
5354
/,
5455
*,
5556
fill_value: Array | complex,
57+
kwargs: dict[str, Array] | None = None,
5658
xp: ModuleType | None = None,
5759
) -> Array: ...
5860

@@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02
6567
/,
6668
*,
6769
fill_value: Array | complex | None = None,
70+
kwargs: dict[str, Array] | None = None,
6871
xp: ModuleType | None = None,
6972
) -> Array:
7073
"""
@@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02
9194
It does not need to be scalar; it needs however to be broadcastable with
9295
`cond` and `args`.
9396
Mutually exclusive with `f2`. You must provide one or the other.
97+
kwargs : dict of str : Array pairs
98+
Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with
99+
`cond`.
94100
xp : array_namespace, optional
95101
The standard-compatible namespace for `cond` and `args`. Default: infer.
96102
@@ -129,6 +135,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
129135
args_ = list(args) if isinstance(args, tuple) else [args]
130136
del args
131137

138+
kwargs_ = {} if kwargs is None else kwargs
139+
kwkeys = list(kwargs_.keys())
140+
args_ = [*args_, *kwargs_.values()]
141+
del kwargs
142+
132143
xp = array_namespace(cond, fill_value, *args_) if xp is None else xp
133144

134145
if isinstance(fill_value, int | float | complex | NoneType):
@@ -139,8 +150,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
139150
if is_dask_namespace(xp):
140151
meta_xp = meta_namespace(cond, fill_value, *args_, xp=xp)
141152
# map_blocks doesn't descend into tuples of Arrays
142-
return xp.map_blocks(_apply_where, cond, f1, f2, fill_value, *args_, xp=meta_xp)
143-
return _apply_where(cond, f1, f2, fill_value, *args_, xp=xp)
153+
return xp.map_blocks(
154+
_apply_where, cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=meta_xp
155+
)
156+
157+
return _apply_where(cond, f1, f2, fill_value, *args_, kwkeys=kwkeys, xp=xp)
144158

145159

146160
def _apply_where( # numpydoc ignore=PR01,RT01
@@ -149,15 +163,26 @@ def _apply_where( # numpydoc ignore=PR01,RT01
149163
f2: Callable[..., Array] | None,
150164
fill_value: Array | int | float | complex | bool | None,
151165
*args: Array,
166+
kwkeys: list[str],
152167
xp: ModuleType,
153168
) -> Array:
154169
"""Helper of `apply_where`. On Dask, this runs on a single chunk."""
155170

171+
nargs = len(args) - len(kwkeys)
172+
kwargs = dict(zip(kwkeys, args[nargs:], strict=True))
173+
args = args[:nargs]
174+
156175
if not capabilities(xp, device=_compat.device(cond))["boolean indexing"]:
157176
# jax.jit does not support assignment by boolean mask
158-
return xp.where(cond, f1(*args), f2(*args) if f2 is not None else fill_value)
177+
return xp.where(
178+
cond,
179+
f1(*args, **kwargs),
180+
f2(*args, **kwargs) if f2 is not None else fill_value,
181+
)
159182

160-
temp1 = f1(*(arr[cond] for arr in args))
183+
temp1 = f1(
184+
*(arr[cond] for arr in args), **{key: val[cond] for key, val in kwargs.items()}
185+
)
161186

162187
if f2 is None:
163188
dtype = xp.result_type(temp1, fill_value)
@@ -167,7 +192,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01
167192
out = xp.astype(fill_value, dtype, copy=True)
168193
else:
169194
ncond = ~cond
170-
temp2 = f2(*(arr[ncond] for arr in args))
195+
temp2 = f2(
196+
*(arr[ncond] for arr in args),
197+
**{key: val[ncond] for key, val in kwargs.items()},
198+
)
171199
dtype = xp.result_type(temp1, temp2)
172200
out = xp.empty_like(cond, dtype=dtype)
173201
out = at(out, ncond).set(temp2)

tests/test_funcs.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,8 @@ def test_device(self, xp: ModuleType, device: Device):
210210
deadline=None,
211211
)
212212
@given(
213-
n_arrays=st.integers(min_value=1, max_value=3),
213+
n_arrays=st.integers(min_value=0, max_value=3),
214+
n_kwarrays=st.integers(min_value=0, max_value=3),
214215
rng_seed=st.integers(min_value=1000000000, max_value=9999999999),
215216
dtype=npst.floating_dtypes(sizes=(32, 64)),
216217
p=st.floats(min_value=0, max_value=1),
@@ -219,6 +220,7 @@ def test_device(self, xp: ModuleType, device: Device):
219220
def test_hypothesis(
220221
self,
221222
n_arrays: int,
223+
n_kwarrays: int,
222224
rng_seed: int,
223225
dtype: np.dtype[Any],
224226
p: float,
@@ -233,9 +235,14 @@ def test_hypothesis(
233235
):
234236
pytest.xfail(reason="NumPy 1.x dtype promotion for scalars")
235237

236-
mbs = npst.mutually_broadcastable_shapes(num_shapes=n_arrays + 1, min_side=0)
238+
_ = hypothesis.assume(n_arrays + n_kwarrays > 0)
239+
mbs = npst.mutually_broadcastable_shapes(
240+
num_shapes=1 + n_arrays + n_kwarrays, min_side=0
241+
)
237242
input_shapes, _ = data.draw(mbs)
238-
cond_shape, *shapes = input_shapes
243+
cond_shape = input_shapes[0]
244+
shapes = input_shapes[1 : 1 + n_arrays]
245+
kwshapes = input_shapes[1 + n_arrays :]
239246

240247
# cupy/cupy#8382
241248
# https://github.com/jax-ml/jax/issues/26658
@@ -257,22 +264,34 @@ def test_hypothesis(
257264
for shape in shapes
258265
)
259266

260-
def f1(*args: Array) -> Array:
261-
return cast(Array, sum(args))
267+
kwargs = {
268+
f"kw{n}": xp.asarray(
269+
data.draw(npst.arrays(dtype=dtype.type, shape=shape, elements=elements))
270+
)
271+
for n, shape in enumerate(kwshapes)
272+
}
273+
kwkeys = kwargs.keys()
274+
275+
def f1(*args: Array, **kwargs: dict[str, Array]) -> Array:
276+
assert kwargs.keys() == kwkeys
277+
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
278+
return cast(Array, sum(args_kwargs))
262279

263-
def f2(*args: Array) -> Array:
264-
return cast(Array, sum(args) / 2)
280+
def f2(*args: Array, **kwargs: dict[str, Array]) -> Array:
281+
assert kwargs.keys() == kwkeys
282+
args_kwargs = cast(tuple[Array, ...], (*args, *kwargs.values()))
283+
return cast(Array, sum(args_kwargs) / 2)
265284

266285
rng = np.random.default_rng(rng_seed)
267286
cond = xp.asarray(rng.random(size=cond_shape) > p)
268287

269-
res1 = apply_where(cond, arrays, f1, fill_value=fill_value)
270-
res2 = apply_where(cond, arrays, f1, f2)
271-
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value)
288+
res1 = apply_where(cond, arrays, f1, fill_value=fill_value, kwargs=kwargs)
289+
res2 = apply_where(cond, arrays, f1, f2, kwargs=kwargs)
290+
res3 = apply_where(cond, arrays, f1, fill_value=float_fill_value, kwargs=kwargs)
272291

273-
ref1 = xp.where(cond, f1(*arrays), fill_value)
274-
ref2 = xp.where(cond, f1(*arrays), f2(*arrays))
275-
ref3 = xp.where(cond, f1(*arrays), float_fill_value)
292+
ref1 = xp.where(cond, f1(*arrays, **kwargs), fill_value)
293+
ref2 = xp.where(cond, f1(*arrays, **kwargs), f2(*arrays, **kwargs))
294+
ref3 = xp.where(cond, f1(*arrays, **kwargs), float_fill_value)
276295

277296
xp_assert_close(res1, ref1, rtol=2e-16)
278297
xp_assert_equal(res2, ref2)

0 commit comments

Comments
 (0)