Skip to content

Commit 34b3fd8

Browse files
committed
Implemented compress which removes constant part in programs
1 parent 6bab7ee commit 34b3fd8

2 files changed

Lines changed: 25 additions & 0 deletions

File tree

examples/pbe/dataset_generator_unique.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from synth import Dataset, PBE
1111
from synth.pbe.task_generator import TaskGenerator
12+
from synth.semantic import DSLEvaluator
1213
from synth.utils import chrono
1314
from synth.syntax import CFG, Type, Program
1415

@@ -156,6 +157,8 @@ def generate_programs_and_samples_for(
156157
prog, unique = task_generator.generate_program(tr)
157158
while not unique:
158159
prog, unique = task_generator.generate_program(tr)
160+
if isinstance(task_generator.evaluator, DSLEvaluator):
161+
prog = task_generator.evaluator.compress(prog)
159162
programs.add(prog)
160163
# Phase 2
161164
samples, equiv = generate_samples_for(
@@ -173,6 +176,8 @@ def generate_programs_and_samples_for(
173176
prog, unique = task_generator.generate_program(tr)
174177
while not unique:
175178
prog, unique = task_generator.generate_program(tr)
179+
if isinstance(task_generator.evaluator, DSLEvaluator):
180+
prog = task_generator.evaluator.compress(prog)
176181
# Compute semantic hash
177182
cl = None
178183
has_none = False

synth/semantic/evaluator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,26 @@ def __init__(self, semantics: Dict[Primitive, Any], use_cache: bool = True) -> N
3636
self._total_requests = 0
3737
self._cache_hits = 0
3838

39+
def compress(self, program: Program) -> Program:
40+
"""
41+
Return a semantically equivalent version of the program by evaluating constant expressions.
42+
Note for data saving/loading purposes, partial applications are left untouched.
43+
"""
44+
if isinstance(program, Function):
45+
args = [self.compress(p) for p in program.arguments]
46+
if len(program.type.returns().arguments()) == 0 and all(
47+
a.is_constant() for a in args
48+
):
49+
before = self.use_cache
50+
self.use_cache = False
51+
value = self.eval(program, [])
52+
self.use_cache = before
53+
return Constant(program.type.returns(), value, True)
54+
else:
55+
return Function(program.function, args)
56+
else:
57+
return program
58+
3959
def eval(self, program: Program, input: List) -> Any:
4060
key = __tuplify__(input)
4161
if self.use_cache and key not in self._cache:

0 commit comments

Comments
 (0)