@@ -94,7 +94,7 @@ def test_parameter_invalid_access():
     assertRaises(RuntimeError, p1.list_row_sparse_data, row_id)
 
 @with_seed()
-def test_paramdict():
+def test_parameter_dict():
     ctx = mx.cpu(1)
     params0 = gluon.ParameterDict('net_')
     params0.get('w0', shape=(10, 10))
@@ -107,15 +107,15 @@ def test_paramdict():
     prev_w0 = params0.get('w0').data(ctx)
     prev_w1 = params0.get('w1').row_sparse_data(all_row_ids)
     # save params
-    params0.save('test_paramdict.params')
+    params0.save('test_parameter_dict.params')
 
     # load params
     params1 = gluon.ParameterDict('net_')
     params1.get('w0', shape=(10, 10))
     params1.get('w1', shape=(10, 10), stype='row_sparse')
-    params1.load('test_paramdict.params', ctx)
+    params1.load('test_parameter_dict.params', ctx)
     trainer1 = mx.gluon.Trainer(params1, 'sgd')
-
+
     # compare the values before and after save/load
     cur_w0 = params1.get('w0').data(ctx)
     cur_w1 = params1.get('w1').row_sparse_data(all_row_ids)
@@ -127,13 +127,30 @@ def test_paramdict():
     params2 = gluon.ParameterDict('net_')
     params2.get('w0', shape=(10, 10))
     params2.get('w1', shape=(10, 10))
-    params2.load('test_paramdict.params', ctx)
+    params2.load('test_parameter_dict.params', ctx)
 
     # compare the values before and after save/load
     cur_w0 = params2.get('w0').data(ctx)
     cur_w1 = params2.get('w1').data(ctx)
     mx.test_utils.assert_almost_equal(prev_w0.asnumpy(), cur_w0.asnumpy())
     mx.test_utils.assert_almost_equal(prev_w1.asnumpy(), cur_w1.asnumpy())
+
+    # test the dtype casting functionality
+    params0 = gluon.ParameterDict('')
+    params0.get('w0', shape=(10, 10), dtype='float32')
+    params0.get('w1', shape=(10, 10), dtype='int8')
+    params0.initialize(mx.init.One(), ctx=ctx)
+    params0.save('test_parameter_dict.params')
+
+    params1 = gluon.ParameterDict('')
+    params1.get('w0', shape=(10, 10), dtype='float16')
+    params1.get('w1', shape=(10, 10), dtype='float64')
+    params1.load('test_parameter_dict.params', cast_dtype=True, dtype_source='current')
+    assert params1['w0'].data().dtype == np.float16
+    assert params1['w1'].data().dtype == np.float64
+    params1.load('test_parameter_dict.params', cast_dtype=True, dtype_source='saved')
+    assert params1['w0'].data().dtype == np.float32
+    assert params1['w1'].data().dtype == np.int8
 
 
 @with_seed()
@@ -242,7 +259,7 @@ def __init__(self, **kwargs):
 
 
 @with_seed()
-def test_collect_paramters():
+def test_collect_parameters():
     net = nn.HybridSequential(prefix="test_")
     with net.name_scope():
         net.add(nn.Conv2D(10, 3))
@@ -355,18 +372,30 @@ def hybrid_forward(self, F, x):
     net_fp32.forward(data)
     net_fp32.export(tmpfile, 0)
 
-    # 2. Load the saved model and verify if all the params are loaded correctly.
-    # and choose one of the param to verify the type if fp64.
-    sm = mx.sym.load(tmpfile + '-symbol.json')
+    # 2.a Load the saved model and verify that all the params are loaded correctly,
+    # and choose one of the params to verify that its dtype is fp64.
+    sym_file = tmpfile + '-symbol.json'
+    params_file = tmpfile + '-0000.params'
+    sm = mx.sym.load(sym_file)
     inputs = mx.sym.var('data', dtype='float64')
     net_fp64 = mx.gluon.SymbolBlock(sm, inputs)
-    net_fp64.collect_params().load(tmpfile + '-0000.params', ctx=ctx)
-    # 3. Get a conv layer's weight parameter name. Conv layer's weight param is
+    net_fp64.collect_params().load(params_file, ctx=ctx)
+    # Get a conv layer's weight parameter name. Conv layer's weight param is
     # expected to be of dtype casted, fp64.
     for param_name in net_fp64.params.keys():
         if 'conv' in param_name and 'weight' in param_name:
             break
     assert np.dtype(net_fp64.params[param_name].dtype) == np.dtype(np.float64)
+
+    # 2.b Verify the same functionality with the imports API
+    net_fp_64 = mx.gluon.SymbolBlock.imports(sym_file, 'data', params_file, ctx=ctx)
+
+    # Get a conv layer's weight parameter name. Conv layer's weight param is
+    # expected to be cast to fp64.
+    for param_name in net_fp_64.params.keys():
+        if 'conv' in param_name and 'weight' in param_name:
+            break
+    assert np.dtype(net_fp_64.params[param_name].dtype) == np.dtype(np.float64)
 
     # Cast the symbol block to FP32 and try to forward a FP32 data.
     # This will verify SymbolBlock.cast() functionality.
@@ -2750,6 +2779,17 @@ def test_gluon_param_load():
     net.cast('float16')
     net.load_parameters('test_gluon_param_load.params', cast_dtype=True)
     mx.nd.waitall()
+
+@with_seed()
+def test_gluon_param_load_dtype_source():
+    net = mx.gluon.nn.Dense(10, in_units=10)
+    net.initialize()
+    net.cast('float16')
+    net.save_parameters('test_gluon_param_load_dtype_source.params')
+    net.cast('float32')
+    net.load_parameters('test_gluon_param_load_dtype_source.params', cast_dtype=True, dtype_source="saved")
+    assert net.weight.dtype == np.float16
+    mx.nd.waitall()
 
 if __name__ == '__main__':
     import nose