@@ -284,8 +284,12 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
284284
285285 split_args = []
286286
287+ number_of_samples = 0
288+
287289 if isinstance (X , Mapping ): #check for dict
288290 for k , v in X .items ():
291+ number_of_samples = v .size (0 )
292+
289293 v .detach ()
290294 split_v = torch .split (v , sample_per_forward_pass )
291295 #create sub-dictionary split for each forward pass with same keys&values
@@ -295,16 +299,18 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
295299 split_args [split_idx ][k ] = split
296300
297301 elif torch .is_tensor (X ): #check for tensor
302+ number_of_samples = X .size (0 )
298303 X .detach ()
299304 split_args = torch .split (X , sample_per_forward_pass )
300305 else :
301306 raise RuntimeError ("Error in model data type, only dict or tensors supported" )
302307
308+
303309 for i in range (num_predictions ):
304310
305311 probas = None
306312
307- for samples in split_args :
313+ for index , samples in enumerate ( split_args ) :
308314 #call Skorch infer function to perform model forward pass
309315 #In comparison to: predict(), predict_proba() the infer()
310316 # does not change train/eval mode of other layers
@@ -313,10 +319,12 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
313319 prediction = logits_adaptor (logits , samples )
314320 mask = ~ prediction .isnan ()
315321 prediction [mask ] = prediction [mask ].unsqueeze (0 ).softmax (1 )
316- prediction = to_numpy (prediction )
317- probas = prediction if probas is None else np .vstack ((probas , prediction ))
318322
323+ if probas is None : probas = torch .empty ((number_of_samples , prediction .shape [- 1 ]))
324+
325+ probas [range (sample_per_forward_pass * index , sample_per_forward_pass * (index + 1 )), :] = prediction
319326
327+ probas = to_numpy (prediction )
320328 predictions .append (probas )
321329
322330 # set dropout layers to eval
0 commit comments