Skip to content

Commit 88fb47c

Browse files
committed
Fixes to infer_shape for tensordot
1 parent b5c1c3e commit 88fb47c

2 files changed

Lines changed: 7 additions & 1 deletion

File tree

src/blosc2/shape_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,8 @@ def _lookup_value(self, node):
542542
return self.shapes.get(node.id, None)
543543
elif isinstance(node, ast.Constant):
544544
return node.value
545+
elif isinstance(node, ast.Tuple):
546+
return tuple(e.value for e in node.elts)
545547
else:
546548
return None
547549

tests/ndarray/test_lazyexpr.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1714,10 +1714,14 @@ def test_lazylinalg():
17141714
np.testing.assert_array_almost_equal(out[()], npres)
17151715

17161716
# --- tensordot ---
1717-
out = blosc2.lazyexpr("tensordot(A, B, axes=1)")
1717+
out = blosc2.lazyexpr("tensordot(A, B, axes=1)") # test with int axes
17181718
npres = np.tensordot(npA, npB, axes=1)
17191719
assert out.shape == npres.shape
17201720
np.testing.assert_array_almost_equal(out[()], npres)
1721+
out = blosc2.lazyexpr("tensordot(A, B, axes=((1,) , (0,)))") # test with tuple axes
1722+
npres = np.tensordot(npA, npB, axes=((1,), (0,)))
1723+
assert out.shape == npres.shape
1724+
np.testing.assert_array_almost_equal(out[()], npres)
17211725

17221726
# --- vecdot ---
17231727
out = blosc2.lazyexpr("vecdot(x, y)")

0 commit comments

Comments
 (0)