Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit 4dc9130

Browse files
TC (transitive closure) computation for SCO (rdfs:subClassOf) and SPO (rdfs:subPropertyOf).
1 parent 4d6b0f1 commit 4dc9130

1 file changed

Lines changed: 108 additions & 35 deletions

File tree

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/BackwardChainingReasonerDataframe.scala

Lines changed: 108 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@ import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
1818
import net.sansa_stack.inference.spark.utils.NTriplesToParquetConverter.{DEFAULT_NUM_THREADS, DEFAULT_PARALLELISM}
1919
import net.sansa_stack.inference.utils.RuleUtils._
2020
import net.sansa_stack.inference.utils.{Logging, TripleUtils}
21+
import org.apache.jena.rdf.model.Resource
2122

2223
import scala.concurrent.duration.FiniteDuration
2324

2425

25-
//case class RDFTriple(s: Node, p: Node, o: Node)
26+
// case class RDFTriple(s: Node, p: Node, o: Node)
2627
case class RDFTriple(s: String, p: String, o: String)
2728

2829
/**
@@ -34,19 +35,17 @@ class BackwardChainingReasonerDataframe(
3435
val graph: Dataset[RDFTriple]) extends Logging {
3536

3637
import org.apache.spark.sql.functions._
38+
private implicit def resourceToNodeConverter(resource: Resource): Node = resource.asNode()
3739

3840
val precomputeSchema: Boolean = true
3941

40-
var schema: Map[Node, Dataset[RDFTriple]] = Map()
42+
lazy val schema: Map[Node, Dataset[RDFTriple]] = if (precomputeSchema) extractWithIndex(graph) else Map()
4143

4244
def isEntailed(triple: Triple): Boolean = {
4345
isEntailed(new TriplePattern(triple))
4446
}
4547

4648
def isEntailed(tp: TriplePattern): Boolean = {
47-
48-
if (precomputeSchema) schema = extractWithIndex(graph)
49-
5049
val tree = buildTree(new AndNode(tp), Seq())
5150
println(tree.toString)
5251

@@ -90,6 +89,26 @@ class BackwardChainingReasonerDataframe(
9089
lookup(tp.asTriple())
9190
}
9291

92+
private def lookupSimple(tp: Triple): Dataset[RDFTriple] = {
93+
info(s"Lookup data for $tp")
94+
val s = tp.getSubject.toString()
95+
val p = tp.getPredicate.toString()
96+
val o = tp.getObject.toString()
97+
98+
var filteredGraph = graph
99+
100+
if(tp.getSubject.isConcrete) {
101+
filteredGraph.filter(t => t.s.equals(s))
102+
}
103+
if(tp.getPredicate.isConcrete) {
104+
filteredGraph = filteredGraph.filter(t => t.p.equals(p))
105+
}
106+
if(tp.getObject.isConcrete) {
107+
filteredGraph = filteredGraph.filter(t => t.o.equals(o))
108+
}
109+
filteredGraph
110+
}
111+
93112
private def lookup(tp: Triple): Dataset[RDFTriple] = {
94113

95114
val terminological = TripleUtils.isTerminological(tp)
@@ -196,7 +215,11 @@ class BackwardChainingReasonerDataframe(
196215
dataset.sparkSession.sql(sql).as[RDFTriple]
197216
}
198217

199-
val properties = Set(RDFS.subClassOf, RDFS.subPropertyOf, RDFS.domain, RDFS.range).map(p => p.asNode())
218+
val properties = Set(
219+
(RDFS.subClassOf, true, "SCO"),
220+
(RDFS.subPropertyOf, true, "SPO"),
221+
(RDFS.domain, false, "DOM"),
222+
(RDFS.range, false, "RAN"))
200223
val DUMMY_VAR = NodeFactory.createVariable("VAR");
201224

202225
/**
@@ -211,18 +234,25 @@ class BackwardChainingReasonerDataframe(
211234

212235
// for each schema property p
213236
val index =
214-
properties.map { p =>
237+
properties.map { entry =>
238+
val p = entry._1
239+
val tc = entry._2
240+
val alias = entry._3
241+
215242
// get triples (s, p, o)
216-
var triples = lookup(new TriplePattern(DUMMY_VAR, p, DUMMY_VAR))
243+
var triples = lookupSimple(Triple.create(DUMMY_VAR, p, DUMMY_VAR))
244+
245+
// compute TC if necessary
246+
if (tc) triples = computeTC(triples)
217247

218248
// broadcast the triples
219-
triples = broadcast(triples)
249+
triples = broadcast(triples).alias(alias)
220250

221251
// register as a table
222252
triples.createOrReplaceTempView(FmtUtils.stringForNode(p).replace(":", "_"))
223253

224254
// add to index
225-
(p -> triples)
255+
(p.asNode() -> triples)
226256
}
227257
log.info("Finished schema extraction.")
228258

@@ -232,17 +262,10 @@ class BackwardChainingReasonerDataframe(
232262
def query(tp: Triple): Dataset[RDFTriple] = {
233263
import org.apache.spark.sql.functions._
234264

235-
val domain = broadcast(graph.filter(t => t.p == RDFS.domain.toString)).alias("DOMAIN")
236-
domain.createOrReplaceTempView("DOMAIN")
237-
238-
val range = broadcast(graph.filter(t => t.p == RDFS.range.toString)).alias("RANGE")
239-
range.createOrReplaceTempView("RANGE")
240-
241-
val sco = broadcast(graph.filter(t => t.p == RDFS.subClassOf.toString)).alias("SCO")
242-
sco.createOrReplaceTempView("SCO")
243-
244-
val spo = broadcast(graph.filter(t => t.p == RDFS.subPropertyOf.toString)).alias("SPO")
245-
spo.createOrReplaceTempView("SPO")
265+
val domain = schema.getOrElse(RDFS.domain, broadcast(graph.filter(t => t.p == RDFS.domain.toString)).alias("DOMAIN"))
266+
val range = schema.getOrElse(RDFS.range, broadcast(graph.filter(t => t.p == RDFS.range.toString)).alias("RANGE"))
267+
val sco = schema.getOrElse(RDFS.subClassOf, broadcast(computeTC(graph.filter(t => t.p == RDFS.subClassOf.toString))).alias("SCO"))
268+
val spo = schema.getOrElse(RDFS.subPropertyOf, broadcast(computeTC(graph.filter(t => t.p == RDFS.subPropertyOf.toString))).alias("SPO"))
246269

247270
// asserted triples
248271
var ds = lookup(tp)
@@ -325,11 +348,11 @@ class BackwardChainingReasonerDataframe(
325348
.select(types("s").alias("s"), lit(RDF.`type`.toString).alias("p"), sco("o").alias("o"))
326349
.as[RDFTriple]
327350

328-
// println(s"|rdf:type|=${ds.count()}")
329-
// println(s"|rdfs2|=${rdfs2.count()}")
330-
// println(s"|rdfs3|=${rdfs3.count()}")
331-
// println(s"|rdf:type/rdfs2/rdfs3/|=${types.count()}")
332-
// println(s"|rdfs9|=${rdfs9.count()}")
351+
// log.info(s"|rdf:type|=${ds.count()}")
352+
// log.info(s"|rdfs2|=${rdfs2.count()}")
353+
// log.info(s"|rdfs3|=${rdfs3.count()}")
354+
// log.info(s"|rdf:type/rdfs2/rdfs3/|=${types.count()}")
355+
// log.info(s"|rdfs9|=${rdfs9.count()}")
333356

334357

335358

@@ -384,10 +407,46 @@ class BackwardChainingReasonerDataframe(
384407
ds.distinct()
385408
}
386409

387-
410+
/**
411+
* Computes the transitive closure for a Dataset of triples. The assumption is that this Dataset is already
412+
* filtered by a single predicate.
413+
*
414+
* @param ds the Dataset of triples
415+
* @return a Dataset containing the transitive closure of the triples
416+
*/
417+
private def computeTC(ds: Dataset[RDFTriple]): Dataset[RDFTriple] = {
418+
var tc = ds
419+
tc.cache()
420+
421+
// the join is iterated until a fixed point is reached
422+
var i = 1
423+
var oldCount = 0L
424+
var nextCount = tc.count()
425+
do {
426+
log.info(s"iteration $i...")
427+
oldCount = nextCount
428+
429+
val joined = tc.alias("A")
430+
.join(tc.alias("B"), $"A.o" === $"B.s")
431+
.select($"A.s", $"A.p", $"B.o")
432+
.as[RDFTriple]
433+
434+
tc = tc
435+
.union(joined)
436+
.distinct()
437+
.cache()
438+
nextCount = tc.count()
439+
i += 1
440+
} while (nextCount != oldCount)
441+
442+
tc.unpersist()
443+
444+
log.info("TC has " + nextCount + " edges.")
445+
tc
446+
}
388447
}
389448

390-
object BackwardChainingReasonerDataframe {
449+
object BackwardChainingReasonerDataframe extends Logging{
391450

392451
val DEFAULT_PARALLELISM = 200
393452
val DEFAULT_NUM_THREADS = 4
@@ -430,7 +489,7 @@ object BackwardChainingReasonerDataframe {
430489

431490
// compute size here to have it cached
432491
time {
433-
println(s"|G|=${graph.count()}")
492+
log.info(s"|G|=${graph.count()}")
434493
}
435494

436495
val rules = RuleSets.RDFS_SIMPLE
@@ -458,6 +517,20 @@ object BackwardChainingReasonerDataframe {
458517
NodeFactory.createVariable("o"))
459518
compare(tp, reasoner)
460519

520+
// :s rdfs:subClassOf VAR
521+
tp = Triple.create(
522+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#ClericalStaff"),
523+
RDFS.subClassOf.asNode(),
524+
NodeFactory.createVariable("o"))
525+
compare(tp, reasoner, true)
526+
527+
// :s rdfs:subPropertyOf VAR
528+
tp = Triple.create(
529+
NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#headOf"),
530+
RDFS.subPropertyOf.asNode(),
531+
NodeFactory.createVariable("o"))
532+
compare(tp, reasoner)
533+
461534
// VAR :p VAR
462535
tp = Triple.create(
463536
NodeFactory.createVariable("s"),
@@ -510,15 +583,15 @@ object BackwardChainingReasonerDataframe {
510583
session.stop()
511584
}
512585

513-
def compare(tp: Triple, reasoner: BackwardChainingReasonerDataframe): Unit = {
586+
def compare(tp: Triple, reasoner: BackwardChainingReasonerDataframe, show: Boolean = false): Unit = {
514587
time {
515588
val triples = reasoner.query(tp)
516589
println(triples.count())
517-
// println(triples.show(false))
590+
if (show) triples.show(false)
518591
}
519592

520593
// time {
521-
// println(reasoner.isEntailed(tp))
594+
// log.info(reasoner.isEntailed(tp))
522595
// }
523596
}
524597

@@ -527,7 +600,7 @@ object BackwardChainingReasonerDataframe {
527600
val t0 = System.nanoTime()
528601
val result = block // call-by-name
529602
val t1 = System.nanoTime()
530-
println("Elapsed time: " + FiniteDuration(t1 - t0, "ns").pretty)
603+
log.info("Elapsed time: " + FiniteDuration(t1 - t0, "ns").pretty)
531604
result
532605
}
533606
}
@@ -570,8 +643,8 @@ object PrettyDuration {
570643
}
571644

572645
def abbreviate(unit: TimeUnit): String = unit match {
573-
case NANOSECONDS => "ns"
574-
case MICROSECONDS => "μs"
646+
case NANOSECONDS => "ns"
647+
case MICROSECONDS => "micros"
575648
case MILLISECONDS => "ms"
576649
case SECONDS => "s"
577650
case MINUTES => "min"

0 commit comments

Comments (0)