package net.sansa_stack.inference.spark.backwardchaining

import net.sansa_stack.inference.rules.RuleSets
import net.sansa_stack.inference.rules.plan.{SimplePlanGenerator, SimpleSQLGenerator, TriplesSchema}
import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
import net.sansa_stack.inference.utils.{Logging, TripleUtils}
import net.sansa_stack.inference.utils.RuleUtils._
import net.sansa_stack.inference.utils.TripleUtils._
import org.apache.jena.graph.{NodeFactory, Triple}
import org.apache.jena.reasoner.TriplePattern
import org.apache.jena.reasoner.rulesys.Rule
import org.apache.jena.reasoner.rulesys.impl.BindingVector
import org.apache.jena.vocabulary.RDF
import org.apache.spark.sql.{Dataset, SparkSession}


// a simple string-based triple representation that can be encoded as a Spark Dataset
// case class RDFTriple(s: Node, p: Node, o: Node)
case class RDFTriple(s: String, p: String, o: String)

/**
 * A backward-chaining reasoner that checks whether a triple (pattern) is entailed by an
 * RDF graph held in a Spark Dataset plus a set of inference rules. For a given goal, it
 * builds an AND-OR proof tree from the rules whose head subsumes the goal and evaluates
 * each rule application as a SQL query over the registered `TRIPLES` view.
 *
 * @author Lorenz Buehmann
 */
class BackwardChainingReasonerDataframe(
    val session: SparkSession,
    val rules: Set[Rule],
    val graph: Dataset[RDFTriple]) extends Logging {

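  /**
    * Checks whether the given (possibly partially concrete) triple is entailed by the
    * graph and the rules.
    */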
  def isEntailed(triple: Triple): Boolean = {
    isEntailed(new TriplePattern(triple))
  }

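  /**
    * Checks whether the given triple pattern is entailed, i.e. whether at least one
    * matching triple is asserted in the graph or can be derived via the rules.
    */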
  def isEntailed(tp: TriplePattern): Boolean = {

    // build the proof tree for the goal
    val tree = buildTree(new AndNode(tp), Seq())
    println(tree.toString)

    // evaluate the tree, i.e. collect all asserted and derived triples matching the goal
    val triples = processTree(tree)
    triples.explain(true)

    // the pattern is entailed iff at least one matching triple was found
    val cnt = triples.count()
    println(cnt)

    cnt > 0
  }

  import org.apache.spark.sql.functions._

  // generates logical execution plans for rules over the TRIPLES schema
  val planGenerator = new SimplePlanGenerator(TriplesSchema.get())

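  /**
    * Evaluates an AND-node of the proof tree: looks up the asserted triples matching its
    * triple pattern and unions them with the triples derived by each applicable rule
    * (the OR-children), which are in turn evaluated recursively.
    */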
  private def processTree(tree: AndNode): Dataset[RDFTriple] = {
    // 1. look up the asserted triples matching the pattern in the graph;
    // terminological (schema) data is usually small, so hint Spark to broadcast it in joins
    var assertedTriples = lookup(tree.element)
    if (TripleUtils.isTerminological(tree.element.asTriple())) {
      assertedTriples = broadcast(assertedTriples)
    }

    // 2. process the inference rules that can infer the triple pattern
    val inferredTriples = tree.children.map(child => {
      println(s"processing rule ${child.element}")

      // first process the children, i.e. get the data for each triple pattern in the body of the rule
      val childrenTriples: Seq[Dataset[RDFTriple]] = child.children.map(processTree)

      val union = childrenTriples.reduce(_ union _)

      // then apply the rule on the UNION of the children's data
      applyRule(child.element, union)
    })

    // 3. return the union of asserted and inferred triples
    var triples = assertedTriples
    if (inferredTriples.nonEmpty) triples = triples.union(inferredTriples.reduce(_ union _))

    triples
  }

  private def lookup(tp: TriplePattern): Dataset[RDFTriple] = {
    val s = tp.getSubject.toString
    val p = tp.getPredicate.toString
    val o = tp.getObject.toString

    // restrict the graph to triples matching the concrete parts of the pattern
    var filteredGraph = graph
    if (tp.getSubject.isConcrete) {
      filteredGraph = filteredGraph.filter(t => t.s == s)
    }
    if (tp.getPredicate.isConcrete) {
      filteredGraph = filteredGraph.filter(t => t.p == p)
    }
    if (tp.getObject.isConcrete) {
      filteredGraph = filteredGraph.filter(t => t.o == o)
    }
    filteredGraph
  }

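  /**
    * Recursively builds the AND-OR proof tree for the given goal: for every rule whose head
    * subsumes the goal, the rule is instantiated with the goal's bindings and added as an
    * OR-node, and its body triple patterns become new AND-subgoals. Rules already applied on
    * the current branch are skipped to avoid infinite recursion.
    */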
  private def buildTree(tree: AndNode, visited: Seq[Rule]): AndNode = {
    val tp = tree.element

    rules.filterNot(visited.contains(_)).foreach(r => {
      // check if the head is more general than the triple pattern in question
      val head = r.headTriplePatterns()

      head.foreach(headTP => {
        val subsumes = headTP.subsumes(tp)

        if (subsumes) {
          // instantiate the rule
          val boundRule = instantiateRule(r, tp)

          // add a new OR-node to the tree
          val node = new OrNode(boundRule)
          // println(node)
          tree.children :+= node

          // the body triple patterns of the instantiated rule become new subgoals
          boundRule.bodyTriplePatterns().foreach(newTp => {
            node.children :+= buildTree(new AndNode(newTp), visited ++ Seq(r))
          })
        }
      })

    })

    tree
  }

  /**
    * Instantiates the rule by binding the variables in its head to the concrete nodes
    * of the given triple pattern.
    */
  private def instantiateRule(rule: Rule, tp: TriplePattern): Rule = {
    val headTP = rule.headTriplePatterns().head // TODO handle rules with multiple head TPs

    // create a binding for the rule variables
    val binding = new BindingVector(5)

    // the subject
    if (tp.getSubject.isConcrete && headTP.getSubject.isVariable) {
      binding.bind(headTP.getSubject, tp.getSubject)
    }
    // the predicate
    if (tp.getPredicate.isConcrete && headTP.getPredicate.isVariable) {
      binding.bind(headTP.getPredicate, tp.getPredicate)
    }
    // the object
    if (tp.getObject.isConcrete && headTP.getObject.isVariable) {
      binding.bind(headTP.getObject, tp.getObject)
    }

    rule.instantiate(binding)
  }

  import session.implicits._

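  /**
    * Applies a single rule by translating it into a SQL query (via [[SimpleSQLGenerator]])
    * and executing it. Note that the generated query runs against the registered `TRIPLES`
    * view; the passed dataset is currently only used to obtain the Spark session.
    */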
  private def applyRule(rule: Rule, dataset: Dataset[RDFTriple]): Dataset[RDFTriple] = {
    // convert the rule to SQL
    val sqlGenerator = new SimpleSQLGenerator()
    val sql = sqlGenerator.generateSQLQuery(rule)
//    val sql =
//      """
//        |SELECT rel0.s, 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p, 'http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AS o
//        | FROM TRIPLES rel1 INNER JOIN TRIPLES rel0 ON rel1.s=rel0.p
//        | WHERE rel1.o='http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AND rel1.p='http://www.w3.org/2000/01/rdf-schema#domain'
//      """.stripMargin

    // generate the logical execution plan (currently not executed; the SQL query above is run directly)
    val plan = planGenerator.generateLogicalPlan(rule)

    dataset.sparkSession.sql(sql).as[RDFTriple]
  }


}

object BackwardChainingReasonerDataframe {

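  /**
    * Entry point for a small smoke test: loads an RDF graph from disk (args(0) is the path
    * to the input RDF file), registers it as the `TRIPLES` view, and checks entailment of
    * `?s rdf:type ub:Person` using the rdfs2 and rdfs3 rules.
    */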
  def main(args: Array[String]): Unit = {

    val parallelism = 200

    // the Spark config
    val session = SparkSession.builder
      .appName("Spark Backward Chaining")
      .master("local[4]")
      // .config("spark.eventLog.enabled", "true")
      .config("spark.hadoop.validateOutputSpecs", "false") // override output files
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.default.parallelism", parallelism)
      .config("spark.ui.showConsoleProgress", "false")
      .config("spark.sql.shuffle.partitions", parallelism)
      .getOrCreate()

    import session.implicits._
//    implicit val myObjEncoder = org.apache.spark.sql.Encoders.kryo[RDFTriple]

    // load the triples from disk and register them as the TRIPLES view
    val triples = RDFGraphLoader.loadFromDisk(session, args(0))
//      .triples.map(t => RDFTriple(t.getSubject, t.getPredicate, t.getObject))
      .triples.map(t => RDFTriple(t.getSubject.toString, t.getPredicate.toString, t.getObject.toString))
    val graph = session.createDataset(triples).cache()
    graph.createOrReplaceTempView("TRIPLES")

    // use only the rdfs2 and rdfs3 rules from the simple RDFS rule set
    val rules = RuleSets.RDFS_SIMPLE
      .filter(r => Seq("rdfs2", "rdfs3").contains(r.getName))

    // the goal: all resources of type ub:Person
    val tp = Triple.create(
      NodeFactory.createVariable("s"),
      RDF.`type`.asNode(),
      NodeFactory.createURI("http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person"))

    val reasoner = new BackwardChainingReasonerDataframe(session, rules, graph)

    println(reasoner.isEntailed(tp))

    session.stop()
  }
}