@@ -104,6 +104,15 @@ def run_model_on_task( # noqa: PLR0913
104104 Result of the run.
105105 flow : OpenMLFlow (optional, only if `return_flow` is True).
106106 Flow generated from the model.
107+
108+ Examples
109+ --------
110+ >>> import openml
111+ >>> import openml_sklearn # doctest: +SKIP
112+ >>> from sklearn.tree import DecisionTreeClassifier # doctest: +SKIP
113+ >>> clf = DecisionTreeClassifier() # doctest: +SKIP
114+ >>> task = openml.tasks.get_task(1) # doctest: +SKIP
115+ >>> run = openml.runs.run_model_on_task(clf, task) # doctest: +SKIP
107116 """
108117 if avoid_duplicate_runs is None :
109118 avoid_duplicate_runs = openml .config .avoid_duplicate_runs
@@ -273,9 +282,7 @@ def run_flow_on_task( # noqa: C901, PLR0912, PLR0915, PLR0913
273282 setup_id = setup_exists (flow_from_server )
274283 ids = run_exists (task .task_id , setup_id )
275284 if ids :
276- error_message = (
277- "One or more runs of this setup were already performed on the task."
278- )
285+ error_message = "One or more runs of this setup were already performed on the task."
279286 raise OpenMLRunsExistError (ids , error_message )
280287 else :
281288 # Flow does not exist on server and we do not want to upload it.
@@ -505,11 +512,15 @@ def _run_task_get_arffcontent( # noqa: PLR0915, PLR0912, C901
505512 # this information is multiple times overwritten, but due to the ordering
506513 # of tne loops, eventually it contains the information based on the full
507514 # dataset size
508- user_defined_measures_per_fold = OrderedDict () # type: 'OrderedDict[str, OrderedDict]'
515+ user_defined_measures_per_fold = (
516+ OrderedDict ()
517+ ) # type: 'OrderedDict[str, OrderedDict]'
509518 # stores sample-based evaluation measures (sublevel of fold-based)
510519 # will also be filled on a non sample-based task, but the information
511520 # is the same as the fold-based measures, and disregarded in that case
512- user_defined_measures_per_sample = OrderedDict () # type: 'OrderedDict[str, OrderedDict]'
521+ user_defined_measures_per_sample = (
522+ OrderedDict ()
523+ ) # type: 'OrderedDict[str, OrderedDict]'
513524
514525 # TODO use different iterator to only provide a single iterator (less
515526 # methods, less maintenance, less confusion)
@@ -557,9 +568,14 @@ def _run_task_get_arffcontent( # noqa: PLR0915, PLR0912, C901
557568 ) # job_rvals contain the output of all the runs with one-to-one correspondence with `jobs`
558569
559570 for n_fit , rep_no , fold_no , sample_no in jobs :
560- pred_y , proba_y , test_indices , test_y , inner_trace , user_defined_measures_fold = job_rvals [
561- n_fit - 1
562- ]
571+ (
572+ pred_y ,
573+ proba_y ,
574+ test_indices ,
575+ test_y ,
576+ inner_trace ,
577+ user_defined_measures_fold ,
578+ ) = job_rvals [n_fit - 1 ]
563579
564580 if inner_trace is not None :
565581 traces .append (inner_trace )
@@ -598,7 +614,11 @@ def _calculate_local_measure( # type: ignore
598614 if isinstance (test_y [i ], (int , np .integer ))
599615 else test_y [i ]
600616 )
601- pred_prob = proba_y .iloc [i ] if isinstance (proba_y , pd .DataFrame ) else proba_y [i ]
617+ pred_prob = (
618+ proba_y .iloc [i ]
619+ if isinstance (proba_y , pd .DataFrame )
620+ else proba_y [i ]
621+ )
602622
603623 arff_line = format_prediction (
604624 task = task ,
@@ -661,11 +681,13 @@ def _calculate_local_measure( # type: ignore
661681 if rep_no not in user_defined_measures_per_sample [measure ]:
662682 user_defined_measures_per_sample [measure ][rep_no ] = OrderedDict ()
663683 if fold_no not in user_defined_measures_per_sample [measure ][rep_no ]:
664- user_defined_measures_per_sample [measure ][rep_no ][fold_no ] = OrderedDict ()
684+ user_defined_measures_per_sample [measure ][rep_no ][
685+ fold_no
686+ ] = OrderedDict ()
665687
666- user_defined_measures_per_fold [measure ][rep_no ][fold_no ] = user_defined_measures_fold [
667- measure
668- ]
688+ user_defined_measures_per_fold [measure ][rep_no ][fold_no ] = (
689+ user_defined_measures_fold [ measure ]
690+ )
669691 user_defined_measures_per_sample [measure ][rep_no ][fold_no ][sample_no ] = (
670692 user_defined_measures_fold [measure ]
671693 )
@@ -821,7 +843,9 @@ def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun: # noqa: FBT0
821843 run : OpenMLRun
822844 Run corresponding to ID, fetched from the server.
823845 """
824- run_dir = Path (openml .utils ._create_cache_directory_for_id (RUNS_CACHE_DIR_NAME , run_id ))
846+ run_dir = Path (
847+ openml .utils ._create_cache_directory_for_id (RUNS_CACHE_DIR_NAME , run_id )
848+ )
825849 run_file = run_dir / "description.xml"
826850
827851 run_dir .mkdir (parents = True , exist_ok = True )
@@ -840,7 +864,9 @@ def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun: # noqa: FBT0
840864 return _create_run_from_xml (run_xml )
841865
842866
843- def _create_run_from_xml (xml : str , from_server : bool = True ) -> OpenMLRun : # noqa: PLR0915, PLR0912, C901, FBT002
867+ def _create_run_from_xml (
868+ xml : str , from_server : bool = True
869+ ) -> OpenMLRun : # noqa: PLR0915, PLR0912, C901, FBT002
844870 """Create a run object from xml returned from server.
845871
846872 Parameters
@@ -870,11 +896,13 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None): # type: ignore
870896 if not from_server :
871897 return None
872898
873- raise AttributeError ("Run XML does not contain required (server) field: " , fieldname )
899+ raise AttributeError (
900+ "Run XML does not contain required (server) field: " , fieldname
901+ )
874902
875- run = xmltodict .parse (xml , force_list = [ "oml:file" , "oml:evaluation" , "oml:parameter_setting" ])[
876- "oml:run"
877- ]
903+ run = xmltodict .parse (
904+ xml , force_list = [ "oml:file" , "oml:evaluation" , "oml:parameter_setting" ]
905+ )[ "oml:run" ]
878906 run_id = obtain_field (run , "oml:run_id" , from_server , cast = int )
879907 uploader = obtain_field (run , "oml:uploader" , from_server , cast = int )
880908 uploader_name = obtain_field (run , "oml:uploader_name" , from_server )
@@ -1029,7 +1057,9 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None): # type: ignore
10291057
10301058def _get_cached_run (run_id : int ) -> OpenMLRun :
10311059 """Load a run from the cache."""
1032- run_cache_dir = openml .utils ._create_cache_directory_for_id (RUNS_CACHE_DIR_NAME , run_id )
1060+ run_cache_dir = openml .utils ._create_cache_directory_for_id (
1061+ RUNS_CACHE_DIR_NAME , run_id
1062+ )
10331063 run_file = run_cache_dir / "description.xml"
10341064 try :
10351065 with run_file .open (encoding = "utf8" ) as fh :
@@ -1199,7 +1229,9 @@ def __list_runs(api_call: str) -> pd.DataFrame:
11991229 runs_dict = xmltodict .parse (xml_string , force_list = ("oml:run" ,))
12001230 # Minimalistic check if the XML is useful
12011231 if "oml:runs" not in runs_dict :
1202- raise ValueError (f'Error in return XML, does not contain "oml:runs": { runs_dict } ' )
1232+ raise ValueError (
1233+ f'Error in return XML, does not contain "oml:runs": { runs_dict } '
1234+ )
12031235
12041236 if "@xmlns:oml" not in runs_dict ["oml:runs" ]:
12051237 raise ValueError (
@@ -1213,7 +1245,9 @@ def __list_runs(api_call: str) -> pd.DataFrame:
12131245 f'"http://openml.org/openml": { runs_dict } ' ,
12141246 )
12151247
1216- assert isinstance (runs_dict ["oml:runs" ]["oml:run" ], list ), type (runs_dict ["oml:runs" ])
1248+ assert isinstance (runs_dict ["oml:runs" ]["oml:run" ], list ), type (
1249+ runs_dict ["oml:runs" ]
1250+ )
12171251
12181252 runs = {
12191253 int (r ["oml:run_id" ]): {
0 commit comments