@@ -5,7 +5,7 @@ import net.sansa_stack.inference.rules.RuleSets
55import net .sansa_stack .inference .rules .plan .SimpleSQLGenerator
66import net .sansa_stack .inference .spark .backwardchaining .tree .{AndNode , OrNode }
77import net .sansa_stack .inference .utils .RuleUtils ._
8- import net .sansa_stack .inference .utils .{Logging , TripleUtils }
8+ import net .sansa_stack .inference .utils .{CollectionUtils , Logging , TripleUtils }
99import org .apache .jena .graph .{Node , NodeFactory , Triple }
1010import org .apache .jena .rdf .model .Resource
1111import org .apache .jena .reasoner .TriplePattern
@@ -15,6 +15,7 @@ import org.apache.jena.sparql.util.FmtUtils
1515import org .apache .jena .vocabulary .{RDF , RDFS }
1616import org .apache .spark .sql .{Dataset , SparkSession }
1717
18+ import scala .collection .mutable
1819import scala .concurrent .duration .FiniteDuration
1920
2021
@@ -296,11 +297,16 @@ class BackwardChainingReasonerDataframe(
296297 if (tp.getSubject.isConcrete) { // find triples where s occurs as subject or object
297298 instanceTriples = instanceTriples.filter(t => t.s == tp.getSubject.toString() || t.o == tp.getSubject.toString())
298299 }
299- val rdfs7 = spo
300- .join(instanceTriples.alias(" DATA" ), $" SPO.s" === $" DATA.p" , " inner" )
301- .select($" DATA.s" .alias(" s" ), $" SPO.o" .alias(" p" ), $" DATA.s" .alias(" o" ))
302- .as[RDFTriple ]
303- instanceTriples = instanceTriples.union(rdfs7).cache()
300+ val spoBC = session.sparkContext.broadcast(
301+ CollectionUtils .toMultiMap(spo.select(" s" , " o" ).collect().map(r => (r.getString(0 ), r.getString(1 ))))
302+ )
303+ val rdfs7 = instanceTriples.flatMap(t => spoBC.value.getOrElse(t.p, Set [String ]()).map(supProp => RDFTriple (t.s, supProp, t.o)))
304+
305+ // val rdfs7 = spo
306+ // .join(instanceTriples.alias("DATA"), $"SPO.s" === $"DATA.p", "inner")
307+ // .select($"DATA.s".alias("s"), $"SPO.o".alias("p"), $"DATA.s".alias("o"))
308+ // .as[RDFTriple]
309+ // instanceTriples = instanceTriples.union(rdfs7).cache()
304310
305311 // rdfs2 (domain)
306312 var dom = if (tp.getObject.isConcrete) domain.filter(_.o == tp.getObject.toString()) else domain
@@ -318,10 +324,17 @@ class BackwardChainingReasonerDataframe(
318324 } else {
319325 instanceTriples
320326 }
321- val rdfs2 = dom
322- .join(data.alias(" DATA" ), $" DOM.s" === $" DATA.p" , " inner" )
323- .select($" DATA.s" , lit(RDF .`type`.toString).alias(" p" ), dom(" o" ).alias(" o" ))
324- .as[RDFTriple ]
327+
328+ val rdftype = RDF .`type`.toString
329+
330+ // val rdfs2 = dom
331+ // .join(data.alias("DATA"), $"DOM.s" === $"DATA.p", "inner")
332+ // .select($"DATA.s", lit(RDF.`type`.toString).alias("p"), dom("o").alias("o"))
333+ // .as[RDFTriple]
334+ val domBC = session.sparkContext.broadcast(
335+ CollectionUtils .toMultiMap(dom.select(" s" , " o" ).collect().map(r => (r.getString(0 ), r.getString(1 ))))
336+ )
337+ val rdfs2 = data.flatMap(t => domBC.value.getOrElse(t.p, Set [String ]()).map(o => RDFTriple (t.s, rdftype, o)))
325338
326339 // rdfs3 (range)
327340 var ran = if (tp.getObject.isConcrete) range.filter(_.o == tp.getObject.toString()) else range
@@ -339,10 +352,12 @@ class BackwardChainingReasonerDataframe(
339352 } else {
340353 instanceTriples
341354 }
342- val rdfs3 = ran
343- .join(data.alias(" DATA" ), $" RAN.s" === $" DATA.p" , " inner" )
344- .select($" DATA.o" .alias(" s" ), lit(RDF .`type`.toString).alias(" p" ), ran(" o" ).alias(" o" ))
345- .as[RDFTriple ]
355+ // val rdfs3 = ran
356+ // .join(data.alias("DATA"), $"RAN.s" === $"DATA.p", "inner")
357+ // .select($"DATA.o".alias("s"), lit(RDF.`type`.toString).alias("p"), ran("o").alias("o"))
358+ // .as[RDFTriple]
359+ val ranBC = session.sparkContext.broadcast(CollectionUtils .toMultiMap(ran.select(" s" , " o" ).collect().map(r => (r.getString(0 ), r.getString(1 )))))
360+ val rdfs3 = data.flatMap(t => ranBC.value.getOrElse(t.p, Set [String ]()).map(o => RDFTriple (t.o, rdftype, o)))
346361
347362 // all rdf:type triples
348363 val types = rdfs2
0 commit comments