Skip to content

Commit a28cb51

Browse files
committed
add function to check flow equality
1 parent a5cb405 commit a28cb51

3 files changed

Lines changed: 103 additions & 38 deletions

File tree

openml/flows/functions.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,4 +127,37 @@ def _list_flows(api_call):
127127
'uploader': flow_['oml:uploader']}
128128
flows[fid] = flow
129129

130-
return flows
130+
return flows
131+
132+
133+
def are_flows_equal(flow1, flow2):
134+
"""Check equality of two flows.
135+
136+
Two flows are equal if their all keys which are not set by the server
137+
are equal, as well as all their parameters and components.
138+
"""
139+
if not isinstance(flow2, flow1.__class__):
140+
return False
141+
142+
# Name is actually not generated by the server, but it will be
143+
# tested further down with a getter (allows mocking in the tests)
144+
generated_by_the_server = ['flow_id', 'uploader', 'version',
145+
'upload_date', ]
146+
ignored_by_python_API = ['binary_url', 'binary_format', 'binary_md5',
147+
'model']
148+
149+
for key in set(flow1.__dict__.keys()).union(flow2.__dict__.keys()):
150+
if key in generated_by_the_server + ignored_by_python_API:
151+
continue
152+
attr1 = getattr(flow1, key, None)
153+
attr2 = getattr(flow2, key, None)
154+
if key == 'components':
155+
for name in set(attr1.keys()).union(attr2.keys()):
156+
if not (name in attr1 and name in attr2):
157+
return False
158+
if not are_flows_equal(attr1[name], attr2[name]):
159+
return False
160+
else:
161+
if attr1 != attr2:
162+
return False
163+
return True

tests/test_flows/test_flow.py

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@
22
import hashlib
33
import re
44
import time
5-
import random
6-
import unittest
75

86
import xmltodict
97

@@ -26,38 +24,6 @@
2624
from openml.flows.sklearn_converter import _format_external_version
2725

2826

29-
def are_flows_equal(flow1, flow2):
30-
"""Check equality of two flows.
31-
32-
Two flows are equal if their all keys which are not set by the server
33-
are equal, as well as all their parameters and components.
34-
"""
35-
if not isinstance(flow2, flow1.__class__):
36-
return False
37-
38-
# Name is actually not generated by the server, but it will be
39-
# tested further down with a getter (allows mocking in the tests)
40-
generated_by_the_server = ['name', 'flow_id', 'uploader', 'version',
41-
'upload_date', 'source_url',
42-
'binary_url', 'source_format',
43-
'binary_format', 'source_md5',
44-
'binary_md5', 'model']
45-
46-
for key in set(flow1.__dict__.keys()).union(flow2.__dict__.keys()):
47-
if key in generated_by_the_server:
48-
continue
49-
attr1 = getattr(flow1, key, None)
50-
attr2 = getattr(flow2, key, None)
51-
if key == 'components':
52-
for name in set(attr1.keys()).union(attr2.keys()):
53-
if not are_flows_equal(attr1[name], attr2[name]):
54-
return False
55-
else:
56-
if attr1 != attr2:
57-
return False
58-
return True
59-
60-
6127
def get_sentinel():
6228
# Create a unique prefix for the flow. Necessary because the flow is
6329
# identified by its name and external version online. Having a unique
@@ -134,7 +100,7 @@ def test_to_xml_from_xml(self):
134100
xml = flow._to_xml()
135101
xml_dict = xmltodict.parse(xml)
136102
new_flow = openml.flows.OpenMLFlow._from_dict(xml_dict)
137-
self.assertTrue(are_flows_equal(new_flow, flow))
103+
self.assertTrue(openml.flows.functions.are_flows_equal(new_flow, flow))
138104
self.assertIsNot(new_flow, flow)
139105

140106
def test_publish_flow(self):
@@ -267,7 +233,7 @@ def test_sklearn_to_upload_to_flow(self):
267233

268234
self.assertEqual(server_xml, local_xml)
269235

270-
self.assertTrue(are_flows_equal(new_flow, flow))
236+
self.assertTrue(openml.flows.functions.are_flows_equal(new_flow, flow))
271237
self.assertIsNot(new_flow, flow)
272238

273239
fixture_name = '%ssklearn.model_selection._search.RandomizedSearchCV(' \

tests/test_flows/test_flow_functions.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import OrderedDict
2+
import copy
13
import unittest
24

35
import openml
@@ -39,4 +41,68 @@ def test_list_flows_paginate(self):
3941
flows = openml.flows.list_flows(offset=i, size=size)
4042
self.assertGreaterEqual(size, len(flows))
4143
for did in flows:
42-
self._check_flow(flows[did])
44+
self._check_flow(flows[did])
45+
46+
def test_are_flows_equal(self):
47+
flow = openml.flows.OpenMLFlow(name='Test',
48+
description='Test flow',
49+
model=None,
50+
components=OrderedDict(),
51+
parameters=OrderedDict(),
52+
parameters_meta_info=OrderedDict(),
53+
external_version='1',
54+
tags=['abc', 'def'],
55+
language='English',
56+
dependencies='abc',
57+
class_name='Test',
58+
custom_name='Test')
59+
60+
# Test most important values that can be set by a user
61+
self.assertTrue(openml.flows.functions.are_flows_equal(flow, flow))
62+
for attribute, new_value in [('name', 'Tes'),
63+
('description', 'Test flo'),
64+
('external_version', '2'),
65+
('tags', ['abc', 'de']),
66+
('language', 'english'),
67+
('dependencies', 'ab'),
68+
('class_name', 'Tes'),
69+
('custom_name', 'Tes')]:
70+
new_flow = copy.deepcopy(flow)
71+
setattr(new_flow, attribute, new_value)
72+
self.assertNotEqual(getattr(flow, attribute), getattr(new_flow, attribute))
73+
self.assertFalse(openml.flows.functions.are_flows_equal(flow, new_flow),
74+
msg=attribute)
75+
76+
# Test that the API ignores several keys when comparing flows
77+
self.assertTrue(openml.flows.functions.are_flows_equal(flow, flow))
78+
for attribute, new_value in [('flow_id', 1),
79+
('uploader', 1),
80+
('version', 1),
81+
('upload_date', '18.12.1988'),
82+
('binary_url', 'openml.org'),
83+
('binary_format', 'gzip'),
84+
('binary_md5', '12345'),
85+
('model', [])]:
86+
new_flow = copy.deepcopy(flow)
87+
setattr(new_flow, attribute, new_value)
88+
self.assertNotEqual(getattr(flow, attribute), getattr(new_flow, attribute))
89+
self.assertTrue(openml.flows.functions.are_flows_equal(flow, new_flow),
90+
msg=attribute)
91+
92+
# Now test for parameters
93+
flow.parameters['abc'] = 1.0
94+
flow.parameters['def'] = 2.0
95+
self.assertTrue(openml.flows.functions.are_flows_equal(flow, flow))
96+
new_flow = copy.deepcopy(flow)
97+
new_flow.parameters['abc'] = 2.0
98+
self.assertFalse(openml.flows.functions.are_flows_equal(flow, new_flow))
99+
100+
# Now test for components (subflows)
101+
parent_flow = copy.deepcopy(flow)
102+
subflow = copy.deepcopy(flow)
103+
parent_flow.components['subflow'] = subflow
104+
self.assertTrue(openml.flows.functions.are_flows_equal(parent_flow, parent_flow))
105+
self.assertFalse(openml.flows.functions.are_flows_equal(parent_flow, subflow))
106+
new_flow = copy.deepcopy(parent_flow)
107+
new_flow.components['subflow'].name = 'Subflow name'
108+
self.assertFalse(openml.flows.functions.are_flows_equal(parent_flow, new_flow))

0 commit comments

Comments
 (0)