|
22 | 22 |
|
23 | 23 | from dataclass_array import array_dataclass |
24 | 24 | from dataclass_array import ops |
25 | | -from dataclass_array.typing import DcOrArray, Shape # pylint: disable=g-multiple-import |
| 25 | +from dataclass_array.typing import DcOrArray, Shape # pylint: disable=g-multiple-import,g-importing-member |
26 | 26 | from dataclass_array.utils import inspect_utils |
27 | 27 | from dataclass_array.utils import np_utils |
28 | 28 | from dataclass_array.utils import py_utils |
@@ -307,21 +307,30 @@ def _vmap_method( |
307 | 307 | xnp: enp.NpModule, |
308 | 308 | ) -> _Out: |
309 | 309 | """Vectorize self using the `xnp` backend. Assume `self` was flatten.""" |
310 | | - if xnp is enp.lazy.np: |
| 310 | + is_jax = enp.lazy.is_jax_xnp(xnp) |
| 311 | + is_torch = enp.lazy.is_torch_xnp(xnp) |
| 312 | + |
| 313 | + if enp.lazy.is_np_xnp(xnp): |
311 | 314 | return _vmap_method_np(args, map_non_static=map_non_static) |
312 | | - elif xnp is enp.lazy.jnp: |
| 315 | + elif is_jax or is_torch: |
| 316 | + if is_jax: |
| 317 | + make_vmap_fn = _jax_vmap_cached |
| 318 | + elif is_torch: |
| 319 | + make_vmap_fn = _torch_vmap_cached |
| 320 | + else: |
| 321 | + raise ValueError('Unexpected') |
313 | 322 | return _vmap_method_jax_torch( |
314 | 323 | args, |
315 | 324 | map_non_static=map_non_static, |
316 | | - make_vmap_fn=_jax_vmap_cached, |
| 325 | + make_vmap_fn=make_vmap_fn, |
317 | 326 | ) |
318 | | - elif xnp is enp.lazy.tnp: |
319 | | - return _vmap_method_tf(args, map_non_static=map_non_static) |
320 | | - elif xnp is enp.lazy.torch: |
321 | | - return _vmap_method_jax_torch( |
322 | | - args, |
323 | | - map_non_static=map_non_static, |
324 | | - make_vmap_fn=_torch_vmap_cached, |
| 327 | + elif enp.lazy.is_tf_xnp(xnp): |
| 328 | + # return _vmap_method_tf(args, map_non_static=map_non_static) |
| 329 | + |
| 330 | + # TODO(epot): Use `tf.vectorized_map()` once TF support custom nesting |
| 331 | + raise NotImplementedError( |
| 332 | + 'vectorization not supported in TF yet due to lack of `tf.nest` ' |
| 333 | + 'support. Please upvote or comment b/152678472.' |
325 | 334 | ) |
326 | 335 | raise TypeError(f'Invalid numpy module: {xnp}') |
327 | 336 |
|
@@ -400,13 +409,52 @@ def _vmap_method_tf( |
400 | 409 | map_non_static: _MapNonStatic, |
401 | 410 | ) -> _OutT: |
402 | 411 | """vectorization using `tf` backend.""" |
403 | | - # TODO(epot): Use `tf.vectorized_map()` once TF support custom nesting |
404 | | - raise NotImplementedError( |
405 | | - 'vectorization not supported in TF yet due to lack of `tf.nest` ' |
406 | | - 'support. Please upvote or comment b/152678472.' |
| 412 | + |
| 413 | + # Flatten args |
| 414 | + |
| 415 | + args_info = args.map(lambda _: None) |
| 416 | + # ... except the non-static ones |
| 417 | + args_info = map_non_static(lambda _: 0, args_info) |
| 418 | + |
| 419 | + # Split args in static/non-static |
| 420 | + static_args = {} |
| 421 | + nonstatic_args = {} |
| 422 | + for a, ai in zip(args, args_info): |
| 423 | + assert a.name == ai.name |
| 424 | + if ai.value is None: |
| 425 | + static_args[a.name] = a.value |
| 426 | + else: |
| 427 | + nonstatic_args[a.name] = a.value |
| 428 | + |
| 429 | + def new_fn(non_statics, statics): |
| 430 | + # Merge args and call the function |
| 431 | + new_args = args.replace_args_values(dict(**non_statics, **statics)) |
| 432 | + return new_args.call() |
| 433 | + |
| 434 | + # `vectorized_map(` uses autograph, which fails, so use tf.map_fn instead |
| 435 | + return _better_map_fn( # |
| 436 | + functools.partial(new_fn, statics=static_args), |
| 437 | + nonstatic_args, |
407 | 438 | ) |
408 | 439 |
|
409 | 440 |
|
# `tf.map_fn` does not support an output structure that differs from the
# input structure unless `fn_output_signature` is provided, so infer it here.
def _better_map_fn(fn, elems, **kwargs):
  """Like `tf.map_fn`, but auto-infers `fn_output_signature` when missing.

  When the caller does not pass `fn_output_signature`, `fn` is traced once
  (via `tf.function`) on the unbatched element specs of `elems` to recover
  the structure and specs of its outputs.

  Args:
    fn: Function mapped over the leading axis of `elems`.
    elems: (Possibly nested) structure of tensors to map over.
    **kwargs: Forwarded to `tf.map_fn` (e.g. `fn_output_signature`).

  Returns:
    The result of `tf.map_fn(fn, elems, **kwargs)`.
  """
  tf = enp.lazy.tf
  if 'fn_output_signature' not in kwargs:
    # Spec describing a single (unbatched) element of `elems`.
    unbatched_specs = tf.nest.map_structure(
        lambda t: tf.type_spec_from_value(t)._unbatch(),  # pylint: disable=protected-access
        elems,
    )
    # Trace `fn` on the unbatched specs to discover its output structure,
    # then convert each traced output back into a spec.
    traced_outputs = (
        tf.function(fn).get_concrete_function(unbatched_specs).structured_outputs
    )
    kwargs['fn_output_signature'] = tf.nest.map_structure(
        tf.type_spec_from_value, traced_outputs
    )
  return tf.map_fn(fn, elems, **kwargs)
| 456 | + |
| 457 | + |
410 | 458 | def _stack(*vals: _OutT) -> _OutT: |
411 | 459 | """Stack the given tree.""" |
412 | 460 | assert vals |
|
0 commit comments