Skip to content

Commit bbf4ce1

Browse files
committed
Update
1 parent 621d13e commit bbf4ce1

1 file changed

Lines changed: 199 additions & 24 deletions

File tree

examples/pbe/evaluate.py

Lines changed: 199 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import atexit
2+
from collections import defaultdict
23
import os
34
import sys
45
from typing import Callable, Iterable, List, Optional, Tuple
@@ -13,6 +14,12 @@
1314
from torch.nn.utils.rnn import PackedSequence
1415

1516
from dsl_loader import add_dsl_choice_arg, load_DSL
17+
from examples.pbe.transduction.knowledge_graph.kg_path_finder import (
18+
build_wrapper,
19+
choose_best_path,
20+
find_paths_from_level,
21+
)
22+
from examples.pbe.transduction.knowledge_graph.preprocess_tasks import sketch
1623

1724
from synth import Dataset, PBE, Task
1825
from synth.nn import (
@@ -24,7 +31,7 @@
2431
from synth.pbe import IOEncoder
2532
from synth.semantic import DSLEvaluator
2633
from synth.semantic.evaluator import DSLEvaluatorWithConstant
27-
from synth.specification import PBEWithConstants
34+
from synth.specification import Example, PBEWithConstants
2835
from synth.syntax import (
2936
CFG,
3037
ProbDetGrammar,
@@ -34,6 +41,8 @@
3441
Program,
3542
)
3643
from synth.syntax.grammars.heap_search import HSEnumerator
44+
from synth.syntax.program import Function, Primitive, Variable
45+
from synth.syntax.type_system import STRING, Arrow
3746
from synth.utils import chrono
3847

3948
import argparse
@@ -199,7 +208,10 @@ def produce_pcfgs(
199208
max_depth = max(task.solution.depth() for task in full_dataset)
200209
else:
201210
max_depth = 10 # TODO: set as parameter
202-
cfgs = [CFG.depth_constraint(dsl, t, max_depth) for t in all_type_requests]
211+
cfgs = [
212+
CFG.depth_constraint(dsl, t, max_depth, min_variable_depth=0)
213+
for t in all_type_requests
214+
]
203215

204216
class MyPredictor(nn.Module):
205217
def __init__(self, size: int) -> None:
@@ -364,10 +376,9 @@ def constants_injector(
364376
constants_out = task.specification.constants_out
365377
if len(constants_out) == 0:
366378
constants_out.append("")
367-
name = task.metadata["name"]
368-
program = task.solution
369-
if program == None:
370-
return (False, time, programs, None, None)
379+
# program = task.solution
380+
# if program == None:
381+
# return (False, time, programs, None, None)
371382
with chrono.clock("search.constant_injector") as c:
372383

373384
# print("\n-----------------------")
@@ -398,19 +409,6 @@ def constants_injector(
398409
if not found:
399410
break
400411
if found:
401-
# print("Solution found.\n")
402-
# print("\t", program)
403-
# print(
404-
# "\nWorking for all ",
405-
# counter,
406-
# "/",
407-
# len(task.specification.examples),
408-
# " examples in ",
409-
# time,
410-
# "/",
411-
# task_timeout,
412-
# "s.",
413-
# )
414412
return (
415413
True,
416414
c.elapsed_time(),
@@ -421,15 +419,192 @@ def constants_injector(
421419
return (False, time, programs, None, None)
422420

423421

422+
def sketched_base(
423+
evaluator: DSLEvaluator,
424+
task: Task[PBE],
425+
pcfg: ProbDetGrammar,
426+
custom_enumerate: Callable[[ProbDetGrammar], HSEnumerator],
427+
) -> Tuple[bool, float, int, Optional[Program]]:
428+
programs = 0
429+
global task_timeout
430+
if task.metadata.get("constants", None) is not None:
431+
original_timeout = task_timeout
432+
verbose = False
433+
# (
434+
# task.metadata["constant_post_processing"] == 0
435+
# and task.metadata["constant_detection"] == 0
436+
# and task.metadata["knowledge_graph_relationship"] > 0
437+
# )
438+
if verbose:
439+
print("should solve:", task.metadata.get("name", "???"))
440+
with chrono.clock("additional") as c:
441+
wrapper = build_wrapper(
442+
"http://192.168.1.20:9999/blazegraph/namespace/kb/sparql"
443+
)
444+
constants = task.metadata.get("constants", None)
445+
constants_in = task.metadata.get("constants_in", [])
446+
pbe = task.specification
447+
new_pseudo_tasks = defaultdict(lambda: defaultdict(list))
448+
# print("working on:", task.metadata["name"])
449+
# print("constants out.:", constants)
450+
# print("constants inp.:", constants_in)
451+
true_inputs = (
452+
[
453+
sketch(pbe.examples[i].inputs[0], constants_in)
454+
for i in range(len(pbe.examples))
455+
]
456+
if constants_in
457+
else [pbe.examples[i].inputs for i in range(len(pbe.examples))]
458+
)
459+
# print("true_inputs:", true_inputs)
460+
n = len(true_inputs[0])
461+
for i in range(len(pbe.examples)):
462+
subtasks = sketch(pbe.examples[i].output, constants)
463+
for j in range(len(subtasks)):
464+
for k in range(n):
465+
new_pseudo_tasks[j][k].append((true_inputs[i][k], subtasks[j]))
466+
solution_part = []
467+
prob = 1
468+
for j, possibles in new_pseudo_tasks.items():
469+
any_solved = False
470+
relevant_alternatives = {
471+
k: pairs
472+
for k, pairs in possibles.items()
473+
if not all(len(out) == 0 for _, out in pairs)
474+
and not all(len(inp) == 0 for inp, _ in pairs)
475+
}
476+
subn = len(relevant_alternatives)
477+
if subn == 0:
478+
continue
479+
# print(
480+
# f"\t\tpart[{j}] before:{possibles}")
481+
# print(
482+
# f"\t\tpart[{j}] before:{len(possibles)} after:{len(relevant_alternatives)}")
483+
for k, pairs in relevant_alternatives.items():
484+
# print("\tsub task:", pairs)
485+
d = task.metadata["knowledge_graph_relationship"] - 1
486+
paths = find_paths_from_level(pairs, wrapper, d)
487+
# print("\t\tfound paths:", paths)
488+
if paths:
489+
any_solved = True
490+
if len(paths) > 1:
491+
paths = [choose_best_path(paths, pairs, wrapper)]
492+
custom_input = Variable(0, STRING)
493+
if not (k == 0 and k + 1 >= len(constants_in)):
494+
custom_input = Function(
495+
Primitive(
496+
f"between {constants_in[k] if k > 0 else 'start'} and {constants_in[k + 1] if k + 1 < len(constants_in) else 'end'}",
497+
Arrow(STRING, STRING),
498+
),
499+
[custom_input],
500+
)
501+
solution_part.append(
502+
Function(
503+
Primitive(
504+
"start->" + "->".join(paths[0]) + "->end",
505+
Arrow(STRING, STRING),
506+
),
507+
[custom_input],
508+
)
509+
)
510+
if verbose:
511+
print(
512+
"\tresult:", "start->" + "->".join(paths[0]) + "->end"
513+
)
514+
else:
515+
sub_task = Task(
516+
task.type_request,
517+
PBE(
518+
[
519+
Example([pairs[i][0]], pairs[i][1])
520+
for i in range(len(pbe.examples))
521+
],
522+
),
523+
)
524+
task_timeout = original_timeout - c.elapsed_time()
525+
task_timeout /= subn
526+
if verbose:
527+
print(
528+
"\tsolving with timeout",
529+
task_timeout,
530+
"s :",
531+
sub_task.specification.examples,
532+
)
533+
534+
(
535+
solved,
536+
_,
537+
enumerated,
538+
partial_sol,
539+
part_prob,
540+
) = base(evaluator, sub_task, pcfg, custom_enumerate)
541+
task_timeout = original_timeout
542+
if verbose:
543+
print("\tresult:", solved, partial_sol)
544+
if c.elapsed_time() >= task_timeout:
545+
return (False, c.elapsed_time(), programs, None, None)
546+
if solved:
547+
any_solved = True
548+
prob *= part_prob
549+
solution_part.append(partial_sol)
550+
programs += enumerated
551+
if any_solved:
552+
break
553+
if not any_solved:
554+
return False, c.elapsed_time(), programs, None, None
555+
# Convert back to a program
556+
some_output: str = pbe.examples[0].output
557+
start_cste = len(constants) > 0 and some_output.startswith(constants[0])
558+
i = 0
559+
concat_type = STRING
560+
if start_cste:
561+
arguments = [Primitive('"' + constants[0] + '"', STRING)]
562+
for cste in constants[1:]:
563+
arguments.append(solution_part[i])
564+
concat_type = Arrow(concat_type, STRING)
565+
arguments.append(Primitive('"' + cste + '"', STRING))
566+
concat_type = Arrow(concat_type, STRING)
567+
i += 1
568+
if i < len(solution_part):
569+
arguments.append(solution_part[i])
570+
concat_type = Arrow(concat_type, STRING)
571+
572+
else:
573+
arguments = [solution_part.pop(0)]
574+
for cste in constants:
575+
arguments.append(Primitive('"' + cste + '"', STRING))
576+
concat_type = Arrow(concat_type, STRING)
577+
if i < len(solution_part):
578+
arguments.append(solution_part[i])
579+
concat_type = Arrow(concat_type, STRING)
580+
i += 1
581+
if i < len(solution_part):
582+
arguments.append(solution_part[i])
583+
concat_type = Arrow(concat_type, STRING)
584+
end_solution = (
585+
Function(Primitive("concat", concat_type), arguments)
586+
if len(arguments) > 1
587+
else arguments[0]
588+
)
589+
return True, c.elapsed_time(), programs, end_solution, prob
590+
591+
else:
592+
# print("timeout:", task_timeout)
593+
if task.specification.get_specification(PBEWithConstants) is not None:
594+
return constants_injector(evaluator, task, pcfg, custom_enumerate)
595+
else:
596+
return base(evaluator, task, pcfg, custom_enumerate)
597+
598+
424599
# Main ====================================================================
425600

426601
if __name__ == "__main__":
427602
full_dataset, dsl, evaluator, lexicon, model_name = load_dataset()
428-
method = base
429-
name = "base"
430-
if isinstance(evaluator, DSLEvaluatorWithConstant):
431-
method = constants_injector
432-
name = "constants_injector"
603+
method = sketched_base
604+
name = "sketched_base"
605+
# if isinstance(evaluator, DSLEvaluatorWithConstant):
606+
# method = constants_injector
607+
# name = "constants_injector"
433608

434609
pcfgs = produce_pcfgs(full_dataset, dsl, lexicon)
435610
file = os.path.join(

0 commit comments

Comments
 (0)