Skip to content

Commit e5ba984

Browse files
Abhishek
authored and committed
[DOC] Add usage examples to core function docstrings (#1538)
1 parent 7feb2a3 commit e5ba984

4 files changed

Lines changed: 103 additions & 24 deletions

File tree

openml/datasets/functions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,11 @@ def get_datasets(
364364
-------
365365
datasets : list of datasets
366366
A list of dataset objects.
367+
368+
Examples
369+
--------
370+
>>> import openml
371+
>>> datasets = openml.datasets.get_datasets([1, 2, 3]) # doctest: +SKIP
367372
"""
368373
datasets = []
369374
for dataset_id in dataset_ids:
@@ -446,6 +451,13 @@ def get_dataset( # noqa: C901, PLR0912
446451
-------
447452
dataset : :class:`openml.OpenMLDataset`
448453
The downloaded dataset.
454+
455+
Examples
456+
--------
457+
>>> import openml
458+
>>> dataset = openml.datasets.get_dataset(1) # doctest: +SKIP
459+
>>> dataset = openml.datasets.get_dataset("iris", version=1) # doctest: +SKIP
460+
>>> dataset = openml.datasets.get_dataset(1, download_data=True) # doctest: +SKIP
449461
"""
450462
if download_all_files:
451463
warnings.warn(

openml/runs/functions.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def run_model_on_task( # noqa: PLR0913
104104
Result of the run.
105105
flow : OpenMLFlow (optional, only if `return_flow` is True).
106106
Flow generated from the model.
107+
108+
Examples
109+
--------
110+
>>> import openml
111+
>>> import openml_sklearn # doctest: +SKIP
112+
>>> from sklearn.tree import DecisionTreeClassifier # doctest: +SKIP
113+
>>> clf = DecisionTreeClassifier() # doctest: +SKIP
114+
>>> task = openml.tasks.get_task(1) # doctest: +SKIP
115+
>>> run = openml.runs.run_model_on_task(clf, task) # doctest: +SKIP
107116
"""
108117
if avoid_duplicate_runs is None:
109118
avoid_duplicate_runs = openml.config.avoid_duplicate_runs
@@ -273,9 +282,7 @@ def run_flow_on_task( # noqa: C901, PLR0912, PLR0915, PLR0913
273282
setup_id = setup_exists(flow_from_server)
274283
ids = run_exists(task.task_id, setup_id)
275284
if ids:
276-
error_message = (
277-
"One or more runs of this setup were already performed on the task."
278-
)
285+
error_message = "One or more runs of this setup were already performed on the task."
279286
raise OpenMLRunsExistError(ids, error_message)
280287
else:
281288
# Flow does not exist on server and we do not want to upload it.
@@ -505,11 +512,15 @@ def _run_task_get_arffcontent( # noqa: PLR0915, PLR0912, C901
505512
# this information is multiple times overwritten, but due to the ordering
506513
# of the loops, eventually it contains the information based on the full
507514
# dataset size
508-
user_defined_measures_per_fold = OrderedDict() # type: 'OrderedDict[str, OrderedDict]'
515+
user_defined_measures_per_fold = (
516+
OrderedDict()
517+
) # type: 'OrderedDict[str, OrderedDict]'
509518
# stores sample-based evaluation measures (sublevel of fold-based)
510519
# will also be filled on a non sample-based task, but the information
511520
# is the same as the fold-based measures, and disregarded in that case
512-
user_defined_measures_per_sample = OrderedDict() # type: 'OrderedDict[str, OrderedDict]'
521+
user_defined_measures_per_sample = (
522+
OrderedDict()
523+
) # type: 'OrderedDict[str, OrderedDict]'
513524

514525
# TODO use different iterator to only provide a single iterator (less
515526
# methods, less maintenance, less confusion)
@@ -557,9 +568,14 @@ def _run_task_get_arffcontent( # noqa: PLR0915, PLR0912, C901
557568
) # job_rvals contain the output of all the runs with one-to-one correspondence with `jobs`
558569

559570
for n_fit, rep_no, fold_no, sample_no in jobs:
560-
pred_y, proba_y, test_indices, test_y, inner_trace, user_defined_measures_fold = job_rvals[
561-
n_fit - 1
562-
]
571+
(
572+
pred_y,
573+
proba_y,
574+
test_indices,
575+
test_y,
576+
inner_trace,
577+
user_defined_measures_fold,
578+
) = job_rvals[n_fit - 1]
563579

564580
if inner_trace is not None:
565581
traces.append(inner_trace)
@@ -598,7 +614,11 @@ def _calculate_local_measure( # type: ignore
598614
if isinstance(test_y[i], (int, np.integer))
599615
else test_y[i]
600616
)
601-
pred_prob = proba_y.iloc[i] if isinstance(proba_y, pd.DataFrame) else proba_y[i]
617+
pred_prob = (
618+
proba_y.iloc[i]
619+
if isinstance(proba_y, pd.DataFrame)
620+
else proba_y[i]
621+
)
602622

603623
arff_line = format_prediction(
604624
task=task,
@@ -661,11 +681,13 @@ def _calculate_local_measure( # type: ignore
661681
if rep_no not in user_defined_measures_per_sample[measure]:
662682
user_defined_measures_per_sample[measure][rep_no] = OrderedDict()
663683
if fold_no not in user_defined_measures_per_sample[measure][rep_no]:
664-
user_defined_measures_per_sample[measure][rep_no][fold_no] = OrderedDict()
684+
user_defined_measures_per_sample[measure][rep_no][
685+
fold_no
686+
] = OrderedDict()
665687

666-
user_defined_measures_per_fold[measure][rep_no][fold_no] = user_defined_measures_fold[
667-
measure
668-
]
688+
user_defined_measures_per_fold[measure][rep_no][fold_no] = (
689+
user_defined_measures_fold[measure]
690+
)
669691
user_defined_measures_per_sample[measure][rep_no][fold_no][sample_no] = (
670692
user_defined_measures_fold[measure]
671693
)
@@ -821,7 +843,9 @@ def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun: # noqa: FBT0
821843
run : OpenMLRun
822844
Run corresponding to ID, fetched from the server.
823845
"""
824-
run_dir = Path(openml.utils._create_cache_directory_for_id(RUNS_CACHE_DIR_NAME, run_id))
846+
run_dir = Path(
847+
openml.utils._create_cache_directory_for_id(RUNS_CACHE_DIR_NAME, run_id)
848+
)
825849
run_file = run_dir / "description.xml"
826850

827851
run_dir.mkdir(parents=True, exist_ok=True)
@@ -840,7 +864,9 @@ def get_run(run_id: int, ignore_cache: bool = False) -> OpenMLRun: # noqa: FBT0
840864
return _create_run_from_xml(run_xml)
841865

842866

843-
def _create_run_from_xml(xml: str, from_server: bool = True) -> OpenMLRun: # noqa: PLR0915, PLR0912, C901, FBT002
867+
def _create_run_from_xml(
868+
xml: str, from_server: bool = True
869+
) -> OpenMLRun: # noqa: PLR0915, PLR0912, C901, FBT002
844870
"""Create a run object from xml returned from server.
845871
846872
Parameters
@@ -870,11 +896,13 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None): # type: ignore
870896
if not from_server:
871897
return None
872898

873-
raise AttributeError("Run XML does not contain required (server) field: ", fieldname)
899+
raise AttributeError(
900+
"Run XML does not contain required (server) field: ", fieldname
901+
)
874902

875-
run = xmltodict.parse(xml, force_list=["oml:file", "oml:evaluation", "oml:parameter_setting"])[
876-
"oml:run"
877-
]
903+
run = xmltodict.parse(
904+
xml, force_list=["oml:file", "oml:evaluation", "oml:parameter_setting"]
905+
)["oml:run"]
878906
run_id = obtain_field(run, "oml:run_id", from_server, cast=int)
879907
uploader = obtain_field(run, "oml:uploader", from_server, cast=int)
880908
uploader_name = obtain_field(run, "oml:uploader_name", from_server)
@@ -1029,7 +1057,9 @@ def obtain_field(xml_obj, fieldname, from_server, cast=None): # type: ignore
10291057

10301058
def _get_cached_run(run_id: int) -> OpenMLRun:
10311059
"""Load a run from the cache."""
1032-
run_cache_dir = openml.utils._create_cache_directory_for_id(RUNS_CACHE_DIR_NAME, run_id)
1060+
run_cache_dir = openml.utils._create_cache_directory_for_id(
1061+
RUNS_CACHE_DIR_NAME, run_id
1062+
)
10331063
run_file = run_cache_dir / "description.xml"
10341064
try:
10351065
with run_file.open(encoding="utf8") as fh:
@@ -1199,7 +1229,9 @@ def __list_runs(api_call: str) -> pd.DataFrame:
11991229
runs_dict = xmltodict.parse(xml_string, force_list=("oml:run",))
12001230
# Minimalistic check if the XML is useful
12011231
if "oml:runs" not in runs_dict:
1202-
raise ValueError(f'Error in return XML, does not contain "oml:runs": {runs_dict}')
1232+
raise ValueError(
1233+
f'Error in return XML, does not contain "oml:runs": {runs_dict}'
1234+
)
12031235

12041236
if "@xmlns:oml" not in runs_dict["oml:runs"]:
12051237
raise ValueError(
@@ -1213,7 +1245,9 @@ def __list_runs(api_call: str) -> pd.DataFrame:
12131245
f'"http://openml.org/openml": {runs_dict}',
12141246
)
12151247

1216-
assert isinstance(runs_dict["oml:runs"]["oml:run"], list), type(runs_dict["oml:runs"])
1248+
assert isinstance(runs_dict["oml:runs"]["oml:run"], list), type(
1249+
runs_dict["oml:runs"]
1250+
)
12171251

12181252
runs = {
12191253
int(r["oml:run_id"]): {

openml/study/functions.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@ def get_suite(suite_id: int | str) -> OpenMLBenchmarkSuite:
3030
-------
3131
OpenMLSuite
3232
The OpenML suite object
33+
34+
Examples
35+
--------
36+
>>> import openml
37+
>>> suite = openml.study.get_suite(99) # doctest: +SKIP
38+
>>> suite = openml.study.get_suite("OpenML-CC18") # doctest: +SKIP
3339
"""
3440
study = _get_study(suite_id, entity_type="task")
3541
assert isinstance(study, OpenMLBenchmarkSuite)
@@ -59,6 +65,11 @@ def get_study(
5965
-------
6066
OpenMLStudy
6167
The OpenML study object
68+
69+
Examples
70+
--------
71+
>>> import openml
72+
>>> study = openml.study.get_study(1) # doctest: +SKIP
6273
"""
6374
if study_id == "OpenML100":
6475
message = (
@@ -109,7 +120,10 @@ def _get_study(id_: int | str, entity_type: str) -> BaseStudy:
109120
tags = []
110121
if "oml:tag" in result_dict:
111122
for tag in result_dict["oml:tag"]:
112-
current_tag = {"name": tag["oml:name"], "write_access": tag["oml:write_access"]}
123+
current_tag = {
124+
"name": tag["oml:name"],
125+
"write_access": tag["oml:write_access"],
126+
}
113127
if "oml:window_start" in tag:
114128
current_tag["window_start"] = tag["oml:window_start"]
115129
tags.append(current_tag)
@@ -210,6 +224,15 @@ def create_study(
210224
-------
211225
OpenMLStudy
212226
A local OpenML study object (call publish method to upload to server)
227+
228+
Examples
229+
--------
230+
>>> import openml
231+
>>> study = openml.study.create_study( # doctest: +SKIP
232+
... name="My Study",
233+
... description="A study on classification tasks",
234+
... run_ids=[1, 2, 3],
235+
... )
213236
"""
214237
return OpenMLStudy(
215238
study_id=None,

openml/tasks/functions.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,7 +380,11 @@ def get_tasks(
380380
tasks = []
381381
for task_id in task_ids:
382382
tasks.append(
383-
get_task(task_id, download_data=download_data, download_qualities=download_qualities)
383+
get_task(
384+
task_id,
385+
download_data=download_data,
386+
download_qualities=download_qualities,
387+
)
384388
)
385389
return tasks
386390

@@ -411,6 +415,12 @@ def get_task(
411415
Returns
412416
-------
413417
task: OpenMLTask
418+
419+
Examples
420+
--------
421+
>>> import openml
422+
>>> task = openml.tasks.get_task(1) # doctest: +SKIP
423+
>>> task = openml.tasks.get_task(1, download_splits=True) # doctest: +SKIP
414424
"""
415425
if not isinstance(task_id, int):
416426
raise TypeError(f"Task id should be integer, is {type(task_id)}")

0 commit comments

Comments
 (0)