@@ -29,14 +29,22 @@ import net.sansa_stack.inference.utils.{GraphUtils, Logging, RuleUtils}
2929 */
3030object RuleDependencyGraphGenerator extends Logging {
3131
32+ sealed trait RuleDependencyDirection
33+ case object ConsumerProducer extends RuleDependencyDirection
34+ case object ProducerConsumer extends RuleDependencyDirection
35+
3236 /**
3337 * Generates the rule dependency graph for a given set of rules.
3438 *
3539 * @param rules the set of rules
36- * @param f a function that denotes whether a rule r1 depends on another rule r2
40+ * @param f a function that denotes whether a rule `r1` depends on another rule `r2`
3741 * @return the rule dependency graph
3842 */
39- def generate (rules : Set [Rule ], f : (Rule , Rule ) => Option [TriplePattern ] = dependsOnSmart, pruned : Boolean = false ): RuleDependencyGraph = {
43+ def generate (rules : Set [Rule ],
44+ f : (Rule , Rule ) => Option [TriplePattern ] = dependsOnSmart,
45+ pruned : Boolean = false ,
46+ dependencyDirection : RuleDependencyDirection = ConsumerProducer ): RuleDependencyGraph = {
47+
4048 // create empty graph
4149 var g = new RuleDependencyGraph ()
4250
@@ -48,11 +56,21 @@ object RuleDependencyGraphGenerator extends Logging {
4856
4957 val r1r2 = f(r1, r2) // r1 depends on r2
5058 if (r1r2.isDefined) {
51- g += (r1 ~+> r2) (r1r2.get)
59+ val edgeLabel = r1r2.get
60+ val edge = dependencyDirection match {
61+ case ConsumerProducer => (r1 ~+> r2) (edgeLabel)
62+ case ProducerConsumer => (r2 ~+> r1) (edgeLabel)
63+ }
64+ g += edge
5265 } else {
5366 val r2r1 = f(r2, r1)
5467 if (r2r1.isDefined) { // r2 depends on r1
55- g += (r2 ~+> r1) (r2r1.get)
68+ val edgeLabel = r2r1.get
69+ val edge = dependencyDirection match {
70+ case ConsumerProducer => (r2 ~+> r1) (edgeLabel)
71+ case ProducerConsumer => (r1 ~+> r2) (edgeLabel)
72+ }
73+ g += edge
5674 }
5775 }
5876
@@ -68,7 +86,7 @@ object RuleDependencyGraphGenerator extends Logging {
6886 g = removeEdgesWithPredicateAlreadyTC(g)
6987 g = removeCyclesIfPredicateIsTC(g)
7088 g = removeEdgesWithCycleOverTCNode(g)
71- g = prune(g)
89+ // g = prune(g)
7290 // g = prune1(g)
7391 }
7492
@@ -257,7 +275,9 @@ object RuleDependencyGraphGenerator extends Logging {
257275 outgoingEdges.foreach(e => {
258276 val targetNode = e.target
259277 val rule = targetNode.value
260- val predicate = e.label.asInstanceOf [TriplePattern ].getPredicate
278+ val edgeLabel = e.label
279+ val predicate = edgeLabel.asInstanceOf [TriplePattern ].getPredicate
280+ // check if the target node computes the TC for the current edge predicate
261281 val isTCNode = RuleUtils .isTransitiveClosure(rule, predicate)
262282 debug(s " Direct successor: ${rule.getName}\t\t isTC = $isTCNode" )
263283
@@ -467,6 +487,7 @@ object RuleDependencyGraphGenerator extends Logging {
467487 ).head
468488 // val edge = (node ~+> node)(head)
469489 redundantEdges +:= edge
490+ debug(s " remove edge $edge" )
470491 }
471492
472493 }
@@ -489,5 +510,5 @@ object RuleDependencyGraphGenerator extends Logging {
489510 " [" + e.source.getName + " ~> " + e.target.getName + " ] '" + e.label
490511 }
491512
492- // override def debug(msg: => String): Unit = println(msg)
513+ override def debug (msg : => String ): Unit = println(msg)
493514}
0 commit comments