Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit effe5bf

Browse files
More analysis of backward chaining best practices.
1 parent 558003d commit effe5bf

4 files changed

Lines changed: 109 additions & 24 deletions

File tree

sansa-inference-common/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@
9797
</dependencies>
9898

9999
<build>
100-
<sourceDirectory>src/main/scala,src/main/resources</sourceDirectory>
101-
<testSourceDirectory>src/test/scala</testSourceDirectory>
100+
<!--<sourceDirectory>src/main/scala,src/main/resources</sourceDirectory>-->
101+
<!--<testSourceDirectory>src/test/scala</testSourceDirectory>-->
102102
<!--<outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>-->
103103
<!--<testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>-->
104104
<plugins>

sansa-inference-common/src/main/scala/net/sansa_stack/inference/rules/plan/SimpleSQLGenerator.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class SimpleSQLGenerator(val sqlSchema: SQLSchema = SQLSchemaDefault) extends SQ
2222
var idx = 0
2323

2424
def generateSQLQuery(rule: Rule): String = {
25-
info(s"Rule:\n$rule")
25+
debug(s"Rule:\n$rule")
2626

2727
reset()
2828

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/BackwardChainingReasonerDataframe.scala

Lines changed: 103 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
package net.sansa_stack.inference.spark.backwardchaining
22

3-
import net.sansa_stack.inference.rules.RuleSets
4-
import net.sansa_stack.inference.rules.plan.{SimplePlanGenerator, SimpleSQLGenerator, TriplesSchema}
5-
import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
6-
import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
7-
import net.sansa_stack.inference.utils.{Logging, TripleUtils}
8-
import net.sansa_stack.inference.utils.RuleUtils._
9-
import org.apache.jena.graph.{NodeFactory, Triple}
3+
import org.apache.jena.graph.{Node, NodeFactory, Triple}
104
import org.apache.jena.reasoner.TriplePattern
115
import org.apache.jena.reasoner.rulesys.Rule
126
import org.apache.jena.reasoner.rulesys.impl.BindingVector
13-
import org.apache.jena.vocabulary.RDF
7+
import org.apache.jena.sparql.util.FmtUtils
8+
import org.apache.jena.vocabulary.{RDF, RDFS}
149
import org.apache.spark.sql.{Dataset, SparkSession}
15-
import net.sansa_stack.inference.utils.TripleUtils._
10+
11+
import net.sansa_stack.inference.rules.RuleSets
12+
import net.sansa_stack.inference.rules.plan.SimpleSQLGenerator
13+
import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
14+
import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
15+
import net.sansa_stack.inference.utils.RuleUtils._
16+
import net.sansa_stack.inference.utils.{Logging, TripleUtils}
1617

1718

1819
//case class RDFTriple(s: Node, p: Node, o: Node)
@@ -26,24 +27,31 @@ class BackwardChainingReasonerDataframe(
2627
val rules: Set[Rule],
2728
val graph: Dataset[RDFTriple]) extends Logging {
2829

30+
import org.apache.spark.sql.functions._
31+
32+
val precomputeSchema: Boolean = true
33+
34+
var schema: Map[Node, Dataset[RDFTriple]] = Map()
35+
2936
def isEntailed(triple: Triple): Boolean = {
3037
isEntailed(new TriplePattern(triple))
3138
}
3239

3340
def isEntailed(tp: TriplePattern): Boolean = {
3441

42+
if (precomputeSchema) schema = extractWithIndex(graph)
43+
3544
val tree = buildTree(new AndNode(tp), Seq())
3645
println(tree.toString)
3746

3847
val triples = processTree(tree)
3948
triples.explain(true)
40-
println(triples.count())
49+
println(triples.distinct().count())
4150

4251
false
4352
}
4453

45-
import org.apache.spark.sql.functions._
46-
val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
54+
4755

4856
private def processTree(tree: AndNode): Dataset[RDFTriple] = {
4957
// 1. look for asserted triples in the graph
@@ -58,10 +66,11 @@ class BackwardChainingReasonerDataframe(
5866
// first process the children, i.e. we get the data for each triple pattern in the body of the rule
5967
val childrenTriples: Seq[Dataset[RDFTriple]] = child.children.map(processTree(_))
6068

61-
val union = childrenTriples.reduce(_ union _)
69+
val baseTriples = if (childrenTriples.size > 1) childrenTriples.reduce(_ union _) else childrenTriples.head
70+
6271

6372
// then apply the rule on the UNION of the children data
64-
applyRule(child.element, union)
73+
applyRule(child.element, baseTriples)
6574
})
6675

6776
var triples = assertedTriples
@@ -72,14 +81,25 @@ class BackwardChainingReasonerDataframe(
7281
}
7382

7483
private def lookup(tp: TriplePattern): Dataset[RDFTriple] = {
84+
85+
val terminological = TripleUtils.isTerminological(tp.asTriple())
86+
87+
var filteredGraph =
88+
if (terminological) {
89+
schema.getOrElse(tp.getPredicate, graph)
90+
} else {
91+
graph
92+
}
93+
94+
info(s"Lookup data for $tp")
7595
val s = tp.getSubject.toString()
7696
val p = tp.getPredicate.toString()
7797
val o = tp.getObject.toString()
78-
var filteredGraph = graph
98+
7999
if(tp.getSubject.isConcrete) {
80100
filteredGraph = filteredGraph.filter(t => t.s.equals(s))
81101
}
82-
if(tp.getPredicate.isConcrete) {
102+
if(!terminological && tp.getPredicate.isConcrete) {
83103
filteredGraph = filteredGraph.filter(t => t.p.equals(p))
84104
}
85105
if(tp.getObject.isConcrete) {
@@ -147,7 +167,7 @@ class BackwardChainingReasonerDataframe(
147167
private def applyRule(rule: Rule, dataset: Dataset[RDFTriple]): Dataset[RDFTriple] = {
148168
// convert to SQL
149169
val sqlGenerator = new SimpleSQLGenerator()
150-
val sql = sqlGenerator.generateSQLQuery(rule)
170+
var sql = sqlGenerator.generateSQLQuery(rule)
151171
// val sql =
152172
// """
153173
// |SELECT rel0.s, 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p, 'http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AS o
@@ -156,12 +176,49 @@ class BackwardChainingReasonerDataframe(
156176
// """.stripMargin
157177

158178
// generate logical execution plan
159-
val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
160-
val plan = planGenerator.generateLogicalPlan(rule)
179+
// val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
180+
// val plan = planGenerator.generateLogicalPlan(rule)
161181

182+
val tableName = s"TRIPLES_${rule.getName}"
183+
sql = sql.replace("TRIPLES", tableName)
184+
println(s"SQL NEW: $sql")
185+
dataset.createOrReplaceTempView(tableName)
162186
dataset.sparkSession.sql(sql).as[RDFTriple]
163187
}
164188

189+
val properties = Set(RDFS.subClassOf, RDFS.subPropertyOf, RDFS.domain, RDFS.range).map(p => p.asNode())
190+
val DUMMY_VAR = NodeFactory.createVariable("VAR");
191+
192+
/**
193+
* Computes the triples for each schema property p, e.g. `rdfs:subClassOf`, and returns them as a mapping from p
194+
* to the [[Dataset]] containing the triples.
195+
*
196+
* @param graph the RDF graph
197+
* @return a mapping from each schema property to the Dataset containing its triples
198+
*/
199+
def extractWithIndex(graph: Dataset[RDFTriple]): Map[Node, Dataset[RDFTriple]] = {
200+
log.info("Started schema extraction...")
201+
202+
// for each schema property p
203+
val index =
204+
properties.map { p =>
205+
// get triples (s, p, o)
206+
var triples = lookup(new TriplePattern(DUMMY_VAR, p, DUMMY_VAR))
207+
208+
// broadcast the triples
209+
triples = broadcast(triples)
210+
211+
// register as a table
212+
triples.createOrReplaceTempView(FmtUtils.stringForNode(p).replace(":", "_"))
213+
214+
// add to index
215+
(p -> triples)
216+
}
217+
log.info("Finished schema extraction.")
218+
219+
index.toMap
220+
}
221+
165222

166223
}
167224

@@ -182,6 +239,7 @@ object BackwardChainingReasonerDataframe {
182239
.config("spark.default.parallelism", parallelism)
183240
.config("spark.ui.showConsoleProgress", "false")
184241
.config("spark.sql.shuffle.partitions", parallelism)
242+
.config("spark.sql.autoBroadcastJoinThreshold", "10485760")
185243
.getOrCreate()
186244

187245
import session.implicits._
@@ -192,9 +250,32 @@ object BackwardChainingReasonerDataframe {
192250
.triples.map(t => RDFTriple(t.getSubject.toString(), t.getPredicate.toString(), t.getObject.toString()))
193251
val graph = session.createDataset(triples).cache()
194252
graph.createOrReplaceTempView("TRIPLES")
253+
import org.apache.spark.sql.functions._
254+
val domain = graph.filter(t => t.p == RDFS.domain.toString)
255+
broadcast(domain).createOrReplaceTempView("DOMAIN")
256+
257+
val query =
258+
"""
259+
|SELECT rel0.s AS s, 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p, 'http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AS o
260+
| FROM DOMAIN rel1 JOIN TRIPLES rel0 ON rel1.s=rel0.p
261+
| WHERE rel1.o='http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AND rel1.p='http://www.w3.org/2000/01/rdf-schema#domain'
262+
| UNION
263+
| SELECT *
264+
| FROM TRIPLES
265+
| WHERE p ='http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AND o ='http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person'
266+
""".stripMargin
267+
268+
val ds = session.sql(query)
269+
ds.explain()
270+
println(ds.distinct().count())
271+
195272

196273
val rules = RuleSets.RDFS_SIMPLE
197-
.filter(r => Seq("rdfs2", "rdfs3").contains(r.getName))
274+
.filter(r => Seq(
275+
"rdfs2"
276+
// , "rdfs3"
277+
// , "rdfs9"
278+
).contains(r.getName))
198279

199280
val tp = Triple.create(
200281
NodeFactory.createVariable("s"),
@@ -206,6 +287,8 @@ object BackwardChainingReasonerDataframe {
206287
println(reasoner.isEntailed(tp))
207288

208289
session.stop()
290+
291+
209292
}
210293
}
211294

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/tree/AndOrTree.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@ abstract class Node[T, C <: Node[_, _]](val element: T, var children: Seq[C] = S
1818
override def toString: String = print(0)
1919

2020
def print(indent: Int): String = {
21-
renderElement() + "\n" + children.map(c => "---" * indent + c.print(indent + 1)).mkString("\n")
21+
indentS(renderElement(), indent) + "\n" + children.map(c => "---" * indent + c.print(indent + 1)).mkString("\n")
2222
}
2323

24+
def indentS(s: String, i: Int): String = "---" * i + s
25+
2426
def renderElement(): String = element.toString
2527
}
2628

0 commit comments

Comments
 (0)