Skip to content

Commit 30abf74

Browse files
committed
Dropout softmax handle NaN's
1 parent 4e408b2 commit 30abf74

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

modAL/dropout.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,10 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
292292
#In comparison to: predict(), predict_proba() the infer()
293293
# does not change train/eval mode of other layers
294294
prediction = classifier.estimator.infer(samples)
295-
prediction_proba = to_numpy(prediction.softmax(1))
296-
probas = prediction_proba if probas is None else np.vstack((probas, prediction_proba))
295+
mask = ~prediction.isnan()
296+
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
297+
prediction = to_numpy(prediction)
298+
probas = prediction if probas is None else np.vstack((probas, prediction))
297299

298300
predictions.append(probas)
299301

0 commit comments

Comments
 (0)