|
11 | 11 | hs_enumerate_prob_grammar, |
12 | 12 | bs_enumerate_prob_grammar, |
13 | 13 | bps_enumerate_prob_grammar, |
| 14 | + as_enumerate_prob_grammar, |
14 | 15 | ProgramEnumerator, |
15 | 16 | auto_type, |
16 | 17 | ) |
|
21 | 22 | import timeout_decorator |
22 | 23 |
|
23 | 24 | SEARCH_ALGOS = { |
| 25 | + "a_star": as_enumerate_prob_grammar, |
24 | 26 | "bee_search": bs_enumerate_prob_grammar, |
25 | 27 | "beap_search": bps_enumerate_prob_grammar, |
26 | 28 | "heap_search": hs_enumerate_prob_grammar, |
27 | 29 | "cd4": lambda x: cd(x, k=4), |
28 | | - "cd12": lambda x: cd(x, k=12), |
29 | | - "cd100": lambda x: cd(x, k=100), |
| 30 | + "cd16": lambda x: cd(x, k=16), |
| 31 | + "cd64": lambda x: cd(x, k=64), |
30 | 32 | } |
31 | 33 |
|
32 | 34 | parser = argparse.ArgumentParser( |
|
48 | 50 | parser.add_argument( |
49 | 51 | dest="max_non_terminals", type=int, help="maximum number of non terminals" |
50 | 52 | ) |
| 53 | +parser.add_argument( |
| 54 | + dest="scaling", |
| 55 | + choices=["distance", "nonterminals", "derivations"], |
| 56 | + help="maximum number of non terminals", |
| 57 | +) |
51 | 58 | parser.add_argument( |
52 | 59 | "-t", "--timeout", type=int, default=300, help="timeout in seconds (default: 300)" |
53 | 60 | ) |
|
61 | 68 | programs: int = parameters.n |
62 | 69 | timeout: int = parameters.timeout |
63 | 70 | seed: int = parameters.seed |
64 | | -# max_rules: int = parameters.max_rules |
| 71 | +scaling: str = parameters.scaling |
65 | 72 | max_non_terminals: int = parameters.max_non_terminals |
66 | | - |
67 | 73 | file_name = output_file[: -len(".csv")] if output_file.endswith(".csv") else output_file |
68 | 74 |
|
69 | 75 |
|
@@ -225,6 +231,41 @@ def fun(): |
225 | 231 |
|
226 | 232 | # Main ==================================================================== |
227 | 233 |
|
| 234 | + |
| 235 | +def gen_distance(non_terminals: int) -> CFG: |
| 236 | + syntax = { |
| 237 | + "1": "s1", |
| 238 | + } |
| 239 | + for i in range(2, non_terminals + 1): |
| 240 | + syntax[f"cast{i}"] = f"s{i} -> s1" |
| 241 | + syntax[f"s{i}"] = f"s{i}" |
| 242 | + syntax[f"+{i}"] = f"s1 -> s{i} -> s{i}" |
| 243 | + syntax[f"*{i}"] = f"s{i-1} -> s{i} -> s{i+1} -> s{i}" |
| 244 | + return CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1) |
| 245 | + |
| 246 | + |
| 247 | +def gen_nonterminals(non_terminals: int) -> CFG: |
| 248 | + syntax = { |
| 249 | + "1": "s1", |
| 250 | + } |
| 251 | + for i in range(2, non_terminals + 1): |
| 252 | + syntax[f"cast{i}"] = f"s{i} -> s1" |
| 253 | + syntax[f"s{i}"] = f"s{i}" |
| 254 | + syntax[f"+{i}"] = f"s1 -> s{i}" |
| 255 | + return CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1) |
| 256 | + |
| 257 | + |
| 258 | +def gen_derivations(non_terminals: int) -> CFG: |
| 259 | + syntax = { |
| 260 | + "1": "s1", |
| 261 | + } |
| 262 | + for i in range(2, non_terminals + 1): |
| 263 | + syntax[f"m{i}"] = f"s1 -> s1 -> s1" |
| 264 | + syntax[f"f{i}"] = f"s1 -> s1" |
| 265 | + syntax[f"s{i}"] = f"s1" |
| 266 | + return CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1) |
| 267 | + |
| 268 | + |
228 | 269 | if __name__ == "__main__": |
229 | 270 | # trace_rules = [ |
230 | 271 | # ( |
@@ -253,6 +294,11 @@ def fun(): |
253 | 294 | # ) |
254 | 295 | # save(trace_rules, file_name + "_rules.csv") |
255 | 296 | # print("csv file was saved as:", file_name + "_rules.csv") |
| 297 | + gene = { |
| 298 | + "distance": gen_distance, |
| 299 | + "nonterminals": gen_nonterminals, |
| 300 | + "derivations": gen_derivations, |
| 301 | + } |
256 | 302 | summary_trace = [ |
257 | 303 | ( |
258 | 304 | "search", |
@@ -288,15 +334,7 @@ def fun(): |
288 | 334 | non_terminals_values.append(last) |
289 | 335 | first = True |
290 | 336 | for non_terminals in non_terminals_values: |
291 | | - syntax = { |
292 | | - "1": "s1", |
293 | | - } |
294 | | - for i in range(2, non_terminals + 1): |
295 | | - syntax[f"cast{i}"] = f"s{i} -> s1" |
296 | | - syntax[f"s{i}"] = f"s{i}" |
297 | | - syntax[f"+{i}"] = f"s1 -> s{i} -> s{i}" |
298 | | - syntax[f"*{i}"] = f"s{i-1} -> s{i} -> s{i+1} -> s{i}" |
299 | | - cfg = CFG.infinite(DSL(auto_type(syntax)), auto_type("s1->s1"), n_gram=1) |
| 337 | + cfg = gene[scaling](non_terminals) |
300 | 338 | if seed < 0: |
301 | 339 | pcfg = ProbDetGrammar.uniform(cfg) |
302 | 340 | else: |
|
0 commit comments