1515 ProgramEnumerator ,
1616 auto_type ,
1717)
18+ from synth .syntax .grammars .det_grammar import DetGrammar
1819from synth .syntax .grammars .enumeration .constant_delay import (
1920 enumerate_prob_grammar as cd ,
2021)
2122import tqdm
2223import timeout_decorator
2324
2425SEARCH_ALGOS = {
25- "a_star" : as_enumerate_prob_grammar ,
26- "bee_search" : bs_enumerate_prob_grammar ,
26+ # "a_star": as_enumerate_prob_grammar,
27+ # "bee_search": bs_enumerate_prob_grammar,
2728 "beap_search" : bps_enumerate_prob_grammar ,
28- "heap_search" : hs_enumerate_prob_grammar ,
29- "cd4" : lambda x : cd (x , k = 4 ),
30- "cd16" : lambda x : cd (x , k = 16 ),
31- "cd64" : lambda x : cd (x , k = 64 ),
29+ # "heap_search": hs_enumerate_prob_grammar,
30+ # "cd4": lambda x: cd(x, k=4),
31+ # "cd16": lambda x: cd(x, k=16),
32+ # "cd64": lambda x: cd(x, k=64),
3233}
3334
3435parser = argparse .ArgumentParser (
@@ -163,6 +164,8 @@ def enumerative_search(
163164 Tuple [str , int , int , float , int , int , int , int ],
164165 List [Tuple [str , int , int , float , int , int , int , int ]],
165166]:
167+ import numpy as np
168+
166169 n = 0
167170 non_terminals = len (pcfg .rules )
168171 derivation_rules = sum (len (pcfg .rules [S ]) for S in pcfg .rules )
@@ -171,9 +174,11 @@ def enumerative_search(
171174 pbar = tqdm .tqdm (total = programs , desc = title or name , smoothing = 0 )
172175 enumerator = custom_enumerate (pcfg )
173176 gen = enumerator .generator ()
177+ det_g = pcfg .grammar
178+ assert isinstance (det_g , DetGrammar )
174179 program = 1
175- datum_each = 100000
176- target_generation_speed = 1000000
180+ datum_each = 10000
181+ target_generation_speed = 10000
177182 start = 0
178183 detailed = []
179184 try :
@@ -184,9 +189,15 @@ def fun():
184189 get_next = timeout_decorator .timeout (timeout , timeout_exception = StopIteration )(
185190 fun
186191 )
192+ last_multiset = None
193+ max_dist = - 1
187194 start = time .perf_counter_ns ()
188195 while program is not None :
189196 program = get_next ()
197+ new_m = det_g .to_multiset (program )
198+ if last_multiset is not None :
199+ max_dist = max (np .sum (np .abs (new_m - last_multiset )), max_dist )
200+ last_multiset = new_m
190201 n += 1
191202 if n % datum_each == 0 or n >= programs :
192203 used_time = time .perf_counter_ns () - start
@@ -200,7 +211,7 @@ def fun():
200211 (
201212 name ,
202213 non_terminals ,
203- derivation_rules ,
214+ max_dist ,
204215 used_time / 1e9 ,
205216 n ,
206217 enumerator .programs_in_queues (),
0 commit comments