@@ -889,49 +889,52 @@ def _format_external_version(
889889 return '%s==%s' % (model_package_name , model_package_version_number )
890890
891891 @staticmethod
892- def _check_parameter_value_recursive (param_grid : Union [Dict , List [Dict ]], parameter_name : str , legal_values : Optional [List ]):
892+ def _check_parameter_value_recursive (param_grid : Union [Dict , List [Dict ]],
893+ parameter_name : str ,
894+ legal_values : Optional [List ]):
893895 """
894- Checks within a flow (recursively) whether a given hyperparameter complies to one of the values presented in a
895- grid. If the hyperparameter does not exist in the grid, True is returned.
896+ Checks within a flow (recursively) whether a given hyperparameter
897+ complies to one of the values presented in a grid. If the
898+ hyperparameter does not exist in the grid, True is returned.
896899
897900 Parameters
898901 ----------
899902 param_grid: Union[Dict, List[Dict]]
900- Dict mapping from hyperparameter list to value, to a list of such dicts
903+ Dict mapping from hyperparameter list to value, to a list of
904+ such dicts
901905
902906 parameter_name: str
903907 The hyperparameter that needs to be inspected
904908
905909 legal_values: List
906- The values that are accepted. None if no values are legal (the presence of the hyperparameter will trigger
907- to return False)
910+ The values that are accepted. None if no values are legal (the
911+ presence of the hyperparameter will trigger to return False)
908912
909913 Returns
910914 -------
911915 bool
912- True if all occurrences of the hyperparameter only have legal values, False otherwise
916+ True if all occurrences of the hyperparameter only have legal
917+ values, False otherwise
913918
914919 """
915920 if isinstance (param_grid , dict ):
916921 for param , value in param_grid .items ():
917922 # n_jobs is scikitlearn parameter for paralizing jobs
918923 if param .split ('__' )[- 1 ] == parameter_name :
919- # 0 = illegal value (?), 1 / None = use one core,
920- # n = use n cores,
921- # -1 = use all available cores -> this makes it hard to
922- # measure runtime in a fair way
923924 if legal_values is None or value not in legal_values :
924925 return False
925926 return True
926927 elif isinstance (param_grid , list ):
927928 return all (
928- SklearnExtension ._check_parameter_value_recursive (sub_grid , parameter_name , legal_values )
929+ SklearnExtension ._check_parameter_value_recursive (sub_grid ,
930+ parameter_name ,
931+ legal_values )
929932 for sub_grid in param_grid
930933 )
931934
932935 def _prevent_optimize_n_jobs (self , model ):
933936 """
934- Ensures that HPO classess will not optimize the n_jobs hyperparameter
937+ Ensures that HPO classes will not optimize the n_jobs hyperparameter
935938
936939 Parameters:
937940 -----------
@@ -955,7 +958,8 @@ def _prevent_optimize_n_jobs(self, model):
955958 '{GridSearchCV, RandomizedSearchCV}. '
956959 'Should implement param check. ' )
957960
958- if not SklearnExtension ._check_parameter_value_recursive (param_distributions , 'n_jobs' , None ):
961+ if not SklearnExtension ._check_parameter_value_recursive (param_distributions ,
962+ 'n_jobs' , None ):
959963 raise PyOpenMLError ('openml-python should not be used to '
960964 'optimize the n_jobs parameter.' )
961965
@@ -980,12 +984,14 @@ def _can_measure_cputime(self, model: Any) -> bool:
980984 raise ValueError ('model should be BaseEstimator or BaseSearchCV' )
981985
982986 # check the parameters for n_jobs
983- return SklearnExtension ._check_parameter_value_recursive (model .get_params (), 'n_jobs' , [1 , None ])
987+ return SklearnExtension ._check_parameter_value_recursive (model .get_params (),
988+ 'n_jobs' ,
989+ [1 , None ])
984990
985991 def _can_measure_wallclocktime (self , model : Any ) -> bool :
986992 """
987993 Returns True if the parameter settings of model are chosen s.t. the model
988- will run on a preset number of cores (if so, openml-python can measure wallclock time)
994+ will run on a preset number of cores (if so, openml-python can measure wall-clock time)
989995
990996 Parameters:
991997 -----------
@@ -1003,7 +1009,14 @@ def _can_measure_wallclocktime(self, model: Any) -> bool:
10031009 raise ValueError ('model should be BaseEstimator or BaseSearchCV' )
10041010
10051011 # check the parameters for n_jobs
1006- return not SklearnExtension ._check_parameter_value_recursive (model .get_params (), 'n_jobs' , [- 1 ])
1012+ # note that clause 1 will return True also when there is no occurrence
1013+ # of n_jobs (the negate will make this fn return false). For that
1014+ # reason, we need to add clause 2 that returns True if n_jobs does not
1015+ # exist in the flow
1016+ return not SklearnExtension ._check_parameter_value_recursive (
1017+ model .get_params (), 'n_jobs' , [- 1 ]) or \
1018+ SklearnExtension ._check_parameter_value_recursive (
1019+ model .get_params (), 'n_jobs' , None )
10071020
10081021 ################################################################################################
10091022 # Methods for performing runs with extension modules
@@ -1102,10 +1115,10 @@ def _run_model_on_fold(
11021115 information.
11031116
11041117 Furthermore, it will measure run time measures in case multi-core behaviour allows this.
1105- * exact user cpu time will be measured if the number of cores is set (recursive throughout the model)
1106- exactly to 1
1107- * wall clock time will be measured if the number of cores is set (recursive throughout the model) to any given
1108- number (but not when it is set to -1)
1118+ * exact user cpu time will be measured if the number of cores is set (recursive throughout
1119+ the model) exactly to 1
1120+ * wall clock time will be measured if the number of cores is set (recursive throughout the
1121+ model) to any given number (but not when it is set to -1)
11091122
11101123 Returns the data that is necessary to construct the OpenML Run object. Is used by
11111124 run_task_get_arff_content. Do not use this function unless you know what you are doing.
@@ -1182,7 +1195,7 @@ def _prediction_to_probabilities(
11821195 # but not desirable if we want to upload to OpenML).
11831196
11841197 model_copy = sklearn .base .clone (model , safe = True )
1185- # security check
1198+ # sanity check: prohibit users from optimizing n_jobs
11861199 self ._prevent_optimize_n_jobs (model_copy )
11871200 # Runtime can be measured if the model is run sequentially
11881201 can_measure_cputime = self ._can_measure_cputime (model_copy )
@@ -1228,7 +1241,8 @@ def _prediction_to_probabilities(
12281241 user_defined_measures ['usercpu_time_millis_training' ] = modelfit_duration_cputime
12291242 if can_measure_wallclocktime :
12301243 modelfit_duration_walltime = (time .time () - modelfit_start_walltime ) * 1000
1231- user_defined_measures ['wall_clock_time_millis_training' ] = modelfit_duration_walltime
1244+ user_defined_measures ['wall_clock_time_millis_training' ] = \
1245+ modelfit_duration_walltime
12321246
12331247 except AttributeError as e :
12341248 # typically happens when training a regressor on classification task
@@ -1264,14 +1278,16 @@ def _prediction_to_probabilities(
12641278 pred_y = model_copy .predict (test_x )
12651279
12661280 if can_measure_cputime :
1267- modelpredict_duration_cputime = (time .process_time () - modelpredict_start_cputime ) * 1000
1281+ modelpredict_duration_cputime = (time .process_time () -
1282+ modelpredict_start_cputime ) * 1000
12681283 user_defined_measures ['usercpu_time_millis_testing' ] = modelpredict_duration_cputime
1269- user_defined_measures ['usercpu_time_millis' ] = modelfit_duration_cputime + modelpredict_duration_cputime
1284+ user_defined_measures ['usercpu_time_millis' ] = (
1285+ modelfit_duration_cputime + modelpredict_duration_cputime )
12701286 if can_measure_wallclocktime :
12711287 modelpredict_duration_walltime = (time .time () - modelpredict_start_walltime ) * 1000
12721288 user_defined_measures ['wall_clock_time_millis_testing' ] = modelpredict_duration_walltime
1273- user_defined_measures ['wall_clock_time_millis' ] = modelfit_duration_walltime + \
1274- modelpredict_duration_walltime
1289+ user_defined_measures ['wall_clock_time_millis' ] = (
1290+ modelfit_duration_walltime + modelpredict_duration_walltime )
12751291
12761292 # add client-side calculated metrics. These is used on the server as
12771293 # consistency check, only useful for supervised tasks
0 commit comments