Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit e02bd1a

Browse files
Extended BC structure.
1 parent 3498a6b commit e02bd1a

5 files changed

Lines changed: 1014 additions & 319 deletions

File tree

Lines changed: 2 additions & 256 deletions
Original file line numberDiff line numberDiff line change
@@ -1,266 +1,12 @@
11
package net.sansa_stack.inference.spark.backwardchaining
22

3-
import java.io.PrintWriter
4-
5-
import org.apache.calcite.interpreter.Bindables.{BindableFilter, BindableJoin, BindableProject}
6-
import org.apache.calcite.rel.{RelNode, RelVisitor}
7-
8-
import net.sansa_stack.inference.rules.RuleSets
9-
import net.sansa_stack.inference.rules.plan.{SimplePlanGenerator, SimpleSQLGenerator, TriplesSchema}
10-
import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
11-
import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
12-
import net.sansa_stack.inference.spark.data.model.RDFGraph
13-
import net.sansa_stack.inference.utils.Logging
14-
import org.apache.jena.graph.{Node, NodeFactory, Triple}
153
import org.apache.jena.reasoner.TriplePattern
16-
import org.apache.jena.reasoner.rulesys.Rule
17-
import org.apache.jena.reasoner.rulesys.impl.BindingVector
18-
import org.apache.jena.vocabulary.RDF
19-
20-
import net.sansa_stack.inference.utils.RuleUtils._
21-
import net.sansa_stack.inference.utils.TripleUtils._
22-
import org.apache.calcite.rel.externalize.RelWriterImpl
23-
import org.apache.calcite.rex.RexCall
24-
import org.apache.spark.rdd.RDD
25-
import org.apache.spark.sql.SparkSession
26-
import org.apache.spark.sql.catalyst.plans.logical.Project
274

285
/**
296
* @author Lorenz Buehmann
307
*/
31-
/**
 * Prototype of a backward-chaining reasoner over an [[RDFGraph]].
 *
 * An entailment query is answered by recursively expanding all rules whose
 * head subsumes the query pattern into an AND/OR proof tree, then evaluating
 * that tree against the asserted triples. NOTE(review): evaluation is still a
 * work in progress — `isEntailed` always returns `false` and the plan-to-RDD
 * translation in [[RDDRelVisitor]] is a stub.
 *
 * @param rules the inference rules to chain over
 * @param graph the RDF graph holding the asserted triples
 * @author Lorenz Buehmann
 */
class BackwardChainingReasoner(val rules: Set[Rule], val graph: RDFGraph) extends Logging {

  /** Entailment check for a concrete triple; delegates to the pattern variant. */
  def isEntailed(triple: Triple): Boolean = isEntailed(new TriplePattern(triple))

  /**
   * Entailment check for a triple pattern: builds and prints the proof tree,
   * evaluates it, and (for now) always answers `false`.
   */
  def isEntailed(tp: TriplePattern): Boolean = {
    val proofTree = buildTree(new AndNode(tp), Seq())
    println(proofTree.toString)

    val rdd = processTree(proofTree)
    // println(rdd.count())

    false
  }

  // shared generator for Calcite logical plans over the ternary triples schema
  val planGenerator = new SimplePlanGenerator(TriplesSchema.get())

  /**
   * Evaluates a proof tree bottom-up. NOTE(review): currently returns all
   * asserted triples — the per-pattern lookup and the joins are not wired in yet.
   */
  private def processTree(tree: AndNode): RDD[Triple] = {
    // 1. look for asserted triples in the graph
    var result = graph.triples // lookup(tree.element)

    // 2. process the inference rules that can infer the triple pattern
    for (child <- tree.children) {
      println(s"processing rule ${child.element}")

      processRule(child.element)

      val targetTp = child.element.headTriplePatterns().head

      // recursively process each instantiated body atom of the rule
      var node2RDD = child.children.map(c => (c, processTree(c))).toMap

      // and join them (pairwise join variables are only printed so far)
      for (pair <- node2RDD.map(_._1).toList.combinations(2)) {
        val vars = joinVars(pair(0).element, pair(1).element)
        println(vars.mkString("\n"))
      }

      applyRule(child.element)
    }

    result
  }

  /**
   * Walks a Calcite relational plan; the translation of each operator into
   * RDD operations is still a stub (only the filter condition is inspected).
   */
  class RDDRelVisitor(rdd: RDD[Triple]) extends RelVisitor {

    override def visit(node: RelNode, ordinal: Int, parent: RelNode): Unit = {
      println(node)

      val rdd = node match {
        case project: BindableProject => // TODO translate projection

        case join: BindableJoin =>      // TODO translate join

        case filter: BindableFilter =>
          val operands = filter.getCondition.asInstanceOf[RexCall].getOperands

        case _ =>
      }

      super.visit(node, ordinal, parent)
    }

    override def go(node: RelNode): RelNode = super.go(node)
  }

  /** Generates the logical plan for a rule and walks it with the stub visitor. */
  private def processRule(rule: Rule) = {
    val plan = planGenerator.generateLogicalPlan(rule)
    new RDDRelVisitor(graph.triples).go(plan)
  }

  /**
   * Positions (1 = subject, 2 = predicate, 3 = object) of the body pattern's
   * nodes that also occur as variables in the head pattern.
   */
  private def selectedVars(body: TriplePattern, head: TriplePattern): Seq[Int] = {
    val headVars = head.vars()

    var positions: Seq[Int] = Seq()
    if (headVars.contains(body.getSubject)) {
      positions +:= 1
    }
    if (headVars.contains(body.getPredicate)) {
      positions +:= 2
    }
    if (headVars.contains(body.getObject)) {
      positions +:= 3
    }

    positions
  }

  /**
   * Shared variables of two triple patterns as (variable, position in tp1,
   * position in tp2) with positions 1 = subject, 2 = predicate, 3 = object.
   */
  private def joinVars(tp1: TriplePattern, tp2: TriplePattern): Seq[(Node, Int, Int)] = {
    // collect the variable positions of tp1
    var leftVars: Seq[(Node, Int)] = Seq()
    if (tp1.getSubject.isVariable) {
      leftVars +:= (tp1.getSubject, 1)
    }
    if (tp1.getPredicate.isVariable) {
      leftVars +:= (tp1.getPredicate, 2)
    }
    if (tp1.getObject.isVariable) {
      leftVars +:= (tp1.getObject, 3)
    }

    // match each tp1 variable against every position of tp2
    var joined: Seq[(Node, Int, Int)] = Seq()
    for ((node, index) <- leftVars) {
      if (tp2.getSubject.equals(node)) {
        joined +:= (node, index, 1)
      }
      if (tp2.getPredicate.equals(node)) {
        joined +:= (node, index, 2)
      }
      if (tp2.getObject.equals(node)) {
        joined +:= (node, index, 3)
      }
    }

    joined
  }

  /** Asserted triples of the graph matching the given pattern. */
  private def lookup(tp: TriplePattern): RDD[Triple] = graph.find(tp.asTriple())

  /**
   * Recursively expands `tree` with an Or-node for every not-yet-visited rule
   * whose head subsumes the tree's pattern; body atoms of the instantiated
   * rule become new And-children. `visited` prevents infinite rule recursion.
   */
  private def buildTree(tree: AndNode, visited: Seq[Rule]): AndNode = {
    val tp = tree.element

    for (r <- rules if !visited.contains(r)) {
      // check if the head is more general than the triple in question
      for (headTP <- r.headTriplePatterns()) {
        if (headTP.subsumes(tp)) {
          // instantiate the rule
          val boundRule = instantiateRule(r, tp)

          // add new Or-node to tree
          val node = new OrNode(boundRule)
          // println(node)
          tree.children :+= node

          for (newTp <- boundRule.bodyTriplePatterns()) {
            node.children :+= buildTree(new AndNode(newTp), visited ++ Seq(r))
          }
        }
      }
    }

    tree
  }

  /**
   * Creates a binding of the rule's head variables against the concrete
   * parts of `tp` and returns the correspondingly instantiated rule.
   */
  private def instantiateRule(rule: Rule, tp: TriplePattern): Rule = {
    val headTP = rule.headTriplePatterns().head // TODO handle rules with multiple head TPs

    val binding = new BindingVector(5)

    // the subject
    if (tp.getSubject.isConcrete && headTP.getSubject.isVariable) {
      binding.bind(headTP.getSubject, tp.getSubject)
    }
    // the predicate
    if (tp.getPredicate.isConcrete && headTP.getPredicate.isVariable) {
      binding.bind(headTP.getPredicate, tp.getPredicate)
    }
    // the object
    if (tp.getObject.isConcrete && headTP.getObject.isVariable) {
      binding.bind(headTP.getObject, tp.getObject)
    }

    rule.instantiate(binding)
  }

  /**
   * Compiles a rule to SQL and a Calcite logical plan, then prints the plan.
   * NOTE(review): the generated SQL is currently unused.
   */
  private def applyRule(rule: Rule) = {
    // convert to SQL
    val sqlGenerator = new SimpleSQLGenerator()
    val sql = sqlGenerator.generateSQLQuery(rule)

    // generate logical execution plan
    val generator = new SimplePlanGenerator(TriplesSchema.get())
    val plan = generator.generateLogicalPlan(rule)

    // apply plan
    plan.explain(new RelWriterImpl(new PrintWriter(System.out)))
  }

}
231-
232-
/**
 * Entry point: runs a single backward-chaining entailment check on a graph
 * loaded from disk, using only the `rdfs2` rule of the simple RDFS rule set.
 * Expects the input path as the first program argument.
 */
object BackwardChainingReasoner {

  def main(args: Array[String]): Unit = {
    val parallelism = 20

    // the SPARK config
    val session = SparkSession.builder
      .appName(s"Spark Backward Chaining")
      .master("local[4]")
      // .config("spark.eventLog.enabled", "true")
      .config("spark.hadoop.validateOutputSpecs", "false") // override output files
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.default.parallelism", parallelism)
      .config("spark.ui.showConsoleProgress", "false")
      .config("spark.sql.shuffle.partitions", parallelism)
      .getOrCreate()

    // the data and the (single) rule to chain over
    val graph = RDFGraphLoader.loadFromDisk(session, args(0))
    val rules = RuleSets.RDFS_SIMPLE.filter(_.getName == "rdfs2")

    // query pattern: ?s rdf:type ub:Person
    val queryPattern = Triple.create(
      NodeFactory.createVariable("s"),
      RDF.`type`.asNode(),
      NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person"))

    val reasoner = new BackwardChainingReasoner(rules, graph)
    println(reasoner.isEntailed(queryPattern))

    session.stop()
  }
}

0 commit comments

Comments (0)