Skip to content

Commit 23edc89

Browse files
committed
address feedback about patching
* improve thread-safety of patching * gate repeated patch calls there still exist problematic edge cases (race condition where one thread restores while another is using patched functions)
1 parent efbd93b commit 23edc89

3 files changed

Lines changed: 106 additions & 51 deletions

File tree

mkl_fft/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949

5050
import mkl_fft.interfaces # isort: skip
5151

52-
5352
__all__ = [
5453
"fft",
5554
"ifft",

mkl_fft/_patch_numpy.py

Lines changed: 87 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -27,94 +27,131 @@
2727
"""Define functions for patching NumPy with MKL-based NumPy interface."""
2828

2929
from contextlib import ContextDecorator
30-
from threading import local as threading_local
30+
from threading import Lock
3131

3232
import numpy as np
3333

3434
import mkl_fft.interfaces.numpy_fft as _nfft
3535

36-
_tls = threading_local()
3736

38-
39-
class _Patch:
40-
"""Internal object for patching NumPy with mkl_fft interfaces."""
41-
42-
_is_patched = False
43-
__patched_functions__ = _nfft.__all__
44-
_restore_dict = {}
37+
class _GlobalPatch(ContextDecorator):
38+
def __init__(self):
39+
self._lock = Lock()
40+
self._patch_count = 0
41+
self._restore_dict = {}
42+
# make _patched_functions a tuple (immutable)
43+
self._patched_functions = tuple(_nfft.__all__)
4544

4645
def _register_func(self, name, func):
47-
if name not in self.__patched_functions__:
48-
raise ValueError("%s not an mkl_fft function." % name)
49-
f = getattr(np.fft, name)
50-
self._restore_dict[name] = f
46+
if name not in self._patched_functions:
47+
raise ValueError(f"{name} not an mkl_fft function.")
48+
if name not in self._restore_dict:
49+
self._restore_dict[name] = getattr(np.fft, name)
5150
setattr(np.fft, name, func)
5251

5352
def _restore_func(self, name, verbose=False):
54-
if name not in self.__patched_functions__:
55-
raise ValueError("%s not an mkl_fft function." % name)
53+
if name not in self._patched_functions:
54+
raise ValueError(f"{name} not an mkl_fft function.")
5655
try:
5756
val = self._restore_dict[name]
5857
except KeyError:
5958
if verbose:
60-
print("failed to restore")
59+
print(f"failed to restore {name}")
6160
return
6261
else:
6362
if verbose:
64-
print("found and restoring...")
63+
print(f"found and restoring {name}...")
6564
setattr(np.fft, name, val)
6665

67-
def restore(self, verbose=False):
68-
for name in self._restore_dict.keys():
69-
self._restore_func(name, verbose=verbose)
70-
self._is_patched = False
71-
72-
def do_patch(self):
73-
for f in self.__patched_functions__:
74-
self._register_func(f, getattr(_nfft, f))
75-
self._is_patched = True
66+
def do_patch(self, verbose=False):
67+
with self._lock:
68+
if self._patch_count == 0:
69+
if verbose:
70+
print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.")
71+
print(
72+
"Please direct bug reports to https://github.com/IntelPython/mkl_fft"
73+
)
74+
for f in self._patched_functions:
75+
self._register_func(f, getattr(_nfft, f))
76+
self._patch_count += 1
77+
78+
def do_restore(self, verbose=False):
79+
with self._lock:
80+
if self._patch_count > 0:
81+
self._patch_count -= 1
82+
if self._patch_count == 0:
83+
if verbose:
84+
print("Now restoring original NumPy FFT submodule.")
85+
for name in tuple(self._restore_dict):
86+
self._restore_func(name, verbose=verbose)
87+
self._restore_dict.clear()
7688

7789
def is_patched(self):
78-
return self._is_patched
90+
with self._lock:
91+
return self._patch_count > 0
7992

93+
def __enter__(self):
94+
self.do_patch()
95+
return self
8096

81-
def _initialize_tls():
82-
_tls.patch = _Patch()
83-
_tls.initialized = True
97+
def __exit__(self, *exc):
98+
self.do_restore()
99+
return False
84100

85101

86-
def _is_tls_initialized():
87-
return (getattr(_tls, "initialized", None) is not None) and (
88-
_tls.initialized is True
89-
)
102+
_patch = _GlobalPatch()
90103

91104

92105
def patch_numpy_fft(verbose=False):
93-
if verbose:
94-
print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.")
95-
print(
96-
"Please direct bug reports to https://github.com/IntelPython/mkl_fft"
97-
)
98-
if not _is_tls_initialized():
99-
_initialize_tls()
100-
_tls.patch.do_patch()
106+
"""Patch NumPy's fft submodule with mkl_fft's numpy_interface.
107+
108+
Parameters
109+
----------
110+
verbose : bool, optional
111+
print message when starting the patching process.
112+
113+
"""
114+
_patch.do_patch(verbose=verbose)
101115

102116

103117
def restore_numpy_fft(verbose=False):
104-
if verbose:
105-
print("Now restoring original NumPy FFT submodule.")
106-
if not _is_tls_initialized():
107-
_initialize_tls()
108-
_tls.patch.restore(verbose=verbose)
118+
"""
119+
Restore NumPy's fft submodule to its original implementations.
120+
121+
Parameters
122+
----------
123+
verbose : bool, optional
124+
print message when starting restoration process.
125+
126+
"""
127+
_patch.do_restore(verbose=verbose)
109128

110129

111130
def is_patched():
112-
if not _is_tls_initialized():
113-
_initialize_tls()
114-
return _tls.patch.is_patched()
131+
"""Return True if NumPy's fft submodule is currently patched by mkl_fft."""
132+
return _patch.is_patched()
115133

116134

117135
class mkl_fft(ContextDecorator):
136+
"""
137+
Context manager and decorator to temporarily patch NumPy fft submodule
138+
with MKL-based implementations.
139+
140+
Examples
141+
--------
142+
>>> import mkl_fft
143+
>>> mkl_fft.is_patched()
144+
# False
145+
146+
>>> with mkl_fft.mkl_fft(): # Enable mkl_fft in Numpy
147+
>>> print(mkl_fft.is_patched())
148+
# True
149+
150+
>>> mkl_fft.is_patched()
151+
# False
152+
153+
"""
154+
118155
def __enter__(self):
119156
patch_numpy_fft()
120157
return self

mkl_fft/tests/test_patch.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,22 @@ def test_patch():
4040
mkl_fft.restore_numpy_fft() # Disable mkl_fft in Numpy
4141
assert not mkl_fft.is_patched()
4242
assert np.fft.fft.__module__ == old_module
43+
44+
45+
def test_patch_redundant_patching():
46+
old_module = np.fft.fft.__module__
47+
assert not mkl_fft.is_patched()
48+
49+
mkl_fft.patch_numpy_fft()
50+
mkl_fft.patch_numpy_fft()
51+
52+
assert mkl_fft.is_patched()
53+
assert np.fft.fft.__module__ == _nfft.fft.__module__
54+
55+
mkl_fft.restore_numpy_fft()
56+
assert mkl_fft.is_patched()
57+
assert np.fft.fft.__module__ == _nfft.fft.__module__
58+
59+
mkl_fft.restore_numpy_fft()
60+
assert not mkl_fft.is_patched()
61+
assert np.fft.fft.__module__ == old_module

0 commit comments

Comments
 (0)