1+ from typing import Sequence
2+
13import numpy as np
24from sklearn .exceptions import NotFittedError
5+ from sklearn .base import BaseEstimator
36
47
5- def check_class_labels (* args ) :
8+ def check_class_labels (* args : BaseEstimator ) -> bool :
69 """
7- Checks the known class labels for each classifier. Returns True if all classifier
8- knows the same labels and returns False if not.
10+ Checks the known class labels for each classifier.
911
10- Parameters
11- ----------
12- *args: sklearn classifier objects
13- Classifier objects to check the known class labels.
12+ Args:
13+ *args: Classifier objects to check the known class labels.
1414
15- Returns
16- -------
17- bool
18- True, if class labels match for all classifiers,
19- False otherwise.
15+ Returns:
16+ True, if class labels match for all classifiers, False otherwise.
2017 """
2118 try :
2219 classes_ = [estimator .classes_ for estimator in args ]
@@ -30,31 +27,19 @@ def check_class_labels(*args):
3027 return True
3128
3229
33- def check_class_proba (proba , known_labels , all_labels ) :
30+ def check_class_proba (proba : np . ndarray , known_labels : Sequence , all_labels : Sequence ) -> np . ndarray :
3431 """
35- Checks the class probabilities and reshapes it if not all labels are present
36- in the classifier.
37-
38- Parameters
39- ----------
40- proba: numpy.ndarray of shape (n_samples, n_known_classes)
41- The class probabilities of a classifier.
42-
43- known_labels:
44- The class labels known by the classifier.
32+ Checks the class probabilities and reshapes it if not all labels are present in the classifier.
4533
46- all_labels:
47- All class labels.
48-
49- Returns
50- -------
51- aug_proba: numpy.ndarray of shape (n_samples, n_classes)
52- Class probabilities augmented such that the probability of all classes
53- is present. If the classifier is unaware of a particular class, all
54- probabilities are zero.
34+ Args:
35+ proba: The class probabilities of a classifier.
36+ known_labels: The class labels known by the classifier.
37+ all_labels: All class labels.
5538
39+ Returns:
40+ Class probabilities augmented such that the probability of all classes is present. If the classifier is unaware
41+ of a particular class, all probabilities are zero.
5642 """
57-
5843 # TODO: rewrite this function using numpy.insert
5944
6045 label_idx_map = - np .ones (len (all_labels ), dtype = 'int' )
0 commit comments