@@ -391,14 +391,7 @@ def visit_Call(self, node): # noqa : C901
391391 # --- Parse keyword args ---
392392 kwargs = {}
393393 for kw in node .keywords :
394- if isinstance (kw .value , ast .Constant ):
395- kwargs [kw .arg ] = kw .value .value
396- elif isinstance (kw .value , ast .Tuple ):
397- kwargs [kw .arg ] = tuple (
398- e .value if isinstance (e , ast .Constant ) else self ._lookup_value (e ) for e in kw .value .elts
399- )
400- else :
401- kwargs [kw .arg ] = self ._lookup_value (kw .value )
394+ kwargs [kw .arg ] = self ._lookup_value (kw .value )
402395
403396 # ------- handle linear algebra ---------------
404397 if base_name in linalg_funcs :
@@ -539,17 +532,62 @@ def _eval_slice(self, node):
539532 else :
540533 raise ValueError (f"Unsupported slice expression: { ast .dump (node )} " )
541534
542- def _lookup_value (self , node ):
535+ def _lookup_value (self , node ): # noqa : C901
543536 """Look up a value in self.shapes if node is a variable name, else constant value."""
537+ # Name -> lookup in shapes mapping
544538 if isinstance (node , ast .Name ):
545539 return self .shapes .get (node .id , None )
546- elif isinstance (node , ast .Constant ):
540+
541+ # Constant -> return its value
542+ if isinstance (node , ast .Constant ):
547543 return node .value
548- elif isinstance (node , ast .Tuple ):
549- return tuple (e .value for e in node .elts )
550- else :
544+
545+ # Tuple of constants / expressions
546+ if isinstance (node , ast .Tuple ):
547+ vals = []
548+ for e in node .elts :
549+ v = self ._lookup_value (e )
550+ vals .append (v )
551+ return tuple (vals )
552+
553+ # Unary operations (e.g. -1)
554+ if isinstance (node , ast .UnaryOp ):
555+ # handle negative constants like -1
556+ if isinstance (node .op , ast .USub ):
557+ val = self ._lookup_value (node .operand )
558+ if isinstance (val , (int , float )):
559+ return - val
560+ # handle + (USub) if needed
561+ if isinstance (node .op , ast .UAdd ):
562+ return self ._lookup_value (node .operand )
551563 return None
552564
565+ # Simple binary ops with constant operands (e.g. 1+2)
566+ if isinstance (node , ast .BinOp ):
567+ left = self ._lookup_value (node .left )
568+ right = self ._lookup_value (node .right )
569+ if left is None or right is None :
570+ return None
571+ try :
572+ if isinstance (node .op , ast .Add ):
573+ return left + right
574+ if isinstance (node .op , ast .Sub ):
575+ return left - right
576+ if isinstance (node .op , ast .Mult ):
577+ return left * right
578+ if isinstance (node .op , ast .FloorDiv ):
579+ return left // right
580+ if isinstance (node .op , ast .Div ):
581+ return left / right
582+ if isinstance (node .op , ast .Mod ):
583+ return left % right
584+ except Exception :
585+ return None
586+ return None
587+
588+ # fallback
589+ return None
590+
553591
554592# --- Public API ---
555593def infer_shape (expr , shapes ):
0 commit comments