Skip to content

Commit b730ea4

Browse files
docs: improve utils
1 parent 12bb10b commit b730ea4

3 files changed

Lines changed: 31 additions & 37 deletions

File tree

modAL/utils/combination.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
from typing import Callable, Optional, Collection, Tuple
1+
from typing import Callable, Optional, Sequence, Tuple
22

33
import numpy as np
44
from sklearn.base import BaseEstimator
55

66
from modAL.utils.data import modALinput
77

88

9-
def make_linear_combination(*functions: Callable, weights: Optional[Collection] = None) -> Callable:
9+
def make_linear_combination(*functions: Callable, weights: Optional[Sequence] = None) -> Callable:
1010
"""
1111
Takes the given functions and makes a function which returns the linear combination of the output of original
1212
functions. It works well with functions returning numpy arrays of the same shape.
@@ -35,7 +35,7 @@ def linear_combination(*args, **kwargs):
3535
return linear_combination
3636

3737

38-
def make_product(*functions: Callable, exponents: Optional[Collection] = None) -> Callable:
38+
def make_product(*functions: Callable, exponents: Optional[Sequence] = None) -> Callable:
3939
"""
4040
Takes the given functions and makes a function which returns the product of the output of original functions. It
4141
works well with functions returning numpy arrays of the same shape.

modAL/utils/data.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import Union
2+
from itertools import chain
23

34
import numpy as np
45
import scipy.sparse as sp
@@ -7,12 +8,20 @@
78
modALinput = Union[list, np.ndarray, sp.csr_matrix]
89

910

10-
def data_vstack(blocks):
11+
def data_vstack(blocks: modALinput) -> modALinput:
1112
"""
1213
Stack vertically both sparse and dense arrays.
14+
15+
Args:
16+
blocks: Sequence of modALinput objects.
17+
18+
Returns:
19+
New sequence of vertically stacked elements.
1320
"""
1421
if isinstance(blocks[0], np.ndarray):
1522
return np.concatenate(blocks)
23+
elif isinstance(blocks[0], list):
24+
return list(chain(blocks))
1625
elif sp.issparse(blocks[0]):
1726
return sp.vstack(blocks)
1827
else:

modAL/utils/validation.py

Lines changed: 18 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,19 @@
1+
from typing import Sequence
2+
13
import numpy as np
24
from sklearn.exceptions import NotFittedError
5+
from sklearn.base import BaseEstimator
36

47

5-
def check_class_labels(*args):
8+
def check_class_labels(*args: BaseEstimator) -> bool:
69
"""
7-
Checks the known class labels for each classifier. Returns True if all classifier
8-
knows the same labels and returns False if not.
10+
Checks the known class labels for each classifier.
911
10-
Parameters
11-
----------
12-
*args: sklearn classifier objects
13-
Classifier objects to check the known class labels.
12+
Args:
13+
*args: Classifier objects to check the known class labels.
1414
15-
Returns
16-
-------
17-
bool
18-
True, if class labels match for all classifiers,
19-
False otherwise.
15+
Returns:
16+
True, if class labels match for all classifiers, False otherwise.
2017
"""
2118
try:
2219
classes_ = [estimator.classes_ for estimator in args]
@@ -30,31 +27,19 @@ def check_class_labels(*args):
3027
return True
3128

3229

33-
def check_class_proba(proba, known_labels, all_labels):
30+
def check_class_proba(proba: np.ndarray, known_labels: Sequence, all_labels: Sequence) -> np.ndarray:
3431
"""
35-
Checks the class probabilities and reshapes it if not all labels are present
36-
in the classifier.
37-
38-
Parameters
39-
----------
40-
proba: numpy.ndarray of shape (n_samples, n_known_classes)
41-
The class probabilities of a classifier.
42-
43-
known_labels:
44-
The class labels known by the classifier.
32+
Checks the class probabilities and reshapes it if not all labels are present in the classifier.
4533
46-
all_labels:
47-
All class labels.
48-
49-
Returns
50-
-------
51-
aug_proba: numpy.ndarray of shape (n_samples, n_classes)
52-
Class probabilities augmented such that the probability of all classes
53-
is present. If the classifier is unaware of a particular class, all
54-
probabilities are zero.
34+
Args:
35+
proba: The class probabilities of a classifier.
36+
known_labels: The class labels known by the classifier.
37+
all_labels: All class labels.
5538
39+
Returns:
40+
Class probabilities augmented such that the probability of all classes is present. If the classifier is unaware
41+
of a particular class, all probabilities are zero.
5642
"""
57-
5843
# TODO: rewrite this function using numpy.insert
5944

6045
label_idx_map = -np.ones(len(all_labels), dtype='int')

0 commit comments

Comments
 (0)