Skip to content

Commit b4538d4

Browse files
committed
fixed mypy issues
1 parent f5ac55d commit b4538d4

5 files changed

Lines changed: 54 additions & 16 deletions

File tree

examples/compare_enumeration.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,21 @@
1515
ProgramEnumerator,
1616
auto_type,
1717
)
18+
from synth.syntax.grammars.det_grammar import DetGrammar
1819
from synth.syntax.grammars.enumeration.constant_delay import (
1920
enumerate_prob_grammar as cd,
2021
)
2122
import tqdm
2223
import timeout_decorator
2324

2425
SEARCH_ALGOS = {
25-
"a_star": as_enumerate_prob_grammar,
26-
"bee_search": bs_enumerate_prob_grammar,
26+
# "a_star": as_enumerate_prob_grammar,
27+
# "bee_search": bs_enumerate_prob_grammar,
2728
"beap_search": bps_enumerate_prob_grammar,
28-
"heap_search": hs_enumerate_prob_grammar,
29-
"cd4": lambda x: cd(x, k=4),
30-
"cd16": lambda x: cd(x, k=16),
31-
"cd64": lambda x: cd(x, k=64),
29+
# "heap_search": hs_enumerate_prob_grammar,
30+
# "cd4": lambda x: cd(x, k=4),
31+
# "cd16": lambda x: cd(x, k=16),
32+
# "cd64": lambda x: cd(x, k=64),
3233
}
3334

3435
parser = argparse.ArgumentParser(
@@ -163,6 +164,8 @@ def enumerative_search(
163164
Tuple[str, int, int, float, int, int, int, int],
164165
List[Tuple[str, int, int, float, int, int, int, int]],
165166
]:
167+
import numpy as np
168+
166169
n = 0
167170
non_terminals = len(pcfg.rules)
168171
derivation_rules = sum(len(pcfg.rules[S]) for S in pcfg.rules)
@@ -171,9 +174,11 @@ def enumerative_search(
171174
pbar = tqdm.tqdm(total=programs, desc=title or name, smoothing=0)
172175
enumerator = custom_enumerate(pcfg)
173176
gen = enumerator.generator()
177+
det_g = pcfg.grammar
178+
assert isinstance(det_g, DetGrammar)
174179
program = 1
175-
datum_each = 100000
176-
target_generation_speed = 1000000
180+
datum_each = 10000
181+
target_generation_speed = 10000
177182
start = 0
178183
detailed = []
179184
try:
@@ -184,9 +189,15 @@ def fun():
184189
get_next = timeout_decorator.timeout(timeout, timeout_exception=StopIteration)(
185190
fun
186191
)
192+
last_multiset = None
193+
max_dist = -1
187194
start = time.perf_counter_ns()
188195
while program is not None:
189196
program = get_next()
197+
new_m = det_g.to_multiset(program)
198+
if last_multiset is not None:
199+
max_dist = max(np.sum(np.abs(new_m - last_multiset)), max_dist)
200+
last_multiset = new_m
190201
n += 1
191202
if n % datum_each == 0 or n >= programs:
192203
used_time = time.perf_counter_ns() - start
@@ -200,7 +211,7 @@ def fun():
200211
(
201212
name,
202213
non_terminals,
203-
derivation_rules,
214+
max_dist,
204215
used_time / 1e9,
205216
n,
206217
enumerator.programs_in_queues(),

examples/plot_helper.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -246,16 +246,16 @@ def plot_dist(
246246
methods: Dict[str, Dict[int, List]],
247247
y_data: Tuple[int, str],
248248
x_axis_name: str,
249+
nbins: int = 5,
249250
) -> None:
250251
width = 1.0
251252
data_length = 0
252253
a_index, a_name = y_data
253254
max_a = max(
254-
max(max([y[a_index] for y in x]) for x in seed_dico.values())
255+
max(max(y[a_index] for y in x) for x in seed_dico.values())
255256
for seed_dico in methods.values()
256257
)
257258
bottom = None
258-
nbins = 5
259259
bins = [max_a]
260260
while len(bins) <= nbins:
261261
bins.insert(0, np.sqrt(bins[0] + 1))

synth/pbe/solvers/pbe_solver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _close_task_solving_(
7272

7373
def solve(
7474
self, task: Task[PBE], enumerator: ProgramEnumerator[None], timeout: float = 60
75-
) -> Generator[Program, bool, None]:
75+
) -> Generator[Program, None, bool]:
7676
"""
7777
Solve the given task by enumerating programs with the given enumerator.
7878
When the timeout is reached, this function returns.
@@ -101,6 +101,7 @@ def solve(
101101
except StopIteration as e:
102102
self._close_task_solving_(task, enumerator, time, False, program)
103103
raise e
104+
return False
104105

105106
def _test_(self, task: Task[PBE], program: Program) -> bool:
106107
"""

synth/pbe/solvers/restart_pbe_solver.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def _close_task_solving_(
6161

6262
def solve(
6363
self, task: Task[PBE], enumerator: ProgramEnumerator[None], timeout: float = 60
64-
) -> Generator[Program, bool, None]:
64+
) -> Generator[Program, None, bool]:
6565
with chrono.clock(f"solve.{self.name()}.{self.subsolver.name()}") as c: # type: ignore
6666
self._enumerator = enumerator
6767
self._init_task_solving_(task, self._enumerator, timeout)
@@ -73,15 +73,15 @@ def solve(
7373
self._close_task_solving_(
7474
task, self._enumerator, time, False, program
7575
)
76-
return
76+
return False
7777
self._programs += 1
7878
if self._test_(task, program):
7979
should_stop = yield program
8080
if should_stop:
8181
self._close_task_solving_(
8282
task, self._enumerator, time, True, program
8383
)
84-
return
84+
return True
8585
self._score = self.subsolver._score
8686
# Saves data
8787
if self._score > 0:
@@ -92,6 +92,7 @@ def solve(
9292
self._enumerator = self._restart_(self._enumerator)
9393
gen = self._enumerator.generator()
9494
program = next(gen)
95+
return False
9596

9697
def _should_restart_(self) -> bool:
9798
return self.restart_criterion(self)

synth/syntax/grammars/det_grammar.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Generic,
1212
)
1313
from functools import lru_cache
14-
import copy
14+
import numpy as np
1515

1616
from synth.syntax.grammars.grammar import DerivableProgram, Grammar
1717
from synth.syntax.program import Constant, Function, Primitive, Program, Variable
@@ -23,6 +23,13 @@
2323
T = TypeVar("T")
2424

2525

26+
def __tuplify__(element: Any) -> Any:
27+
if isinstance(element, (List, Tuple)):
28+
return tuple(__tuplify__(x) for x in element)
29+
else:
30+
return element
31+
32+
2633
class DetGrammar(Grammar, ABC, Generic[U, V, W]):
2734
"""
2835
Represents a deterministic grammar.
@@ -54,6 +61,14 @@ def __init__(
5461
self.type_request = self._guess_type_request_()
5562
if clean:
5663
self.clean()
64+
self._derivation2index = {}
65+
self._index2derivation = []
66+
67+
for S in self.rules:
68+
for P, args in self.rules[S].items():
69+
elem = __tuplify__((S, P, args))
70+
self._derivation2index[elem] = len(self._index2derivation)
71+
self._index2derivation.append(elem)
5772

5873
@lru_cache()
5974
def primitives_used(self) -> Set[Primitive]:
@@ -83,6 +98,16 @@ def __str__(self) -> str:
8398
s += " {}\n".format(self.__rule_to_str__(P, out))
8499
return s
85100

101+
def to_multiset(self, program: Program) -> np.ndarray:
102+
out = np.zeros((len(self._derivation2index)))
103+
104+
def reduce(acc, S, P, args):
105+
elem = __tuplify__((S, P, args))
106+
out[self._derivation2index[elem]] += 1
107+
108+
self.reduce_derivations(reduce, None, program, None)
109+
return out
110+
86111
def __repr__(self) -> str:
87112
return self.__str__()
88113

0 commit comments

Comments
 (0)