3535)
3636
3737
# NumPy scalar types grouped by kind ('int', 'uint', 'float', 'complex', ...).
# ``np.sctypes`` was removed in NumPy 2.0, so fall back to an explicit table
# of the numeric scalar types when the attribute is unavailable.  Extended
# precision types only exist on some platforms, hence the ``hasattr`` probes.
_NUMPY_SCTYPES = getattr(np, 'sctypes', None)
if _NUMPY_SCTYPES is None:
    _NUMPY_SCTYPES = {
        'int': [np.int8, np.int16, np.int32, np.int64],
        'uint': [np.uint8, np.uint16, np.uint32, np.uint64],
        'float': [np.float16, np.float32, np.float64] + [
            getattr(np, name)
            for name in ('float96', 'float128')
            if hasattr(np, name)
        ],
        'complex': [np.complex64, np.complex128] + [
            getattr(np, name)
            for name in ('complex192', 'complex256')
            if hasattr(np, name)
        ],
    }

# All numeric NumPy scalar types; the 'others' category (bool, object, bytes,
# str, void) is deliberately excluded.
SIMPLE_NUMPY_TYPES = [
    nptype
    for type_cat, nptypes in _NUMPY_SCTYPES.items()
    if type_cat != 'others'
    for nptype in nptypes
]

# Types that are treated as plain parameter values: Python primitives plus the
# numeric NumPy scalars above.  Kept as a tuple so it can be passed directly
# to ``isinstance``.
SIMPLE_TYPES = tuple([bool, int, float, str] + SIMPLE_NUMPY_TYPES)
41+
42+
3843def sklearn_to_flow (o , parent_model = None ):
3944 # TODO: assert that only on first recursion lvl `parent_model` can be None
40- simple_numpy_types = [nptype for type_cat , nptypes in np .sctypes .items ()
41- for nptype in nptypes
42- if type_cat != 'others' ]
43- simple_types = tuple ([bool , int , float , str ] + simple_numpy_types )
4445 if _is_estimator (o ):
4546 # is the main model or a submodel
4647 rval = _serialize_model (o )
@@ -49,8 +50,8 @@ def sklearn_to_flow(o, parent_model=None):
4950 rval = [sklearn_to_flow (element , parent_model ) for element in o ]
5051 if isinstance (o , tuple ):
5152 rval = tuple (rval )
52- elif isinstance (o , simple_types ) or o is None :
53- if isinstance (o , tuple (simple_numpy_types )):
53+ elif isinstance (o , SIMPLE_TYPES ) or o is None :
54+ if isinstance (o , tuple (SIMPLE_NUMPY_TYPES )):
5455 o = o .item ()
5556 # base parameter values
5657 rval = o
@@ -507,28 +508,34 @@ def _extract_information_from_model(model):
507508 rval = sklearn_to_flow (v , model )
508509
def flatten_all(list_):
    """Lazily flatten arbitrarily nested lists/tuples into a flat stream.

    Only ``list`` and ``tuple`` instances are descended into; every other
    element (including strings) is yielded unchanged, in left-to-right order.

    Example: ``[[1, 2], [3, [1]]]`` -> ``1, 2, 3, 1``

    Parameters
    ----------
    list_ : iterable
        The (possibly nested) sequence to flatten.

    Yields
    ------
    object
        Each non-list, non-tuple element found at any nesting depth.
    """
    for el in list_:
        if isinstance(el, (list, tuple)):
            # Recurse into nested containers and forward their leaves.
            yield from flatten_all(el)
        else:
            yield el
516517
517- if isinstance ( rval , ( list , tuple )) :
518- nested_list_of_simple_types = all ([ isinstance ( el , ( bool , str , int , float ))
519- for el in flatten_all ( rval )])
520- else :
521- nested_list_of_simple_types = False
522-
523- if ( isinstance (rval , (list , tuple ))
518+ # In case rval is a list of lists (or tuples), we need to identify two situations :
519+ # - sklearn pipeline steps, feature union or base classifiers in voting classifier.
520+ # They look like e.g. [("imputer", Imputer()), ("classifier", SVC())]
521+ # - a list of lists with simple types (e.g. int or str), such as for an OrdinalEncoder
522+ # where all possible values for each feature are described: [[0,1,2], [1,2,5]]
523+ is_non_empty_list_of_lists_with_same_type = (
524+ isinstance (rval , (list , tuple ))
524525 and len (rval ) > 0
525526 and isinstance (rval [0 ], (list , tuple ))
526- and all ([isinstance (rval [i ], type (rval [0 ]))
527- for i in range (len (rval ))])
528- and not nested_list_of_simple_types ):
527+ and all ([isinstance (rval_i , type (rval [0 ])) for rval_i in rval ])
528+ )
529+
530+ nested_list_of_simple_types = (
531+ is_non_empty_list_of_lists_with_same_type
532+ and all ([isinstance (el , SIMPLE_TYPES ) for el in flatten_all (rval )])
533+ )
529534
530- # Steps in a pipeline or feature union, or base classifiers in
531- # voting classifier
535+ if is_non_empty_list_of_lists_with_same_type and not nested_list_of_simple_types :
536+ # If a list of lists is identified that include 'non-simple' types (e.g. objects),
537+ # we assume they are steps in a pipeline, feature union, or base classifiers in
538+ # a voting classifier.
532539 parameter_value = list ()
533540 reserved_keywords = set (model .get_params (deep = False ).keys ())
534541
0 commit comments