Skip to content

Commit cef1731

Browse files
committed
Add weighted quantile and percentile support with tests
1 parent 4f522e6 commit cef1731

2 files changed

Lines changed: 39 additions & 30 deletions

File tree

jax/_src/numpy/reductions.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import jax
2727
from jax import lax
28-
from jax import numpy as jnp
2928
from jax._src import api
3029
from jax._src import core
3130
from jax._src import deprecations
@@ -2453,6 +2452,7 @@ def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None =
24532452

24542453
def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24552454
method: str, keepdims: bool, squash_nans: bool, weights: Array | None = None) -> Array:
2455+
from jax import numpy as jnp
24562456
if method not in ["linear", "lower", "higher", "midpoint", "nearest", "inverted_cdf"]:
24572457
raise ValueError("method can only be 'linear', 'lower', 'higher', 'midpoint', 'nearest', or 'inverted_cdf'")
24582458
keepdim = []
@@ -2485,7 +2485,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
24852485
axis = _canonicalize_axis(axis, a.ndim)
24862486

24872487
q, = promote_dtypes_inexact(q)
2488-
2488+
q = jnp.atleast_1d(q)
24892489
q_shape = q.shape
24902490
q_ndim = q.ndim
24912491
if q_ndim > 1:
@@ -2500,16 +2500,22 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25002500
a_shape = a.shape
25012501
w_shape = np.shape(weights)
25022502
if w_shape != a_shape:
2503-
if len(w_shape) != 1:
2504-
raise ValueError("1D weights expected when shapes of a and weights differ.")
25052503
if axis is None:
25062504
raise TypeError("Axis must be specified when shapes of a and weights differ.")
2507-
if w_shape[0] != a_shape[axis]:
2508-
raise ValueError("Length of weights not compatible with specified axis.")
2509-
resh = [1] * a.ndim
2510-
resh[axis] = w_shape[0]
2511-
weights = lax.expand_dims(weights, axis)
2512-
weights = _broadcast_to(weights, a.shape)
2505+
if isinstance(axis, tuple):
2506+
if w_shape != tuple(a_shape[i] for i in axis):
2507+
raise ValueError("Shape of weights must match the shape of the axes being reduced.")
2508+
weights = lax.broadcast_in_dim(
2509+
weights,
2510+
shape=a_shape,
2511+
broadcast_dimensions=axis
2512+
)
2513+
else:
2514+
if len(w_shape) != 1 or w_shape[0] != a_shape[axis]:
2515+
raise ValueError("Length of weights not compatible with specified axis.")
2516+
weights = lax.expand_dims(weights, axis)
2517+
weights = _broadcast_to(weights, a.shape)
2518+
25132519

25142520
if squash_nans:
25152521
nan_mask = ~lax_internal._isnan(a)
@@ -2525,14 +2531,14 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None,
25252531

25262532
def _weighted_quantile(qi):
25272533
index_dtype = dtypes.default_int_dtype()
2528-
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype)
2534+
idx = sum(lax.lt(cum_weights_norm, qi), axis=axis, dtype=index_dtype, keepdims=keepdims)
25292535
idx = lax.clamp(0, idx, a_sorted.shape[axis] - 1)
2530-
val = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx, axis), axis)
2536+
val = jnp.take_along_axis(a_sorted, idx, axis)
25312537

25322538
idx_prev = lax.clamp(idx - 1, 0, a_sorted.shape[axis] - 1)
2533-
val_prev = jnp.take_along_axis(a_sorted, jnp.expand_dims(idx_prev, axis), axis)
2534-
cw_prev = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx_prev, axis), axis)
2535-
cw_next = jnp.take_along_axis(cum_weights_norm, jnp.expand_dims(idx, axis), axis)
2539+
val_prev = jnp.take_along_axis(a_sorted, idx_prev, axis)
2540+
cw_prev = jnp.take_along_axis(cum_weights_norm, idx_prev, axis)
2541+
cw_next = jnp.take_along_axis(cum_weights_norm, idx, axis)
25362542

25372543
if method == "linear":
25382544
denom = cw_next - cw_prev
@@ -2552,11 +2558,10 @@ def _weighted_quantile(qi):
25522558
else:
25532559
raise ValueError(f"{method=!r} not recognized")
25542560
return out
2555-
2556-
if q.ndim == 0:
2557-
result = _weighted_quantile(q)
2558-
else:
2559-
result = jax.vmap(_weighted_quantile)(q)
2561+
2562+
result = jax.vmap(_weighted_quantile)(q)
2563+
if q.shape == (1,):
2564+
result = result[0]
25602565
return result
25612566

25622567
if squash_nans:

tests/lax_numpy_reducers_test.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -787,18 +787,22 @@ def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdim
787787
a = rng(a_shape, a_dtype)
788788
q = rng(q_shape, q_dtype)
789789
if axis is None:
790-
weights_shape = a_shape
790+
weights_shape = a_shape
791791
elif isinstance(axis, tuple):
792-
weights_shape = tuple(a_shape[i] for i in axis)
792+
weights_shape = tuple(a_shape[i] for i in axis)
793793
else:
794-
weights_shape = (a_shape[axis],)
794+
weights_shape = (a_shape[axis],)
795795
weights = np.abs(rng(weights_shape, a_dtype)) + 1e-3
796796

797797
def np_fun(a, q, weights):
798-
return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims)
798+
return np.quantile(np.array(a), np.array(q), axis=axis, weights=np.array(weights), method=method, keepdims=keepdims)
799799
def jnp_fun(a, q, weights):
800-
return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims)
801-
args_maker = lambda: [a, q, weights]
800+
return jnp.quantile(a, q, axis=axis, weights=weights, method=method, keepdims=keepdims)
801+
args_maker = lambda: [
802+
rng(a_shape, a_dtype),
803+
rng(q_shape, q_dtype),
804+
np.abs(rng(weights_shape, a_dtype)) + 1e-3
805+
]
802806
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, tol=1e-6)
803807
self._CompileAndCheck(jnp_fun, args_maker, rtol=1e-6)
804808

@@ -807,27 +811,27 @@ def test_weighted_quantile_negative_weights(self):
807811
weights = jnp.array([1, -1, 1, 1, 1], dtype=float)
808812
q = jnp.array([0.5])
809813
with self.assertRaisesRegex(ValueError, "Weights must be non-negative"):
810-
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
814+
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights)
811815

812816
def test_weighted_quantile_all_weights_zero(self):
813817
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
814818
weights = jnp.zeros_like(a)
815819
q = jnp.array([0.5])
816820
with self.assertRaisesRegex(ValueError, "Sum of weights must not be zero"):
817-
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
821+
jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights)
818822

819823
def test_weighted_quantile_weights_with_nan(self):
820824
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
821825
weights = jnp.array([1, np.nan, 1, 1, 1], dtype=float)
822826
q = jnp.array([0.5])
823-
result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, squash_nans=False, weights=weights)
827+
result = jnp.quantile(a, q, axis=0, method="linear", keepdims=False, weights=weights)
824828
assert np.isnan(np.array(result)).all()
825829

826830
def test_weighted_quantile_scalar_q(self):
827831
a = jnp.array([1, 2, 3, 4, 5], dtype=float)
828832
weights = jnp.array([1, 2, 1, 1, 1], dtype=float)
829833
q = 0.5
830-
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, squash_nans=False, weights=weights)
834+
result = jnp.quantile(a, q, axis=0, method="inverted_cdf", keepdims=False, weights=weights)
831835
assert jnp.issubdtype(result.dtype, jnp.floating)
832836
assert result.shape == ()
833837

0 commit comments

Comments
 (0)