11package net .sansa_stack .inference .flink .forwardchaining
22
3- import net .sansa_stack .inference .flink .data .RDFGraph
4- import org .apache .flink .api .common .typeinfo .TypeInformation
5- import org .apache .flink .api .scala .{DataSet , _ }
6- import org .apache .flink .util .Collector
73import net .sansa_stack .inference .data .RDFTriple
8- import net .sansa_stack .inference .utils .Profiler
4+ import net .sansa_stack .inference .flink .data .RDFGraph
5+ import org .apache .flink .api .scala .DataSet
96
107import scala .collection .mutable
11- import scala .reflect .ClassTag
128
139/**
1410 * A forward chaining based reasoner.
1511 *
1612 * @author Lorenz Buehmann
1713 */
18- trait ForwardRuleReasoner extends Profiler {
14+ trait ForwardRuleReasoner extends TransitiveReasoner {
1915
  /**
    * Applies forward chaining to the given RDF graph and returns a new RDF graph that contains all additional
    * triples that could be inferred by the rules.
    *
    * @param graph the RDF graph to apply the rules on
    * @return a new RDF graph containing the inferred triples
    */
  def apply(graph: RDFGraph): RDFGraph
2824
29- // def computeTransitiveClosure[A, B, C](s: mutable.Set[(A, B, C)]): mutable.Set[(A, B, C)] = {
30- // val t = addTransitive(s)
31- // // recursive call if set changed, otherwise stop and return
32- // if (t.size == s.size) s else computeTransitiveClosure(t)
33- // }
34-
35- def computeTransitiveClosure (s : mutable.Set [RDFTriple ]): mutable.Set [RDFTriple ] = {
36- val t = addTransitive(s)
37- // recursive call if set changed, otherwise stop and return
38- if (t.size == s.size) s else computeTransitiveClosure(t)
39- }
40-
41- // def addTransitive[A, B, C](s: mutable.Set[(A, B, C)]) = {
42- // s ++ (for ((s1, p1, o1) <- s; (s2, p2, o2) <- s if o1 == s2) yield (s1, p1, o2))
43- // }
44-
45- def addTransitive (s : mutable.Set [RDFTriple ]) = {
46- s ++ (for (t1 <- s; t2 <- s if t1.`object` == t2.subject) yield RDFTriple (t1.subject, t1.predicate, t2.`object`))
47- }
48-
49- def computeTransitiveClosure (triples : DataSet [RDFTriple ]): DataSet [RDFTriple ] = {
50- if (triples.count() == 0 ) return triples
51- log.info(" computing TC..." )
52-
53- profile {
54- // keep the predicate
55- val predicate = triples.first(1 ).collect().head.predicate
56-
57- // compute the TC
58- var subjectObjectPairs = triples.map(t => (t.subject, t.`object`))
59-
60- // because join() joins on keys, in addition the pairs are stored in reversed order (o, s)
61- val objectSubjectPairs = subjectObjectPairs.map(t => (t._2, t._1))
62-
63- // the join is iterated until a fixed point is reached
64- var i = 1
65- var oldCount = 0L
66- var nextCount = triples.count()
67- do {
68- log.info(s " iteration $i... " )
69- oldCount = nextCount
70- // perform the join (s1, o1) x (o2, s2), obtaining an DataSet of (s1=o2, (o1, s2)) pairs,
71- // then project the result to obtain the new (s2, o1) paths.
72- subjectObjectPairs = subjectObjectPairs
73- .union(
74- subjectObjectPairs
75- .join(objectSubjectPairs).where(0 ).equalTo(0 )
76- .map(x => (x._2._2, x._1._2))
77- .filter(tuple => tuple._1 != tuple._2)// omit (s1, s1)
78- )
79- .distinct()
80- nextCount = subjectObjectPairs.count()
81- i += 1
82- } while (nextCount != oldCount)
83-
84- println(" TC has " + nextCount + " triples." )
85- subjectObjectPairs.map(p => RDFTriple (p._1, predicate, p._2))
86- }
87- }
88-
89- def computeTransitiveClosure2 (triples : DataSet [RDFTriple ]): DataSet [RDFTriple ] = {
90- if (triples.count() == 0 ) return triples
91- log.info(" computing TC..." )
92-
93- profile {
94- // keep the predicate
95- val predicate = triples.first(1 ).collect().head.predicate
96-
97- // convert to tuples needed for the JOIN operator
98- val subjectObjectPairs = triples.map(t => (t.subject, t.`object`))
99-
100- // compute the TC
101- val res = subjectObjectPairs.iterateWithTermination(10 ) {
102- prevPaths : DataSet [(String , String )] =>
103-
104- val nextPaths = prevPaths
105- .join(subjectObjectPairs).where(1 ).equalTo(0 ) {
106- (left, right) => (left._1, right._2)
107- }
108- .union(prevPaths)
109- .groupBy(0 , 1 )
110- .reduce((l ,r) => l)
111-
112- val terminate = prevPaths
113- .coGroup(nextPaths)
114- .where(0 ).equalTo(0 ) {
115- (prev, next, out : Collector [(String , String )]) => {
116- val prevPaths = prev.toSet
117- for (n <- next)
118- if (! prevPaths.contains(n)) out.collect(n)
119- }
120- }.withForwardedFieldsSecond(" *" )
121- (nextPaths, terminate)
122- }
123-
124- // map back to RDF triples
125- res.map(p => RDFTriple (p._1, predicate, p._2))
126- }
127- }
128-
129- def computeTransitiveClosure [A : ClassTag : TypeInformation ](edges : DataSet [(A , A )]): DataSet [(A , A )] = {
130- log.info(" computing TC..." )
131- // we keep the transitive closure cached
132- var tc = edges
133-
134- // because join() joins on keys, in addition the pairs are stored in reversed order (o, s)
135- val edgesReversed = tc.map(t => (t._2, t._1))
136-
137- // the join is iterated until a fixed point is reached
138- var i = 1
139- var oldCount = 0L
140- var nextCount = tc.count()
141- do {
142- log.info(s " iteration $i... " )
143- oldCount = nextCount
144- // perform the join (x, y) x (y, x), obtaining an DataSet of (x=y, (y, x)) pairs,
145- // then project the result to obtain the new (x, y) paths.
146- val join = tc.join(edgesReversed).where(0 ).equalTo(0 )
147- join.print()
148- tc = tc
149- .union(join.map(x => (x._2._2, x._2._1)))
150- .distinct()
151- nextCount = tc.count()
152- i += 1
153- } while (nextCount != oldCount)
154-
155- println(" TC has " + nextCount + " edges." )
156- tc
157- }
158-
15925 /**
16026 * Extracts all triples for the given predicate.
16127 *
0 commit comments