Skip to content

Commit 2ae3a65

Browse files
committed
runtime_improvement, removement of np.vstack
1 parent 83d92cb commit 2ae3a65

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

modAL/dropout.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,12 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
284284

285285
split_args = []
286286

287+
number_of_samples = 0
288+
287289
if isinstance(X, Mapping): #check for dict
288290
for k, v in X.items():
291+
number_of_samples = v.size(0)
292+
289293
v.detach()
290294
split_v = torch.split(v, sample_per_forward_pass)
291295
#create sub-dictionary split for each forward pass with same keys&values
@@ -295,16 +299,18 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
295299
split_args[split_idx][k] = split
296300

297301
elif torch.is_tensor(X): #check for tensor
302+
number_of_samples = X.size(0)
298303
X.detach()
299304
split_args = torch.split(X, sample_per_forward_pass)
300305
else:
301306
raise RuntimeError("Error in model data type, only dict or tensors supported")
302307

308+
303309
for i in range(num_predictions):
304310

305311
probas = None
306312

307-
for samples in split_args:
313+
for index, samples in enumerate(split_args):
308314
#call Skorch infer function to perform model forward pass
309315
#In comparison to: predict(), predict_proba() the infer()
310316
# does not change train/eval mode of other layers
@@ -313,10 +319,12 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
313319
prediction = logits_adaptor(logits, samples)
314320
mask = ~prediction.isnan()
315321
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
316-
prediction = to_numpy(prediction)
317-
probas = prediction if probas is None else np.vstack((probas, prediction))
318322

323+
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]))
324+
325+
probas[range(sample_per_forward_pass*index, sample_per_forward_pass*(index+1)), :] = prediction
319326

327+
probas = to_numpy(prediction)
320328
predictions.append(probas)
321329

322330
# set dropout layers to eval

0 commit comments

Comments
 (0)