Skip to content

Commit ad081bb

Browse files
committed
change position of softmax in get_predictions
1 parent 2ae3a65 commit ad081bb

1 file changed

Lines changed: 4 additions & 5 deletions

File tree

modAL/dropout.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,16 +315,15 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
315315
#In comparison to: predict(), predict_proba() the infer()
316316
# does not change train/eval mode of other layers
317317
logits = classifier.estimator.infer(samples)
318-
319318
prediction = logits_adaptor(logits, samples)
320-
mask = ~prediction.isnan()
321-
prediction[mask] = prediction[mask].unsqueeze(0).softmax(1)
322319

323320
if probas is None: probas = torch.empty((number_of_samples, prediction.shape[-1]))
324-
325321
probas[range(sample_per_forward_pass*index, sample_per_forward_pass*(index+1)), :] = prediction
326322

327-
probas = to_numpy(prediction)
323+
324+
mask = ~probas.isnan()
325+
probas[mask] = probas[mask].unsqueeze(0).softmax(1)
326+
probas = to_numpy(probas)
328327
predictions.append(probas)
329328

330329
# set dropout layers to eval

0 commit comments

Comments
 (0)