@@ -41,6 +41,7 @@ def apply_where( # numpydoc ignore=GL08
4141 f2 : Callable [..., Array ],
4242 / ,
4343 * ,
44+ kwargs : dict [str , Array ] | None = None ,
4445 xp : ModuleType | None = None ,
4546) -> Array : ...
4647
@@ -53,6 +54,7 @@ def apply_where( # numpydoc ignore=GL08
5354 / ,
5455 * ,
5556 fill_value : Array | complex ,
57+ kwargs : dict [str , Array ] | None = None ,
5658 xp : ModuleType | None = None ,
5759) -> Array : ...
5860
@@ -65,6 +67,7 @@ def apply_where( # numpydoc ignore=PR01,PR02
6567 / ,
6668 * ,
6769 fill_value : Array | complex | None = None ,
70+ kwargs : dict [str , Array ] | None = None ,
6871 xp : ModuleType | None = None ,
6972) -> Array :
7073 """
@@ -91,6 +94,9 @@ def apply_where( # numpydoc ignore=PR01,PR02
9194 It does not need to be scalar; it needs however to be broadcastable with
9295 `cond` and `args`.
9396 Mutually exclusive with `f2`. You must provide one or the other.
97+ kwargs : dict of str : Array pairs
98+ Keyword argument(s) to `f1` (and `f2`). Values must be broadcastable with
99+ `cond`.
94100 xp : array_namespace, optional
95101 The standard-compatible namespace for `cond` and `args`. Default: infer.
96102
@@ -129,6 +135,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
129135 args_ = list (args ) if isinstance (args , tuple ) else [args ]
130136 del args
131137
138+ kwargs_ = {} if kwargs is None else kwargs
139+ kwkeys = list (kwargs_ .keys ())
140+ args_ = [* args_ , * kwargs_ .values ()]
141+ del kwargs
142+
132143 xp = array_namespace (cond , fill_value , * args_ ) if xp is None else xp
133144
134145 if isinstance (fill_value , int | float | complex | NoneType ):
@@ -139,8 +150,11 @@ def apply_where( # numpydoc ignore=PR01,PR02
139150 if is_dask_namespace (xp ):
140151 meta_xp = meta_namespace (cond , fill_value , * args_ , xp = xp )
141152 # map_blocks doesn't descend into tuples of Arrays
142- return xp .map_blocks (_apply_where , cond , f1 , f2 , fill_value , * args_ , xp = meta_xp )
143- return _apply_where (cond , f1 , f2 , fill_value , * args_ , xp = xp )
153+ return xp .map_blocks (
154+ _apply_where , cond , f1 , f2 , fill_value , * args_ , kwkeys = kwkeys , xp = meta_xp
155+ )
156+
157+ return _apply_where (cond , f1 , f2 , fill_value , * args_ , kwkeys = kwkeys , xp = xp )
144158
145159
146160def _apply_where ( # numpydoc ignore=PR01,RT01
@@ -149,15 +163,26 @@ def _apply_where( # numpydoc ignore=PR01,RT01
149163 f2 : Callable [..., Array ] | None ,
150164 fill_value : Array | int | float | complex | bool | None ,
151165 * args : Array ,
166+ kwkeys : list [str ],
152167 xp : ModuleType ,
153168) -> Array :
154169 """Helper of `apply_where`. On Dask, this runs on a single chunk."""
155170
171+ nargs = len (args ) - len (kwkeys )
172+ kwargs = dict (zip (kwkeys , args [nargs :], strict = True ))
173+ args = args [:nargs ]
174+
156175 if not capabilities (xp , device = _compat .device (cond ))["boolean indexing" ]:
157176 # jax.jit does not support assignment by boolean mask
158- return xp .where (cond , f1 (* args ), f2 (* args ) if f2 is not None else fill_value )
177+ return xp .where (
178+ cond ,
179+ f1 (* args , ** kwargs ),
180+ f2 (* args , ** kwargs ) if f2 is not None else fill_value ,
181+ )
159182
160- temp1 = f1 (* (arr [cond ] for arr in args ))
183+ temp1 = f1 (
184+ * (arr [cond ] for arr in args ), ** {key : val [cond ] for key , val in kwargs .items ()}
185+ )
161186
162187 if f2 is None :
163188 dtype = xp .result_type (temp1 , fill_value )
@@ -167,7 +192,10 @@ def _apply_where( # numpydoc ignore=PR01,RT01
167192 out = xp .astype (fill_value , dtype , copy = True )
168193 else :
169194 ncond = ~ cond
170- temp2 = f2 (* (arr [ncond ] for arr in args ))
195+ temp2 = f2 (
196+ * (arr [ncond ] for arr in args ),
197+ ** {key : val [ncond ] for key , val in kwargs .items ()},
198+ )
171199 dtype = xp .result_type (temp1 , temp2 )
172200 out = xp .empty_like (cond , dtype = dtype )
173201 out = at (out , ncond ).set (temp2 )
0 commit comments