@@ -18,11 +18,12 @@ import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
1818import net .sansa_stack .inference .spark .utils .NTriplesToParquetConverter .{DEFAULT_NUM_THREADS , DEFAULT_PARALLELISM }
1919import net .sansa_stack .inference .utils .RuleUtils ._
2020import net .sansa_stack .inference .utils .{Logging , TripleUtils }
21+ import org .apache .jena .rdf .model .Resource
2122
2223import scala .concurrent .duration .FiniteDuration
2324
2425
25- // case class RDFTriple(s: Node, p: Node, o: Node)
26+ // case class RDFTriple(s: Node, p: Node, o: Node)
2627case class RDFTriple (s : String , p : String , o : String )
2728
2829/**
@@ -34,19 +35,17 @@ class BackwardChainingReasonerDataframe(
3435 val graph : Dataset [RDFTriple ]) extends Logging {
3536
3637 import org .apache .spark .sql .functions ._
38+ private implicit def resourceToNodeConverter (resource : Resource ): Node = resource.asNode()
3739
3840 val precomputeSchema : Boolean = true
3941
40- var schema : Map [Node , Dataset [RDFTriple ]] = Map ()
42+ lazy val schema : Map [Node , Dataset [RDFTriple ]] = if (precomputeSchema) extractWithIndex(graph) else Map ()
4143
4244 def isEntailed (triple : Triple ): Boolean = {
4345 isEntailed(new TriplePattern (triple))
4446 }
4547
4648 def isEntailed (tp : TriplePattern ): Boolean = {
47-
48- if (precomputeSchema) schema = extractWithIndex(graph)
49-
5049 val tree = buildTree(new AndNode (tp), Seq ())
5150 println(tree.toString)
5251
@@ -90,6 +89,26 @@ class BackwardChainingReasonerDataframe(
9089 lookup(tp.asTriple())
9190 }
9291
92+ private def lookupSimple (tp : Triple ): Dataset [RDFTriple ] = {
93+ info(s " Lookup data for $tp" )
94+ val s = tp.getSubject.toString()
95+ val p = tp.getPredicate.toString()
96+ val o = tp.getObject.toString()
97+
98+ var filteredGraph = graph
99+
100+ if (tp.getSubject.isConcrete) {
101+ filteredGraph.filter(t => t.s.equals(s))
102+ }
103+ if (tp.getPredicate.isConcrete) {
104+ filteredGraph = filteredGraph.filter(t => t.p.equals(p))
105+ }
106+ if (tp.getObject.isConcrete) {
107+ filteredGraph = filteredGraph.filter(t => t.o.equals(o))
108+ }
109+ filteredGraph
110+ }
111+
93112 private def lookup (tp : Triple ): Dataset [RDFTriple ] = {
94113
95114 val terminological = TripleUtils .isTerminological(tp)
@@ -196,7 +215,11 @@ class BackwardChainingReasonerDataframe(
196215 dataset.sparkSession.sql(sql).as[RDFTriple ]
197216 }
198217
199- val properties = Set (RDFS .subClassOf, RDFS .subPropertyOf, RDFS .domain, RDFS .range).map(p => p.asNode())
218+ val properties = Set (
219+ (RDFS .subClassOf, true , " SCO" ),
220+ (RDFS .subPropertyOf, true , " SPO" ),
221+ (RDFS .domain, false , " DOM" ),
222+ (RDFS .range, false , " RAN" ))
200223 val DUMMY_VAR = NodeFactory .createVariable(" VAR" );
201224
202225 /**
@@ -211,18 +234,25 @@ class BackwardChainingReasonerDataframe(
211234
212235 // for each schema property p
213236 val index =
214- properties.map { p =>
237+ properties.map { entry =>
238+ val p = entry._1
239+ val tc = entry._2
240+ val alias = entry._3
241+
215242 // get triples (s, p, o)
216- var triples = lookup(new TriplePattern (DUMMY_VAR , p, DUMMY_VAR ))
243+ var triples = lookupSimple(Triple .create(DUMMY_VAR , p, DUMMY_VAR ))
244+
245+ // compute TC if necessary
246+ if (tc) triples = computeTC(triples)
217247
218248 // broadcast the triples
219- triples = broadcast(triples)
249+ triples = broadcast(triples).alias(alias)
220250
221251 // register as a table
222252 triples.createOrReplaceTempView(FmtUtils .stringForNode(p).replace(" :" , " _" ))
223253
224254 // add to index
225- (p -> triples)
255+ (p.asNode() -> triples)
226256 }
227257 log.info(" Finished schema extraction." )
228258
@@ -232,17 +262,10 @@ class BackwardChainingReasonerDataframe(
232262 def query (tp : Triple ): Dataset [RDFTriple ] = {
233263 import org .apache .spark .sql .functions ._
234264
235- val domain = broadcast(graph.filter(t => t.p == RDFS .domain.toString)).alias(" DOMAIN" )
236- domain.createOrReplaceTempView(" DOMAIN" )
237-
238- val range = broadcast(graph.filter(t => t.p == RDFS .range.toString)).alias(" RANGE" )
239- range.createOrReplaceTempView(" RANGE" )
240-
241- val sco = broadcast(graph.filter(t => t.p == RDFS .subClassOf.toString)).alias(" SCO" )
242- sco.createOrReplaceTempView(" SCO" )
243-
244- val spo = broadcast(graph.filter(t => t.p == RDFS .subPropertyOf.toString)).alias(" SPO" )
245- spo.createOrReplaceTempView(" SPO" )
265+ val domain = schema.getOrElse(RDFS .domain, broadcast(graph.filter(t => t.p == RDFS .domain.toString)).alias(" DOMAIN" ))
266+ val range = schema.getOrElse(RDFS .range, broadcast(graph.filter(t => t.p == RDFS .range.toString)).alias(" RANGE" ))
267+ val sco = schema.getOrElse(RDFS .subClassOf, broadcast(computeTC(graph.filter(t => t.p == RDFS .subClassOf.toString))).alias(" SCO" ))
268+ val spo = schema.getOrElse(RDFS .subPropertyOf, broadcast(computeTC(graph.filter(t => t.p == RDFS .subPropertyOf.toString))).alias(" SPO" ))
246269
247270 // asserted triples
248271 var ds = lookup(tp)
@@ -325,11 +348,11 @@ class BackwardChainingReasonerDataframe(
325348 .select(types(" s" ).alias(" s" ), lit(RDF .`type`.toString).alias(" p" ), sco(" o" ).alias(" o" ))
326349 .as[RDFTriple ]
327350
328- // println (s"|rdf:type|=${ds.count()}")
329- // println (s"|rdfs2|=${rdfs2.count()}")
330- // println (s"|rdfs3|=${rdfs3.count()}")
331- // println (s"|rdf:type/rdfs2/rdfs3/|=${types.count()}")
332- // println (s"|rdfs9|=${rdfs9.count()}")
351+ // log.info (s"|rdf:type|=${ds.count()}")
352+ // log.info (s"|rdfs2|=${rdfs2.count()}")
353+ // log.info (s"|rdfs3|=${rdfs3.count()}")
354+ // log.info (s"|rdf:type/rdfs2/rdfs3/|=${types.count()}")
355+ // log.info (s"|rdfs9|=${rdfs9.count()}")
333356
334357
335358
@@ -384,10 +407,46 @@ class BackwardChainingReasonerDataframe(
384407 ds.distinct()
385408 }
386409
387-
410+ /**
411+ * Computes the transitive closure for a Dataset of triples. The assumption is that this Dataset is already
412+ * filter by a single predicate.
413+ *
414+ * @param ds the Dataset of triples
415+ * @return a Dataset containing the transitive closure of the triples
416+ */
417+ private def computeTC (ds : Dataset [RDFTriple ]): Dataset [RDFTriple ] = {
418+ var tc = ds
419+ tc.cache()
420+
421+ // the join is iterated until a fixed point is reached
422+ var i = 1
423+ var oldCount = 0L
424+ var nextCount = tc.count()
425+ do {
426+ log.info(s " iteration $i... " )
427+ oldCount = nextCount
428+
429+ val joined = tc.alias(" A" )
430+ .join(tc.alias(" B" ), $" A.o" === $" B.s" )
431+ .select($" A.s" , $" A.p" , $" B.o" )
432+ .as[RDFTriple ]
433+
434+ tc = tc
435+ .union(joined)
436+ .distinct()
437+ .cache()
438+ nextCount = tc.count()
439+ i += 1
440+ } while (nextCount != oldCount)
441+
442+ tc.unpersist()
443+
444+ log.info(" TC has " + nextCount + " edges." )
445+ tc
446+ }
388447}
389448
390- object BackwardChainingReasonerDataframe {
449+ object BackwardChainingReasonerDataframe extends Logging {
391450
392451 val DEFAULT_PARALLELISM = 200
393452 val DEFAULT_NUM_THREADS = 4
@@ -430,7 +489,7 @@ object BackwardChainingReasonerDataframe {
430489
431490 // compute size here to have it cached
432491 time {
433- println (s " |G|= ${graph.count()}" )
492+ log.info (s " |G|= ${graph.count()}" )
434493 }
435494
436495 val rules = RuleSets .RDFS_SIMPLE
@@ -458,6 +517,20 @@ object BackwardChainingReasonerDataframe {
458517 NodeFactory .createVariable(" o" ))
459518 compare(tp, reasoner)
460519
520+ // :s rdfs:subClassOf VAR
521+ tp = Triple .create(
522+ NodeFactory .createURI(" http://swat.cse.lehigh.edu/onto/univ-bench.owl#ClericalStaff" ),
523+ RDFS .subClassOf.asNode(),
524+ NodeFactory .createVariable(" o" ))
525+ compare(tp, reasoner, true )
526+
527+ // :s rdfs:subPropertyOf VAR
528+ tp = Triple .create(
529+ NodeFactory .createURI(" http://swat.cse.lehigh.edu/onto/univ-bench.owl#headOf" ),
530+ RDFS .subPropertyOf.asNode(),
531+ NodeFactory .createVariable(" o" ))
532+ compare(tp, reasoner)
533+
461534 // VAR :p VAR
462535 tp = Triple .create(
463536 NodeFactory .createVariable(" s" ),
@@ -510,15 +583,15 @@ object BackwardChainingReasonerDataframe {
510583 session.stop()
511584 }
512585
513- def compare (tp : Triple , reasoner : BackwardChainingReasonerDataframe ): Unit = {
586+ def compare (tp : Triple , reasoner : BackwardChainingReasonerDataframe , show : Boolean = false ): Unit = {
514587 time {
515588 val triples = reasoner.query(tp)
516589 println(triples.count())
517- // println( triples.show(false) )
590+ if (show) triples.show(false )
518591 }
519592
520593// time {
521- // println (reasoner.isEntailed(tp))
594+ // log.info (reasoner.isEntailed(tp))
522595// }
523596 }
524597
@@ -527,7 +600,7 @@ object BackwardChainingReasonerDataframe {
527600 val t0 = System .nanoTime()
528601 val result = block // call-by-name
529602 val t1 = System .nanoTime()
530- println (" Elapsed time: " + FiniteDuration (t1 - t0, " ns" ).pretty)
603+ log.info (" Elapsed time: " + FiniteDuration (t1 - t0, " ns" ).pretty)
531604 result
532605 }
533606}
@@ -570,8 +643,8 @@ object PrettyDuration {
570643 }
571644
572645 def abbreviate (unit : TimeUnit ): String = unit match {
573- case NANOSECONDS => " ns"
574- case MICROSECONDS => " μs "
646+ case NANOSECONDS => " ns"
647+ case MICROSECONDS => " micros "
575648 case MILLISECONDS => " ms"
576649 case SECONDS => " s"
577650 case MINUTES => " min"
0 commit comments