Skip to content

Commit bcd8473

Browse files
authored
MAINT: housekeeping after upstream progress (#674)
1 parent bdef2b8 commit bcd8473

2 files changed

Lines changed: 2 additions & 34 deletions

File tree

src/array_api_extra/_lib/_funcs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array: # numpydoc ignore=PR01,RT01
291291
m = atleast_nd(m, ndim=2, xp=xp)
292292
m = xp.astype(m, dtype)
293293

294-
avg = _helpers.mean(m, axis=-1, keepdims=True, xp=xp)
294+
avg = xp.mean(m, axis=-1, keepdims=True)
295295

296296
m_shape = eager_shape(m)
297297
fact = m_shape[-1] - 1

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 1 addition & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
is_dask_namespace,
2929
is_jax_namespace,
3030
is_numpy_array,
31-
is_pydata_sparse_namespace,
3231
is_torch_namespace,
3332
)
3433
from ._typing import Array, Device
@@ -53,7 +52,6 @@ def override(func):
5352
"in1d",
5453
"is_python_scalar",
5554
"jax_autojit",
56-
"mean",
5755
"meta_namespace",
5856
"pickle_flatten",
5957
"pickle_unflatten",
@@ -122,29 +120,6 @@ def in1d(
122120
return xp.take(ret, rev_idx, axis=0)
123121

124122

125-
def mean(
126-
x: Array,
127-
/,
128-
*,
129-
axis: int | tuple[int, ...] | None = None,
130-
keepdims: bool = False,
131-
xp: ModuleType | None = None,
132-
) -> Array: # numpydoc ignore=PR01,RT01
133-
"""
134-
Complex mean, https://github.com/data-apis/array-api/issues/846.
135-
"""
136-
if xp is None:
137-
xp = array_namespace(x)
138-
139-
if xp.isdtype(x.dtype, "complex floating"):
140-
x_real = xp.real(x)
141-
x_imag = xp.imag(x)
142-
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
143-
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
144-
return mean_real + (mean_imag * xp.asarray(1j))
145-
return xp.mean(x, axis=axis, keepdims=keepdims)
146-
147-
148123
def is_python_scalar(x: object) -> TypeIs[complex]: # numpydoc ignore=PR01,RT01
149124
"""Return True if `x` is a Python scalar, False otherwise."""
150125
# isinstance(x, float) returns True for np.float64
@@ -332,14 +307,7 @@ def capabilities(
332307
Capabilities of the namespace.
333308
"""
334309
out = xp.__array_namespace_info__().capabilities()
335-
if is_pydata_sparse_namespace(xp):
336-
if out["boolean indexing"]:
337-
# FIXME https://github.com/pydata/sparse/issues/876
338-
# boolean indexing is supported, but not when the index is a sparse array.
339-
# boolean indexing by list or numpy array is not part of the Array API.
340-
out = out.copy()
341-
out["boolean indexing"] = False
342-
elif is_jax_namespace(xp):
310+
if is_jax_namespace(xp):
343311
if out["boolean indexing"]: # pragma: no cover
344312
# Backwards compatibility with jax <0.6.0
345313
# https://github.com/jax-ml/jax/issues/27418

0 commit comments

Comments
 (0)