
Commit 558003d

backward chaining - Spark Dataset-based.
Parent: 571cc01

3 files changed: 225 additions & 45 deletions


sansa-inference-common/src/main/scala/net/sansa_stack/inference/rules/plan/SimpleSQLGenerator.scala

Lines changed: 14 additions & 2 deletions
@@ -78,6 +78,7 @@ class SimpleSQLGenerator(val sqlSchema: SQLSchema = SQLSchemaDefault) extends SQ
 //    expressions += (if(target.getPredicate.isVariable) expressionFor(target.getPredicate, target) else target.getPredicate.toString)
 //    expressions += (if(target.getObject.isVariable) expressionFor(target.getObject, target) else target.getObject.toString)
 
+    var i = 0
     requiredVars.foreach{ v =>
       if (v.isVariable) {
         var done = false
@@ -86,20 +87,31 @@ class SimpleSQLGenerator(val sqlSchema: SQLSchema = SQLSchemaDefault) extends SQ
           val expr = expressionFor(v, tp)
 
           if(expr != "NULL") {
-            expressions += expr
+            expressions += withAlias(expr, i)
             done = true
           }
         }
       } else {
-        expressions += "'" + v.toString + "'"
+        val expr = "'" + v.toString + "'"
+        expressions += withAlias(expr, i)
       }
+      i += 1
     }
 
     sql += expressions.mkString(", ")
 
     sql
   }
 
+  private def withAlias(expr: String, i: Int): String = {
+    val appendix = i match {
+      case 0 => "s"
+      case 1 => "p"
+      case 2 => "o"
+    }
+    expr + " AS " + appendix
+  }
+
   private def fromPart(body: Set[Triple]): String = {
     val joins = determineJoins(body)
 

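Why the aliasing matters: the Dataset-based reasoner added in this commit reads the generated query back with .as[RDFTriple], and Spark binds Dataset columns to case-class fields by name, so every projected expression must be named s, p, or o. A minimal sketch of the helper's behaviour on hypothetical inputs (not from the commit's tests); note that the match only covers the first three projected variables, which suffices for triples but would throw a MatchError for i > 2:

// Hypothetical illustration of the new aliasing behaviour:
withAlias("rel0.s", 0)  // returns "rel0.s AS s"
withAlias("'http://www.w3.org/1999/02/22-rdf-syntax-ns#type'", 1)
                        // returns "'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p"
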
sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/BackwardChainingReasoner.scala

Lines changed: 0 additions & 43 deletions
@@ -42,23 +42,6 @@ class BackwardChainingReasoner(val rules: Set[Rule], val graph: RDFGraph) extend
     val rdd = processTree(tree)
 //    println(rdd.count())
 
-
-
-//    rules.foreach(r => {
-//      // check if the head is more general than the triple in question
-//      var head = r.headTriplePatterns()
-//
-//      head.foreach(headTP => {
-//        val subsumes = headTP.subsumes(tp)
-//
-//        if(subsumes) {
-//          val boundRule = createBinding(r, headTP, tp)
-//          println(boundRule.bodyTriplePatterns().mkString(";"))
-//        }
-//      })
-//
-//    })
-
     false
   }
 
@@ -105,7 +88,6 @@ class BackwardChainingReasoner(val rules: Set[Rule], val graph: RDFGraph) extend
 
       case filter: BindableFilter =>
         val operands = filter.getCondition.asInstanceOf[RexCall].getOperands
-        operands.get(0).
 
       case _ =>
     }
@@ -244,31 +226,6 @@ class BackwardChainingReasoner(val rules: Set[Rule], val graph: RDFGraph) extend
 
   }
 
-  def join(rdd1: RDD[Triple], rdd2: RDD[Triple],
-           idx1: Int, idx2: Int,
-           selectIndexes: Seq[Int],
-           targetTP: TriplePattern): RDD[Triple] = {
-    val tmp1 = idx1 match {
-      case 1 => rdd1.map(t => (t.getSubject, (t.getPredicate, t.getObject)))
-      case 2 => rdd1.map(t => (t.getPredicate, (t.getSubject, t.getObject)))
-      case 3 => rdd1.map(t => (t.getObject, (t.getSubject, t.getPredicate)))
-      case _ => throw new IllegalArgumentException(s"$idx1 is not supported for triples")
-    }
-
-    val tmp2 = idx2 match {
-      case 1 => rdd2.map(t => (t.getSubject, (t.getPredicate, t.getObject)))
-      case 2 => rdd2.map(t => (t.getPredicate, (t.getSubject, t.getObject)))
-      case 3 => rdd2.map(t => (t.getObject, (t.getSubject, t.getPredicate)))
-      case _ => throw new IllegalArgumentException(s"$idx2 is not supported for triples")
-    }
-
-    val joined = tmp1.join(tmp2)
-
-
-
-
-    rdd1
-  }
 
 }
 

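Side note on the deletion above: the removed join helper was never finished — it computed joined but returned rdd1 unchanged — and the Dataset-based reasoner below delegates joins to Spark SQL instead. A minimal sketch of the equivalent subject-to-predicate join over the registered TRIPLES view (assumes a SparkSession named session; not part of the commit):

// Join triples whose predicate matches another triple's subject,
// mirroring the ON condition (rel1.s = rel0.p) in the generated SQL below.
val joined = session.sql(
  """SELECT rel0.s, rel0.p, rel0.o
    |  FROM TRIPLES rel1 INNER JOIN TRIPLES rel0 ON rel1.s = rel0.p
  """.stripMargin)
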
sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/BackwardChainingReasonerDataframe.scala

Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
+package net.sansa_stack.inference.spark.backwardchaining
+
+import net.sansa_stack.inference.rules.RuleSets
+import net.sansa_stack.inference.rules.plan.{SimplePlanGenerator, SimpleSQLGenerator, TriplesSchema}
+import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
+import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
+import net.sansa_stack.inference.utils.{Logging, TripleUtils}
+import net.sansa_stack.inference.utils.RuleUtils._
+import org.apache.jena.graph.{NodeFactory, Triple}
+import org.apache.jena.reasoner.TriplePattern
+import org.apache.jena.reasoner.rulesys.Rule
+import org.apache.jena.reasoner.rulesys.impl.BindingVector
+import org.apache.jena.vocabulary.RDF
+import org.apache.spark.sql.{Dataset, SparkSession}
+import net.sansa_stack.inference.utils.TripleUtils._
+
+
+//case class RDFTriple(s: Node, p: Node, o: Node)
+case class RDFTriple(s: String, p: String, o: String)
+
+/**
+  * @author Lorenz Buehmann
+  */
+class BackwardChainingReasonerDataframe(
+    val session: SparkSession,
+    val rules: Set[Rule],
+    val graph: Dataset[RDFTriple]) extends Logging {
+
+  def isEntailed(triple: Triple): Boolean = {
+    isEntailed(new TriplePattern(triple))
+  }
+
+  def isEntailed(tp: TriplePattern): Boolean = {
+
+    val tree = buildTree(new AndNode(tp), Seq())
+    println(tree.toString)
+
+    val triples = processTree(tree)
+    triples.explain(true)
+    println(triples.count())
+
+    false
+  }
+
+  import org.apache.spark.sql.functions._
+  val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
+
+  private def processTree(tree: AndNode): Dataset[RDFTriple] = {
+    // 1. look for asserted triples in the graph
+    val assertedTriples = lookup(tree.element)
+    if(TripleUtils.isTerminological(tree.element.asTriple())) broadcast(assertedTriples)
+
+
+    // 2. process the inference rules that can infer the triple pattern
+    val inferredTriples = tree.children.map(child => {
+      println(s"processing rule ${child.element}")
+
+      // first process the children, i.e. we get the data for each triple pattern in the body of the rule
+      val childrenTriples: Seq[Dataset[RDFTriple]] = child.children.map(processTree(_))
+
+      val union = childrenTriples.reduce(_ union _)
+
+      // then apply the rule on the UNION of the children data
+      applyRule(child.element, union)
+    })
+
+    var triples = assertedTriples
+
+    if(inferredTriples.nonEmpty) triples = triples.union(inferredTriples.reduce(_ union _))
+
+    triples
+  }
+
+  private def lookup(tp: TriplePattern): Dataset[RDFTriple] = {
+    val s = tp.getSubject.toString()
+    val p = tp.getPredicate.toString()
+    val o = tp.getObject.toString()
+    var filteredGraph = graph
+    if(tp.getSubject.isConcrete) {
+      filteredGraph = filteredGraph.filter(t => t.s.equals(s))
+    }
+    if(tp.getPredicate.isConcrete) {
+      filteredGraph = filteredGraph.filter(t => t.p.equals(p))
+    }
+    if(tp.getObject.isConcrete) {
+      filteredGraph = filteredGraph.filter(t => t.o.equals(o))
+    }
+    filteredGraph
+  }
+
+  private def buildTree(tree: AndNode, visited: Seq[Rule]): AndNode = {
+    val tp = tree.element
+
+    rules.filterNot(visited.contains(_)).foreach(r => {
+      // check if the head is more general than the triple in question
+      var head = r.headTriplePatterns()
+
+      head.foreach(headTP => {
+        val subsumes = headTP.subsumes(tp)
+
+        if(subsumes) {
+          // instantiate the rule
+          val boundRule = instantiateRule(r, tp)
+
+          // add new Or-node to tree
+          val node = new OrNode(boundRule)
+          // println(node)
+          tree.children :+= node
+
+          boundRule.bodyTriplePatterns().foreach(newTp => {
+            node.children :+= buildTree(new AndNode(newTp), visited ++ Seq(r))
+          })
+        }
+      })
+
+    })
+
+    tree
+  }
+
+  /*
+  // create a binding for the rule variables
+   */
+  private def instantiateRule(rule: Rule, tp: TriplePattern): Rule = {
+    val headTP = rule.headTriplePatterns().head // TODO handle rules with multiple head TPs
+
+    val binding = new BindingVector(5)
+
+    // the subject
+    if(tp.getSubject.isConcrete && headTP.getSubject.isVariable) {
+      binding.bind(headTP.getSubject, tp.getSubject)
+    }
+    // the predicate
+    if(tp.getPredicate.isConcrete && headTP.getPredicate.isVariable) {
+      binding.bind(headTP.getPredicate, tp.getPredicate)
+    }
+    // the object
+    if(tp.getObject.isConcrete && headTP.getObject.isVariable) {
+      binding.bind(headTP.getObject, tp.getObject)
+    }
+
+    rule.instantiate(binding)
+  }
+
+  import session.implicits._
+
+  private def applyRule(rule: Rule, dataset: Dataset[RDFTriple]): Dataset[RDFTriple] = {
+    // convert to SQL
+    val sqlGenerator = new SimpleSQLGenerator()
+    val sql = sqlGenerator.generateSQLQuery(rule)
+//    val sql =
+//      """
+//        |SELECT rel0.s, 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p, 'http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AS o
+//        | FROM TRIPLES rel1 INNER JOIN TRIPLES rel0 ON rel1.s=rel0.p
+//        | WHERE rel1.o='http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AND rel1.p='http://www.w3.org/2000/01/rdf-schema#domain'
+//      """.stripMargin
+
+    // generate logical execution plan
+    val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
+    val plan = planGenerator.generateLogicalPlan(rule)
+
+    dataset.sparkSession.sql(sql).as[RDFTriple]
+  }
+
+
+}
+
+object BackwardChainingReasonerDataframe {
+
+
+  def main(args: Array[String]): Unit = {
+
+    val parallelism = 200
+
+    // the SPARK config
+    val session = SparkSession.builder
+      .appName(s"Spark Backward Chaining")
+      .master("local[4]")
+      // .config("spark.eventLog.enabled", "true")
+      .config("spark.hadoop.validateOutputSpecs", "false") // override output files
+      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+      .config("spark.default.parallelism", parallelism)
+      .config("spark.ui.showConsoleProgress", "false")
+      .config("spark.sql.shuffle.partitions", parallelism)
+      .getOrCreate()
+
+    import session.implicits._
+    // implicit val myObjEncoder = org.apache.spark.sql.Encoders.kryo[RDFTriple]
+
+    val triples = RDFGraphLoader.loadFromDisk(session, args(0))
+      // .triples.map(t => RDFTriple(t.getSubject, t.getPredicate, t.getObject))
+      .triples.map(t => RDFTriple(t.getSubject.toString(), t.getPredicate.toString(), t.getObject.toString()))
+    val graph = session.createDataset(triples).cache()
+    graph.createOrReplaceTempView("TRIPLES")
+
+    val rules = RuleSets.RDFS_SIMPLE
+      .filter(r => Seq("rdfs2", "rdfs3").contains(r.getName))
+
+    val tp = Triple.create(
+      NodeFactory.createVariable("s"),
+      RDF.`type`.asNode(),
+      NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person"))
+
+    val reasoner = new BackwardChainingReasonerDataframe(session, rules, graph)
+
+    println(reasoner.isEntailed(tp))
+
+    session.stop()
+  }
+}
+

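For orientation (not part of the commit): the two rules kept from RuleSets.RDFS_SIMPLE are the standard RDFS domain and range rules. A sketch of them in Jena rule syntax via Rule.parseRule; the exact strings shipped in RuleSets.RDFS_SIMPLE may differ in formatting:

import org.apache.jena.reasoner.rulesys.Rule

// rdfs2 (domain) and rdfs3 (range): for the query pattern
// (?s rdf:type ub:Person), buildTree expands both rules into body
// triple patterns that are then looked up or derived recursively.
val rdfs2 = Rule.parseRule("[rdfs2: (?x ?p ?y), (?p rdfs:domain ?c) -> (?x rdf:type ?c)]")
val rdfs3 = Rule.parseRule("[rdfs3: (?x ?p ?y), (?p rdfs:range ?c) -> (?y rdf:type ?c)]")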