@@ -282,26 +282,26 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
282282 # set dropout layers to train mode
283283 set_dropout_mode (classifier .estimator .module_ , dropout_layer_indexes , train_mode = True )
284284
285- for i in range (num_predictions ):
286- split_args = []
287-
288- if isinstance (X , Mapping ): #check for dict
289- for k , v in X .items ():
290- v .detach ()
291- split_v = torch .split (v , sample_per_forward_pass )
292- #create sub-dictionary split for each forward pass with same keys&values
293- for split_idx , split in enumerate (split_v ):
294- if len (split_args )<= split_idx :
295- split_args .append ({})
296- split_args [split_idx ][k ] = split
297-
298- elif torch .is_tensor (X ): #check for tensor
299- X .detach ()
300- split_args = torch .split (X , sample_per_forward_pass )
301- else :
302- raise RuntimeError ("Error in model data type, only dict or tensors supported" )
303-
285+ split_args = []
286+
287+ if isinstance (X , Mapping ): #check for dict
288+ for k , v in X .items ():
289+ v .detach ()
290+ split_v = torch .split (v , sample_per_forward_pass )
291+ #create sub-dictionary split for each forward pass with same keys&values
292+ for split_idx , split in enumerate (split_v ):
293+ if len (split_args )<= split_idx :
294+ split_args .append ({})
295+ split_args [split_idx ][k ] = split
304296
297+ elif torch .is_tensor (X ): #check for tensor
298+ X .detach ()
299+ split_args = torch .split (X , sample_per_forward_pass )
300+ else :
301+ raise RuntimeError ("Error in model data type, only dict or tensors supported" )
302+
303+ for i in range (num_predictions ):
304+
305305 probas = None
306306
307307 for samples in split_args :
0 commit comments