|
16 | 16 |
|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | | -from typing import Iterable # pylint: disable=g-multiple-import |
| 19 | +import functools |
| 20 | +from typing import Any, Callable, Iterable, Optional # pylint: disable=g-multiple-import |
20 | 21 |
|
21 | 22 | from dataclass_array import array_dataclass |
22 | | -from dataclass_array.typing import DcT |
| 23 | +from dataclass_array.typing import Array, DcT # pylint: disable=g-multiple-import |
23 | 24 | from dataclass_array.utils import np_utils |
| 25 | +from etils import enp |
24 | 26 | from etils import epy |
25 | 27 |
|
26 | 28 |
|
27 | | -def stack( |
28 | | - arrays: Iterable[DcT], # list[_DcT['*shape']] |
| 29 | +def _ops_base( |
| 30 | + arrays: Iterable[DcT], |
29 | 31 | *, |
30 | | - axis: int = 0, |
31 | | -) -> DcT: # _DcT['len(arrays) *shape']: |
32 | | - """Stack dataclasses together.""" |
| 32 | + axis: int, |
| 33 | + array_fn: Callable[ |
| 34 | + [ |
| 35 | + enp.NpModule, |
| 36 | + int, |
| 37 | + Any, # array_dataclass._ArrayField[Array['*din']], |
| 38 | + ], |
| 39 | + Array['*dout'], |
| 40 | + ], |
| 41 | + dc_fn: Optional[ |
| 42 | + Callable[ |
| 43 | + [ |
| 44 | + enp.NpModule, |
| 45 | + int, |
| 46 | + Any, # array_dataclass._ArrayField[DcT], |
| 47 | + ], |
| 48 | + DcT, |
| 49 | + ] |
| 50 | + ], |
| 51 | +) -> DcT: |
| 52 | + """Base function for all ops.""" |
33 | 53 | arrays = list(arrays) |
34 | 54 | first_arr = arrays[0] |
35 | 55 |
|
@@ -61,9 +81,41 @@ def stack( |
61 | 81 | # jax.tree_map(lambda x, y: x+y, (None, 10), (1, 2)) == (None, 12) |
62 | 82 | # Similarly, static values will be the ones from the first element. |
63 | 83 | merged_arr = first_arr._map_field( # pylint: disable=protected-access |
64 | | - array_fn=lambda f: xnp.stack( # pylint: disable=g-long-lambda |
| 84 | + array_fn=functools.partial(array_fn, xnp, axis), |
| 85 | + dc_fn=functools.partial(dc_fn, xnp, axis), |
| 86 | + ) |
| 87 | + return merged_arr |
| 88 | + |
| 89 | + |
| 90 | +def stack( |
| 91 | + arrays: Iterable[DcT], # list[_DcT['*shape']] |
| 92 | + *, |
| 93 | + axis: int = 0, |
| 94 | +) -> DcT: # _DcT['len(arrays) *shape']: |
| 95 | + """Stack dataclasses together.""" |
| 96 | + return _ops_base( |
| 97 | + arrays, |
| 98 | + axis=axis, |
| 99 | + array_fn=lambda xnp, axis, f: xnp.stack( # pylint: disable=g-long-lambda |
65 | 100 | [getattr(arr, f.name) for arr in arrays], axis=axis |
66 | 101 | ), |
67 | | - dc_fn=lambda f: stack([getattr(arr, f.name) for arr in arrays]), |
| 102 | + dc_fn=lambda xnp, axis, f: stack( # pylint: disable=g-long-lambda |
| 103 | + [getattr(arr, f.name) for arr in arrays], |
| 104 | + axis=axis, |
| 105 | + ), |
| 106 | + ) |
| 107 | + |
| 108 | + |
| 109 | +def concat(arrays: Iterable[DcT], *, axis: int = 0) -> DcT: |
| 110 | + """Concatenate dataclasses together.""" |
| 111 | + return _ops_base( |
| 112 | + arrays, |
| 113 | + axis=axis, |
| 114 | + array_fn=lambda xnp, axis, f: xnp.concatenate( # pylint: disable=g-long-lambda |
| 115 | + [getattr(arr, f.name) for arr in arrays], axis=axis |
| 116 | + ), |
| 117 | + dc_fn=lambda xnp, axis, f: concat( # pylint: disable=g-long-lambda |
| 118 | + [getattr(arr, f.name) for arr in arrays], |
| 119 | + axis=axis, |
| 120 | + ), |
68 | 121 | ) |
69 | | - return merged_arr |
|
0 commit comments