Skip to content
This repository was archived by the owner on Oct 8, 2020. It is now read-only.

Commit f0f15e3

Browse files
Added method for fixpoint iteration in Spark Dataset.
1 parent a0fb542 commit f0f15e3

1 file changed

Lines changed: 33 additions & 5 deletions

File tree

sansa-inference-spark/src/main/scala/net/sansa_stack/inference/spark/forwardchaining/FixpointIteration.scala

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
package net.sansa_stack.inference.spark.forwardchaining
22

33
import scala.reflect.ClassTag
4-
54
import org.apache.spark.rdd.RDD
6-
75
import net.sansa_stack.inference.utils.Logging
6+
import org.apache.spark.sql.Dataset
87

98
/**
109
* Creates a new RDD by performing bulk iterations using the given step function. The first
@@ -22,10 +21,10 @@ trait FixpointIteration[T] extends Logging {
2221
object FixpointIteration extends Logging {
2322

2423
/**
25-
* Creates a new RDD by performing bulk iterations using the given step function `f`. The first
26-
* RDD the step function returns is the input for the next iteration, the second RDD is
24+
* Creates a new [[RDD]] by performing bulk iterations using the given step function `f`. The first
25+
* RDD the step function returns is the input for the next iteration, the second [[RDD]] is
2726
* the termination criterion. The iterations terminate when either the termination criterion
28-
* RDD contains no elements or when `maxIterations` iterations have been performed.
27+
* [[RDD]] contains no elements or when `maxIterations` iterations have been performed.
2928
*
3029
**/
3130
def apply[T: ClassTag](maxIterations: Int = 10)(rdd: RDD[T], f: RDD[T] => RDD[T]): RDD[T] = {
@@ -48,4 +47,33 @@ object FixpointIteration extends Logging {
4847
}
4948
newRDD
5049
}
50+
51+
/**
 * Creates a new [[Dataset]] by performing bulk iterations using the given step function `f`.
 *
 * The input dataset is cached and, in each iteration, the [[Dataset]] returned by `f` is merged
 * into the current result (`union` followed by `distinct`); the merged, cached [[Dataset]] is
 * the input for the next iteration. The iterations terminate when an iteration leaves the
 * element count unchanged, i.e. `f` produced no new elements, or when `maxIterations`
 * iterations have been performed, whichever happens first.
 *
 * If the input dataset is empty, `f` is never applied and the input is returned as is.
 *
 * @param maxIterations upper bound on the number of times `f` is applied
 * @param dataset       the initial dataset
 * @param f             the step function, applied to the current result in each iteration
 * @tparam T the element type of the dataset
 * @return the dataset obtained after the last performed iteration
 */
def apply2[T: ClassTag](maxIterations: Int = 10)(dataset: Dataset[T], f: Dataset[T] => Dataset[T]): Dataset[T] = {
  var newDS = dataset
  newDS.cache()
  var i = 1
  var oldCount = 0L
  // The input may contain duplicates, so its raw count is not comparable with the counts of the
  // de-duplicated results computed below. Only distinguish "empty" (0, loop is skipped) from
  // "non-empty" (1, forces at least one iteration unless maxIterations forbids it).
  var nextCount = if (newDS.count() == 0) 0L else 1L
  // Stop at the fixpoint (count unchanged) or once the iteration budget is used up; without the
  // second condition `maxIterations` would be ignored and the loop would be unbounded.
  while (nextCount != oldCount && i <= maxIterations) {
    log.info(s"iteration $i...")
    oldCount = nextCount
    info(s"i:$nextCount")
    newDS = newDS
      .union(f(newDS))
      .distinct()
      .cache()
    // count() is an action: it materializes the cached result of this iteration.
    nextCount = newDS.count()
    info(s"i+1:$nextCount")
    i += 1
  }
  if (nextCount != oldCount) {
    // The loop ended because of the iteration bound, not because a fixpoint was reached.
    log.info(s"stopped after $maxIterations iterations without reaching a fixpoint")
  }
  newDS
}
5179
}

0 commit comments

Comments
 (0)