@@ -40,7 +40,7 @@ def __init__(self, semantics: Dict[Primitive, Any], use_cache: bool = True) -> N
4040 if len (p .type .arguments ()) == 0 :
4141 self ._dsl_constants [__tuplify__ (val )] = p
4242
43- def compress (self , program : Program ) -> Program :
43+ def compress (self , program : Program , allow_constants : bool = True ) -> Program :
4444 """
4545 Return a semantically equivalent version of the program by evaluating constant expressions.
4646 Note for data saving/loading purposes, partial applications are left untouched.
@@ -55,12 +55,15 @@ def compress(self, program: Program) -> Program:
5555 value = self .eval (program , [])
5656 self .use_cache = before
5757 # Cancel compression of callable
58- if isinstance (value , Callable ): # type: ignore
58+ if isinstance (value , Callable ): # type: ignore
5959 return Function (program .function , args )
6060 tval = __tuplify__ (value )
6161 if tval in self ._dsl_constants :
6262 return self ._dsl_constants [tval ]
63- return Constant (program .type .returns (), value , True )
63+ if allow_constants :
64+ return Constant (program .type .returns (), value , True )
65+ else :
66+ return Function (program .function , args )
6467 else :
6568 return Function (program .function , args )
6669 else :
0 commit comments