Skip to content

Commit 83d92cb

Browse files
committed
split_args replacement
1 parent a4750f0 commit 83d92cb

1 file changed

Lines changed: 19 additions & 19 deletions

File tree

modAL/dropout.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -282,26 +282,26 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
282282
# set dropout layers to train mode
283283
set_dropout_mode(classifier.estimator.module_, dropout_layer_indexes, train_mode=True)
284284

285-
for i in range(num_predictions):
286-
split_args = []
287-
288-
if isinstance(X, Mapping): #check for dict
289-
for k, v in X.items():
290-
v.detach()
291-
split_v = torch.split(v, sample_per_forward_pass)
292-
#create sub-dictionary split for each forward pass with same keys&values
293-
for split_idx, split in enumerate(split_v):
294-
if len(split_args)<=split_idx:
295-
split_args.append({})
296-
split_args[split_idx][k] = split
297-
298-
elif torch.is_tensor(X): #check for tensor
299-
X.detach()
300-
split_args = torch.split(X, sample_per_forward_pass)
301-
else:
302-
raise RuntimeError("Error in model data type, only dict or tensors supported")
303-
285+
split_args = []
286+
287+
if isinstance(X, Mapping): #check for dict
288+
for k, v in X.items():
289+
v.detach()
290+
split_v = torch.split(v, sample_per_forward_pass)
291+
#create sub-dictionary split for each forward pass with same keys&values
292+
for split_idx, split in enumerate(split_v):
293+
if len(split_args)<=split_idx:
294+
split_args.append({})
295+
split_args[split_idx][k] = split
304296

297+
elif torch.is_tensor(X): #check for tensor
298+
X.detach()
299+
split_args = torch.split(X, sample_per_forward_pass)
300+
else:
301+
raise RuntimeError("Error in model data type, only dict or tensors supported")
302+
303+
for i in range(num_predictions):
304+
305305
probas = None
306306

307307
for samples in split_args:

0 commit comments

Comments
 (0)