Skip to content

Commit cce7933

Browse files
committed
logits_adapter_function
1 parent 30abf74 commit cce7933

1 file changed

Lines changed: 41 additions & 20 deletions

File tree

modAL/dropout.py

Lines changed: 41 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import sys
33
import torch
44
from collections.abc import Mapping
5+
from typing import Callable
56

67
from sklearn.base import BaseEstimator
78
from sklearn.preprocessing import normalize
@@ -13,6 +14,10 @@
1314

1415
from skorch.utils import to_numpy
1516

17+
def default_logits_adaptor(input_tensor: torch.Tensor, samples: modALinput) -> torch.Tensor:
    """Identity logits adaptor: return the model output unchanged.

    Default value for the ``logits_adaptor`` parameter of ``get_predictions``
    and the mc-dropout query strategies. Custom adaptors can reshape or
    post-process the raw forward-pass output into the vector format the
    vectorised metric functions expect; this default performs no adaptation.

    Args:
        input_tensor: Raw output (logits) of a model forward pass.
        samples: The input samples that produced ``input_tensor``. Unused
            here; provided so custom adaptors may depend on the inputs.

    Returns:
        ``input_tensor``, unchanged.
    """
    return input_tensor
20+
1621
def KL_divergence(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
1722
random_tie_break: bool = False, dropout_layer_indexes: list = [],
1823
num_cycles : int = 50, **mc_dropout_kwargs) -> np.ndarray:
@@ -36,7 +41,9 @@ def KL_divergence(classifier: BaseEstimator, X: modALinput, n_instances: int = 1
3641

3742
def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies: list = ["bald", "mean_st", "max_entropy", "max_var"],
3843
n_instances: int = 1, random_tie_break: bool = False, dropout_layer_indexes: list = [],
39-
num_cycles : int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
44+
num_cycles : int = 50, sample_per_forward_pass: int = 1000,
45+
logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
46+
**mc_dropout_kwargs) -> np.ndarray:
4047
"""
4148
Multi metric dropout query strategy. Returns the specified metrics for given input data.
4249
Selection of query strategies are:
@@ -49,7 +56,7 @@ def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies:
4956
Function returns dictionary of metrics with their name as key.
5057
The indices of the n-best samples (n_instances) is not used in this function.
5158
"""
52-
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
59+
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
5360

5461
metrics_dict = {}
5562
if "bald" in query_strategies:
@@ -65,7 +72,9 @@ def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies:
6572

6673
def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
6774
random_tie_break: bool = False, dropout_layer_indexes: list = [],
68-
num_cycles : int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
75+
num_cycles : int = 50, sample_per_forward_pass: int = 1000,
76+
logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
77+
**mc_dropout_kwargs,) -> np.ndarray:
6978
"""
7079
Mc-Dropout bald query strategy. Returns the indexes of the instances with the largest BALD
7180
(Bayesian Active Learning by Disagreement) score calculated through the dropout cycles
@@ -91,14 +100,16 @@ def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int =
91100
sample_per_forward_pass: max. sample number for each forward pass.
92101
The allocated RAM does mainly depend on this.
93102
Small number --> small RAM allocation
103+
logits_adaptor: Callable which can be used to adapt the output of a forward pass
104+
to the required vector format for the vectorised metric functions
94105
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty
95106
measure function.
96107
97108
Returns:
98109
The indices of the instances from X chosen to be labelled;
99110
The mc-dropout metric of the chosen instances;
100111
"""
101-
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
112+
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
102113

103114
#calculate BALD (Bayesian active learning divergence))
104115
bald_scores = _bald_divergence(predictions)
@@ -110,7 +121,9 @@ def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int =
110121

111122
def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
112123
random_tie_break: bool = False, dropout_layer_indexes: list = [],
113-
num_cycles : int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
124+
num_cycles : int = 50, sample_per_forward_pass: int = 1000,
125+
logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
126+
**mc_dropout_kwargs) -> np.ndarray:
114127
"""
115128
Mc-Dropout mean standard deviation query strategy. Returns the indexes of the instances
116129
with the largest mean of the per class calculated standard deviations over multiple dropout cycles
@@ -132,6 +145,8 @@ def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: in
132145
sample_per_forward_pass: max. sample number for each forward pass.
133146
The allocated RAM does mainly depend on this.
134147
Small number --> small RAM allocation
148+
logits_adaptor: Callable which can be used to adapt the output of a forward pass
149+
to the required vector format for the vectorised metric functions
135150
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty
136151
measure function.
137152
@@ -141,7 +156,7 @@ def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: in
141156
"""
142157

143158
# set dropout layers to train mode
144-
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
159+
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
145160

146161
mean_standard_deviations = _mean_standard_deviation(predictions)
147162

@@ -152,7 +167,9 @@ def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: in
152167

153168
def mc_dropout_max_entropy(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
154169
random_tie_break: bool = False, dropout_layer_indexes: list = [],
155-
num_cycles : int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
170+
num_cycles : int = 50, sample_per_forward_pass: int = 1000,
171+
logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
172+
**mc_dropout_kwargs) -> np.ndarray:
156173
"""
157174
Mc-Dropout maximum entropy query strategy. Returns the indexes of the instances
158175
with the largest entropy of the per class calculated entropies over multiple dropout cycles
@@ -174,14 +191,16 @@ def mc_dropout_max_entropy(classifier: BaseEstimator, X: modALinput, n_instances
174191
sample_per_forward_pass: max. sample number for each forward pass.
175192
The allocated RAM does mainly depend on this.
176193
Small number --> small RAM allocation
194+
logits_adaptor: Callable which can be used to adapt the output of a forward pass
195+
to the required vector format for the vectorised metric functions
177196
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty
178197
measure function.
179198
180199
Returns:
181200
The indices of the instances from X chosen to be labelled;
182201
The mc-dropout metric of the chosen instances;
183202
"""
184-
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
203+
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
185204

186205
#get entropy values for predictions
187206
entropy = _entropy(predictions)
@@ -193,7 +212,9 @@ def mc_dropout_max_entropy(classifier: BaseEstimator, X: modALinput, n_instances
193212

194213
def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
195214
random_tie_break: bool = False, dropout_layer_indexes: list = [],
196-
num_cycles : int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
215+
num_cycles : int = 50, sample_per_forward_pass: int = 1000,
216+
logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor,
217+
**mc_dropout_kwargs) -> np.ndarray:
197218
"""
198219
Mc-Dropout maximum variation ratios query strategy. Returns the indexes of the instances
199220
with the largest variation ratios over multiple dropout cycles
@@ -215,14 +236,16 @@ def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_i
215236
sample_per_forward_pass: max. sample number for each forward pass.
216237
The allocated RAM does mainly depend on this.
217238
Small number --> small RAM allocation
239+
logits_adaptor: Callable which can be used to adapt the output of a forward pass
240+
to the required vector format for the vectorised metric functions
218241
**uncertainty_measure_kwargs: Keyword arguments to be passed for the uncertainty
219242
measure function.
220243
221244
Returns:
222245
The indices of the instances from X chosen to be labelled;
223246
The mc-dropout metric of the chosen instances;
224247
"""
225-
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
248+
predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)
226249

227250
#get variation ratios values for predictions
228251
variationRatios = _variation_ratios(predictions)
@@ -233,7 +256,8 @@ def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_i
233256
return shuffled_argmax(variationRatios, n_instances=n_instances)
234257

235258
def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_indexes: list,
236-
num_predictions: int = 50, sample_per_forward_pass: int = 1000):
259+
num_predictions: int = 50, sample_per_forward_pass: int = 1000,
260+
logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor):
237261
"""
238262
Runs num_predictions times the prediction of the classifier on the input X
239263
and puts the predictions in a list.
@@ -247,6 +271,8 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
247271
sample_per_forward_pass: max. sample number for each forward pass.
248272
The allocated RAM does mainly depend on this.
249273
Small number --> small RAM allocation
274+
logits_adaptor: Callable which can be used to adapt the output of a forward pass
275+
to the required vector format for the vectorised metric functions
250276
Return:
251277
prediction: list with all predictions
252278
"""
@@ -258,14 +284,6 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
258284
# set dropout layers to train mode
259285
set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=True)
260286

261-
if isinstance(X, Mapping): #check for dict
262-
for k, v in X.items():
263-
v.detach()
264-
elif torch.is_tensor(X): #check for tensor
265-
X.detach()
266-
else:
267-
raise RuntimeError("Error in model data type, only dict or tensors supported")
268-
269287
for i in range(num_predictions):
270288
split_args = []
271289

@@ -287,16 +305,19 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
287305

288306

289307
probas = None
308+
290309
for samples in split_args:
291310
#call Skorch infer function to perform model forward pass
292311
#In comparison to: predict(), predict_proba() the infer()
293312
# does not change train/eval mode of other layers
294-
prediction = classifier.estimator.infer(samples)
313+
logits = classifier.estimator.infer(samples)
314+
prediction = logits_adaptor(logits, samples)
295315
mask = ~prediction.isnan()
296316
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
297317
prediction = to_numpy(prediction)
298318
probas = prediction if probas is None else np.vstack((probas, prediction))
299319

320+
300321
predictions.append(probas)
301322

302323
# set dropout layers to eval

0 commit comments

Comments
 (0)