import sys
import torch
from collections.abc import Mapping
+ from typing import Callable

from sklearn.base import BaseEstimator
from sklearn.preprocessing import normalize

from skorch.utils import to_numpy

+ def default_logits_adaptor(input_tensor: torch.Tensor, samples: modALinput):
+     # Default callable parameter for get_predictions: returns the logits unchanged.
+     return input_tensor
+
def KL_divergence(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
                  random_tie_break: bool = False, dropout_layer_indexes: list = [],
                  num_cycles: int = 50, **mc_dropout_kwargs) -> np.ndarray:
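
The new logits_adaptor hook makes it possible to reshape arbitrary forward-pass outputs before the metric functions run. A minimal sketch of a custom adaptor, assuming a hypothetical module whose forward pass returns a (logits, extras) tuple rather than a plain tensor:

def tuple_logits_adaptor(input_tensor, samples):
    # Keep only the logits; discard auxiliary outputs such as hidden states.
    logits, _extras = input_tensor
    return logits
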
@@ -36,7 +41,9 @@ def KL_divergence(classifier: BaseEstimator, X: modALinput, n_instances: int = 1

def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies: list = ["bald", "mean_st", "max_entropy", "max_var"],
                     n_instances: int = 1, random_tie_break: bool = False, dropout_layer_indexes: list = [],
-                    num_cycles: int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
+                    num_cycles: int = 50, sample_per_forward_pass: int = 1000,
+                    logits_adaptor: Callable[[torch.Tensor, modALinput], torch.Tensor] = default_logits_adaptor,
+                    **mc_dropout_kwargs) -> np.ndarray:
    """
    Multi-metric dropout query strategy. Returns the specified metrics for the given input data.
    The available query strategies are:
@@ -49,7 +56,7 @@ def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies:
    The function returns a dictionary of metrics keyed by strategy name.
    The indices of the n best samples (n_instances) are not used in this function.
    """
-    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
+    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)

    metrics_dict = {}
    if "bald" in query_strategies:
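
A usage sketch for the extended signature, assuming `learner` is a modAL deep active learner wrapping a skorch classifier (the names here are illustrative, not part of this diff). Per the docstring, the call returns a dictionary keyed by strategy name:

metrics = mc_dropout_multi(learner, X_pool,
                           query_strategies=["bald", "max_entropy"],
                           num_cycles=20,
                           logits_adaptor=default_logits_adaptor)
# e.g. metrics["bald"] holds the BALD scores for X_pool
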
@@ -65,7 +72,9 @@ def mc_dropout_multi(classifier: BaseEstimator, X: modALinput, query_strategies:

def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
                    random_tie_break: bool = False, dropout_layer_indexes: list = [],
-                   num_cycles: int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
+                   num_cycles: int = 50, sample_per_forward_pass: int = 1000,
+                   logits_adaptor: Callable[[torch.Tensor, modALinput], torch.Tensor] = default_logits_adaptor,
+                   **mc_dropout_kwargs) -> np.ndarray:
    """
    MC-Dropout BALD query strategy. Returns the indices of the instances with the largest BALD
    (Bayesian Active Learning by Disagreement) scores calculated over the dropout cycles.
@@ -91,14 +100,16 @@ def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int =
        sample_per_forward_pass: max. number of samples for each forward pass.
            The allocated RAM mainly depends on this:
            a small number results in a small RAM allocation.
+       logits_adaptor: Callable that adapts the output of a forward pass
+           to the vector format required by the vectorised metric functions.
        **mc_dropout_kwargs: Keyword arguments to be passed to the uncertainty
            measure function.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the MC-dropout metric of the chosen instances.
    """
-    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
+    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)

    # calculate BALD (Bayesian Active Learning by Disagreement) scores
    bald_scores = _bald_divergence(predictions)
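
For reference, BALD is the mutual information between the predictions and the model posterior. Over a stacked array `probas` of dropout predictions with shape (num_cycles, n_samples, n_classes), the standard formula can be sketched in NumPy as follows (a sketch of the underlying math, not a copy of the _bald_divergence internals):

mean_proba = probas.mean(axis=0)                                             # (n_samples, n_classes)
entropy_of_mean = -np.sum(mean_proba * np.log(mean_proba + 1e-12), axis=-1)  # total uncertainty
mean_of_entropy = -np.sum(probas * np.log(probas + 1e-12), axis=-1).mean(axis=0)  # aleatoric part
bald_scores = entropy_of_mean - mean_of_entropy  # high score = high disagreement between cycles
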
@@ -110,7 +121,9 @@ def mc_dropout_bald(classifier: BaseEstimator, X: modALinput, n_instances: int =

def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
                       random_tie_break: bool = False, dropout_layer_indexes: list = [],
-                      num_cycles: int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
+                      num_cycles: int = 50, sample_per_forward_pass: int = 1000,
+                      logits_adaptor: Callable[[torch.Tensor, modALinput], torch.Tensor] = default_logits_adaptor,
+                      **mc_dropout_kwargs) -> np.ndarray:
    """
    MC-Dropout mean standard deviation query strategy. Returns the indices of the instances
    with the largest mean of the per-class standard deviations calculated over multiple dropout cycles.
@@ -132,6 +145,8 @@ def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: in
        sample_per_forward_pass: max. number of samples for each forward pass.
            The allocated RAM mainly depends on this:
            a small number results in a small RAM allocation.
+       logits_adaptor: Callable that adapts the output of a forward pass
+           to the vector format required by the vectorised metric functions.
        **mc_dropout_kwargs: Keyword arguments to be passed to the uncertainty
            measure function.

@@ -141,7 +156,7 @@ def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: in
    """

    # set dropout layers to train mode
-    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
+    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)

    mean_standard_deviations = _mean_standard_deviation(predictions)

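The mean standard deviation reduces to a two-step reduction over the same (num_cycles, n_samples, n_classes) stack `probas`; a sketch of the underlying math, not of _mean_standard_deviation itself:

per_class_std = probas.std(axis=0)      # (n_samples, n_classes): spread over dropout cycles
mean_std = per_class_std.mean(axis=-1)  # (n_samples,): averaged over the classes
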
@@ -152,7 +167,9 @@ def mc_dropout_mean_st(classifier: BaseEstimator, X: modALinput, n_instances: in

def mc_dropout_max_entropy(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
                           random_tie_break: bool = False, dropout_layer_indexes: list = [],
-                          num_cycles: int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
+                          num_cycles: int = 50, sample_per_forward_pass: int = 1000,
+                          logits_adaptor: Callable[[torch.Tensor, modALinput], torch.Tensor] = default_logits_adaptor,
+                          **mc_dropout_kwargs) -> np.ndarray:
    """
    MC-Dropout maximum entropy query strategy. Returns the indices of the instances
    with the largest entropy of the class probabilities calculated over multiple dropout cycles.
@@ -174,14 +191,16 @@ def mc_dropout_max_entropy(classifier: BaseEstimator, X: modALinput, n_instances
        sample_per_forward_pass: max. number of samples for each forward pass.
            The allocated RAM mainly depends on this:
            a small number results in a small RAM allocation.
+       logits_adaptor: Callable that adapts the output of a forward pass
+           to the vector format required by the vectorised metric functions.
        **mc_dropout_kwargs: Keyword arguments to be passed to the uncertainty
            measure function.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the MC-dropout metric of the chosen instances.
    """
-    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
+    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)

    # get entropy values for the predictions
    entropy = _entropy(predictions)
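
The predictive entropy underlying _entropy can be sketched with the standard formula over the averaged dropout probabilities `probas` of shape (num_cycles, n_samples, n_classes); implementation details may differ:

mean_proba = probas.mean(axis=0)                                   # (n_samples, n_classes)
entropy = -np.sum(mean_proba * np.log(mean_proba + 1e-12), axis=-1)  # (n_samples,)
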
@@ -193,7 +212,9 @@ def mc_dropout_max_entropy(classifier: BaseEstimator, X: modALinput, n_instances

def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_instances: int = 1,
                                   random_tie_break: bool = False, dropout_layer_indexes: list = [],
-                                  num_cycles: int = 50, sample_per_forward_pass: int = 1000, **mc_dropout_kwargs) -> np.ndarray:
+                                  num_cycles: int = 50, sample_per_forward_pass: int = 1000,
+                                  logits_adaptor: Callable[[torch.Tensor, modALinput], torch.Tensor] = default_logits_adaptor,
+                                  **mc_dropout_kwargs) -> np.ndarray:
    """
    MC-Dropout maximum variation ratios query strategy. Returns the indices of the instances
    with the largest variation ratios calculated over multiple dropout cycles.
@@ -215,14 +236,16 @@ def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_i
        sample_per_forward_pass: max. number of samples for each forward pass.
            The allocated RAM mainly depends on this:
            a small number results in a small RAM allocation.
+       logits_adaptor: Callable that adapts the output of a forward pass
+           to the vector format required by the vectorised metric functions.
        **mc_dropout_kwargs: Keyword arguments to be passed to the uncertainty
            measure function.

    Returns:
        The indices of the instances from X chosen to be labelled;
        the MC-dropout metric of the chosen instances.
    """
-    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass)
+    predictions = get_predictions(classifier, X, dropout_layer_indexes, num_cycles, sample_per_forward_pass, logits_adaptor)

    # get variation ratio values for the predictions
    variationRatios = _variation_ratios(predictions)
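
The variation ratio is one minus the probability of the modal class. One common vectorised formulation over the averaged dropout probabilities `probas` (a sketch of the standard definition, not necessarily of _variation_ratios):

mean_proba = probas.mean(axis=0)                     # (n_samples, n_classes)
variation_ratios = 1.0 - mean_proba.max(axis=-1)     # 0 = unanimous, high = uncertain
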
@@ -233,7 +256,8 @@ def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_i
    return shuffled_argmax(variationRatios, n_instances=n_instances)

def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_indexes: list,
-                   num_predictions: int = 50, sample_per_forward_pass: int = 1000):
+                   num_predictions: int = 50, sample_per_forward_pass: int = 1000,
+                   logits_adaptor: Callable[[torch.Tensor, modALinput], torch.Tensor] = default_logits_adaptor):
    """
    Runs the prediction of the classifier on the input X num_predictions times
    and collects the predictions in a list.
@@ -247,6 +271,8 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
        sample_per_forward_pass: max. number of samples for each forward pass.
            The allocated RAM mainly depends on this:
            a small number results in a small RAM allocation.
+       logits_adaptor: Callable that adapts the output of a forward pass
+           to the vector format required by the vectorised metric functions.
    Returns:
        predictions: list with all predictions
    """
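
Shape note: get_predictions returns a list of num_predictions arrays, each of shape (n_samples, n_classes); stacking them with np.stack(predictions) yields the (num_cycles, n_samples, n_classes) array assumed in the metric sketches above.
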
@@ -258,14 +284,6 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
    # set dropout layers to train mode
    set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=True)

-    if isinstance(X, Mapping):  # check for dict
-        for k, v in X.items():
-            v.detach()
-    elif torch.is_tensor(X):  # check for tensor
-        X.detach()
-    else:
-        raise RuntimeError("Error in model data type, only dict or tensors supported")
-
    for i in range(num_predictions):
        split_args = []

@@ -287,16 +305,19 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde


        probas = None
+
        for samples in split_args:
            # Call skorch's infer() to perform the model forward pass.
            # Unlike predict() and predict_proba(), infer() does not
            # change the train/eval mode of the other layers.
-            prediction = classifier.estimator.infer(samples)
+            logits = classifier.estimator.infer(samples)
+            prediction = logits_adaptor(logits, samples)
            # mask out NaN entries before applying the softmax
            mask = ~prediction.isnan()
            prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
            prediction = to_numpy(prediction)
            probas = prediction if probas is None else np.vstack((probas, prediction))

+
        predictions.append(probas)

    # set dropout layers to eval
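
An end-to-end usage sketch with a custom adaptor (illustrative names: `learner` and `X_pool` as above, `tuple_logits_adaptor` from the earlier sketch). Per the docstring, the call yields the chosen indices and their metric values:

query_idx, query_metric = mc_dropout_bald(learner, X_pool,
                                          n_instances=10,
                                          num_cycles=25,
                                          logits_adaptor=tuple_logits_adaptor)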