Skip to content

Commit 4e408b2

Browse files
committed
Remove mask parameter & change to ignore NaN's
1 parent e76e8c2 commit 4e408b2

1 file changed

Lines changed: 12 additions & 21 deletions

File tree

modAL/dropout.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,12 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
302302

303303
return predictions
304304

305-
def entropy_sum(values: np.array, mask: np.ndarray = None, axis: int =-1):
306-
if mask is None:
307-
mask = np.ones(values.shape, dtype=bool)
305+
def entropy_sum(values: np.array, axis: int =-1):
308306
#sum Scipy basic entropy function: entr()
309-
return np.sum(entr(values), where=mask, axis=axis)
307+
entropy = entr(values)
308+
return np.sum(entropy, where=~np.isnan(entropy), axis=axis)
310309

311-
def _mean_standard_deviation(proba: list, mask: np.ndarray = None) -> np.ndarray:
310+
def _mean_standard_deviation(proba: list) -> np.ndarray:
312311
"""
313312
Calculates the mean of the per class calculated standard deviations.
314313
@@ -324,15 +323,13 @@ def _mean_standard_deviation(proba: list, mask: np.ndarray = None) -> np.ndarray
324323
"""
325324

326325
proba_stacked = np.stack(proba, axis=len(proba[0].shape))
327-
if mask is None:
328-
mask = np.ones(proba[0].shape, dtype=bool)
329326

330327
standard_deviation_class_vise = np.std(proba_stacked, axis=-1)
331-
mean_standard_deviation = np.mean(standard_deviation_class_vise, where=mask, axis=-1)
328+
mean_standard_deviation = np.mean(standard_deviation_class_vise, where=~np.isnan(standard_deviation_class_vise), axis=-1)
332329

333330
return mean_standard_deviation
334331

335-
def _entropy(proba: list, mask: np.ndarray = None) -> np.ndarray:
332+
def _entropy(proba: list) -> np.ndarray:
336333
"""
337334
Calculates the entropy per class over dropout cycles
338335
@@ -348,15 +345,13 @@ def _entropy(proba: list, mask: np.ndarray = None) -> np.ndarray:
348345
"""
349346

350347
proba_stacked = np.stack(proba, axis=len(proba[0].shape))
351-
if mask is None:
352-
mask = np.ones(proba[0].shape, dtype=bool)
353348

354349
#calculate entropy per class and sum along dropout cycles
355350
entropy_classes = entropy_sum(proba_stacked, axis=-1)
356-
entropy = np.mean(entropy_classes, where=mask, axis=-1)
351+
entropy = np.mean(entropy_classes, where=~np.isnan(entropy_classes), axis=-1)
357352
return entropy
358353

359-
def _variation_ratios(proba: list, mask: np.ndarray = None) -> np.ndarray:
354+
def _variation_ratios(proba: list) -> np.ndarray:
360355
"""
361356
Calculates the variation ratios over dropout cycles
362357
@@ -371,13 +366,12 @@ def _variation_ratios(proba: list, mask: np.ndarray = None) -> np.ndarray:
371366
Returns the variation ratios of the dropout cycles.
372367
"""
373368
proba_stacked = np.stack(proba, axis=len(proba[0].shape))
374-
if mask is None:
375-
mask = np.ones(proba[0].shape, dtype=bool)
369+
376370
#Calculate the variation ratios over the mean of dropout cycles
377371
valuesDCMean = np.mean(proba_stacked, axis=-1)
378-
return 1 - np.amax(valuesDCMean, initial=0, where=mask, axis=-1)
372+
return 1 - np.amax(valuesDCMean, initial=0, where=~np.isnan(valuesDCMean), axis=-1)
379373

380-
def _bald_divergence(proba: list, mask: np.ndarray = None) -> np.ndarray:
374+
def _bald_divergence(proba: list) -> np.ndarray:
381375
"""
382376
Calculates the bald divergence for each instance
383377
@@ -412,10 +406,7 @@ def _bald_divergence(proba: list, mask: np.ndarray = None) -> np.ndarray:
412406
#sum all dimensions of diff besides first dim (instances)
413407
shaped = np.reshape(diff, (diff.shape[0], -1))
414408

415-
if mask is None:
416-
mask = np.ones(shaped.shape, dtype=bool)
417-
418-
bald = np.sum(shaped, where=mask, axis=-1)
409+
bald = np.sum(shaped, where=~np.isnan(shaped), axis=-1)
419410
return bald
420411

421412
def _KL_divergence(proba) -> np.ndarray:

0 commit comments

Comments
 (0)