@@ -1,18 +1,19 @@
 package net.sansa_stack.inference.spark.backwardchaining
 
-import net.sansa_stack.inference.rules.RuleSets
-import net.sansa_stack.inference.rules.plan.{SimplePlanGenerator, SimpleSQLGenerator, TriplesSchema}
-import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
-import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
-import net.sansa_stack.inference.utils.{Logging, TripleUtils}
-import net.sansa_stack.inference.utils.RuleUtils._
-import org.apache.jena.graph.{NodeFactory, Triple}
+import org.apache.jena.graph.{Node, NodeFactory, Triple}
 import org.apache.jena.reasoner.TriplePattern
 import org.apache.jena.reasoner.rulesys.Rule
 import org.apache.jena.reasoner.rulesys.impl.BindingVector
-import org.apache.jena.vocabulary.RDF
+import org.apache.jena.sparql.util.FmtUtils
+import org.apache.jena.vocabulary.{RDF, RDFS}
 import org.apache.spark.sql.{Dataset, SparkSession}
-import net.sansa_stack.inference.utils.TripleUtils._
+
+import net.sansa_stack.inference.rules.RuleSets
+import net.sansa_stack.inference.rules.plan.SimpleSQLGenerator
+import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
+import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader
+import net.sansa_stack.inference.utils.RuleUtils._
+import net.sansa_stack.inference.utils.{Logging, TripleUtils}
 
 
 // case class RDFTriple(s: Node, p: Node, o: Node)
@@ -26,24 +27,31 @@ class BackwardChainingReasonerDataframe(
   val rules: Set[Rule],
   val graph: Dataset[RDFTriple]) extends Logging {
 
+  import org.apache.spark.sql.functions._
+
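+  // when enabled, the terminological part of the graph is extracted and indexed once
+  // up front (see extractWithIndex) and reused for all subsequent lookups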
+  val precomputeSchema: Boolean = true
+
+  var schema: Map[Node, Dataset[RDFTriple]] = Map()
+
   def isEntailed(triple: Triple): Boolean = {
     isEntailed(new TriplePattern(triple))
   }
 
   def isEntailed(tp: TriplePattern): Boolean = {
 
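+    // build the schema index first, so that terminological triple patterns can be
+    // answered from small broadcast Datasets instead of the full graph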
+    if (precomputeSchema) schema = extractWithIndex(graph)
+
     val tree = buildTree(new AndNode(tp), Seq())
     println(tree.toString)
 
     val triples = processTree(tree)
     triples.explain(true)
-    println(triples.count())
+    println(triples.distinct().count())
 
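+    // TODO the computed triples are only counted for debugging so far; the actual
+    // entailment decision is not derived yet, so `false` is returned unconditionally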
     false
   }
 
-  import org.apache.spark.sql.functions._
-  val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
+
 
   private def processTree(tree: AndNode): Dataset[RDFTriple] = {
     // 1. look for asserted triples in the graph
@@ -58,10 +66,11 @@ class BackwardChainingReasonerDataframe(
       // first process the children, i.e. we get the data for each triple pattern in the body of the rule
       val childrenTriples: Seq[Dataset[RDFTriple]] = child.children.map(processTree(_))
 
-      val union = childrenTriples.reduce(_ union _)
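+      // union the data of all children; with a single child the union can be skipped entirely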
+      val baseTriples = if (childrenTriples.size > 1) childrenTriples.reduce(_ union _) else childrenTriples.head
+
 
       // then apply the rule on the UNION of the children data
-      applyRule(child.element, union)
+      applyRule(child.element, baseTriples)
     })
 
     var triples = assertedTriples
@@ -72,14 +81,25 @@ class BackwardChainingReasonerDataframe(
   }
 
   private def lookup(tp: TriplePattern): Dataset[RDFTriple] = {
+
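+    // terminological (schema-level) patterns can be answered from the precomputed
+    // schema index; all other patterns have to be matched against the full graph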
+    val terminological = TripleUtils.isTerminological(tp.asTriple())
+
+    var filteredGraph =
+      if (terminological) {
+        schema.getOrElse(tp.getPredicate, graph)
+      } else {
+        graph
+      }
+
+    info(s"Lookup data for $tp")
     val s = tp.getSubject.toString()
     val p = tp.getPredicate.toString()
     val o = tp.getObject.toString()
-    var filteredGraph = graph
+
     if (tp.getSubject.isConcrete) {
       filteredGraph = filteredGraph.filter(t => t.s.equals(s))
     }
-    if (tp.getPredicate.isConcrete) {
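+    // note: skipping the predicate filter assumes the schema index actually holds this
+    // predicate; with precomputeSchema == false the fallback above is the full graph,
+    // where the filter would still be required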
+    if (!terminological && tp.getPredicate.isConcrete) {
       filteredGraph = filteredGraph.filter(t => t.p.equals(p))
     }
     if (tp.getObject.isConcrete) {
@@ -147,7 +167,7 @@ class BackwardChainingReasonerDataframe(
   private def applyRule(rule: Rule, dataset: Dataset[RDFTriple]): Dataset[RDFTriple] = {
     // convert to SQL
     val sqlGenerator = new SimpleSQLGenerator()
-    val sql = sqlGenerator.generateSQLQuery(rule)
+    var sql = sqlGenerator.generateSQLQuery(rule)
 //    val sql =
 //      """
 //        |SELECT rel0.s, 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p, 'http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AS o
@@ -156,12 +176,49 @@ class BackwardChainingReasonerDataframe(
 //      """.stripMargin
 
     // generate logical execution plan
-    val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
-    val plan = planGenerator.generateLogicalPlan(rule)
+//    val planGenerator = new SimplePlanGenerator(TriplesSchema.get())
+//    val plan = planGenerator.generateLogicalPlan(rule)
 
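+    // register the input under a rule-specific view name, presumably so that rules
+    // applied on different inputs don't overwrite each other's TRIPLES view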
+    val tableName = s"TRIPLES_${rule.getName}"
+    sql = sql.replace("TRIPLES", tableName)
+    println(s"SQL NEW: $sql")
+    dataset.createOrReplaceTempView(tableName)
     dataset.sparkSession.sql(sql).as[RDFTriple]
   }
 
+  val properties = Set(RDFS.subClassOf, RDFS.subPropertyOf, RDFS.domain, RDFS.range).map(p => p.asNode())
+  val DUMMY_VAR = NodeFactory.createVariable("VAR")
+
+  /**
+    * Computes the triples for each schema property p, e.g. `rdfs:subClassOf`, and returns them as a mapping
+    * from p to the [[Dataset]] containing those triples.
+    *
+    * @param graph the RDF graph
+    * @return a mapping from each schema property to the Dataset of its triples
+    */
+  def extractWithIndex(graph: Dataset[RDFTriple]): Map[Node, Dataset[RDFTriple]] = {
+    log.info("Started schema extraction...")
+
+    // for each schema property p
+    val index =
+      properties.map { p =>
+        // get triples (s, p, o)
+        var triples = lookup(new TriplePattern(DUMMY_VAR, p, DUMMY_VAR))
+
+        // broadcast the triples
+        triples = broadcast(triples)
+
+        // register as a table
+        triples.createOrReplaceTempView(FmtUtils.stringForNode(p).replace(":", "_"))
+
+        // add to index
+        (p -> triples)
+      }
+    log.info("Finished schema extraction.")
+
+    index.toMap
+  }
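+  // for example, with Jena's default prefix mapping the rdfs:subClassOf triples end up
+  // broadcast and registered as a temp view named "rdfs_subClassOf" (":" replaced by "_")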
+
 
 }
 
@@ -182,6 +239,7 @@ object BackwardChainingReasonerDataframe {
       .config("spark.default.parallelism", parallelism)
       .config("spark.ui.showConsoleProgress", "false")
       .config("spark.sql.shuffle.partitions", parallelism)
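+      // 10485760 bytes = 10 MB, Spark's default broadcast-join threshold, made explicit here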
+      .config("spark.sql.autoBroadcastJoinThreshold", "10485760")
       .getOrCreate()
 
     import session.implicits._
@@ -192,9 +250,32 @@ object BackwardChainingReasonerDataframe {
       .triples.map(t => RDFTriple(t.getSubject.toString(), t.getPredicate.toString(), t.getObject.toString()))
     val graph = session.createDataset(triples).cache()
     graph.createOrReplaceTempView("TRIPLES")
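+    // hand-crafted sanity check: materialize the rdfs:domain triples, broadcast them,
+    // and run the SQL that rule rdfs2 is expected to generate for the LUBM class Person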
+    import org.apache.spark.sql.functions._
+    val domain = graph.filter(t => t.p == RDFS.domain.toString)
+    broadcast(domain).createOrReplaceTempView("DOMAIN")
+
+    val query =
+      """
+        |SELECT rel0.s AS s, 'http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AS p, 'http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AS o
+        | FROM DOMAIN rel1 JOIN TRIPLES rel0 ON rel1.s=rel0.p
+        | WHERE rel1.o='http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person' AND rel1.p='http://www.w3.org/2000/01/rdf-schema#domain'
+        | UNION
+        | SELECT *
+        | FROM TRIPLES
+        | WHERE p='http://www.w3.org/1999/02/22-rdf-syntax-ns#type' AND o='http://swat.cse.lehigh.edu/onto/univ-bench.owl#Person'
+      """.stripMargin
+
+    val ds = session.sql(query)
+    ds.explain()
+    println(ds.distinct().count())
+
 
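+    // restrict the rule set to rdfs2 for now; rdfs3 and rdfs9 are kept as comments for later runs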
     val rules = RuleSets.RDFS_SIMPLE
-      .filter(r => Seq("rdfs2", "rdfs3").contains(r.getName))
+      .filter(r => Seq(
+        "rdfs2"
+        // , "rdfs3"
+        // , "rdfs9"
+      ).contains(r.getName))
 
     val tp = Triple.create(
       NodeFactory.createVariable("s"),
@@ -206,6 +287,8 @@ object BackwardChainingReasonerDataframe {
     println(reasoner.isEntailed(tp))
 
     session.stop()
+
+
   }
 }
 