Skip to content

Commit 7238857

Browse files
committed
incorporated Pieters review
1 parent 67163a2 commit 7238857

3 files changed

Lines changed: 52 additions & 30 deletions

File tree

openml/extensions/sklearn/extension.py

Lines changed: 43 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -889,49 +889,52 @@ def _format_external_version(
889889
return '%s==%s' % (model_package_name, model_package_version_number)
890890

891891
@staticmethod
892-
def _check_parameter_value_recursive(param_grid: Union[Dict, List[Dict]], parameter_name: str, legal_values: Optional[List]):
892+
def _check_parameter_value_recursive(param_grid: Union[Dict, List[Dict]],
893+
parameter_name: str,
894+
legal_values: Optional[List]):
893895
"""
894-
Checks within a flow (recursively) whether a given hyperparameter complies to one of the values presented in a
895-
grid. If the hyperparameter does not exist in the grid, True is returned.
896+
Checks within a flow (recursively) whether a given hyperparameter
897+
complies with one of the values presented in a grid. If the
898+
hyperparameter does not exist in the grid, True is returned.
896899
897900
Parameters
898901
----------
899902
param_grid: Union[Dict, List[Dict]]
900-
Dict mapping from hyperparameter list to value, to a list of such dicts
903+
Dict mapping from hyperparameter list to value, or a list of
904+
such dicts
901905
902906
parameter_name: str
903907
The hyperparameter that needs to be inspected
904908
905909
legal_values: List
906-
The values that are accepted. None if no values are legal (the presence of the hyperparameter will trigger
907-
to return False)
910+
The values that are accepted. None if no values are legal (the
911+
presence of the hyperparameter will cause False to be returned)
908912
909913
Returns
910914
-------
911915
bool
912-
True if all occurrences of the hyperparameter only have legal values, False otherwise
916+
True if all occurrences of the hyperparameter only have legal
917+
values, False otherwise
913918
914919
"""
915920
if isinstance(param_grid, dict):
916921
for param, value in param_grid.items():
917922
# n_jobs is the scikit-learn parameter for parallelizing jobs
918923
if param.split('__')[-1] == parameter_name:
919-
# 0 = illegal value (?), 1 / None = use one core,
920-
# n = use n cores,
921-
# -1 = use all available cores -> this makes it hard to
922-
# measure runtime in a fair way
923924
if legal_values is None or value not in legal_values:
924925
return False
925926
return True
926927
elif isinstance(param_grid, list):
927928
return all(
928-
SklearnExtension._check_parameter_value_recursive(sub_grid, parameter_name, legal_values)
929+
SklearnExtension._check_parameter_value_recursive(sub_grid,
930+
parameter_name,
931+
legal_values)
929932
for sub_grid in param_grid
930933
)
931934

932935
def _prevent_optimize_n_jobs(self, model):
933936
"""
934-
Ensures that HPO classess will not optimize the n_jobs hyperparameter
937+
Ensures that HPO classes will not optimize the n_jobs hyperparameter
935938
936939
Parameters:
937940
-----------
@@ -955,7 +958,8 @@ def _prevent_optimize_n_jobs(self, model):
955958
'{GridSearchCV, RandomizedSearchCV}. '
956959
'Should implement param check. ')
957960

958-
if not SklearnExtension._check_parameter_value_recursive(param_distributions, 'n_jobs', None):
961+
if not SklearnExtension._check_parameter_value_recursive(param_distributions,
962+
'n_jobs', None):
959963
raise PyOpenMLError('openml-python should not be used to '
960964
'optimize the n_jobs parameter.')
961965

@@ -980,12 +984,14 @@ def _can_measure_cputime(self, model: Any) -> bool:
980984
raise ValueError('model should be BaseEstimator or BaseSearchCV')
981985

982986
# check the parameters for n_jobs
983-
return SklearnExtension._check_parameter_value_recursive(model.get_params(), 'n_jobs', [1, None])
987+
return SklearnExtension._check_parameter_value_recursive(model.get_params(),
988+
'n_jobs',
989+
[1, None])
984990

985991
def _can_measure_wallclocktime(self, model: Any) -> bool:
986992
"""
987993
Returns True if the parameter settings of model are chosen s.t. the model
988-
will run on a preset number of cores (if so, openml-python can measure wallclock time)
994+
will run on a preset number of cores (if so, openml-python can measure wall-clock time)
989995
990996
Parameters:
991997
-----------
@@ -1003,7 +1009,14 @@ def _can_measure_wallclocktime(self, model: Any) -> bool:
10031009
raise ValueError('model should be BaseEstimator or BaseSearchCV')
10041010

10051011
# check the parameters for n_jobs
1006-
return not SklearnExtension._check_parameter_value_recursive(model.get_params(), 'n_jobs', [-1])
1012+
# note that clause 1 will return True also when there is no occurrence
1013+
# of n_jobs (the negate will make this fn return false). For that
1014+
# reason, we need to add clause 2 that returns True if n_jobs does not
1015+
# exist in the flow
1016+
return not SklearnExtension._check_parameter_value_recursive(
1017+
model.get_params(), 'n_jobs', [-1]) or \
1018+
SklearnExtension._check_parameter_value_recursive(
1019+
model.get_params(), 'n_jobs', None)
10071020

10081021
################################################################################################
10091022
# Methods for performing runs with extension modules
@@ -1102,10 +1115,10 @@ def _run_model_on_fold(
11021115
information.
11031116
11041117
Furthermore, it will measure run time measures in case multi-core behaviour allows this.
1105-
* exact user cpu time will be measured if the number of cores is set (recursive throughout the model)
1106-
exactly to 1
1107-
* wall clock time will be measured if the number of cores is set (recursive throughout the model) to any given
1108-
number (but not when it is set to -1)
1118+
* exact user cpu time will be measured if the number of cores is set (recursive throughout
1119+
the model) exactly to 1
1120+
* wall clock time will be measured if the number of cores is set (recursive throughout the
1121+
model) to any given number (but not when it is set to -1)
11091122
11101123
Returns the data that is necessary to construct the OpenML Run object. Is used by
11111124
run_task_get_arff_content. Do not use this function unless you know what you are doing.
@@ -1182,7 +1195,7 @@ def _prediction_to_probabilities(
11821195
# but not desirable if we want to upload to OpenML).
11831196

11841197
model_copy = sklearn.base.clone(model, safe=True)
1185-
# security check
1198+
# sanity check: prohibit users from optimizing n_jobs
11861199
self._prevent_optimize_n_jobs(model_copy)
11871200
# Runtime can be measured if the model is run sequentially
11881201
can_measure_cputime = self._can_measure_cputime(model_copy)
@@ -1228,7 +1241,8 @@ def _prediction_to_probabilities(
12281241
user_defined_measures['usercpu_time_millis_training'] = modelfit_duration_cputime
12291242
if can_measure_wallclocktime:
12301243
modelfit_duration_walltime = (time.time() - modelfit_start_walltime) * 1000
1231-
user_defined_measures['wall_clock_time_millis_training'] = modelfit_duration_walltime
1244+
user_defined_measures['wall_clock_time_millis_training'] = \
1245+
modelfit_duration_walltime
12321246

12331247
except AttributeError as e:
12341248
# typically happens when training a regressor on a classification task
@@ -1264,14 +1278,16 @@ def _prediction_to_probabilities(
12641278
pred_y = model_copy.predict(test_x)
12651279

12661280
if can_measure_cputime:
1267-
modelpredict_duration_cputime = (time.process_time() - modelpredict_start_cputime) * 1000
1281+
modelpredict_duration_cputime = (time.process_time() -
1282+
modelpredict_start_cputime) * 1000
12681283
user_defined_measures['usercpu_time_millis_testing'] = modelpredict_duration_cputime
1269-
user_defined_measures['usercpu_time_millis'] = modelfit_duration_cputime + modelpredict_duration_cputime
1284+
user_defined_measures['usercpu_time_millis'] = (
1285+
modelfit_duration_cputime + modelpredict_duration_cputime)
12701286
if can_measure_wallclocktime:
12711287
modelpredict_duration_walltime = (time.time() - modelpredict_start_walltime) * 1000
12721288
user_defined_measures['wall_clock_time_millis_testing'] = modelpredict_duration_walltime
1273-
user_defined_measures['wall_clock_time_millis'] = modelfit_duration_walltime + \
1274-
modelpredict_duration_walltime
1289+
user_defined_measures['wall_clock_time_millis'] = (
1290+
modelfit_duration_walltime + modelpredict_duration_walltime)
12751291

12761292
# add client-side calculated metrics. These are used on the server as
12771293
# consistency check, only useful for supervised tasks

openml/testing.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,15 @@ def _check_fold_timing_evaluations(
158158
# a dict mapping from openml measure to a tuple with the minimum and
159159
# maximum allowed value
160160
check_measures = {
161+
# should take at least one millisecond (?)
161162
'usercpu_time_millis_testing': (0, max_time_allowed),
162163
'usercpu_time_millis_training': (0, max_time_allowed),
163-
# should take at least one millisecond (?)
164-
'usercpu_time_millis': (0, max_time_allowed)}
164+
'usercpu_time_millis': (0, max_time_allowed),
165+
'wall_clock_time_millis_training': (0, max_time_allowed),
166+
'wall_clock_time_millis_testing': (0, max_time_allowed),
167+
'wall_clock_time_millis': (0, max_time_allowed),
168+
'predictive_accuracy': (0, 1)
169+
}
165170

166171
if task_type in (TaskTypeEnum.SUPERVISED_CLASSIFICATION, TaskTypeEnum.LEARNING_CURVE):
167172
check_measures['predictive_accuracy'] = (0, 1.)

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ def _check_sample_evaluations(self, sample_evaluations, num_repeats,
293293
'wall_clock_time_millis_training': (0, max_time_allowed),
294294
'wall_clock_time_millis_testing': (0, max_time_allowed),
295295
'wall_clock_time_millis': (0, max_time_allowed),
296-
'predictive_accuracy': (0, 1)}
296+
'predictive_accuracy': (0, 1)
297+
}
297298

298299
self.assertIsInstance(sample_evaluations, dict)
299300
if sys.version_info[:2] >= (3, 3):

0 commit comments

Comments
 (0)