11package net .sansa_stack .inference .spark .forwardchaining
22
33import scala .reflect .ClassTag
4-
54import org .apache .spark .rdd .RDD
6-
75import net .sansa_stack .inference .utils .Logging
6+ import org .apache .spark .sql .Dataset
87
98/**
109 * Creates a new RDD by performing bulk iterations using the given step function. The first
@@ -22,10 +21,10 @@ trait FixpointIteration[T] extends Logging {
2221object FixpointIteration extends Logging {
2322
2423 /**
25- * Creates a new RDD by performing bulk iterations using the given step function `f`. The first
26- * RDD the step function returns is the input for the next iteration, the second RDD is
24+ * Creates a new [[ RDD ]] by performing bulk iterations using the given step function `f`. The first
25+ * RDD the step function returns is the input for the next iteration, the second [[ RDD ]] is
2726 * the termination criterion. The iterations terminate when either the termination criterion
28- * RDD contains no elements or when `maxIterations` iterations have been performed.
27+ * [[ RDD ]] contains no elements or when `maxIterations` iterations have been performed.
2928 *
3029 **/
3130 def apply [T : ClassTag ](maxIterations : Int = 10 )(rdd : RDD [T ], f : RDD [T ] => RDD [T ]): RDD [T ] = {
@@ -48,4 +47,33 @@ object FixpointIteration extends Logging {
4847 }
4948 newRDD
5049 }
50+
51+ /**
52+ *
53+ * Creates a new [[Dataset ]] by performing bulk iterations using the given step function `f`. The first
54+ * [[Dataset ]] the step function returns is the input for the next iteration, the second RDD is
55+ * the termination criterion. The iterations terminate when either the termination criterion
56+ * RDD contains no elements or when `maxIterations` iterations have been performed.
57+ *
58+ **/
59+ def apply2 [T : ClassTag ](maxIterations : Int = 10 )(dataset : Dataset [T ], f : Dataset [T ] => Dataset [T ]): Dataset [T ] = {
60+ var newDS = dataset
61+ newDS.cache()
62+ var i = 1
63+ var oldCount = 0L
64+ var nextCount = if (newDS.count() == 0 ) 0L else 1L
65+ while (nextCount != oldCount) {
66+ log.info(s " iteration $i... " )
67+ oldCount = nextCount
68+ info(s " i: $nextCount" )
69+ newDS = newDS
70+ .union(f(newDS))
71+ .distinct()
72+ .cache()
73+ nextCount = newDS.count()
74+ info(s " i+1: $nextCount" )
75+ i += 1
76+ }
77+ newDS
78+ }
5179}
0 commit comments