Skip to content

Commit 877d010

Browse files
committed
[FEAT] Output is a blosc2.NDArray
1 parent bc6729c commit 877d010

1 file changed

Lines changed: 5 additions & 9 deletions

File tree

src/blosc2/ndarray.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3729,26 +3729,22 @@ def matmul(x1: NDArray, x2: NDArray, **kwargs: Any) -> NDArray:
37293729
p1, q1 = x1.chunks[-2:]
37303730
q2 = x2.chunks[-1]
37313731

3732-
result = np.zeros((n, m), dtype=x1.dtype)
3733-
# result = blosc2.zeros(n) # TODO: file a ticket for blosc2.zeros()
3732+
result = blosc2.zeros((n, m), dtype=x1.dtype)
37343733

37353734
for row in range(0, n, p1):
37363735
row_end = (row+p1) if (row+p1) < n else n
37373736
for col in range(0, m, q2):
37383737
col_end = (col+q2) if (col+q2) < m else m
3739-
bres = result[row:row_end, col:col_end]
37403738
for aux in range(0, l, q1):
37413739
aux_end = (aux+q1) if (aux+q1) < l else l
37423740
bx1 = x1[row:row_end, aux:aux_end]
37433741
bx2 = x2[aux:aux_end, col:col_end]
3744-
bres[:] += np.matmul(bx1, bx2)
3742+
result[row:row_end, col:col_end] += np.matmul(bx1, bx2)
37453743

37463744
if x1_is_vector and x2_is_vector:
3747-
result = result[0][0]
3748-
elif x1_is_vector:
3749-
result = result.reshape((m,))
3750-
elif x2_is_vector:
3751-
result = result.reshape((n,))
3745+
return result[0][0]
3746+
3747+
result.squeeze()
37523748

37533749
return result
37543750

0 commit comments

Comments
 (0)