11import atexit
2+ from collections import defaultdict
23import os
34import sys
45from typing import Callable , Iterable , List , Optional , Tuple
1314from torch .nn .utils .rnn import PackedSequence
1415
1516from dsl_loader import add_dsl_choice_arg , load_DSL
17+ from examples .pbe .transduction .knowledge_graph .kg_path_finder import (
18+ build_wrapper ,
19+ choose_best_path ,
20+ find_paths_from_level ,
21+ )
22+ from examples .pbe .transduction .knowledge_graph .preprocess_tasks import sketch
1623
1724from synth import Dataset , PBE , Task
1825from synth .nn import (
2431from synth .pbe import IOEncoder
2532from synth .semantic import DSLEvaluator
2633from synth .semantic .evaluator import DSLEvaluatorWithConstant
27- from synth .specification import PBEWithConstants
34+ from synth .specification import Example , PBEWithConstants
2835from synth .syntax import (
2936 CFG ,
3037 ProbDetGrammar ,
3441 Program ,
3542)
3643from synth .syntax .grammars .heap_search import HSEnumerator
44+ from synth .syntax .program import Function , Primitive , Variable
45+ from synth .syntax .type_system import STRING , Arrow
3746from synth .utils import chrono
3847
3948import argparse
@@ -199,7 +208,10 @@ def produce_pcfgs(
199208 max_depth = max (task .solution .depth () for task in full_dataset )
200209 else :
201210 max_depth = 10 # TODO: set as parameter
202- cfgs = [CFG .depth_constraint (dsl , t , max_depth ) for t in all_type_requests ]
211+ cfgs = [
212+ CFG .depth_constraint (dsl , t , max_depth , min_variable_depth = 0 )
213+ for t in all_type_requests
214+ ]
203215
204216 class MyPredictor (nn .Module ):
205217 def __init__ (self , size : int ) -> None :
@@ -364,10 +376,9 @@ def constants_injector(
364376 constants_out = task .specification .constants_out
365377 if len (constants_out ) == 0 :
366378 constants_out .append ("" )
367- name = task .metadata ["name" ]
368- program = task .solution
369- if program == None :
370- return (False , time , programs , None , None )
379+ # program = task.solution
380+ # if program == None:
381+ # return (False, time, programs, None, None)
371382 with chrono .clock ("search.constant_injector" ) as c :
372383
373384 # print("\n-----------------------")
@@ -398,19 +409,6 @@ def constants_injector(
398409 if not found :
399410 break
400411 if found :
401- # print("Solution found.\n")
402- # print("\t", program)
403- # print(
404- # "\nWorking for all ",
405- # counter,
406- # "/",
407- # len(task.specification.examples),
408- # " examples in ",
409- # time,
410- # "/",
411- # task_timeout,
412- # "s.",
413- # )
414412 return (
415413 True ,
416414 c .elapsed_time (),
@@ -421,15 +419,192 @@ def constants_injector(
421419 return (False , time , programs , None , None )
422420
423421
def _sketch_kg_program(
    path: Iterable[str], k: int, constants_in: List[str]
) -> Program:
    """Build the pseudo-program that applies a knowledge-graph path to input piece ``k``.

    The path is encoded in the name of a ``STRING -> STRING`` primitive
    (``start->...->end``).  Unless ``k`` is 0 and there is no input constant
    after it, the input variable is first wrapped in a ``between ... and ...``
    primitive naming the input constants that delimit piece ``k``.
    """
    custom_input: Program = Variable(0, STRING)
    if not (k == 0 and k + 1 >= len(constants_in)):
        left = constants_in[k] if k > 0 else "start"
        right = constants_in[k + 1] if k + 1 < len(constants_in) else "end"
        custom_input = Function(
            Primitive(f"between {left} and {right} ", Arrow(STRING, STRING)),
            [custom_input],
        )
    return Function(
        Primitive("start->" + "->".join(path) + "->end", Arrow(STRING, STRING)),
        [custom_input],
    )


def _sketch_concat_program(
    constants: List[str],
    solution_part: List[Program],
    starts_with_constant: bool,
) -> Program:
    """Interleave quoted constants and solved parts into one ``concat`` program.

    When ``starts_with_constant`` is true the argument list starts with
    ``constants[0]`` and then alternates part / constant; otherwise it starts
    with the first solved part and alternates constant / part.  A single
    argument is returned as is, without a ``concat`` wrapper.

    ``solution_part`` is not modified.  An ``IndexError`` is raised when there
    are fewer solved parts than the interleaving requires.
    """
    # NOTE(review): the concat type is nested on the left
    # (Arrow(previous, STRING)); confirm this is the intended shape for an
    # n-ary concat primitive.
    concat_type = STRING
    i = 0
    if starts_with_constant:
        arguments: List[Program] = [Primitive('"' + constants[0] + '"', STRING)]
        for cste in constants[1:]:
            arguments.append(solution_part[i])
            concat_type = Arrow(concat_type, STRING)
            arguments.append(Primitive('"' + cste + '"', STRING))
            concat_type = Arrow(concat_type, STRING)
            i += 1
        if i < len(solution_part):
            arguments.append(solution_part[i])
            concat_type = Arrow(concat_type, STRING)
    else:
        arguments = [solution_part[0]]
        remaining = solution_part[1:]
        for cste in constants:
            arguments.append(Primitive('"' + cste + '"', STRING))
            concat_type = Arrow(concat_type, STRING)
            if i < len(remaining):
                arguments.append(remaining[i])
                concat_type = Arrow(concat_type, STRING)
                i += 1
        if i < len(remaining):
            arguments.append(remaining[i])
            concat_type = Arrow(concat_type, STRING)
    if len(arguments) > 1:
        return Function(Primitive("concat", concat_type), arguments)
    return arguments[0]


def sketched_base(
    evaluator: DSLEvaluator,
    task: Task[PBE],
    pcfg: ProbDetGrammar,
    custom_enumerate: Callable[[ProbDetGrammar], HSEnumerator],
    *,
    kg_endpoint: str = "http://192.168.1.20:9999/blazegraph/namespace/kb/sparql",
    verbose: bool = False,
) -> Tuple[bool, float, int, Optional[Program], Optional[float]]:
    """Solve ``task`` piece by piece using the constants listed in its metadata.

    If ``task.metadata`` has no ``"constants"`` entry, the call is forwarded
    unchanged to ``constants_injector`` (when the specification provides a
    ``PBEWithConstants``) or to ``base``.

    Otherwise every example output is passed to ``sketch`` together with the
    constants, and every input either to ``sketch`` with ``constants_in`` or
    used as is.  For each output piece, the candidate input pieces are tried in
    turn: first a knowledge-graph path is looked up with
    ``find_paths_from_level``; if none is found, ``base`` is run on a sub-task
    made of the (input piece, output piece) pairs.  The first candidate that
    succeeds provides the program for that piece.  The pieces are finally
    joined with the constants into a single ``concat`` program.

    The module-level ``task_timeout`` is temporarily replaced by the share of
    the remaining time granted to each call to ``base`` and is always restored
    afterwards, even if ``base`` raises.

    Args:
        evaluator: forwarded to ``base`` / ``constants_injector``.
        task: the task to solve; its metadata is read for ``"constants"``,
            ``"constants_in"``, ``"knowledge_graph_relationship"`` and
            ``"name"``.
        pcfg: forwarded to ``base`` / ``constants_injector``.
        custom_enumerate: forwarded to ``base`` / ``constants_injector``.
        kg_endpoint: URL handed to ``build_wrapper``.
        verbose: print progress information.

    Returns:
        ``(solved, elapsed_time, programs, program, probability)``, where
        ``programs`` is the sum of the enumeration counts reported by
        ``base``.  ``program`` and ``probability`` are ``None`` on failure.
    """
    global task_timeout

    constants = task.metadata.get("constants", None)
    if constants is None:
        if task.specification.get_specification(PBEWithConstants) is not None:
            return constants_injector(evaluator, task, pcfg, custom_enumerate)
        return base(evaluator, task, pcfg, custom_enumerate)

    programs = 0
    original_timeout = task_timeout
    if verbose:
        print("should solve:", task.metadata.get("name", "???"))
    with chrono.clock("additional") as c:
        wrapper = build_wrapper(kg_endpoint)
        constants_in = task.metadata.get("constants_in", [])
        pbe = task.specification
        examples = pbe.examples

        if constants_in:
            true_inputs = [
                sketch(example.inputs[0], constants_in) for example in examples
            ]
        else:
            true_inputs = [example.inputs for example in examples]
        n = len(true_inputs[0])

        # new_pseudo_tasks[j][k] collects, over all examples, the pairs
        # (k-th input piece, j-th output piece).
        new_pseudo_tasks = defaultdict(lambda: defaultdict(list))
        for example, pieces in zip(examples, true_inputs):
            for j, subtask in enumerate(sketch(example.output, constants)):
                for k in range(n):
                    new_pseudo_tasks[j][k].append((pieces[k], subtask))

        solution_part: List[Program] = []
        prob = 1
        for possibles in new_pseudo_tasks.values():
            # Drop candidates whose outputs are all empty or whose inputs are
            # all empty.
            relevant_alternatives = {
                k: pairs
                for k, pairs in possibles.items()
                if any(len(out) > 0 for _, out in pairs)
                and any(len(inp) > 0 for inp, _ in pairs)
            }
            subn = len(relevant_alternatives)
            if subn == 0:
                continue

            any_solved = False
            for k, pairs in relevant_alternatives.items():
                d = task.metadata["knowledge_graph_relationship"] - 1
                paths = find_paths_from_level(pairs, wrapper, d)
                if paths:
                    if len(paths) > 1:
                        paths = [choose_best_path(paths, pairs, wrapper)]
                    solution_part.append(
                        _sketch_kg_program(paths[0], k, constants_in)
                    )
                    if verbose:
                        print(
                            "\t result:", "start->" + "->".join(paths[0]) + "->end"
                        )
                    any_solved = True
                    break

                # No knowledge-graph path: fall back to enumeration on the
                # pairs themselves.  Iterating over `pairs` (instead of
                # indexing by example count) keeps this safe when a piece does
                # not exist for every example.
                sub_task = Task(
                    task.type_request,
                    PBE([Example([inp], out) for inp, out in pairs]),
                )
                # Split the remaining time evenly between the candidates.
                task_timeout = (original_timeout - c.elapsed_time()) / subn
                if verbose:
                    print(
                        "\t solving with timeout",
                        task_timeout,
                        "s :",
                        sub_task.specification.examples,
                    )
                try:
                    (
                        solved,
                        _,
                        enumerated,
                        partial_sol,
                        part_prob,
                    ) = base(evaluator, sub_task, pcfg, custom_enumerate)
                finally:
                    # Never leak the reduced timeout to later tasks.
                    task_timeout = original_timeout
                # Count the work before any early return so that the reported
                # total is accurate on timeout too.
                programs += enumerated
                if verbose:
                    print("\t result:", solved, partial_sol)
                if c.elapsed_time() >= task_timeout:
                    return (False, c.elapsed_time(), programs, None, None)
                if solved:
                    prob *= part_prob
                    solution_part.append(partial_sol)
                    any_solved = True
                    break

            if not any_solved:
                return False, c.elapsed_time(), programs, None, None

        # Convert the solved pieces back to a single program.
        some_output: str = examples[0].output
        start_cste = len(constants) > 0 and some_output.startswith(constants[0])
        end_solution = _sketch_concat_program(constants, solution_part, start_cste)
        return True, c.elapsed_time(), programs, end_solution, prob
597+
598+
424599# Main ====================================================================
425600
426601if __name__ == "__main__" :
427602 full_dataset , dsl , evaluator , lexicon , model_name = load_dataset ()
428- method = base
429- name = "base "
430- if isinstance (evaluator , DSLEvaluatorWithConstant ):
431- method = constants_injector
432- name = "constants_injector"
603+ method = sketched_base
604+ name = "sketched_base "
605+ # if isinstance(evaluator, DSLEvaluatorWithConstant):
606+ # method = constants_injector
607+ # name = "constants_injector"
433608
434609 pcfgs = produce_pcfgs (full_dataset , dsl , lexicon )
435610 file = os .path .join (
0 commit comments