This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit 86d960e

DataFrame-based reasoner impl.
1 parent c898f8f · commit 86d960e

2 files changed

Lines changed: 325 additions & 2 deletions


Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
package net.sansa_stack.inference.spark.forwardchaining

import net.sansa_stack.inference.data.RDFTriple
import scala.language.implicitConversions

import org.apache.jena.riot.Lang
import org.apache.jena.vocabulary.{RDF, RDFS}
import org.apache.spark.SparkConf
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Dataset, SQLContext, SparkSession}
import org.slf4j.LoggerFactory

import net.sansa_stack.inference.spark.data.model.{RDFGraph, RDFGraphDataFrame}
import net.sansa_stack.inference.spark.utils.RDFSSchemaExtractor


/**
 * A forward chaining implementation of the RDFS entailment regime.
 *
 * @constructor create a new RDFS forward chaining reasoner
 * @param session the Apache Spark session
 * @author Lorenz Buehmann
 */
class ForwardRuleReasonerRDFSDataframe(session: SparkSession, parallelism: Int = 2)
  extends TransitiveReasoner(session.sparkContext, parallelism) {

  val sqlContext = session.sqlContext
  import sqlContext.implicits._

  private val logger = com.typesafe.scalalogging.Logger(LoggerFactory.getLogger(this.getClass.getName))

  def apply(graph: RDFGraphDataFrame): RDFGraphDataFrame = {
    logger.info("materializing graph...")
    val startTime = System.currentTimeMillis()

    val sqlSchema = graph.schema

    val extractor = new RDFSSchemaExtractor()

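    // as used below, the index maps each schema property (rdfs:subClassOf, rdfs:subPropertyOf,
    // rdfs:domain, rdfs:range) to the DataFrame holding its triples, so each rule only has to
    // join against the relevant (small) subset of the schema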
    var index = extractor.extractWithIndex(graph)

    var triples = graph.toDataFrame(session).alias("DATA")

    // broadcast the tables for the schema triples
    index = index.map { e =>
      val property = e._1
      val dataframe = e._2

      property -> broadcast(dataframe).as(property.getURI)
    }

    // RDFS rules dependency was analyzed in \todo(add references) and the same ordering is used here


    // 1. we first compute the transitive closure of rdfs:subPropertyOf and rdfs:subClassOf

    /**
     * rdfs11 xxx rdfs:subClassOf yyy .
     *        yyy rdfs:subClassOf zzz .   xxx rdfs:subClassOf zzz .
     */
    val subClassOfTriples = index(RDFS.subClassOf.asNode()) // extract rdfs:subClassOf triples
    val subClassOfTriplesTrans = broadcast(computeTransitiveClosureDF(subClassOfTriples.as[RDFTriple]).toDF().alias("SC"))
//    val subClassOfMap = CollectionUtils.toMultiMap(subClassOfTriplesTrans.rdd.map(r => (r.getString(0) -> (r.getString(2)))).collect)
//    val subClassOfMapBC = session.sparkContext.broadcast(subClassOfMap)
//    val checkSubclass = udf((cls: String) => subClassOfMapBC.value.contains(cls))
//    val makeSuperTypeTriple = udf((ind: String, cls: String) => (ind, subClassOfMapBC.value(cls)))
    /*
       rdfs5 xxx rdfs:subPropertyOf yyy .
             yyy rdfs:subPropertyOf zzz .   xxx rdfs:subPropertyOf zzz .
     */
    val subPropertyOfTriples = index(RDFS.subPropertyOf.asNode()) // extract rdfs:subPropertyOf triples
    val subPropertyOfTriplesTrans = broadcast(computeTransitiveClosureDF(subPropertyOfTriples.as[RDFTriple]).toDF().alias("SP"))


//    // a map structure should be more efficient
//    val subClassOfMap = subClassOfTriplesTrans.collect().map(row => row(0).asInstanceOf[String] -> row(1).asInstanceOf[String]).toMap
//    val subPropertyOfMap = subPropertyOfTriplesTrans.collect().map(row => row(0).asInstanceOf[String] -> row(1).asInstanceOf[String]).toMap
//
//    // distribute the schema data structures by means of shared variables
//    // the assumption here is that the schema is usually much smaller than the instance data
//    val subClassOfMapBC = session.sparkContext.broadcast(subClassOfMap)
//    val subPropertyOfMapBC = session.sparkContext.broadcast(subPropertyOfMap)
//
//    def containsPredicateAsKey(map: Map[String, String]) = udf((predicate : String) => map.contains(predicate))
//    def fillPredicate(map: Map[String, String]) = udf((predicate : String) => if(map.contains(predicate)) map(predicate) else "")


    // Broadcast
//    val subClassOfTriplesTransDataBC = session.sparkContext.broadcast(subPropertyOfTriplesTrans.collectAsList())
//    val subClassOfTriplesTransSchemaBC = session.sparkContext.broadcast(subPropertyOfTriplesTrans.schema)
//    val subClassOfTriplesTransBCDF = session.sqlContext.createDataFrame(
//      subClassOfTriplesTransDataBC.value,
//      subClassOfTriplesTransSchemaBC.value).alias("SCBC")

    // 2. SubPropertyOf inheritance according to rdfs7 is computed

    /*
       rdfs7 aaa rdfs:subPropertyOf bbb .
             xxx aaa yyy .   xxx bbb yyy .
     */
    val triplesRDFS7 =
      triples // all triples (s p1 o)
        .join(subPropertyOfTriplesTrans, $"DATA.${sqlSchema.predicateCol}" === $"SP.${sqlSchema.subjectCol}", "inner") // such that p1 has a super property p2
        .select($"DATA.${sqlSchema.subjectCol}", $"SP.${sqlSchema.objectCol}", $"DATA.${sqlSchema.objectCol}") // create triple (s p2 o)

//    val triplesRDFS7 =
//      triples // all triples (s p1 o)
//        .filter(containsPredicateAsKey(subPropertyOfMapBC.value)($"DATA.predicate")) // such that p1 has a super property p2
//        .withColumn("CC", fillPredicate(subPropertyOfMapBC.value)($"DATA.predicate"))
//        .select($"DATA.subject", $"CC", $"DATA.object") // create triple (s p2 o)
//
//    triplesRDFS7.explain(true)

    // add triples
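    // folding the rdfs7 output back into the triples ensures the domain/range and subclass rules
    // below also see statements derived via super-properties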
    triples = triples.union(triplesRDFS7).alias("DATA")

    // 3. Domain and Range inheritance according to rdfs2 and rdfs3 is computed

    /*
       rdfs2 aaa rdfs:domain xxx .
             yyy aaa zzz .   yyy rdf:type xxx .
     */
    val domainTriples = broadcast(index(RDFS.domain.asNode()).alias("DOM"))

    val triplesRDFS2 =
      triples
        .join(domainTriples, $"DATA.${sqlSchema.predicateCol}" === $"DOM.${sqlSchema.subjectCol}", "inner")
        .select($"DATA.${sqlSchema.subjectCol}", $"DOM.${sqlSchema.objectCol}") // (yyy, xxx)
//    triples.createOrReplaceTempView("DATA")
//    domainTriples.createOrReplaceTempView("DOM")
//    val triplesRDFS2 = session.sql("SELECT A.subject, B.object FROM DATA A INNER JOIN DOM B ON A.predicate=B.subject")
//    triplesRDFS2.explain(true)

    /*
       rdfs3 aaa rdfs:range xxx .
             yyy aaa zzz .   zzz rdf:type xxx .
     */
    val rangeTriples = broadcast(index(RDFS.range.asNode()).alias("RAN"))

    val triplesRDFS3 =
      triples
        .join(rangeTriples, $"DATA.${sqlSchema.predicateCol}" === $"RAN.${sqlSchema.subjectCol}", "inner")
        .select($"DATA.${sqlSchema.objectCol}", $"RAN.${sqlSchema.objectCol}") // (zzz, xxx)

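    // rdfs2 and rdfs3 both yield two-column (instance, class) pairs; the rdf:type predicate is only
    // attached later, when the final set of inferred triples is assembled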
    val tuples23 = triplesRDFS2.union(triplesRDFS3)

    // get rdf:type tuples here as intermediate result
    val typeTuples = triples
      .where(s"${sqlSchema.predicateCol} = '${RDF.`type`.getURI}'")
      .select(sqlSchema.subjectCol, sqlSchema.objectCol)
      .union(tuples23)
      .alias("TYPES")

    // 4. SubClass inheritance according to rdfs9

    /*
       rdfs9 xxx rdfs:subClassOf yyy .
             zzz rdf:type xxx .   zzz rdf:type yyy .
     */
    val tuplesRDFS9 = typeTuples
      .join(subClassOfTriplesTrans, $"TYPES.${sqlSchema.objectCol}" === $"SC.${sqlSchema.subjectCol}", "inner")
      .select($"TYPES.${sqlSchema.subjectCol}", $"SC.${sqlSchema.objectCol}") // (zzz, yyy)

//    val triplesRDFS9 =
//      typeTuples
//        .where(checkSubclass($"TYPES.object"))
//        .map(r => (r.getString(0), subClassOfMapBC.value(r.getString(1)).toArray))
//        .toDF("subject", "objects")
//    triplesRDFS9.printSchema()
//
//    val exploded = triplesRDFS9.flatMap(row => {
//      val objects = row.getAs[Array[String]]("objects")
//      objects.map(o => (row.getString(0), o))
//    }).toDF("subject", "object")

//    explode("objects", "object") {
//      case Row(classes: Array[Row]) => classes.map(clsRow => clsRow(0).asInstanceOf[String])
//      case _ => println("ELSE")
//        Seq()
//    }
//    exploded.show()

//      .explode()
//      .join(subClassOfTriplesTrans, $"TYPES.object" === $"SC.subject", "inner")
//      .withColumn("const", lit(RDF.`type`.getURI))
//      .select("DATA.subject", "const", "SC.object")
//      .select($"TYPES.subject", $"SC.object") // (zzz, yyy)
//    println("existing types:" + existingTypes.count())
//    println("SC:" + subClassOfTriplesTrans.count())
//    println("SP:" + subPropertyOfTriplesTrans.count())
//    println("TYPES:" + typeTuples.count())
//    println("R7:" + triplesRDFS7.count())
//    println("R2:" + triplesRDFS2.count())
//    println("R3:" + triplesRDFS3.count())
//    println("R9:" + tuplesRDFS9.count())

    // 5. merge triples and remove duplicates
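    // the two-column (subject, class) tuples get the constant rdf:type predicate added before being
    // merged with the already complete triples from the schema closures, rule rdfs7 and the input data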
    val allTriples =
      typeTuples.union(tuples23).union(tuplesRDFS9)
        .withColumn("const", lit(RDF.`type`.getURI))
        .select(sqlSchema.subjectCol, "const", sqlSchema.objectCol)
        .union(subClassOfTriplesTrans)
        .union(subPropertyOfTriplesTrans)
        .union(triplesRDFS7)
        .union(triples)
        .distinct()
//        .selectExpr("subject", "'" + RDF.`type`.getURI + "' as predicate", "object")
//    allTriples.explain()

    logger.info("...finished materialization in " + (System.currentTimeMillis() - startTime) + "ms.")
//    val newSize = allTriples.count()
//    logger.info(s"|G_inf|=$newSize")

    // return graph with inferred triples
    new RDFGraphDataFrame(allTriples)
  }

  /**
   * Computes the transitive closure for a Dataframe of triples
   *
   * @param edges the Dataframe of triples
   * @return a Dataframe containing the transitive closure of the triples
   */
  def computeTransitiveClosureDF(edges: Dataset[RDFTriple]): Dataset[RDFTriple] = {
    log.info("computing TC...")
    // implicit val myObjEncoder = org.apache.spark.sql.Encoders.kryo[RDFTriple]
    val spark = edges.sparkSession.sqlContext
    import spark.implicits._

//    profile {
    // we keep the transitive closure cached
    var tc = edges
    tc.cache()

    // the join is iterated until a fixed point is reached
    var i = 1
    var oldCount = 0L
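    // count() is an action, so it materializes the cached closure; its result drives the fixed-point check below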
    var nextCount = tc.count()
    do {
      log.info(s"iteration $i...")
      oldCount = nextCount

//      val df1 = tc.alias("df1")
//      val df2 = tc.alias("df2")
      // perform the join (x, y) x (y, x), obtaining an RDD of (x=y, (y, x)) pairs,
      // then project the result to obtain the new (x, y) paths.

      tc.createOrReplaceTempView("SC")
      var joined = tc.as("A").join(tc.as("B"), $"A.o" === $"B.s").select("A.s", "A.p", "B.o").as[RDFTriple]
//      var joined = tc
//        .join(edges, tc("o") === edges("s"))
//        .select(tc("s"), tc("p"), edges("o"))
//        .as[RDFTriple]
//      tc.sqlContext.
//        sql("SELECT A.subject, A.predicate, B.object FROM SC A INNER JOIN SC B ON A.object = B.subject")

//      joined.explain()
//      var joined = df1.join(df2, df1("object") === df2("subject"), "inner")
//      println("JOINED:\n" + joined.collect().mkString("\n"))
//      joined = joined.select(df2(s"df1.$col1"), df1(s"df1.$col2"))
//      println(joined.collect().mkString("\n"))

      tc = tc
        .union(joined)
        .distinct()
        .cache()
      nextCount = tc.count()
      i += 1
    } while (nextCount != oldCount)

    tc.sqlContext.uncacheTable("SC")
    log.info("TC has " + nextCount + " edges.")
    tc
//    }
  }

  /**
   * Applies forward chaining to the given RDF graph and returns a new RDF graph that contains all additional
   * triples based on the underlying set of rules.
   *
   * @param graph the RDF graph
   * @return the materialized RDF graph
   */
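  // note: this RDD-based overload is effectively a no-op in this implementation and returns the input graph unchanged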
  override def apply(graph: RDFGraph): RDFGraph = graph

}

object ForwardRuleReasonerRDFSDataframe {
  def apply(session: SparkSession, parallelism: Int = 2): ForwardRuleReasonerRDFSDataframe = new ForwardRuleReasonerRDFSDataframe(session, parallelism)

  def main(args: Array[String]): Unit = {
    import net.sansa_stack.inference.spark.data.loader.sql.rdf._

    val parallelism = 2

    // register the custom classes for Kryo serializer
    val conf = new SparkConf()
    conf.registerKryoClasses(Array(classOf[org.apache.jena.graph.Triple]))
    conf.set("spark.extraListeners", "net.sansa_stack.inference.spark.utils.CustomSparkListener")

    // the SPARK config
    val session = SparkSession.builder
      .appName(s"SPARK DataFrame-based RDFS Reasoning")
      .master("local[4]")
//      .config("spark.eventLog.enabled", "true")
      .config("spark.hadoop.validateOutputSpecs", "false") // override output files
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.default.parallelism", parallelism)
      .config("spark.ui.showConsoleProgress", "false")
      .config("spark.sql.shuffle.partitions", parallelism)
      .config(conf)
      .getOrCreate()

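    // load the N-Triples file given as the first program argument into a DataFrame of triples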
    val triples = session.read.rdf(Lang.NTRIPLES)(args(0))
    triples.createOrReplaceTempView("TRIPLES")

    val graph = new RDFGraphDataFrame(triples)

    val infGraph = ForwardRuleReasonerRDFSDataframe(session).apply(graph)
    println(infGraph.size())
  }
}

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/rules/plan/PlanExecutorNative.scala

Lines changed: 1 addition & 2 deletions
@@ -349,8 +349,7 @@ class PlanExecutorNative(sc: SparkContext) extends PlanExecutor[Jena, RDD[Triple
         projectList.toList
       case logical.Filter(condition, child) =>
         expressionsFor(child)
-      case SubqueryAlias(alias: String, child: LogicalPlan,
-        view: scala.Option[org.apache.spark.sql.catalyst.TableIdentifier]) =>
+      case SubqueryAlias(alias: String, child: LogicalPlan) =>
         expressionsFor(child)
       case _ =>
         logicalPlan.expressions.toList
