Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit e781e51

Browse files
DataSet split function.
1 parent 92a60ef commit e781e51

2 files changed

Lines changed: 20 additions & 3 deletions

File tree

sansa-inference-flink/src/main/scala/net/sansa_stack/inference/flink/data/RDFGraph.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import net.sansa_stack.inference.flink.utils.DataSetUtils
44
import org.apache.flink.api.scala.{DataSet, _}
55
import org.apache.jena.graph.Triple
66
import net.sansa_stack.inference.data.RDFTriple
7+
import net.sansa_stack.inference.flink.utils.DataSetUtils.DataSetOps
78

89
/**
910
* A data structure that comprises a set of triples.
@@ -59,7 +60,7 @@ case class RDFGraph(triples: DataSet[RDFTriple]) {
5960
* @return the difference of both graphs
6061
*/
6162
def subtract(other: RDFGraph): RDFGraph = {
62-
RDFGraph(DataSetUtils.subtract(this.triples, other.triples))
63+
RDFGraph(triples.subtract(other.triples))
6364
}
6465

6566
/**

sansa-inference-flink/src/main/scala/net/sansa_stack/inference/flink/utils/DataSetUtils.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,24 @@ import scala.reflect.ClassTag
1414
*/
1515
object DataSetUtils {
1616

17-
def subtract[T: ClassTag: TypeInformation](first: DataSet[T], second: DataSet[T]): DataSet[T] = {
18-
first.coGroup(second).where("*").equalTo("*")(new MinusCoGroupFunction[T](true)).name("subtract")
17+
implicit class DataSetOps[T: ClassTag : TypeInformation](dataset: DataSet[T]) {
18+
19+
/**
20+
* Splits an RDD into two parts based on the given filter function. Note, that filtering is done twice on the same
21+
* data twice, thus, caching beforehand is recommended!
22+
*
23+
* @param f the boolean filter function
24+
* @return two RDDs
25+
*/
26+
def partitionBy(f: T => Boolean): (DataSet[T], DataSet[T]) = {
27+
val passes = dataset.filter(f)
28+
val fails = dataset.filter(e => !f(e)) // Flink doesn't have filterNot
29+
(passes, fails)
30+
}
31+
32+
def subtract(other: DataSet[T]): DataSet[T] = {
33+
dataset.coGroup(other).where("*").equalTo("*")(new MinusCoGroupFunction[T](true)).name("subtract")
34+
}
1935
}
2036

2137
}

0 commit comments

Comments
 (0)