11package net .sansa_stack .inference .spark .utils
22
33import org .apache .jena .vocabulary .RDFS
4+ import org .apache .spark .SparkContext
45import org .apache .spark .broadcast .Broadcast
56import org .apache .spark .rdd .RDD
6- import org .apache .spark .sql .{ DataFrame , SparkSession }
7+ import org .apache .spark .sql .DataFrame
78
8- import net .sansa_stack .inference .spark .data .{RDFGraphDataFrame , RDFGraphNative }
9+ import net .sansa_stack .inference .data .RDFTriple
10+ import net .sansa_stack .inference .spark .data .{RDFGraph , RDFGraphDataFrame , RDFGraphNative }
911import net .sansa_stack .inference .utils .{CollectionUtils , Logging }
1012
1113/**
@@ -20,9 +22,41 @@ import net.sansa_stack.inference.utils.{CollectionUtils, Logging}
2022 *
2123 * @author Lorenz Buehmann
2224 */
23- class RDFSSchemaExtractor (session : SparkSession ) extends Logging {
25+ class RDFSSchemaExtractor (sc : SparkContext ) extends Logging {
2426
/** URIs of the RDFS vocabulary terms treated as schema properties by this extractor. */
val properties: Set[String] =
  Set(RDFS.subClassOf, RDFS.subPropertyOf, RDFS.domain, RDFS.range).map(_.getURI)
28+
/**
 * Extracts the RDF graph containing only the schema triples from the RDF graph.
 *
 * A triple is kept when its `p` value is contained in `properties`.
 *
 * @param graph the RDF graph
 * @return a new RDF graph wrapping only the triples that passed the filter
 */
def extract(graph: RDFGraph): RDFGraph = {
  log.info(" Started schema extraction...")

  // Copy the member into a local value: a closure that reads `properties` directly
  // references `this`, and this instance holds the SparkContext passed to the
  // constructor. Capturing only the local set keeps the extractor itself out of
  // the function handed to `filter`.
  val schemaProperties = properties

  val filteredTriples = graph.triples.filter(t => schemaProperties.contains(t.p))

  log.info(" Finished schema extraction.")

  new RDFGraph(filteredTriples)
}
44+
/**
 * Extracts the schema triples from the given triples.
 *
 * A triple is kept when its `p` value is contained in `properties`.
 *
 * @param triples the triples
 * @return the triples that passed the filter
 */
def extract(triples: RDD[RDFTriple]): RDD[RDFTriple] = {
  log.info(" Started schema extraction...")

  // Copy the member into a local value: a closure that reads `properties` directly
  // references `this`, and this instance holds the SparkContext passed to the
  // constructor. Capturing only the local set keeps the extractor itself out of
  // the function handed to `filter`.
  val schemaProperties = properties

  val filteredTriples = triples.filter(t => schemaProperties.contains(t.p))

  log.info(" Finished schema extraction.")

  filteredTriples
}
2660
2761
2862 /**
@@ -32,7 +66,7 @@ class RDFSSchemaExtractor(session : SparkSession) extends Logging{
3266 * @param graph the RDF graph
3367 * @return a mapping from the corresponding schema property to the RDD of s-o pairs
3468 */
35- def extract (graph : RDFGraphNative ): Map [String , RDD [(String , String )]] = {
69+ def extractWithIndex (graph : RDFGraphNative ): Map [String , RDD [(String , String )]] = {
3670 log.info(" Started schema extraction..." )
3771
3872 // for each schema property p
@@ -59,7 +93,7 @@ class RDFSSchemaExtractor(session : SparkSession) extends Logging{
5993 * @param graph the RDF graph
6094 * @return a mapping from the corresponding schema property to the DataFrame of s-o pairs
6195 */
62- def extract (graph : RDFGraphDataFrame ): Map [String , DataFrame ] = {
96+ def extractWithIndex (graph : RDFGraphDataFrame ): Map [String , DataFrame ] = {
6397 log.info(" Started schema extraction..." )
6498
6599 // for each schema property p
@@ -87,8 +121,8 @@ class RDFSSchemaExtractor(session : SparkSession) extends Logging{
87121 * @return a mapping from the corresponding schema property to the broadcast variable that wraps the multimap
88122 * with s-o pairs
89123 */
90- def extractAndDistribute (graph : RDFGraphNative ): Map [String , Broadcast [Map [String , Set [String ]]]] = {
91- val schema = extract (graph)
124+ def extractWithIndexAndDistribute (graph : RDFGraphNative ): Map [String , Broadcast [Map [String , Set [String ]]]] = {
125+ val schema = extractWithIndex (graph)
92126
93127 log.info(" Started schema distribution..." )
94128 val index =
@@ -100,7 +134,7 @@ class RDFSSchemaExtractor(session : SparkSession) extends Logging{
100134 val mmap = CollectionUtils .toMultiMap(rdd.collect())
101135
102136 // broadcast
103- val bv = session.sparkContext .broadcast(mmap)
137+ val bv = sc .broadcast(mmap)
104138
105139 // add to index
106140 (p -> bv)
0 commit comments