Skip to content

Commit d61b500

Browse files
committed
add: expected error reduction almost implemented
1 parent d623d93 commit d61b500

1 file changed

Lines changed: 37 additions & 0 deletions

File tree

modAL/expected_error_reduction.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""
2+
Expected error reduction framework for active learning.
3+
"""
4+
5+
from typing import Tuple
6+
7+
import numpy as np
8+
9+
from scipy.stats import entropy
10+
from sklearn.base import clone
11+
12+
from modAL.models import ActiveLearner
13+
from modAL.utils.data import modALinput, data_vstack
14+
from modAL.utils.selection import multi_argmax
15+
16+
17+
def expected_error_reduction(classifier: ActiveLearner, X: modALinput,
18+
p_subsample=1.0: np.float, n_instances=1: int) -> Tuple[np.ndarray, modALinput]:
19+
20+
expected_error = np.full(shape=(len(X), ), fill_value=-np.nan)
21+
possible_labels = np.unique(classifier.y_training)
22+
23+
for x_idx, x in enumerate(X):
24+
# subsample the data if needed
25+
if np.random.rand() <= p_subsample:
26+
# estimate the expected error
27+
for y in possible_labels:
28+
X_new = data_vstack((classifier.X_training, x))
29+
y_new = None
30+
31+
refitted_estimator = clone(classifier.estimator).fit()
32+
33+
34+
query_idx = multi_argmax(expected_error, n_instances)
35+
36+
return query_idx, X[query_idx]
37+

0 commit comments

Comments
 (0)