99import numpy as np
1010
1111import blosc2
12- from blosc2 .ndarray import get_intersecting_chunks , npvecdot , slice_to_chunktuple
12+ from blosc2 .ndarray import get_intersecting_chunks , nptranspose , npvecdot , slice_to_chunktuple
1313
1414if TYPE_CHECKING :
1515 from collections .abc import Sequence
@@ -79,9 +79,8 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
7979 if np .isscalar (x1 ) or np .isscalar (x2 ):
8080 raise ValueError ("Arguments can't be scalars." )
8181
82- # Added this to pass array-api tests (which use internal getitem to check results)
83- x1 = blosc2 .asarray (x1 )
84- x2 = blosc2 .asarray (x2 )
82+ # Makes a SimpleProxy if inputs are not blosc2 arrays
83+ x1 , x2 = blosc2 .asarray (x1 ), blosc2 .asarray (x2 )
8584
8685 # Validate matrix multiplication compatibility
8786 if x1 .shape [builtins .max (- 1 , - len (x2 .shape ))] != x2 .shape [builtins .max (- 2 , - len (x2 .shape ))]:
@@ -183,9 +182,6 @@ def tensordot(
183182 """
184183 fast_path = kwargs .pop ("fast_path" , None ) # for testing purposes
185184 # TODO: add fast path for when don't need to change chunkshapes
186- # Added this to pass array-api tests (which use internal getitem to check results)
187- if isinstance (x1 , np .ndarray ) and isinstance (x2 , np .ndarray ):
188- return np .tensordot (x1 , x2 , axes = axes )
189185
190186 x1 , x2 = blosc2 .asarray (x1 ), blosc2 .asarray (x2 )
191187
@@ -261,24 +257,8 @@ def tensordot(
261257 a_selection = tuple (next (rchunk_iter ) if a else slice (None , None , 1 ) for a in a_keep )
262258 b_selection = tuple (next (rchunk_iter ) if b else slice (None , None , 1 ) for b in b_keep )
263259 res_chunks = tuple (s .stop - s .start for s in res_chunk )
264-
265- if fast_path : # just load everything
266- bx1 = x1 [a_selection ]
267- bx2 = x2 [b_selection ]
268- newshape_a = (
269- math .prod ([bx1 .shape [i ] for i in a_keep_axes ]),
270- math .prod ([bx1 .shape [a ] for a in a_axes ]),
271- )
272- newshape_b = (
273- math .prod ([bx2 .shape [b ] for b in b_axes ]),
274- math .prod ([bx2 .shape [i ] for i in b_keep_axes ]),
275- )
276- at = bx1 .transpose (newaxes_a ).reshape (newshape_a )
277- bt = bx2 .transpose (newaxes_b ).reshape (newshape_b )
278- res = np .dot (at , bt )
279- result [res_chunk ] += res .reshape (res_chunks )
280- else : # operands too big, have to go chunk-by-chunk
281- for ochunk in product (* op_chunks ):
260+ for ochunk in product (* op_chunks ):
261+ if not fast_path : # operands too big, have to go chunk-by-chunk
282262 op_chunk = tuple (
283263 slice (rc * rcs , builtins .min ((rc + 1 ) * rcs , x1s ), 1 )
284264 for rc , rcs , x1s in zip (ochunk , a_chunks_red , a_shape_red , strict = True )
@@ -293,21 +273,23 @@ def tensordot(
293273 op_chunk [next (order_iter )] if not b else bs_
294274 for bs_ , b in zip (b_selection , b_keep , strict = True )
295275 )
296- bx1 = x1 [a_selection ]
297- bx2 = x2 [b_selection ]
298- # adapted from numpy tensordot
299- newshape_a = (
300- math .prod ([bx1 .shape [i ] for i in a_keep_axes ]),
301- math .prod ([bx1 .shape [a ] for a in a_axes ]),
302- )
303- newshape_b = (
304- math .prod ([bx2 .shape [b ] for b in b_axes ]),
305- math .prod ([bx2 .shape [i ] for i in b_keep_axes ]),
306- )
307- at = bx1 .transpose (newaxes_a ).reshape (newshape_a )
308- bt = bx2 .transpose (newaxes_b ).reshape (newshape_b )
309- res = np .dot (at , bt )
310- result [res_chunk ] += res .reshape (res_chunks )
276+ bx1 = x1 [a_selection ]
277+ bx2 = x2 [b_selection ]
278+ # adapted from numpy tensordot
279+ newshape_a = (
280+ math .prod ([bx1 .shape [i ] for i in a_keep_axes ]),
281+ math .prod ([bx1 .shape [a ] for a in a_axes ]),
282+ )
283+ newshape_b = (
284+ math .prod ([bx2 .shape [b ] for b in b_axes ]),
285+ math .prod ([bx2 .shape [i ] for i in b_keep_axes ]),
286+ )
287+ at = nptranspose (bx1 , newaxes_a ).reshape (newshape_a )
288+ bt = nptranspose (bx2 , newaxes_b ).reshape (newshape_b )
289+ res = np .dot (at , bt )
290+ result [res_chunk ] += res .reshape (res_chunks )
291+ if fast_path : # already done everything
292+ break
311293 return result
312294
313295
@@ -396,19 +378,17 @@ def vecdot(x1: blosc2.NDArray, x2: blosc2.NDArray, axis: int = -1, **kwargs) ->
396378 )
397379 b_selection = tuple (next (rchunk_iter ) if b else slice (None , None , 1 ) for b in b_keep )
398380
399- if fast_path : # just load everything, also handles case of 0 in shapes
400- bx1 = x1 [a_selection ]
401- bx2 = x2 [b_selection ]
402- result [res_chunk ] += npvecdot (bx1 , bx2 , axis = axis ) # handles conjugation of bx1
403- else : # operands too big, have to go chunk-by-chunk
404- for ochunk in range (0 , a_shape_red , a_chunks_red ):
381+ for ochunk in range (0 , a_shape_red , a_chunks_red ):
382+ if not fast_path : # operands too big, go chunk-by-chunk
405383 op_chunk = (slice (ochunk , builtins .min (ochunk + a_chunks_red , x1 .shape [a_axes ]), 1 ),)
406384 a_selection = a_selection [:a_axes ] + op_chunk + a_selection [a_axes + 1 :]
407385 b_selection = b_selection [:b_axes ] + op_chunk + b_selection [b_axes + 1 :]
408- bx1 = x1 [a_selection ]
409- bx2 = x2 [b_selection ]
410- res = npvecdot (bx1 , bx2 , axis = axis ) # handles conjugation of bx1
411- result [res_chunk ] += res
386+ bx1 = x1 [a_selection ]
387+ bx2 = x2 [b_selection ]
388+ res = npvecdot (bx1 , bx2 , axis = axis ) # handles conjugation of bx1
389+ result [res_chunk ] += res
390+ if fast_path : # already done everything
391+ break
412392 return result
413393
414394
@@ -517,7 +497,7 @@ def permute_dims(
517497 src_slice = tuple (slice (start , stop ) for start , stop in start_stop )
518498 dst_slice = tuple (slice (start_stop [ax ][0 ], start_stop [ax ][1 ]) for ax in axes )
519499
520- transposed = np . transpose (arr [src_slice ], axes = axes )
500+ transposed = nptranspose (arr [src_slice ], axes = axes )
521501 result [dst_slice ] = np .ascontiguousarray (transposed )
522502
523503 return result
@@ -648,6 +628,7 @@ def outer(x1: blosc2.blosc2.NDArray, x2: blosc2.blosc2.NDArray, **kwargs: Any) -
648628 out: blosc2.NDArray
649629 A two-dimensional array containing the outer product and whose shape is (N, M).
650630 """
631+ x1 , x2 = blosc2 .asarray (x1 ), blosc2 .asarray (x2 )
651632 if (x1 .ndim != 1 ) or (x2 .ndim != 1 ):
652633 raise ValueError ("outer only valid for 1D inputs." )
653634 return tensordot (x1 , x2 , ((), ()), ** kwargs ) # for testing purposes
0 commit comments