@@ -787,18 +787,22 @@ def testWeightedQuantile(self, a_shape, a_dtype, q_shape, q_dtype, axis, keepdim
787787 a = rng (a_shape , a_dtype )
788788 q = rng (q_shape , q_dtype )
789789 if axis is None :
790- weights_shape = a_shape
790+ weights_shape = a_shape
791791 elif isinstance (axis , tuple ):
792- weights_shape = tuple (a_shape [i ] for i in axis )
792+ weights_shape = tuple (a_shape [i ] for i in axis )
793793 else :
794- weights_shape = (a_shape [axis ],)
794+ weights_shape = (a_shape [axis ],)
795795 weights = np .abs (rng (weights_shape , a_dtype )) + 1e-3
796796
797797 def np_fun (a , q , weights ):
798- return np .quantile (np .array (a ), np .array (q ), axis = axis , weights = np .array (weights ), method = method , keepdims = keepdims )
798+ return np .quantile (np .array (a ), np .array (q ), axis = axis , weights = np .array (weights ), method = method , keepdims = keepdims )
799799 def jnp_fun (a , q , weights ):
800- return jnp .quantile (a , q , axis = axis , weights = weights , method = method , keepdims = keepdims )
801- args_maker = lambda : [a , q , weights ]
800+ return jnp .quantile (a , q , axis = axis , weights = weights , method = method , keepdims = keepdims )
801+ args_maker = lambda : [
802+ rng (a_shape , a_dtype ),
803+ rng (q_shape , q_dtype ),
804+ np .abs (rng (weights_shape , a_dtype )) + 1e-3
805+ ]
802806 self ._CheckAgainstNumpy (np_fun , jnp_fun , args_maker , tol = 1e-6 )
803807 self ._CompileAndCheck (jnp_fun , args_maker , rtol = 1e-6 )
804808
@@ -807,27 +811,27 @@ def test_weighted_quantile_negative_weights(self):
807811 weights = jnp .array ([1 , - 1 , 1 , 1 , 1 ], dtype = float )
808812 q = jnp .array ([0.5 ])
809813 with self .assertRaisesRegex (ValueError , "Weights must be non-negative" ):
810- jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
814+ jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , weights = weights )
811815
812816 def test_weighted_quantile_all_weights_zero (self ):
813817 a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
814818 weights = jnp .zeros_like (a )
815819 q = jnp .array ([0.5 ])
816820 with self .assertRaisesRegex (ValueError , "Sum of weights must not be zero" ):
817- jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
821+ jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , weights = weights )
818822
819823 def test_weighted_quantile_weights_with_nan (self ):
820824 a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
821825 weights = jnp .array ([1 , np .nan , 1 , 1 , 1 ], dtype = float )
822826 q = jnp .array ([0.5 ])
823- result = jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , squash_nans = False , weights = weights )
827+ result = jnp .quantile (a , q , axis = 0 , method = "linear" , keepdims = False , weights = weights )
824828 assert np .isnan (np .array (result )).all ()
825829
826830 def test_weighted_quantile_scalar_q (self ):
827831 a = jnp .array ([1 , 2 , 3 , 4 , 5 ], dtype = float )
828832 weights = jnp .array ([1 , 2 , 1 , 1 , 1 ], dtype = float )
829833 q = 0.5
830- result = jnp .quantile (a , q , axis = 0 , method = "inverted_cdf" , keepdims = False , squash_nans = False , weights = weights )
834+ result = jnp .quantile (a , q , axis = 0 , method = "inverted_cdf" , keepdims = False , weights = weights )
831835 assert jnp .issubdtype (result .dtype , jnp .floating )
832836 assert result .shape == ()
833837
0 commit comments