Skip to content

Commit d11d73b

Browse files
committed
Fixed U-PCFGs prediction
1 parent abce696 commit d11d73b

3 files changed

Lines changed: 33 additions & 2 deletions

File tree

examples/pbe/model_prediction.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from synth.filter import add_dfta_constraints
2121
from synth.syntax import CFG, UCFG, ProbDetGrammar, ProbUGrammar, DSL, Type
2222
from synth.utils import load_object, save_object
23+
from synth.utils.data_storage import legacy_save_object
2324

2425

2526
parser = argparse.ArgumentParser(
@@ -187,7 +188,10 @@ def produce_pcfgs(
187188
# ================================
188189
def save_pcfgs() -> None:
189190
print("Saving PCFGs...", end="")
190-
save_object(file, pcfgs)
191+
if constrained:
192+
legacy_save_object(file, pcfgs)
193+
else:
194+
save_object(file, pcfgs, compress_level=9)
191195
print("done!")
192196

193197
atexit.register(save_pcfgs)

examples/pbe/solve.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from synth.syntax.program import Program
3535
from synth.task import Task
3636
from synth.utils import load_object
37+
from synth.utils.data_storage import legacy_load_object
3738
from synth.utils.import_utils import import_file_function
3839
from synth.pbe.solvers import (
3940
NaivePBESolver,
@@ -296,7 +297,10 @@ def load_pcfgs(
296297
pcfg_file: Optional[str],
297298
) -> Union[List[ProbDetGrammar], List[ProbUGrammar]]:
298299
if pcfg_file is not None:
299-
return load_object(pcfg_file)
300+
if constrained:
301+
return legacy_load_object(pcfg_file)
302+
else:
303+
return load_object(pcfg_file)
300304
pcfgs = []
301305
for task in full_dataset:
302306
constant_types = set()

synth/utils/data_storage.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,26 @@ def save_object(
2929
if optimize:
3030
content = pickletools.optimize(content)
3131
fd.write(content)
32+
33+
34+
35+
def legacy_load_object(
36+
path: str, **kwargs: Any
37+
) -> Any:
38+
"""
39+
DEPRECATED
40+
Load an arbitrary object from the specified file.
41+
"""
42+
with open(path, "rb") as fd:
43+
return pickle.load(fd)
44+
45+
46+
def legacy_save_object(
47+
path: str, obj: Any, **kwargs: Any
48+
) -> None:
49+
"""
50+
DEPRECATED
51+
Save an arbitrary object to the specified path.
52+
"""
53+
with open(path, "wb") as fd:
54+
pickle.dump(obj, fd)

0 commit comments

Comments
 (0)