This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit 95a560e

Load RDF graph into DataFrame via separate Reader implementation.
1 parent e36d452 commit 95a560e

3 files changed

Lines changed: 71 additions & 4 deletions

File tree

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/data/loader/RDFGraphLoader.scala
sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/data/loader/sql/DefaultSource.scala
sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/data/loader/sql/NTriplesRelation.scala

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/data/loader/RDFGraphLoader.scala

Lines changed: 15 additions & 4 deletions
@@ -5,10 +5,13 @@ import java.net.URI
 import scala.language.implicitConversions
 
 import org.apache.spark.SparkContext
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider, TableScan}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+import org.apache.spark.sql.{Dataset, Row, SQLContext, SparkSession}
 import org.slf4j.LoggerFactory
 
-import net.sansa_stack.inference.data.RDFTriple
+import net.sansa_stack.inference.data.{RDFTriple, SQLSchema, SQLSchemaDefault}
 import net.sansa_stack.inference.spark.data.model.{RDFGraph, RDFGraphDataFrame, RDFGraphDataset, RDFGraphNative}
 import net.sansa_stack.inference.utils.NTriplesStringToRDFTriple
 
@@ -192,7 +195,15 @@ object RDFGraphLoader {
    * @param minPartitions min number of partitions for Hadoop RDDs ([[SparkContext.defaultMinPartitions]])
    * @return an RDF graph based on a [[org.apache.spark.sql.DataFrame]]
    */
-  def loadFromDiskAsDataFrame(session: SparkSession, path: String, minPartitions: Int): RDFGraphDataFrame = {
-    new RDFGraphDataFrame(loadFromDiskAsRDD(session, path, minPartitions).toDataFrame(session))
+  def loadFromDiskAsDataFrame(session: SparkSession, path: String, minPartitions: Int, sqlSchema: SQLSchema = SQLSchemaDefault): RDFGraphDataFrame = {
+    val df = session
+      .read
+      .format("net.sansa_stack.inference.spark.data.loader.sql")
+      .load(path)
+
+    // register the DataFrame as a table
+    df.createOrReplaceTempView(sqlSchema.triplesTable)
+
+    new RDFGraphDataFrame(df)
   }
 }
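
The loader now delegates file parsing to the custom data source added below instead of converting an RDD by hand. A minimal usage sketch, assuming a local SparkSession and a placeholder input path (neither is part of the commit):

```scala
import org.apache.spark.sql.SparkSession

import net.sansa_stack.inference.data.SQLSchemaDefault
import net.sansa_stack.inference.spark.data.loader.RDFGraphLoader

val session = SparkSession.builder()
  .appName("rdf-loader-example") // illustrative setup, not from the commit
  .master("local[*]")
  .getOrCreate()

// parse an N-Triples file into a DataFrame-backed graph; "/tmp/data.nt" is a placeholder
val graph = RDFGraphLoader.loadFromDiskAsDataFrame(session, "/tmp/data.nt", minPartitions = 4)

// the loader registered the DataFrame as a temp view named by the schema object,
// so the triples are queryable with plain SQL
session.sql(s"SELECT s, p, o FROM ${SQLSchemaDefault.triplesTable} LIMIT 10").show()
```

Note that `minPartitions` is still accepted for API compatibility, but the new implementation no longer passes it to the underlying read.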
sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/data/loader/sql/DefaultSource.scala

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+package net.sansa_stack.inference.spark.data.loader.sql
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
+import org.apache.spark.sql.types.StructType
+
+
+class DefaultSource extends RelationProvider with SchemaRelationProvider {
+  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String])
+  : BaseRelation = {
+    createRelation(sqlContext, parameters, null)
+  }
+  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]
+                              , schema: StructType)
+  : BaseRelation = {
+    parameters.getOrElse("path", sys.error("'path' must be specified for our data."))
+    return new NTriplesRelation(parameters.get("path").get, schema)(sqlContext)
+  }
+}
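
Spark resolves `format("net.sansa_stack.inference.spark.data.loader.sql")` by loading the `DefaultSource` class inside that package and dispatching to one of the two `createRelation` overloads: the `RelationProvider` one when the caller gives no schema (it forwards a `null` schema, which `NTriplesRelation` checks for), the `SchemaRelationProvider` one when an explicit schema is supplied. A sketch of calling the source directly rather than through `RDFGraphLoader`, with a placeholder path:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{StringType, StructField, StructType}

val session = SparkSession.builder().master("local[*]").getOrCreate()

// no schema given: NTriplesRelation falls back to its default s/p/o string columns
val df = session.read
  .format("net.sansa_stack.inference.spark.data.loader.sql")
  .load("/tmp/data.nt") // placeholder path

// explicit schema: routed through the SchemaRelationProvider overload
val dfTyped = session.read
  .format("net.sansa_stack.inference.spark.data.loader.sql")
  .schema(StructType(Seq(
    StructField("s", StringType, nullable = true),
    StructField("p", StringType, nullable = true),
    StructField("o", StringType, nullable = true))))
  .load("/tmp/data.nt")
```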
sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/data/loader/sql/NTriplesRelation.scala

Lines changed: 37 additions & 0 deletions

@@ -0,0 +1,37 @@
+package net.sansa_stack.inference.spark.data.loader.sql
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.sources.{BaseRelation, TableScan}
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+import net.sansa_stack.inference.utils.NTriplesStringToRDFTriple
+
+class NTriplesRelation(location: String, userSchema: StructType)
+                      (@transient val sqlContext: SQLContext)
+  extends BaseRelation
+    with TableScan
+    with Serializable {
+  override def schema: StructType = {
+    if (this.userSchema != null) {
+      this.userSchema
+    }
+    else {
+      StructType(
+        Seq(
+          StructField("s", StringType, true),
+          StructField("p", StringType, true),
+          StructField("o", StringType, true)
+        ))
+    }
+  }
+  override def buildScan(): RDD[Row] = {
+    val rdd = sqlContext
+      .sparkContext
+      .textFile(location)
+
+    val rows = rdd.map(new NTriplesStringToRDFTriple()).map(t => Row.fromSeq(Seq(t.s, t.p, t.o)))
+
+    rows
+  }
+}
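
`NTriplesRelation` implements only `TableScan`, the simplest scan interface: every query reads the whole file and produces full rows, leaving column pruning and filtering to Spark afterwards (`PrunedScan` or `PrunedFilteredScan` would be the push-down alternatives). A hedged query sketch against the view registered by the loader, reusing the `session` from the first sketch; the exact serialization of subjects and predicates in the rows depends on `NTriplesStringToRDFTriple`:

```scala
import net.sansa_stack.inference.data.SQLSchemaDefault

// count triples per predicate; assumes loadFromDiskAsDataFrame already
// registered the temp view named by SQLSchemaDefault.triplesTable
val predicateCounts = session.sql(
  s"""SELECT p, COUNT(*) AS cnt
     |FROM ${SQLSchemaDefault.triplesTable}
     |GROUP BY p
     |ORDER BY cnt DESC""".stripMargin)

predicateCounts.show(10)
```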
