package net.sansa_stack.inference.spark.forwardchaining

import net.sansa_stack.inference.data.RDFTriple
import scala.language.implicitConversions

import org.apache.jena.riot.Lang
import org.apache.jena.vocabulary.{RDF, RDFS}
import org.apache.spark.SparkConf
import org.apache.spark.sql.functions._
import org.apache.spark.sql.{Dataset, SQLContext, SparkSession}
import org.slf4j.LoggerFactory

import net.sansa_stack.inference.spark.data.model.{RDFGraph, RDFGraphDataFrame}
import net.sansa_stack.inference.spark.utils.RDFSSchemaExtractor

/**
  * A forward chaining implementation of the RDFS entailment regime.
  *
  * @constructor create a new RDFS forward chaining reasoner
  * @param session the Apache Spark session
  * @param parallelism the degree of parallelism
  * @author Lorenz Buehmann
  */
class ForwardRuleReasonerRDFSDataframe(session: SparkSession, parallelism: Int = 2)
  extends TransitiveReasoner(session.sparkContext, parallelism) {

  val sqlContext = session.sqlContext
  import sqlContext.implicits._

  private val logger = com.typesafe.scalalogging.Logger(LoggerFactory.getLogger(this.getClass.getName))

  def apply(graph: RDFGraphDataFrame): RDFGraphDataFrame = {
    logger.info("materializing graph...")
    val startTime = System.currentTimeMillis()

    val sqlSchema = graph.schema

    val extractor = new RDFSSchemaExtractor()

    var index = extractor.extractWithIndex(graph)

    var triples = graph.toDataFrame(session).alias("DATA")

    // broadcast the tables for the schema triples
    index = index.map { e =>
      val property = e._1
      val dataframe = e._2

      property -> broadcast(dataframe).as(property.getURI)
    }

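    // the assumption here is that the schema is usually much smaller than the instance
    // data, so marking the schema DataFrames for broadcast lets Spark choose broadcast
    // hash joins in the rule applications below instead of shuffling the instance data
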
    // the dependencies among the RDFS rules were analyzed in \todo(add references)
    // and the same ordering is used here

    // 1. we first compute the transitive closure of rdfs:subPropertyOf and rdfs:subClassOf

    /*
        rdfs11  xxx rdfs:subClassOf yyy .
                yyy rdfs:subClassOf zzz .    =>    xxx rdfs:subClassOf zzz .
     */
    val subClassOfTriples = index(RDFS.subClassOf.asNode()) // extract rdfs:subClassOf triples
    val subClassOfTriplesTrans = broadcast(computeTransitiveClosureDF(subClassOfTriples.as[RDFTriple]).toDF().alias("SC"))
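    // e.g. (hypothetical schema): from :Student rdfs:subClassOf :Person and
    // :Person rdfs:subClassOf :Agent, the closure also contains :Student rdfs:subClassOf :Agent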
//    val subClassOfMap = CollectionUtils.toMultiMap(subClassOfTriplesTrans.rdd.map(r => (r.getString(0) -> (r.getString(2)))).collect)
//    val subClassOfMapBC = session.sparkContext.broadcast(subClassOfMap)
//    val checkSubclass = udf((cls: String) => subClassOfMapBC.value.contains(cls))
//    val makeSuperTypeTriple = udf((ind: String, cls: String) => (ind, subClassOfMapBC.value(cls)))
    /*
        rdfs5   xxx rdfs:subPropertyOf yyy .
                yyy rdfs:subPropertyOf zzz .    =>    xxx rdfs:subPropertyOf zzz .
     */
    val subPropertyOfTriples = index(RDFS.subPropertyOf.asNode()) // extract rdfs:subPropertyOf triples
    val subPropertyOfTriplesTrans = broadcast(computeTransitiveClosureDF(subPropertyOfTriples.as[RDFTriple]).toDF().alias("SP"))

//    // a map structure should be more efficient
//    val subClassOfMap = subClassOfTriplesTrans.collect().map(row => row(0).asInstanceOf[String] -> row(1).asInstanceOf[String]).toMap
//    val subPropertyOfMap = subPropertyOfTriplesTrans.collect().map(row => row(0).asInstanceOf[String] -> row(1).asInstanceOf[String]).toMap
//
//    // distribute the schema data structures by means of shared variables
//    // the assumption here is that the schema is usually much smaller than the instance data
//    val subClassOfMapBC = session.sparkContext.broadcast(subClassOfMap)
//    val subPropertyOfMapBC = session.sparkContext.broadcast(subPropertyOfMap)
//
//    def containsPredicateAsKey(map: Map[String, String]) = udf((predicate: String) => map.contains(predicate))
//    def fillPredicate(map: Map[String, String]) = udf((predicate: String) => if (map.contains(predicate)) map(predicate) else "")

//    Broadcast
//    val subClassOfTriplesTransDataBC = session.sparkContext.broadcast(subPropertyOfTriplesTrans.collectAsList())
//    val subClassOfTriplesTransSchemaBC = session.sparkContext.broadcast(subPropertyOfTriplesTrans.schema)
//    val subClassOfTriplesTransBCDF = session.sqlContext.createDataFrame(
//      subClassOfTriplesTransDataBC.value,
//      subClassOfTriplesTransSchemaBC.value).alias("SCBC")

    // 2. SubPropertyOf inheritance according to rdfs7 is computed

    /*
        rdfs7   aaa rdfs:subPropertyOf bbb .
                xxx aaa yyy .    =>    xxx bbb yyy .
     */
    val triplesRDFS7 =
      triples // all triples (s p1 o)
        .join(subPropertyOfTriplesTrans, $"DATA.${sqlSchema.predicateCol}" === $"SP.${sqlSchema.subjectCol}", "inner") // such that p1 has a super property p2
        .select($"DATA.${sqlSchema.subjectCol}", $"SP.${sqlSchema.objectCol}", $"DATA.${sqlSchema.objectCol}") // create triple (s p2 o)
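    // e.g. (hypothetical schema): from :name rdfs:subPropertyOf rdfs:label, the triple
    // (:s :name "Alice") additionally yields (:s rdfs:label "Alice")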

//    val triplesRDFS7 =
//      triples // all triples (s p1 o)
//        .filter(containsPredicateAsKey(subPropertyOfMapBC.value)($"DATA.predicate")) // such that p1 has a super property p2
//        .withColumn("CC", fillPredicate(subPropertyOfMapBC.value)($"DATA.predicate"))
//        .select($"DATA.subject", $"CC", $"DATA.object") // create triple (s p2 o)
//
//    triplesRDFS7.explain(true)

    // add triples
    triples = triples.union(triplesRDFS7).alias("DATA")

    // 3. Domain and Range inheritance according to rdfs2 and rdfs3 is computed

    /*
        rdfs2   aaa rdfs:domain xxx .
                yyy aaa zzz .    =>    yyy rdf:type xxx .
     */
    val domainTriples = broadcast(index(RDFS.domain.asNode()).alias("DOM"))

    val triplesRDFS2 =
      triples
        .join(domainTriples, $"DATA.${sqlSchema.predicateCol}" === $"DOM.${sqlSchema.subjectCol}", "inner")
        .select($"DATA.${sqlSchema.subjectCol}", $"DOM.${sqlSchema.objectCol}") // (yyy, xxx)
//    triples.createOrReplaceTempView("DATA")
//    domainTriples.createOrReplaceTempView("DOM")
//    val triplesRDFS2 = session.sql("SELECT A.subject, B.object FROM DATA A INNER JOIN DOM B ON A.predicate=B.subject")
//    triplesRDFS2.explain(true)

    /*
        rdfs3   aaa rdfs:range xxx .
                yyy aaa zzz .    =>    zzz rdf:type xxx .
     */
    val rangeTriples = broadcast(index(RDFS.range.asNode()).alias("RAN"))

    val triplesRDFS3 =
      triples
        .join(rangeTriples, $"DATA.${sqlSchema.predicateCol}" === $"RAN.${sqlSchema.subjectCol}", "inner")
        .select($"DATA.${sqlSchema.objectCol}", $"RAN.${sqlSchema.objectCol}") // (zzz, xxx)

    val tuples23 = triplesRDFS2.union(triplesRDFS3)

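    // e.g. (hypothetical schema): from :writes rdfs:domain :Author and :writes rdfs:range :Book,
    // the triple (:alice :writes :b1) yields the tuple (:alice, :Author) via rdfs2
    // and (:b1, :Book) via rdfs3
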
    // get rdf:type tuples here as intermediate result
    val typeTuples = triples
      .where(s"${sqlSchema.predicateCol} = '${RDF.`type`.getURI}'")
      .select(sqlSchema.subjectCol, sqlSchema.objectCol)
      .union(tuples23)
      .alias("TYPES")

    // 4. SubClass inheritance according to rdfs9

    /*
        rdfs9   xxx rdfs:subClassOf yyy .
                zzz rdf:type xxx .    =>    zzz rdf:type yyy .
     */
    val tuplesRDFS9 = typeTuples
      .join(subClassOfTriplesTrans, $"TYPES.${sqlSchema.objectCol}" === $"SC.${sqlSchema.subjectCol}", "inner")
      .select($"TYPES.${sqlSchema.subjectCol}", $"SC.${sqlSchema.objectCol}") // (zzz, yyy)

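    // e.g. (hypothetical schema): with :Author rdfs:subClassOf :Person, the type tuple
    // (:alice, :Author) from above additionally yields (:alice, :Person)
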
//    val triplesRDFS9 =
//      typeTuples
//        .where(checkSubclass($"TYPES.object"))
//        .map(r => (r.getString(0), subClassOfMapBC.value(r.getString(1)).toArray))
//        .toDF("subject", "objects")
//    triplesRDFS9.printSchema()
//
//    val exploded = triplesRDFS9.flatMap(row => {
//      val objects = row.getAs[Array[String]]("objects")
//      objects.map(o => (row.getString(0), o))
//    }).toDF("subject", "object")

//    explode("objects", "object") {
//      case Row(classes: Array[Row]) => classes.map(clsRow => clsRow(0).asInstanceOf[String])
//      case _ => println("ELSE")
//        Seq()
//    }
//    exploded.show()

//    .explode()
//    .join(subClassOfTriplesTrans, $"TYPES.object" === $"SC.subject", "inner")
//    .withColumn("const", lit(RDF.`type`.getURI))
//    .select("DATA.subject", "const", "SC.object")
//    .select($"TYPES.subject", $"SC.object") // (zzz, yyy)
//    println("existing types:" + existingTypes.count())
//    println("SC:" + subClassOfTriplesTrans.count())
//    println("SP:" + subPropertyOfTriplesTrans.count())
//    println("TYPES:" + typeTuples.count())
//    println("R7:" + triplesRDFS7.count())
//    println("R2:" + triplesRDFS2.count())
//    println("R3:" + triplesRDFS3.count())
//    println("R9:" + tuplesRDFS9.count())

    // 5. merge the inferred triples with the input triples and remove duplicates
    val allTriples =
      typeTuples.union(tuples23).union(tuplesRDFS9)
        .withColumn("const", lit(RDF.`type`.getURI)) // re-create full rdf:type triples from the (s, o) tuples
        .select(sqlSchema.subjectCol, "const", sqlSchema.objectCol)
        .union(subClassOfTriplesTrans)
        .union(subPropertyOfTriplesTrans)
        .union(triplesRDFS7)
        .union(triples)
        .distinct()
//      .selectExpr("subject", "'" + RDF.`type`.getURI + "' as predicate", "object")
//    allTriples.explain()

    logger.info("...finished materialization in " + (System.currentTimeMillis() - startTime) + " ms.")
//    val newSize = allTriples.count()
//    logger.info(s"|G_inf|=$newSize")

    // return graph with inferred triples
    new RDFGraphDataFrame(allTriples)
  }

  /**
    * Computes the transitive closure of a Dataset of triples.
    *
    * @param edges the Dataset of triples
    * @return a Dataset containing the transitive closure of the triples
    */
  def computeTransitiveClosureDF(edges: Dataset[RDFTriple]): Dataset[RDFTriple] = {
    log.info("computing TC...")
//    implicit val myObjEncoder = org.apache.spark.sql.Encoders.kryo[RDFTriple]
    val spark = edges.sparkSession.sqlContext
    import spark.implicits._

//    profile {
    // we keep the transitive closure cached
    var tc = edges
    tc.cache()

    // the join is iterated until a fixed point is reached
    var i = 1
    var oldCount = 0L
    var nextCount = tc.count()
    do {
      log.info(s"iteration $i...")
      oldCount = nextCount

//      val df1 = tc.alias("df1")
//      val df2 = tc.alias("df2")
      // join the paths (x, y) with the paths (y, z) on the object/subject columns,
      // then project the result to obtain the new (x, z) paths

      tc.createOrReplaceTempView("SC")
      var joined = tc.as("A").join(tc.as("B"), $"A.o" === $"B.s").select("A.s", "A.p", "B.o").as[RDFTriple]
//      var joined = tc
//        .join(edges, tc("o") === edges("s"))
//        .select(tc("s"), tc("p"), edges("o"))
//        .as[RDFTriple]
//      tc.sqlContext.
//        sql("SELECT A.subject, A.predicate, B.object FROM SC A INNER JOIN SC B ON A.object = B.subject")

//      joined.explain()
//      var joined = df1.join(df2, df1("object") === df2("subject"), "inner")
//      println("JOINED:\n" + joined.collect().mkString("\n"))
//      joined = joined.select(df2(s"df1.$col1"), df1(s"df1.$col2"))
//      println(joined.collect().mkString("\n"))

      tc = tc
        .union(joined)
        .distinct()
        .cache()
      nextCount = tc.count()
      i += 1
    } while (nextCount != oldCount)

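    // note: the loop is guaranteed to terminate, since each iteration can only add edges
    // and the closure is bounded by the finite set of resource pairs; once an iteration
    // adds no new edges, the counts are equal and the fixed point is reached
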
    tc.sqlContext.uncacheTable("SC")
    log.info("TC has " + nextCount + " edges.")
    tc
//    }
  }

  /**
    * Applies forward chaining to the given RDF graph and returns a new RDF graph that contains all additional
    * triples based on the underlying set of rules.
    *
    * @param graph the RDF graph
    * @return the materialized RDF graph
    */
  override def apply(graph: RDFGraph): RDFGraph = graph // not implemented here; returns the input graph unchanged

}

object ForwardRuleReasonerRDFSDataframe {
  def apply(session: SparkSession, parallelism: Int = 2): ForwardRuleReasonerRDFSDataframe = new ForwardRuleReasonerRDFSDataframe(session, parallelism)

  def main(args: Array[String]): Unit = {
    import net.sansa_stack.inference.spark.data.loader.sql.rdf._

    val parallelism = 2

    // register the custom classes for the Kryo serializer
    val conf = new SparkConf()
    conf.registerKryoClasses(Array(classOf[org.apache.jena.graph.Triple]))
    conf.set("spark.extraListeners", "net.sansa_stack.inference.spark.utils.CustomSparkListener")

    // the Spark config
    val session = SparkSession.builder
      .appName("SPARK DataFrame-based RDFS Reasoning")
      .master("local[4]")
//      .config("spark.eventLog.enabled", "true")
      .config("spark.hadoop.validateOutputSpecs", "false") // override output files
      .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
      .config("spark.default.parallelism", parallelism)
      .config("spark.ui.showConsoleProgress", "false")
      .config("spark.sql.shuffle.partitions", parallelism)
      .config(conf)
      .getOrCreate()

    val triples = session.read.rdf(Lang.NTRIPLES)(args(0))
    triples.createOrReplaceTempView("TRIPLES")

    val graph = new RDFGraphDataFrame(triples)

    val infGraph = ForwardRuleReasonerRDFSDataframe(session).apply(graph)
    println(infGraph.size())
  }
}