@@ -194,18 +194,14 @@ def _init_kvstore(self):
194194
195195 if config ['update_on_kvstore' ] is not None :
196196 update_on_kvstore = config ['update_on_kvstore' ]
197-
198197 if kvstore :
199198 if self ._compression_params :
200199 kvstore .set_gradient_compression (self ._compression_params )
201200 self ._distributed = 'dist' in kvstore .type
202201 if self ._distributed :
203202 # kv.pull(row_sparse_grad) is not supported for dist kvstore
204- # Captures condition for dist_async, dist_device_sync or based on config for
205- # update_on_kvstore
206203 update_on_kvstore = self ._contains_sparse_weight or self ._contains_sparse_grad \
207- or 'device' in kvstore .type or 'async' in kvstore .type \
208- or config ['update_on_kvstore' ]
204+ or 'async' in kvstore .type
209205 if update_on_kvstore :
210206 # optimizer preferably needs to be set before init for multiprecision
211207 kvstore .set_optimizer (self ._optimizer )
@@ -273,20 +269,13 @@ def step(self, batch_size, ignore_stale_grad=False):
273269 If true, ignores Parameters with stale gradient (gradient that has not
274270 been updated by `backward` after last step) and skip update.
275271 """
276- rescale_grad = self ._scale / batch_size
277- if self ._update_on_kvstore and self ._distributed and \
278- self ._optimizer .rescale_grad != rescale_grad :
279- raise UserWarning ('Possible change in the `batch_size` from previous `step` detected.' \
280- 'Optimizer gradient normalizing factor will not change w.r.t new batch_size when ' \
281- 'update_on_kvstore=True and when distributed `kvstore` is used.' )
282-
283- self ._optimizer .rescale_grad = rescale_grad
284-
285272 if not self ._kv_initialized :
286273 self ._init_kvstore ()
287274 if self ._params_to_init :
288275 self ._init_params ()
289276
277+ self ._optimizer .rescale_grad = self ._scale / batch_size
278+
290279 self ._allreduce_grads ()
291280 self ._update (ignore_stale_grad )
292281
0 commit comments