Skip to content

Commit 44c1b34

Browse files
committed
Updated with better experiments
1 parent 278acb1 commit 44c1b34

1 file changed

Lines changed: 51 additions & 13 deletions

File tree

examples/compare_enumeration.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
hs_enumerate_prob_grammar,
1212
bs_enumerate_prob_grammar,
1313
bps_enumerate_prob_grammar,
14+
as_enumerate_prob_grammar,
1415
ProgramEnumerator,
1516
auto_type,
1617
)
@@ -21,12 +22,13 @@
2122
import timeout_decorator
2223

2324
SEARCH_ALGOS = {
25+
"a_star": as_enumerate_prob_grammar,
2426
"bee_search": bs_enumerate_prob_grammar,
2527
"beap_search": bps_enumerate_prob_grammar,
2628
"heap_search": hs_enumerate_prob_grammar,
2729
"cd4": lambda x: cd(x, k=4),
28-
"cd12": lambda x: cd(x, k=12),
29-
"cd100": lambda x: cd(x, k=100),
30+
"cd16": lambda x: cd(x, k=16),
31+
"cd64": lambda x: cd(x, k=64),
3032
}
3133

3234
parser = argparse.ArgumentParser(
@@ -48,6 +50,11 @@
4850
parser.add_argument(
4951
dest="max_non_terminals", type=int, help="maximum number of non terminals"
5052
)
53+
parser.add_argument(
54+
dest="scaling",
55+
choices=["distance", "nonterminals", "derivations"],
56+
help="maximum number of non terminals",
57+
)
5158
parser.add_argument(
5259
"-t", "--timeout", type=int, default=300, help="timeout in seconds (default: 300)"
5360
)
@@ -61,9 +68,8 @@
6168
programs: int = parameters.n
6269
timeout: int = parameters.timeout
6370
seed: int = parameters.seed
64-
# max_rules: int = parameters.max_rules
71+
scaling: str = parameters.scaling
6572
max_non_terminals: int = parameters.max_non_terminals
66-
6773
file_name = output_file[: -len(".csv")] if output_file.endswith(".csv") else output_file
6874

6975

@@ -225,6 +231,41 @@ def fun():
225231

226232
# Main ====================================================================
227233

234+
235+
def gen_distance(non_terminals: int) -> CFG:
236+
syntax = {
237+
"1": "s1",
238+
}
239+
for i in range(2, non_terminals + 1):
240+
syntax[f"cast{i}"] = f"s{i} -> s1"
241+
syntax[f"s{i}"] = f"s{i}"
242+
syntax[f"+{i}"] = f"s1 -> s{i} -> s{i}"
243+
syntax[f"*{i}"] = f"s{i-1} -> s{i} -> s{i+1} -> s{i}"
244+
return CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1)
245+
246+
247+
def gen_nonterminals(non_terminals: int) -> CFG:
248+
syntax = {
249+
"1": "s1",
250+
}
251+
for i in range(2, non_terminals + 1):
252+
syntax[f"cast{i}"] = f"s{i} -> s1"
253+
syntax[f"s{i}"] = f"s{i}"
254+
syntax[f"+{i}"] = f"s1 -> s{i}"
255+
return CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1)
256+
257+
258+
def gen_derivations(non_terminals: int) -> CFG:
259+
syntax = {
260+
"1": "s1",
261+
}
262+
for i in range(2, non_terminals + 1):
263+
syntax[f"m{i}"] = f"s1 -> s1 -> s1"
264+
syntax[f"f{i}"] = f"s1 -> s1"
265+
syntax[f"s{i}"] = f"s1"
266+
return CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1)
267+
268+
228269
if __name__ == "__main__":
229270
# trace_rules = [
230271
# (
@@ -253,6 +294,11 @@ def fun():
253294
# )
254295
# save(trace_rules, file_name + "_rules.csv")
255296
# print("csv file was saved as:", file_name + "_rules.csv")
297+
gene = {
298+
"distance": gen_distance,
299+
"nonterminals": gen_nonterminals,
300+
"derivations": gen_derivations,
301+
}
256302
summary_trace = [
257303
(
258304
"search",
@@ -288,15 +334,7 @@ def fun():
288334
non_terminals_values.append(last)
289335
first = True
290336
for non_terminals in non_terminals_values:
291-
syntax = {
292-
"1": "s1",
293-
}
294-
for i in range(2, non_terminals + 1):
295-
syntax[f"cast{i}"] = f"s{i} -> s1"
296-
syntax[f"s{i}"] = f"s{i}"
297-
syntax[f"+{i}"] = f"s1 -> s{i} -> s{i}"
298-
syntax[f"*{i}"] = f"s{i-1} -> s{i} -> s{i+1} -> s{i}"
299-
cfg = CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1)
337+
cfg = gene[scaling](non_terminals)
300338
if seed < 0:
301339
pcfg = ProbDetGrammar.uniform(cfg)
302340
else:

0 commit comments

Comments
 (0)