Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 69f4f31

Browse files
[BUGFIX] fix model zoo parallel download (#17372)
* use temp file * fix dependency * Update model_store.py * Update test_gluon_model_zoo.py * remove NamedTempFile
1 parent 1cb738a commit 69f4f31

3 files changed

Lines changed: 52 additions & 16 deletions

File tree

python/mxnet/gluon/model_zoo/model_store.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,11 @@
2222
import os
2323
import zipfile
2424
import logging
25+
import tempfile
26+
import uuid
27+
import shutil
2528

26-
from ..utils import download, check_sha1
29+
from ..utils import download, check_sha1, replace_file
2730
from ... import base, util
2831

2932
_model_sha1 = {name: checksum for checksum, name in [
@@ -103,16 +106,21 @@ def get_model_file(name, root=os.path.join(base.data_dir(), 'models')):
103106

104107
util.makedirs(root)
105108

106-
zip_file_path = os.path.join(root, file_name+'.zip')
107109
repo_url = os.environ.get('MXNET_GLUON_REPO', apache_repo_url)
108110
if repo_url[-1] != '/':
109111
repo_url = repo_url + '/'
112+
113+
random_uuid = str(uuid.uuid4())
114+
temp_zip_file_path = os.path.join(root, file_name+'.zip'+random_uuid)
110115
download(_url_format.format(repo_url=repo_url, file_name=file_name),
111-
path=zip_file_path,
112-
overwrite=True)
113-
with zipfile.ZipFile(zip_file_path) as zf:
114-
zf.extractall(root)
115-
os.remove(zip_file_path)
116+
path=temp_zip_file_path, overwrite=True)
117+
with zipfile.ZipFile(temp_zip_file_path) as zf:
118+
temp_dir = tempfile.mkdtemp(dir=root)
119+
zf.extractall(temp_dir)
120+
temp_file_path = os.path.join(temp_dir, file_name+'.params')
121+
replace_file(temp_file_path, file_path)
122+
shutil.rmtree(temp_dir)
123+
os.remove(temp_zip_file_path)
116124

117125
if check_sha1(file_path, sha1_hash):
118126
return file_path

python/mxnet/gluon/utils.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from __future__ import absolute_import
2222

2323
__all__ = ['split_data', 'split_and_load', 'clip_global_norm',
24-
'check_sha1', 'download']
24+
'check_sha1', 'download', 'replace_file']
2525

2626
import os
2727
import sys
@@ -35,7 +35,7 @@
3535
import numpy as np
3636

3737
from .. import ndarray
38-
from ..util import is_np_shape, is_np_array
38+
from ..util import is_np_shape, is_np_array, makedirs
3939
from .. import numpy as _mx_np # pylint: disable=reimported
4040

4141

@@ -209,8 +209,14 @@ def check_sha1(filename, sha1_hash):
209209

210210
if not sys.platform.startswith('win32'):
211211
# refer to https://github.com/untitaker/python-atomicwrites
212-
def _replace_atomic(src, dst):
213-
"""Implement atomic os.replace with linux and OSX. Internal use only"""
212+
def replace_file(src, dst):
213+
"""Implement atomic os.replace with linux and OSX.
214+
215+
Parameters
216+
----------
217+
src : source file path
218+
dst : destination file path
219+
"""
214220
try:
215221
os.rename(src, dst)
216222
except OSError:
@@ -252,19 +258,25 @@ def _handle_errors(rv, src):
252258
finally:
253259
raise OSError(msg)
254260

255-
def _replace_atomic(src, dst):
261+
def replace_file(src, dst):
256262
"""Implement atomic os.replace with windows.
263+
257264
refer to https://docs.microsoft.com/en-us/windows/desktop/api/winbase/nf-winbase-movefileexw
258265
The function fails when any one of the steps (copy, flush, delete) fails.
259-
Internal use only"""
266+
267+
Parameters
268+
----------
269+
src : source file path
270+
dst : destination file path
271+
"""
260272
_handle_errors(ctypes.windll.kernel32.MoveFileExW(
261273
_str_to_unicode(src), _str_to_unicode(dst),
262274
_windows_default_flags | _MOVEFILE_REPLACE_EXISTING
263275
), src)
264276

265277

266278
def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_ssl=True):
267-
"""Download an given URL
279+
"""Download a given URL
268280
269281
Parameters
270282
----------
@@ -310,7 +322,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
310322
if overwrite or not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
311323
dirname = os.path.dirname(os.path.abspath(os.path.expanduser(fname)))
312324
if not os.path.exists(dirname):
313-
os.makedirs(dirname)
325+
makedirs(dirname)
314326
while retries + 1 > 0:
315327
# Disable pylint too broad Exception
316328
# pylint: disable=W0703
@@ -330,7 +342,7 @@ def download(url, path=None, overwrite=False, sha1_hash=None, retries=5, verify_
330342
# delete the temporary file
331343
if not os.path.exists(fname) or (sha1_hash and not check_sha1(fname, sha1_hash)):
332344
# atomic operation in the same file system
333-
_replace_atomic('{}.{}'.format(fname, random_uuid), fname)
345+
replace_file('{}.{}'.format(fname, random_uuid), fname)
334346
else:
335347
try:
336348
os.remove('{}.{}'.format(fname, random_uuid))

tests/python/unittest/test_gluon_model_zoo.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from mxnet.gluon.model_zoo.vision import get_model
2121
import sys
2222
from common import setup_module, with_seed, teardown
23+
import multiprocessing
2324

2425

2526
def eprint(*args, **kwargs):
@@ -49,6 +50,21 @@ def test_models():
4950
model.collect_params().initialize()
5051
model(mx.nd.random.uniform(shape=data_shape)).wait_to_read()
5152

53+
def parallel_download(model_name):
54+
model = get_model(model_name, pretrained=True, root='./parallel_download')
55+
print(type(model))
56+
57+
@with_seed()
58+
def test_parallel_download():
59+
processes = []
60+
name = 'mobilenetv2_0.25'
61+
for _ in range(10):
62+
p = multiprocessing.Process(target=parallel_download, args=(name,))
63+
processes.append(p)
64+
for p in processes:
65+
p.start()
66+
for p in processes:
67+
p.join()
5268

5369
if __name__ == '__main__':
5470
import nose

0 commit comments

Comments
 (0)