|
4 | 4 | from collections.abc import Mapping |
5 | 5 | from typing import Callable |
6 | 6 |
|
| 7 | + |
7 | 8 | from sklearn.base import BaseEstimator |
8 | 9 | from sklearn.preprocessing import normalize |
9 | 10 |
|
@@ -110,8 +111,8 @@ def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int = |
110 | 111 | The mc-dropout metric of the chosen instances; |
111 | 112 | """ |
112 | 113 | predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor) |
113 | | - |
114 | 114 | #calculate BALD (Bayesian Active Learning by Disagreement)
| 115 | + |
115 | 116 | bald_scores = _bald_divergence(predictions) |
116 | 117 |
|
117 | 118 | if not random_tie_break: |
@@ -276,7 +277,7 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde |
276 | 277 | Return: |
277 | 278 | prediction: list with all predictions |
278 | 279 | """ |
279 | | - |
| 280 | + |
280 | 281 | predictions = [] |
281 | 282 | # set dropout layers to train mode |
282 | 283 | set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=True) |
@@ -308,6 +309,7 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde |
308 | 309 | #In comparison to: predict(), predict_proba() the infer() |
309 | 310 | # does not change train/eval mode of other layers |
310 | 311 | logits = classifier.estimator.infer(samples) |
| 312 | + |
311 | 313 | prediction = logits_adaptor(logits, samples) |
312 | 314 | mask = ~prediction.isnan() |
313 | 315 | prediction[mask] = prediction[mask].unsqueeze(0).softmax(1) |
|
0 commit comments