Skip to content

Commit 0dacf9f

Browse files
committed
Improve shape-parser for various args
1 parent a7096e5 commit 0dacf9f

3 files changed

Lines changed: 54 additions & 15 deletions

File tree

src/blosc2/shape_utils.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,7 @@ def visit_Call(self, node): # noqa : C901
391391
# --- Parse keyword args ---
392392
kwargs = {}
393393
for kw in node.keywords:
394-
if isinstance(kw.value, ast.Constant):
395-
kwargs[kw.arg] = kw.value.value
396-
elif isinstance(kw.value, ast.Tuple):
397-
kwargs[kw.arg] = tuple(
398-
e.value if isinstance(e, ast.Constant) else self._lookup_value(e) for e in kw.value.elts
399-
)
400-
else:
401-
kwargs[kw.arg] = self._lookup_value(kw.value)
394+
kwargs[kw.arg] = self._lookup_value(kw.value)
402395

403396
# ------- handle linear algebra ---------------
404397
if base_name in linalg_funcs:
@@ -539,17 +532,62 @@ def _eval_slice(self, node):
539532
else:
540533
raise ValueError(f"Unsupported slice expression: {ast.dump(node)}")
541534

542-
def _lookup_value(self, node):
535+
def _lookup_value(self, node): # noqa : C901
543536
"""Look up a value in self.shapes if node is a variable name, else constant value."""
537+
# Name -> lookup in shapes mapping
544538
if isinstance(node, ast.Name):
545539
return self.shapes.get(node.id, None)
546-
elif isinstance(node, ast.Constant):
540+
541+
# Constant -> return its value
542+
if isinstance(node, ast.Constant):
547543
return node.value
548-
elif isinstance(node, ast.Tuple):
549-
return tuple(e.value for e in node.elts)
550-
else:
544+
545+
# Tuple of constants / expressions
546+
if isinstance(node, ast.Tuple):
547+
vals = []
548+
for e in node.elts:
549+
v = self._lookup_value(e)
550+
vals.append(v)
551+
return tuple(vals)
552+
553+
# Unary operations (e.g. -1)
554+
if isinstance(node, ast.UnaryOp):
555+
# handle negative constants like -1
556+
if isinstance(node.op, ast.USub):
557+
val = self._lookup_value(node.operand)
558+
if isinstance(val, (int, float)):
559+
return -val
560+
# handle + (USub) if needed
561+
if isinstance(node.op, ast.UAdd):
562+
return self._lookup_value(node.operand)
551563
return None
552564

565+
# Simple binary ops with constant operands (e.g. 1+2)
566+
if isinstance(node, ast.BinOp):
567+
left = self._lookup_value(node.left)
568+
right = self._lookup_value(node.right)
569+
if left is None or right is None:
570+
return None
571+
try:
572+
if isinstance(node.op, ast.Add):
573+
return left + right
574+
if isinstance(node.op, ast.Sub):
575+
return left - right
576+
if isinstance(node.op, ast.Mult):
577+
return left * right
578+
if isinstance(node.op, ast.FloorDiv):
579+
return left // right
580+
if isinstance(node.op, ast.Div):
581+
return left / right
582+
if isinstance(node.op, ast.Mod):
583+
return left % right
584+
except Exception:
585+
return None
586+
return None
587+
588+
# fallback
589+
return None
590+
553591

554592
# --- Public API ---
555593
def infer_shape(expr, shapes):

tests/ndarray/test_lazyexpr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1702,7 +1702,9 @@ def test_lazylinalg():
17021702
npres = np.squeeze(npD, -1)
17031703
assert out.shape == npres.shape
17041704
np.testing.assert_array_almost_equal(out[()], npres)
1705-
1705+
# refresh D since squeeze is in-place
1706+
s = shapes["D"]
1707+
D = blosc2.linspace(0, np.prod(s), shape=s)
17061708
out = blosc2.lazyexpr("D.squeeze(axis=-1)")
17071709
npres = np.squeeze(npD, -1)
17081710
assert out.shape == npres.shape

tests/ndarray/test_squeeze.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
((23, 1, 1, 34), (20, 1, 1, 20), None, 1234, 2),
2020
((80, 1, 51, 60, 1), None, (6, 1, 6, 26, 1), 3.333, 4),
2121
((1, 1, 1), None, None, True, (1, 2)),
22-
((1, 1, 1), None, None, True, None),
2322
],
2423
)
2524
def test_squeeze(shape, chunks, blocks, fill_value, axis):

0 commit comments

Comments
 (0)