Skip to content

Commit aa41e59

Browse files
committed
Refactored for legibility and added comments.
1 parent 0e2fc2e commit aa41e59

1 file changed

Lines changed: 26 additions & 19 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,13 @@
3535
)
3636

3737

38+
SIMPLE_NUMPY_TYPES = [nptype for type_cat, nptypes in np.sctypes.items()
39+
for nptype in nptypes if type_cat != 'others']
40+
SIMPLE_TYPES = tuple([bool, int, float, str] + SIMPLE_NUMPY_TYPES)
41+
42+
3843
def sklearn_to_flow(o, parent_model=None):
3944
# TODO: assert that only on first recursion lvl `parent_model` can be None
40-
simple_numpy_types = [nptype for type_cat, nptypes in np.sctypes.items()
41-
for nptype in nptypes
42-
if type_cat != 'others']
43-
simple_types = tuple([bool, int, float, str] + simple_numpy_types)
4445
if _is_estimator(o):
4546
# is the main model or a submodel
4647
rval = _serialize_model(o)
@@ -49,8 +50,8 @@ def sklearn_to_flow(o, parent_model=None):
4950
rval = [sklearn_to_flow(element, parent_model) for element in o]
5051
if isinstance(o, tuple):
5152
rval = tuple(rval)
52-
elif isinstance(o, simple_types) or o is None:
53-
if isinstance(o, tuple(simple_numpy_types)):
53+
elif isinstance(o, SIMPLE_TYPES) or o is None:
54+
if isinstance(o, tuple(SIMPLE_NUMPY_TYPES)):
5455
o = o.item()
5556
# base parameter values
5657
rval = o
@@ -507,28 +508,34 @@ def _extract_information_from_model(model):
507508
rval = sklearn_to_flow(v, model)
508509

509510
def flatten_all(list_):
510-
""" Flattens arbitrary depth lists of lists. """
511+
""" Flattens arbitrary depth lists of lists (e.g. [[1,2],[3,[1]]] -> [1,2,3,1]). """
511512
for el in list_:
512513
if isinstance(el, (list, tuple)):
513514
yield from flatten_all(el)
514515
else:
515516
yield el
516517

517-
if isinstance(rval, (list, tuple)):
518-
nested_list_of_simple_types = all([isinstance(el, (bool, str, int, float))
519-
for el in flatten_all(rval)])
520-
else:
521-
nested_list_of_simple_types = False
522-
523-
if (isinstance(rval, (list, tuple))
518+
# In case rval is a list of lists (or tuples), we need to identify two situations:
519+
# - sklearn pipeline steps, feature union or base classifiers in voting classifier.
520+
# They look like e.g. [("imputer", Imputer()), ("classifier", SVC())]
521+
# - a list of lists with simple types (e.g. int or str), such as for an OrdinalEncoder
522+
# where all possible values for each feature are described: [[0,1,2], [1,2,5]]
523+
is_non_empty_list_of_lists_with_same_type = (
524+
isinstance(rval, (list, tuple))
524525
and len(rval) > 0
525526
and isinstance(rval[0], (list, tuple))
526-
and all([isinstance(rval[i], type(rval[0]))
527-
for i in range(len(rval))])
528-
and not nested_list_of_simple_types):
527+
and all([isinstance(rval_i, type(rval[0])) for rval_i in rval])
528+
)
529+
530+
nested_list_of_simple_types = (
531+
is_non_empty_list_of_lists_with_same_type
532+
and all([isinstance(el, SIMPLE_TYPES) for el in flatten_all(rval)])
533+
)
529534

530-
# Steps in a pipeline or feature union, or base classifiers in
531-
# voting classifier
535+
if is_non_empty_list_of_lists_with_same_type and not nested_list_of_simple_types:
536+
# If a list of lists is identified that include 'non-simple' types (e.g. objects),
537+
# we assume they are steps in a pipeline, feature union, or base classifiers in
538+
# a voting classifier.
532539
parameter_value = list()
533540
reserved_keywords = set(model.get_params(deep=False).keys())
534541

0 commit comments

Comments
 (0)