Skip to content

Commit 96ddc13

Browse files
PGijsbers authored and mfeurer committed
Add support for serializing numpy data types. (#635)
* Add support for serializing numpy data types.
* Added tests on numpy-types in sklearn_to_flow.
1 parent 98a73b3 commit 96ddc13

2 files changed

Lines changed: 22 additions & 2 deletions

File tree

openml/flows/sklearn_converter.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,10 @@
3737

3838
def sklearn_to_flow(o, parent_model=None):
3939
# TODO: assert that only on first recursion lvl `parent_model` can be None
40-
40+
simple_numpy_types = [nptype for type_cat, nptypes in np.sctypes.items()
41+
for nptype in nptypes
42+
if type_cat != 'others']
43+
simple_types = tuple([bool, int, float, str] + simple_numpy_types)
4144
if _is_estimator(o):
4245
# is the main model or a submodel
4346
rval = _serialize_model(o)
@@ -46,7 +49,9 @@ def sklearn_to_flow(o, parent_model=None):
4649
rval = [sklearn_to_flow(element, parent_model) for element in o]
4750
if isinstance(o, tuple):
4851
rval = tuple(rval)
49-
elif isinstance(o, (bool, int, float, str)) or o is None:
52+
elif isinstance(o, simple_types) or o is None:
53+
if isinstance(o, tuple(simple_numpy_types)):
54+
o = o.item()
5055
# base parameter values
5156
rval = o
5257
elif isinstance(o, dict):

tests/test_flows/test_sklearn.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1180,3 +1180,18 @@ def test_obtain_parameter_values(self):
11801180
if parameter['oml:name'] == 'n_estimators':
11811181
self.assertEqual(parameter['oml:value'], '5')
11821182
self.assertEqual(parameter['oml:component'], 2)
1183+
1184+
def test_numpy_type_allowed_in_flow(self):
1185+
""" Simple numpy types should be serializable. """
1186+
dt = sklearn.tree.DecisionTreeClassifier(
1187+
max_depth=np.float64(3.0),
1188+
min_samples_leaf=np.int32(5)
1189+
)
1190+
sklearn_to_flow(dt)
1191+
1192+
def test_numpy_array_not_allowed_in_flow(self):
1193+
""" Simple numpy arrays should not be serializable. """
1194+
bin = sklearn.preprocessing.MultiLabelBinarizer(
1195+
classes=np.asarray([1, 2, 3])
1196+
)
1197+
self.assertRaises(TypeError, sklearn_to_flow, bin)

0 commit comments

Comments
 (0)