Skip to content

Commit 0b7b0b5

Browse files
lshaw8317 and FrancescAlted
authored and committed
Too difficult to make general
1 parent e283649 commit 0b7b0b5

1 file changed

Lines changed: 60 additions & 48 deletions

File tree

src/blosc2/linalg.py

Lines changed: 60 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -115,61 +115,73 @@ def matmul(x1: blosc2.Array, x2: blosc2.NDArray, **kwargs: Any) -> blosc2.NDArra
115115
# multithreaded matmul
116116
# TODO: handle a) type promotion, b) padding (explicitly), c) (improved) >2D
117117
ops = (x1, x2, result)
118-
blocks = result.blocks
119118
all_ndarray = all(isinstance(value, blosc2.NDArray) and value.shape != () for value in ops)
120119
global try_miniexpr
121120

122121
# Use a local copy so we don't modify the global
123122
use_miniexpr = try_miniexpr
124-
if all_ndarray:
125-
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
123+
if 0 not in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
124+
if all_ndarray:
125+
if any(op.dtype != ops[0].dtype for op in ops): # TODO: Remove this condition
126+
use_miniexpr = False
127+
128+
# Just force same chunk/block shapes
129+
same_chunks = all(op.chunks == result.chunks for op in (x1, x2))
130+
same_blocks = all(op.blocks == result.blocks for op in (x1, x2))
131+
same_shape = all(op.shape == result.shape for op in (x1, x2))
132+
133+
use_miniexpr &= same_blocks & same_chunks & same_shape
134+
135+
# TODO: We can relax this to even just load according to result blockshape, but that's difficult.
136+
# Two easier cases are presented below
137+
# Case 1: Might want to restrict loading across chunk boundaries, in which case would require:
138+
# x1.chunks[-2] % result.blocks[-2] == 0
139+
# x2.chunks[-1] % result.blocks[-1] == 0
140+
# x2.chunks[-2] % x1.blocks[-1] == 0
141+
# Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
142+
# and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
143+
144+
# Case 2: Slightly easier to implement this maybe
145+
# Require that blocks are matmul compatible and broadcastable directly to result
146+
# (M, K) x (K, N) = (M, N)
147+
# so can load block-by-block for inputs and calculate block of output
148+
# Also need to avoid loading across chunk boundaries
149+
# chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
150+
# chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
151+
# chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
152+
# same_blocks = x2.blocks[-2] == x1.blocks[-1]
153+
# same_blocks &= x2.blocks[-1] == result.blocks[-1]
154+
# same_blocks &= result.blocks[-2] == x1.blocks[-2]
155+
# try:
156+
# result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
157+
# if not (same_blocks and chunks_aligned and result_blocks[:-2] == result.blocks[:-2]):
158+
# use_miniexpr = False
159+
# except ValueError:
160+
# use_miniexpr = False
161+
162+
use_miniexpr &= x1.dtype.kind in ("i", "f")
163+
use_miniexpr &= x2.dtype.kind in ("i", "f")
164+
use_miniexpr &= x1.dtype == x2.dtype
165+
166+
else:
126167
use_miniexpr = False
127168

128-
# TODO: In fact the following can be relaxed too, just need to load across block boundaries
129-
# Might want to restrict loading across chunk boundaries, in which case would require:
130-
# x1.chunks[-2] % result.blocks[-2] == 0
131-
# x2.chunks[-1] % result.blocks[-1] == 0
132-
# x2.chunks[-2] % x1.blocks[-1] == 0
133-
# Can then load in x1 as slices of size [result.blocks[-2], x1.blocks[-1]]
134-
# and x2 in slices of [x1.blocks[-1], result.blocks[-1]]
135-
136-
# Require that blocks are matmul compatible and broadcastable directly to result
137-
# (M, K) x (K, N) = (M, N)
138-
# so can load block-by-block for inputs and calculate block of output
139-
# Also need to avoid loading across chunk boundaries
140-
chunks_aligned = x1.chunks[-2] % x1.blocks[-2] == 0
141-
chunks_aligned &= x2.chunks[-1] % x2.blocks[-1] == 0
142-
chunks_aligned &= x2.chunks[-2] % x1.blocks[-1] == 0
143-
same_blocks = x2.blocks[-2] == x1.blocks[-1]
144-
same_blocks &= x2.blocks[-1] == result.blocks[-1]
145-
same_blocks &= result.blocks[-2] == x1.blocks[-2]
146-
try:
147-
result_blocks = np.broadcast_shapes(x1.blocks, x2.blocks)
148-
except ValueError:
149-
use_miniexpr = False
150-
if not (same_blocks and chunks_aligned and result_blocks[:-2] == blocks[:-2]):
151-
use_miniexpr = False
152-
153-
else:
154-
use_miniexpr = False
155-
156-
if use_miniexpr:
157-
prefilter_set = False
158-
try:
159-
result._set_pref_matmul({"x1": x1, "x2": x2}, fp_accuracy=blosc2.FPAccuracy.DEFAULT)
160-
prefilter_set = True
161-
# Data to compress is fetched from operands, so it can be uninitialized here
162-
data = np.empty(result.schunk.chunksize, dtype=np.uint8)
163-
for nchunk_out in range(result.schunk.nchunks):
164-
result.schunk.update_data(nchunk_out, data, copy=False)
165-
except Exception as e:
166-
raise Exception from e
167-
finally:
168-
if prefilter_set:
169-
result.schunk.remove_prefilter("miniexpr")
170-
else: # couldn't do multithreading
171-
print("multithreading failed :( ")
172-
if 0 not in result.shape + x1.shape + x2.shape: # if any array is empty, return array of 0s
169+
if use_miniexpr:
170+
prefilter_set = False
171+
try:
172+
result._set_pref_matmul({"x1": x1, "x2": x2}, fp_accuracy=blosc2.FPAccuracy.DEFAULT)
173+
prefilter_set = True
174+
# Data to compress is fetched from operands, so it can be uninitialized here
175+
data = np.empty(result.schunk.chunksize, dtype=np.uint8)
176+
for nchunk_out in range(result.schunk.nchunks):
177+
result.schunk.update_data(nchunk_out, data, copy=False)
178+
except Exception as e:
179+
raise Exception from e
180+
finally:
181+
if prefilter_set:
182+
result.schunk.remove_prefilter("miniexpr")
183+
else: # couldn't do multithreading
184+
print("multithreading failed :( ")
173185
p, q = result.chunks[-2:]
174186
r = x2.chunks[-1]
175187

0 commit comments

Comments
 (0)