Skip to content

Commit c354007

Browse files
committed
bugfix
1 parent edef889 commit c354007

1 file changed

Lines changed: 8 additions & 6 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -918,9 +918,11 @@ def _get_parameter_values_recursive(param_grid: Union[Dict, List[Dict]],
918918
result.append(value)
919919
return result
920920
elif isinstance(param_grid, list):
921-
result = []
922-
result.extend(SklearnExtension._get_parameter_values_recursive(
923-
sub_grid, parameter_name) for sub_grid in param_grid)
921+
result = list()
922+
for sub_grid in param_grid:
923+
result.extend(SklearnExtension._get_parameter_values_recursive(sub_grid,
924+
parameter_name))
925+
return result
924926

925927
def _prevent_optimize_n_jobs(self, model):
926928
"""
@@ -947,9 +949,9 @@ def _prevent_optimize_n_jobs(self, model):
947949
print('Warning! Using subclass BaseSearchCV other than '
948950
'{GridSearchCV, RandomizedSearchCV}. '
949951
'Should implement param check. ')
950-
951-
if len(SklearnExtension._get_parameter_values_recursive(param_distributions,
952-
'n_jobs')) > 0:
952+
n_jobs_vals = SklearnExtension._get_parameter_values_recursive(param_distributions,
953+
'n_jobs')
954+
if len(n_jobs_vals) > 0:
953955
raise PyOpenMLError('openml-python should not be used to '
954956
'optimize the n_jobs parameter.')
955957

0 commit comments

Comments
 (0)