Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit 3498a6b

Browse files
Option to use flatMap + a broadcast variable instead of a join in the backward-chaining algorithm.
1 parent c66eb87 commit 3498a6b

1 file changed

Lines changed: 29 additions & 14 deletions

File tree

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/backwardchaining/BackwardChainingReasonerDataframe.scala

Lines changed: 29 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ import net.sansa_stack.inference.rules.RuleSets
55
import net.sansa_stack.inference.rules.plan.SimpleSQLGenerator
66
import net.sansa_stack.inference.spark.backwardchaining.tree.{AndNode, OrNode}
77
import net.sansa_stack.inference.utils.RuleUtils._
8-
import net.sansa_stack.inference.utils.{Logging, TripleUtils}
8+
import net.sansa_stack.inference.utils.{CollectionUtils, Logging, TripleUtils}
99
import org.apache.jena.graph.{Node, NodeFactory, Triple}
1010
import org.apache.jena.rdf.model.Resource
1111
import org.apache.jena.reasoner.TriplePattern
@@ -15,6 +15,7 @@ import org.apache.jena.sparql.util.FmtUtils
1515
import org.apache.jena.vocabulary.{RDF, RDFS}
1616
import org.apache.spark.sql.{Dataset, SparkSession}
1717

18+
import scala.collection.mutable
1819
import scala.concurrent.duration.FiniteDuration
1920

2021

@@ -296,11 +297,16 @@ class BackwardChainingReasonerDataframe(
296297
if(tp.getSubject.isConcrete) { // find triples where s occurs as subject or object
297298
instanceTriples = instanceTriples.filter(t => t.s == tp.getSubject.toString() || t.o == tp.getSubject.toString())
298299
}
299-
val rdfs7 = spo
300-
.join(instanceTriples.alias("DATA"), $"SPO.s" === $"DATA.p", "inner")
301-
.select($"DATA.s".alias("s"), $"SPO.o".alias("p"), $"DATA.s".alias("o"))
302-
.as[RDFTriple]
303-
instanceTriples = instanceTriples.union(rdfs7).cache()
300+
val spoBC = session.sparkContext.broadcast(
301+
CollectionUtils.toMultiMap(spo.select("s", "o").collect().map(r => (r.getString(0), r.getString(1))))
302+
)
303+
val rdfs7 = instanceTriples.flatMap(t => spoBC.value.getOrElse(t.p, Set[String]()).map(supProp => RDFTriple(t.s, supProp, t.o)))
304+
305+
// val rdfs7 = spo
306+
// .join(instanceTriples.alias("DATA"), $"SPO.s" === $"DATA.p", "inner")
307+
// .select($"DATA.s".alias("s"), $"SPO.o".alias("p"), $"DATA.s".alias("o"))
308+
// .as[RDFTriple]
309+
// instanceTriples = instanceTriples.union(rdfs7).cache()
304310

305311
// rdfs2 (domain)
306312
var dom = if (tp.getObject.isConcrete) domain.filter(_.o == tp.getObject.toString()) else domain
@@ -318,10 +324,17 @@ class BackwardChainingReasonerDataframe(
318324
} else {
319325
instanceTriples
320326
}
321-
val rdfs2 = dom
322-
.join(data.alias("DATA"), $"DOM.s" === $"DATA.p", "inner")
323-
.select($"DATA.s", lit(RDF.`type`.toString).alias("p"), dom("o").alias("o"))
324-
.as[RDFTriple]
327+
328+
val rdftype = RDF.`type`.toString
329+
330+
// val rdfs2 = dom
331+
// .join(data.alias("DATA"), $"DOM.s" === $"DATA.p", "inner")
332+
// .select($"DATA.s", lit(RDF.`type`.toString).alias("p"), dom("o").alias("o"))
333+
// .as[RDFTriple]
334+
val domBC = session.sparkContext.broadcast(
335+
CollectionUtils.toMultiMap(dom.select("s", "o").collect().map(r => (r.getString(0), r.getString(1))))
336+
)
337+
val rdfs2 = data.flatMap(t => domBC.value.getOrElse(t.p, Set[String]()).map(o => RDFTriple(t.s, rdftype, o)))
325338

326339
// rdfs3 (range)
327340
var ran = if (tp.getObject.isConcrete) range.filter(_.o == tp.getObject.toString()) else range
@@ -339,10 +352,12 @@ class BackwardChainingReasonerDataframe(
339352
} else {
340353
instanceTriples
341354
}
342-
val rdfs3 = ran
343-
.join(data.alias("DATA"), $"RAN.s" === $"DATA.p", "inner")
344-
.select($"DATA.o".alias("s"), lit(RDF.`type`.toString).alias("p"), ran("o").alias("o"))
345-
.as[RDFTriple]
355+
// val rdfs3 = ran
356+
// .join(data.alias("DATA"), $"RAN.s" === $"DATA.p", "inner")
357+
// .select($"DATA.o".alias("s"), lit(RDF.`type`.toString).alias("p"), ran("o").alias("o"))
358+
// .as[RDFTriple]
359+
val ranBC = session.sparkContext.broadcast(CollectionUtils.toMultiMap(ran.select("s", "o").collect().map(r => (r.getString(0), r.getString(1)))))
360+
val rdfs3 = data.flatMap(t => ranBC.value.getOrElse(t.p, Set[String]()).map(o => RDFTriple(t.o, rdftype, o)))
346361

347362
// all rdf:type triples
348363
val types = rdfs2

0 commit comments

Comments (0)