3333from openml .testing import SimpleImputer , TestBase
3434
3535
36-
3736class TestFlow (TestBase ):
3837 _multiprocess_can_split_ = True
3938
@@ -162,12 +161,16 @@ def test_from_xml_to_xml(self):
162161 def test_to_xml_from_xml (self ):
163162 scaler = sklearn .preprocessing .StandardScaler (with_mean = False )
164163 estimator_name = (
165- "base_estimator" if Version (sklearn .__version__ ) < Version ("1.4" ) else "estimator"
164+ "base_estimator"
165+ if Version (sklearn .__version__ ) < Version ("1.4" )
166+ else "estimator"
166167 )
167168 boosting = sklearn .ensemble .AdaBoostClassifier (
168169 ** {estimator_name : sklearn .tree .DecisionTreeClassifier ()},
169170 )
170- model = sklearn .pipeline .Pipeline (steps = (("scaler" , scaler ), ("boosting" , boosting )))
171+ model = sklearn .pipeline .Pipeline (
172+ steps = (("scaler" , scaler ), ("boosting" , boosting ))
173+ )
171174 flow = self .extension .model_to_flow (model )
172175 flow .flow_id = - 234
173176 # end of setup
@@ -180,7 +183,10 @@ def test_to_xml_from_xml(self):
180183 openml .flows .functions .assert_flows_equal (new_flow , flow )
181184 assert new_flow is not flow
182185
183- @pytest .mark .skip (reason = "Pending resolution of #1657" )
186+ @pytest .mark .skipif (
187+ os .getenv ("OPENML_USE_LOCAL_SERVICES" ) == "true" ,
188+ reason = "Pending resolution of #1657" ,
189+ )
184190 @pytest .mark .sklearn ()
185191 @pytest .mark .test_server ()
186192 def test_publish_flow (self ):
@@ -205,7 +211,9 @@ def test_publish_flow(self):
205211
206212 flow .publish ()
207213 TestBase ._mark_entity_for_removal ("flow" , flow .flow_id , flow .name )
208- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } " )
214+ TestBase .logger .info (
215+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } "
216+ )
209217 assert isinstance (flow .flow_id , int )
210218
211219 @pytest .mark .sklearn ()
@@ -215,15 +223,20 @@ def test_publish_existing_flow(self, flow_exists_mock):
215223 flow = self .extension .model_to_flow (clf )
216224 flow_exists_mock .return_value = 1
217225
218- with pytest .raises (openml .exceptions .PyOpenMLError , match = "OpenMLFlow already exists" ):
226+ with pytest .raises (
227+ openml .exceptions .PyOpenMLError , match = "OpenMLFlow already exists"
228+ ):
219229 flow .publish (raise_error_if_exists = True )
220230
221231 TestBase ._mark_entity_for_removal ("flow" , flow .flow_id , flow .name )
222232 TestBase .logger .info (
223233 f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } " ,
224234 )
225235
226- @pytest .mark .skip (reason = "Pending resolution of #1657" )
236+ @pytest .mark .skipif (
237+ os .getenv ("OPENML_USE_LOCAL_SERVICES" ) == "true" ,
238+ reason = "Pending resolution of #1657" ,
239+ )
227240 @pytest .mark .sklearn ()
228241 @pytest .mark .test_server ()
229242 def test_publish_flow_with_similar_components (self ):
@@ -234,7 +247,9 @@ def test_publish_flow_with_similar_components(self):
234247 flow , _ = self ._add_sentinel_to_flow_name (flow , None )
235248 flow .publish ()
236249 TestBase ._mark_entity_for_removal ("flow" , flow .flow_id , flow .name )
237- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } " )
250+ TestBase .logger .info (
251+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } "
252+ )
238253 # For a flow where both components are published together, the upload
239254 # date should be equal
240255 assert flow .upload_date == flow .components ["lr" ].upload_date , (
@@ -249,7 +264,9 @@ def test_publish_flow_with_similar_components(self):
249264 flow1 , sentinel = self ._add_sentinel_to_flow_name (flow1 , None )
250265 flow1 .publish ()
251266 TestBase ._mark_entity_for_removal ("flow" , flow .flow_id , flow .name )
252- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow1 .flow_id } " )
267+ TestBase .logger .info (
268+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow1 .flow_id } "
269+ )
253270
254271 # In order to assign different upload times to the flows!
255272 time .sleep (1 )
@@ -261,29 +278,40 @@ def test_publish_flow_with_similar_components(self):
261278 flow2 , _ = self ._add_sentinel_to_flow_name (flow2 , sentinel )
262279 flow2 .publish ()
263280 TestBase ._mark_entity_for_removal ("flow" , flow2 .flow_id , flow2 .name )
264- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow2 .flow_id } " )
281+ TestBase .logger .info (
282+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow2 .flow_id } "
283+ )
265284 # If one component was published before the other, the components in
266285 # the flow should have different upload dates
267286 assert flow2 .upload_date != flow2 .components ["dt" ].upload_date
268287
269- clf3 = sklearn .ensemble .AdaBoostClassifier (sklearn .tree .DecisionTreeClassifier (max_depth = 3 ))
288+ clf3 = sklearn .ensemble .AdaBoostClassifier (
289+ sklearn .tree .DecisionTreeClassifier (max_depth = 3 )
290+ )
270291 flow3 = self .extension .model_to_flow (clf3 )
271292 flow3 , _ = self ._add_sentinel_to_flow_name (flow3 , sentinel )
272293 # Child flow has different parameter. Check for storing the flow
273294 # correctly on the server should thus not check the child's parameters!
274295 flow3 .publish ()
275296 TestBase ._mark_entity_for_removal ("flow" , flow3 .flow_id , flow3 .name )
276- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow3 .flow_id } " )
297+ TestBase .logger .info (
298+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow3 .flow_id } "
299+ )
277300
278- @pytest .mark .skip (reason = "Pending resolution of #1657" )
301+ @pytest .mark .skipif (
302+ os .getenv ("OPENML_USE_LOCAL_SERVICES" ) == "true" ,
303+ reason = "Pending resolution of #1657" ,
304+ )
279305 @pytest .mark .sklearn ()
280306 @pytest .mark .test_server ()
281307 def test_semi_legal_flow (self ):
282308 # TODO: Test if parameters are set correctly!
283309 # should not throw error as it contains two differentiable forms of
284310 # Bagging i.e., Bagging(Bagging(J48)) and Bagging(J48)
285311 estimator_name = (
286- "base_estimator" if Version (sklearn .__version__ ) < Version ("1.4" ) else "estimator"
312+ "base_estimator"
313+ if Version (sklearn .__version__ ) < Version ("1.4" )
314+ else "estimator"
287315 )
288316 semi_legal = sklearn .ensemble .BaggingClassifier (
289317 ** {
@@ -299,7 +327,9 @@ def test_semi_legal_flow(self):
299327
300328 flow .publish ()
301329 TestBase ._mark_entity_for_removal ("flow" , flow .flow_id , flow .name )
302- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } " )
330+ TestBase .logger .info (
331+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } "
332+ )
303333
304334 @pytest .mark .sklearn ()
305335 @mock .patch ("openml.flows.functions.get_flow" )
@@ -386,14 +416,21 @@ def get_sentinel():
386416 flow_id = openml .flows .flow_exists (name , version )
387417 assert not flow_id
388418
389- @pytest .mark .skip (reason = "Pending resolution of #1657" )
419+ @pytest .mark .skipif (
420+ os .getenv ("OPENML_USE_LOCAL_SERVICES" ) == "true" ,
421+ reason = "Pending resolution of #1657" ,
422+ )
390423 @pytest .mark .sklearn ()
391424 @pytest .mark .test_server ()
392425 def test_existing_flow_exists (self ):
393426 # create a flow
394427 nb = sklearn .naive_bayes .GaussianNB ()
395428
396- sparse = "sparse" if Version (sklearn .__version__ ) < Version ("1.4" ) else "sparse_output"
429+ sparse = (
430+ "sparse"
431+ if Version (sklearn .__version__ ) < Version ("1.4" )
432+ else "sparse_output"
433+ )
397434 ohe_params = {sparse : False , "handle_unknown" : "ignore" }
398435 if Version (sklearn .__version__ ) >= Version ("0.20" ):
399436 ohe_params ["categories" ] = "auto"
@@ -428,7 +465,10 @@ def test_existing_flow_exists(self):
428465 )
429466 assert downloaded_flow_id == flow .flow_id
430467
431- @pytest .mark .skip (reason = "Pending resolution of #1657" )
468+ @pytest .mark .skipif (
469+ os .getenv ("OPENML_USE_LOCAL_SERVICES" ) == "true" ,
470+ reason = "Pending resolution of #1657" ,
471+ )
432472 @pytest .mark .sklearn ()
433473 @pytest .mark .test_server ()
434474 def test_sklearn_to_upload_to_flow (self ):
@@ -449,13 +489,20 @@ def test_sklearn_to_upload_to_flow(self):
449489 )
450490 fu = sklearn .pipeline .FeatureUnion (transformer_list = [("pca" , pca ), ("fs" , fs )])
451491 estimator_name = (
452- "base_estimator" if Version (sklearn .__version__ ) < Version ("1.4" ) else "estimator"
492+ "base_estimator"
493+ if Version (sklearn .__version__ ) < Version ("1.4" )
494+ else "estimator"
453495 )
454496 boosting = sklearn .ensemble .AdaBoostClassifier (
455497 ** {estimator_name : sklearn .tree .DecisionTreeClassifier ()},
456498 )
457499 model = sklearn .pipeline .Pipeline (
458- steps = [("ohe" , ohe ), ("scaler" , scaler ), ("fu" , fu ), ("boosting" , boosting )],
500+ steps = [
501+ ("ohe" , ohe ),
502+ ("scaler" , scaler ),
503+ ("fu" , fu ),
504+ ("boosting" , boosting ),
505+ ],
459506 )
460507 parameter_grid = {
461508 "boosting__n_estimators" : [1 , 5 , 10 , 100 ],
@@ -482,7 +529,9 @@ def test_sklearn_to_upload_to_flow(self):
482529
483530 flow .publish ()
484531 TestBase ._mark_entity_for_removal ("flow" , flow .flow_id , flow .name )
485- TestBase .logger .info (f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } " )
532+ TestBase .logger .info (
533+ f"collected from { __file__ .split ('/' )[- 1 ]} : { flow .flow_id } "
534+ )
486535 assert isinstance (flow .flow_id , int )
487536
488537 # Check whether we can load the flow again
@@ -565,7 +614,10 @@ def test_extract_tags(self):
565614 tags = openml .utils .extract_xml_tags ("oml:tag" , flow_dict )
566615 assert tags == ["study_14" ]
567616
568- flow_xml = "<oml:flow><oml:tag>OpenmlWeka</oml:tag>\n " "<oml:tag>weka</oml:tag></oml:flow>"
617+ flow_xml = (
618+ "<oml:flow><oml:tag>OpenmlWeka</oml:tag>\n "
619+ "<oml:tag>weka</oml:tag></oml:flow>"
620+ )
569621 flow_dict = xmltodict .parse (flow_xml )
570622 tags = openml .utils .extract_xml_tags ("oml:tag" , flow_dict ["oml:flow" ])
571623 assert tags == ["OpenmlWeka" , "weka" ]
0 commit comments