Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 145f82d

Browse files
ThomasDelteil authored and szha committed
Updating SymbolBlock.imports to support different dtypes (#15230)
* updating symbol block for different dtypes * remove logging * update test and fix lint issues * lint * fix initializer
1 parent eb48370 commit 145f82d

3 files changed

Lines changed: 119 additions & 29 deletions

File tree

python/mxnet/gluon/block.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
import re
2727
from collections import OrderedDict
2828

29-
from ..base import mx_real_t
29+
30+
from ..base import mx_real_t, MXNetError
3031
from .. import symbol, ndarray, initializer
3132
from ..symbol import Symbol
3233
from ..ndarray import NDArray
@@ -354,7 +355,7 @@ def save_params(self, filename):
354355
'save_parameters may resolve this error.'%e.message)
355356

356357
def load_parameters(self, filename, ctx=None, allow_missing=False,
357-
ignore_extra=False, cast_dtype=False):
358+
ignore_extra=False, cast_dtype=False, dtype_source='current'):
358359
"""Load parameters from file previously saved by `save_parameters`.
359360
360361
Parameters
@@ -371,7 +372,10 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,
371372
cast_dtype : bool, default False
372373
Cast the data type of the NDArray loaded from the checkpoint to the dtype
373374
provided by the Parameter if any.
374-
375+
dtype_source : str, default 'current'
376+
must be in {'current', 'saved'}
377+
Only valid if cast_dtype=True, specify the source of the dtype for casting
378+
the parameters
375379
References
376380
----------
377381
`Saving and Loading Gluon Models \
@@ -386,7 +390,8 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,
386390
# legacy loading
387391
del loaded
388392
self.collect_params().load(
389-
filename, ctx, allow_missing, ignore_extra, self.prefix, cast_dtype=cast_dtype)
393+
filename, ctx, allow_missing, ignore_extra, self.prefix,
394+
cast_dtype=cast_dtype, dtype_source=dtype_source)
390395
return
391396

392397
if not allow_missing:
@@ -402,7 +407,7 @@ def load_parameters(self, filename, ctx=None, allow_missing=False,
402407
"which contains parameters %s. Set ignore_extra=True to ignore. "%(
403408
name, filename, _brief_print_list(self._params.keys())))
404409
if name in params:
405-
params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype)
410+
params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)
406411

407412
def load_params(self, filename, ctx=None, allow_missing=False,
408413
ignore_extra=False):
@@ -1021,10 +1026,15 @@ def imports(symbol_file, input_names, param_file=None, ctx=None):
10211026
sym = symbol.load(symbol_file)
10221027
if isinstance(input_names, str):
10231028
input_names = [input_names]
1024-
inputs = [symbol.var(i) for i in input_names]
1029+
if param_file is None:
1030+
# Get a valid type inference by using fp32
1031+
inputs = [symbol.var(i, dtype=mx_real_t) for i in input_names]
1032+
else:
1033+
# Do not specify type, rely on saved params type instead
1034+
inputs = [symbol.var(i) for i in input_names]
10251035
ret = SymbolBlock(sym, inputs)
10261036
if param_file is not None:
1027-
ret.collect_params().load(param_file, ctx=ctx)
1037+
ret.collect_params().load(param_file, ctx=ctx, cast_dtype=True, dtype_source='saved')
10281038
return ret
10291039

10301040
def __repr__(self):
@@ -1156,7 +1166,11 @@ def _infer_param_types(in_params, out_params, arg_params, aux_params, default_dt
11561166
# Try to infer types of other parameters.
11571167
if can_infer_input_type:
11581168
params = {k:v for k, v in zip(input_sym_names, input_sym_arg_types)}
1159-
arg_types, _, aux_types = out_params.infer_type(**params)
1169+
try:
1170+
arg_types, _, aux_types = out_params.infer_type(**params)
1171+
except MXNetError:
1172+
# Cannot infer type with current input
1173+
arg_types, aux_types = None, None
11601174

11611175
if arg_types is None or len(arg_types) != len(arg_params):
11621176
arg_types = []

python/mxnet/gluon/parameter.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def __init__(self, name, grad_req='write', shape=None, dtype=mx_real_t,
117117
shape = (shape,)
118118
self._shape = shape
119119
self.name = name
120-
self.dtype = dtype
120+
self._dtype = dtype
121121
self.lr_mult = lr_mult
122122
self.wd_mult = wd_mult
123123
self.grad_req = grad_req
@@ -155,6 +155,18 @@ def grad_req(self, req):
155155
elif self._data is not None:
156156
self._init_grad()
157157

158+
@property
159+
def dtype(self):
160+
"""The type of the parameter.
161+
162+
Setting the dtype value is equivalent to casting the value of the parameter
163+
"""
164+
return self._dtype
165+
166+
@dtype.setter
167+
def dtype(self, dtype):
168+
self.cast(dtype)
169+
158170
@property
159171
def shape(self):
160172
"""The shape of the parameter.
@@ -241,8 +253,24 @@ def _get_row_sparse(self, arr_list, ctx, row_id):
241253
self._trainer._row_sparse_pull(self, results, row_id)
242254
return results
243255

244-
def _load_init(self, data, ctx, cast_dtype=False):
245-
"""(Re)initializes by loading from data."""
256+
def _load_init(self, data, ctx, cast_dtype=False, dtype_source='current'):
257+
"""
258+
(Re)initializes by loading from data.
259+
Parameters
260+
----------
261+
data : NDArray
262+
The data to load
263+
ctx : Context or list of Context
264+
Context(s) initialize loaded parameters on.
265+
cast_dtype : bool, default False
266+
Cast the data type of the parameter
267+
dtype_source : str, default 'current'
268+
must be in {'current', 'saved'}
269+
Only valid if cast_dtype=True, specify the source of the dtype for casting
270+
the parameters
271+
"""
272+
if cast_dtype:
273+
assert dtype_source in ['current', 'saved']
246274
if self.shape:
247275
for self_dim, data_dim in zip(self.shape, data.shape):
248276
assert self_dim in (0, data_dim), \
@@ -252,8 +280,12 @@ def _load_init(self, data, ctx, cast_dtype=False):
252280
self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape))
253281
if self.dtype:
254282
if cast_dtype and np.dtype(self.dtype).type != data.dtype:
255-
data = data.astype(self.dtype, copy=False)
256-
assert np.dtype(self.dtype).type == data.dtype, \
283+
if dtype_source == 'current':
284+
data = data.astype(self.dtype, copy=False)
285+
elif dtype_source == 'saved':
286+
self.dtype = data.dtype
287+
else:
288+
assert np.dtype(self.dtype).type == data.dtype, \
257289
"Failed loading Parameter '%s' from saved params: " \
258290
"dtype incompatible expected %s vs saved %s. " \
259291
"Set cast_dtype=True to cast the dtype of saved params."%(
@@ -580,7 +612,7 @@ def cast(self, dtype):
580612
dtype : str or numpy.dtype
581613
The new data type.
582614
"""
583-
self.dtype = dtype
615+
self._dtype = dtype
584616
if self._data is None:
585617
return
586618
with autograd.pause():
@@ -894,7 +926,8 @@ def save(self, filename, strip_prefix=''):
894926
ndarray.save(filename, arg_dict)
895927

896928
def load(self, filename, ctx=None, allow_missing=False,
897-
ignore_extra=False, restore_prefix='', cast_dtype=False):
929+
ignore_extra=False, restore_prefix='', cast_dtype=False,
930+
dtype_source="current"):
898931
"""Load parameters from file.
899932
900933
Parameters
@@ -911,8 +944,11 @@ def load(self, filename, ctx=None, allow_missing=False,
911944
restore_prefix : str, default ''
912945
prepend prefix to names of stored parameters before loading.
913946
cast_dtype : bool, default False
914-
Cast the data type of the NDArray loaded from the checkpoint to the dtype
915-
provided by the Parameter if any.
947+
Cast the data type of the parameter
948+
dtype_source : str, default 'current'
949+
must be in {'current', 'saved'}
950+
Only valid if cast_dtype=True, specify the source of the dtype for casting
951+
the parameters
916952
"""
917953
if restore_prefix:
918954
for name in self.keys():
@@ -938,4 +974,4 @@ def load(self, filename, ctx=None, allow_missing=False,
938974
"Please make sure source and target networks have the same prefix."%(
939975
name[lprefix:], filename, _brief_print_list(self._params.keys()))
940976
continue
941-
self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype)
977+
self[name]._load_init(arg_dict[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)

tests/python/unittest/test_gluon.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def test_parameter_invalid_access():
9494
assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)
9595

9696
@with_seed()
97-
def test_paramdict():
97+
def test_parameter_dict():
9898
ctx = mx.cpu(1)
9999
params0 = gluon.ParameterDict('net_')
100100
params0.get('w0', shape=(10, 10))
@@ -107,15 +107,15 @@ def test_paramdict():
107107
prev_w0 = params0.get('w0').data(ctx)
108108
prev_w1 = params0.get('w1').row_sparse_data(all_row_ids)
109109
# save params
110-
params0.save('test_paramdict.params')
110+
params0.save('test_parameter_dict.params')
111111

112112
# load params
113113
params1 = gluon.ParameterDict('net_')
114114
params1.get('w0', shape=(10, 10))
115115
params1.get('w1', shape=(10, 10), stype='row_sparse')
116-
params1.load('test_paramdict.params', ctx)
116+
params1.load('test_parameter_dict.params', ctx)
117117
trainer1 = mx.gluon.Trainer(params1, 'sgd')
118-
118+
119119
# compare the values before and after save/load
120120
cur_w0 = params1.get('w0').data(ctx)
121121
cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
@@ -127,13 +127,30 @@ def test_paramdict():
127127
params2 = gluon.ParameterDict('net_')
128128
params2.get('w0', shape=(10, 10))
129129
params2.get('w1', shape=(10, 10))
130-
params2.load('test_paramdict.params', ctx)
130+
params2.load('test_parameter_dict.params', ctx)
131131

132132
# compare the values before and after save/load
133133
cur_w0 = params2.get('w0').data(ctx)
134134
cur_w1 = params2.get('w1').data(ctx)
135135
mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
136136
mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
137+
138+
# test the dtype casting functionality
139+
params0 = gluon.ParameterDict('')
140+
params0.get('w0', shape=(10, 10), dtype='float32')
141+
params0.get('w1', shape=(10, 10), dtype='int8')
142+
params0.initialize(mx.init.One(), ctx=ctx)
143+
params0.save('test_parameter_dict.params')
144+
145+
params1 = gluon.ParameterDict('')
146+
params1.get('w0', shape=(10, 10), dtype='float16')
147+
params1.get('w1', shape=(10, 10), dtype='float64')
148+
params1.load('test_parameter_dict.params', cast_dtype=True, dtype_source='current')
149+
assert params1['w0'].data().dtype == np.float16
150+
assert params1['w1'].data().dtype == np.float64
151+
params1.load('test_parameter_dict.params', cast_dtype=True, dtype_source='saved')
152+
assert params1['w0'].data().dtype == np.float32
153+
assert params1['w1'].data().dtype == np.int8
137154

138155

139156
@with_seed()
@@ -242,7 +259,7 @@ def __init__(self, **kwargs):
242259

243260

244261
@with_seed()
245-
def test_collect_paramters():
262+
def test_collect_parameters():
246263
net = nn.HybridSequential(prefix="test_")
247264
with net.name_scope():
248265
net.add(nn.Conv2D(10, 3))
@@ -355,18 +372,30 @@ def hybrid_forward(self, F, x):
355372
net_fp32.forward(data)
356373
net_fp32.export(tmpfile, 0)
357374

358-
# 2. Load the saved model and verify if all the params are loaded correctly.
359-
# and choose one of the param to verify the type if fp64.
360-
sm = mx.sym.load(tmpfile + '-symbol.json')
375+
# 2.a Load the saved model and verify if all the params are loaded correctly.
376+
# and choose one of the params to verify the type is fp64.
377+
sym_file = tmpfile + '-symbol.json'
378+
params_file = tmpfile + '-0000.params'
379+
sm = mx.sym.load(sym_file)
361380
inputs = mx.sym.var('data', dtype='float64')
362381
net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
363-
net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
364-
# 3. Get a conv layer's weight parameter name. Conv layer's weight param is
382+
net_fp64.collect_params().load(params_file, ctx=ctx)
383+
# Get a conv layer's weight parameter name. Conv layer's weight param is
365384
# expected to be of dtype casted, fp64.
366385
for param_name in net_fp64.params.keys():
367386
if 'conv' in param_name and 'weight' in param_name:
368387
break
369388
assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)
389+
390+
# 2.b Verify same functionality with the imports API
391+
net_fp_64 = mx.gluon.SymbolBlock.imports(sym_file, 'data', params_file, ctx=ctx)
392+
393+
# Get a conv layer's weight parameter name. Conv layer's weight param is
394+
# expected to be of dtype casted, fp64.
395+
for param_name in net_fp_64.params.keys():
396+
if 'conv' in param_name and 'weight' in param_name:
397+
break
398+
assert np.dtype(net_fp_64.params[param_name].dtype) == np.dtype(np.float64)
370399

371400
# Cast the symbol block to FP32 and try to forward a FP32 data.
372401
# This will verify SymbolBlock.cast() functionality.
@@ -2750,6 +2779,17 @@ def test_gluon_param_load():
27502779
net.cast('float16')
27512780
net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
27522781
mx.nd.waitall()
2782+
2783+
@with_seed()
2784+
def test_gluon_param_load_dtype_source():
2785+
net = mx.gluon.nn.Dense(10, in_units=10)
2786+
net.initialize()
2787+
net.cast('float16')
2788+
net.save_parameters('test_gluon_param_load_dtype_source.params')
2789+
net.cast('float32')
2790+
net.load_parameters('test_gluon_param_load_dtype_source.params', cast_dtype=True, dtype_source="saved")
2791+
assert net.weight.dtype == np.float16
2792+
mx.nd.waitall()
27532793

27542794
if __name__ == '__main__':
27552795
import nose

0 commit comments

Comments
 (0)