Skip to content

Commit a4750f0

Browse files
committed
remove time
1 parent f04db99 commit a4750f0

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

modAL/dropout.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from collections.abc import Mapping
55
from typing import Callable
66

7+
78
from sklearn.base import BaseEstimator
89
from sklearn.preprocessing import normalize
910

@@ -110,8 +111,8 @@ def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int =
110111
The mc-dropout metric of the chosen instances;
111112
"""
112113
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
113-
114114
#calculate BALD (Bayesian active learning divergence)
115+
115116
bald_scores = _bald_divergence(predictions)
116117

117118
if not random_tie_break:
@@ -276,7 +277,7 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
276277
Return:
277278
prediction: list with all predictions
278279
"""
279-
280+
280281
predictions = []
281282
# set dropout layers to train mode
282283
set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=True)
@@ -308,6 +309,7 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
308309
#In comparison to: predict(), predict_proba() the infer()
309310
# does not change train/eval mode of other layers
310311
logits = classifier.estimator.infer(samples)
312+
311313
prediction = logits_adaptor(logits, samples)
312314
mask = ~prediction.isnan()
313315
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)

0 commit comments

Comments (0)