Skip to content

Commit e9fa72e

Browse files
committed
Improved compress
1 parent 5247d07 commit e9fa72e

2 files changed

Lines changed: 28 additions & 3 deletions

File tree

synth/semantic/evaluator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def __init__(self, semantics: Dict[Primitive, Any], use_cache: bool = True) -> N
3535
# Statistics
3636
self._total_requests = 0
3737
self._cache_hits = 0
38+
self._dsl_constants: Dict[Any, Primitive] = {}
39+
for p, val in semantics.items():
40+
if len(p.type.arguments()) == 0:
41+
self._dsl_constants[__tuplify__(val)] = p
3842

3943
def compress(self, program: Program) -> Program:
4044
"""
@@ -50,6 +54,9 @@ def compress(self, program: Program) -> Program:
5054
self.use_cache = False
5155
value = self.eval(program, [])
5256
self.use_cache = before
57+
tval = __tuplify__(value)
58+
if tval in self._dsl_constants:
59+
return self._dsl_constants[tval]
5360
return Constant(program.type.returns(), value, True)
5461
else:
5562
return Function(program.function, args)

tests/semantic/test_evaluator.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@
2828
dsl = DSL(syntax)
2929
cfg = CFG.depth_constraint(dsl, FunctionType(INT, INT), max_depth)
3030

31+
other_syntax = {
32+
"+1": FunctionType(INT, INT),
33+
"0": INT,
34+
"2": INT,
35+
}
36+
37+
other_semantics = {
38+
"+1": lambda x: x + 1,
39+
"0": 0,
40+
"2": 2,
41+
}
42+
other_dsl = DSL(other_syntax)
43+
other_cfg = CFG.depth_constraint(other_dsl, FunctionType(INT, INT), max_depth)
44+
3145

3246
def test_eval() -> None:
3347
eval = DSLEvaluator(dsl.instantiate_semantics(semantics))
@@ -82,9 +96,13 @@ def test_use_cache() -> None:
8296

8397

8498
def test_compress() -> None:
85-
eval = DSLEvaluator(dsl.instantiate_semantics(semantics))
86-
p = dsl.auto_parse_program("(+1 0)")
87-
pp = dsl.auto_parse_program("1", constants={"1": (INT, 1)})
99+
eval = DSLEvaluator(other_dsl.instantiate_semantics(other_semantics))
100+
p = other_dsl.auto_parse_program("(+1 0)")
101+
pp = other_dsl.auto_parse_program("1", constants={"1": (INT, 1)})
88102
c = eval.compress(p)
89103
assert c != p
90104
assert c == pp
105+
p = other_dsl.auto_parse_program("(+1 (+1 0))")
106+
pp = other_dsl.auto_parse_program("2")
107+
c = eval.compress(p)
108+
assert c == pp

0 commit comments

Comments
 (0)