Skip to content

Commit c75b4cb

Browse files
committed
Fixed automatic suppression of rules with probability 0
1 parent b5d7882 commit c75b4cb

1 file changed

Lines changed: 16 additions & 5 deletions

File tree

synth/syntax/grammars/tagged_det_grammar.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -253,8 +253,14 @@ def random(
253253

254254
@classmethod
255255
def pcfg_from_samples(
256-
cls, cfg: "CFG", samples: Iterable[Program]
256+
cls, cfg: "CFG", samples: Iterable[Program], remove_zero_rules: bool = False
257257
) -> "ProbDetGrammar[Tuple[CFGState, NoneType], Tuple[List[Tuple[Type, CFGState]], NoneType], List[Tuple[Type, CFGState]]]":
258+
"""
259+
Produces the PCFG whose distribution is the mean distribution over the samples.
260+
261+
Rules with probability zero are not removed unless remove_zero_rules is True.
262+
263+
"""
258264
rules_cnt: Dict[CFGNonTerminal, Dict[DerivableProgram, int]] = {}
259265
for S in cfg.rules:
260266
rules_cnt[S] = {}
@@ -287,9 +293,14 @@ def add_count(S: CFGNonTerminal, P: Program) -> bool:
287293
probabilities: Dict[CFGNonTerminal, Dict[DerivableProgram, float]] = {}
288294
for S in cfg.rules:
289295
total = sum(rules_cnt[S][P] for P in cfg.rules[S])
290-
if total > 0:
291-
probabilities[S] = {}
292-
for P in rules_cnt[S]:
293-
probabilities[S][P] = rules_cnt[S][P] / total
296+
if total <= 0:
297+
if remove_zero_rules:
298+
continue
299+
total = 1
300+
probabilities[S] = {}
301+
for P in rules_cnt[S]:
302+
val = rules_cnt[S].get(P, 0)
303+
if val > 0 or not remove_zero_rules:
304+
probabilities[S][P] = val / total
294305

295306
return ProbDetGrammar(cfg, probabilities)

0 commit comments

Comments
 (0)