Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit c723ae2

Browse files
authored
Fix optimize_for backend_opts to be empty dictionary instead of None (#19812)
1 parent f9d90c9 commit c723ae2

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

python/mxnet/gluon/block.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1397,7 +1397,7 @@ def optimize_for(self, x, *args, backend=None, clear=False,
13971397
self._first_forward = True
13981398
# clear the backend
13991399
self._backend = None
1400-
self._backend_opts = None
1400+
self._backend_opts = {}
14011401

14021402
def _clear_cached_op(self):
14031403
self._cached_graph = ()

tests/python/unittest/test_extensions.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,11 +182,17 @@ def test_subgraph():
182182
sym_filename, params_filename = sym_block3.export('optimized')
183183
assert sym_filename == 'optimized-symbol.json'
184184
assert params_filename is None
185+
186+
# Test with additional input to subgraph op
187+
sym_block3.optimize_for(a_data, b_data, backend="addInputPass")
188+
out5 = sym_block3(a_data, b_data)
189+
190+
# Reload exported block
185191
sym_block4 = nn.SymbolBlock.imports(sym_filename, ['a','b'], params_filename)
186192

187-
out5 = sym_block4(a_data, b_data)
193+
out6 = sym_block4(a_data, b_data)
188194
# check that result matches one executed by MXNet
189-
assert_almost_equal(out[0].asnumpy(), out5[0].asnumpy(), rtol=1e-3, atol=1e-3)
195+
assert_almost_equal(out[0].asnumpy(), out6[0].asnumpy(), rtol=1e-3, atol=1e-3)
190196

191197
@pytest.mark.skipif(check_platform(['x86_64']), reason="not all machine types supported")
192198
@pytest.mark.skipif(is_cd_run(), reason="continuous delivery run - ignoring test")

0 commit comments

Comments
 (0)