@@ -28,11 +28,7 @@ def get_device(use_cuda: bool = True) -> torch.device:
2828# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py
2929@torch .no_grad ()
3030def verbose_allclose (
31- received : torch .Tensor ,
32- expected : torch .Tensor ,
33- rtol = 1e-05 ,
34- atol = 1e-08 ,
35- max_print = 5
31+ received : torch .Tensor , expected : torch .Tensor , rtol = 1e-05 , atol = 1e-08 , max_print = 5
3632) -> list [str ]:
3733 """
3834 Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
@@ -64,9 +60,13 @@ def verbose_allclose(
6460 nan_mismatched = torch .logical_xor (torch .isnan (received ), torch .isnan (expected ))
6561
6662 # Find +inf mismatched elements
67- posinf_mismatched = torch .logical_xor (torch .isposinf (received ), torch .isposinf (expected ))
63+ posinf_mismatched = torch .logical_xor (
64+ torch .isposinf (received ), torch .isposinf (expected )
65+ )
6866 # Find -inf mismatched elements
69- neginf_mismatched = torch .logical_xor (torch .isneginf (received ), torch .isneginf (expected ))
67+ neginf_mismatched = torch .logical_xor (
68+ torch .isneginf (received ), torch .isneginf (expected )
69+ )
7070
7171 # Find all mismatched elements
7272 mismatched = torch .logical_or (
@@ -87,14 +87,18 @@ def verbose_allclose(
8787 i = tuple (index .tolist ())
8888 mismatch_details .append (f"ERROR AT { i } : { received [i ]} { expected [i ]} " )
8989 if num_mismatched > max_print :
90- mismatch_details .append (f"... and { num_mismatched - max_print } more mismatched elements." )
90+ mismatch_details .append (
91+ f"... and { num_mismatched - max_print } more mismatched elements."
92+ )
9193 return mismatch_details
9294
9395 return []
9496
9597
9698@torch .no_grad ()
97- def verbose_allequal (received : torch .Tensor , expected : torch .Tensor , max_print : int = 5 ):
99+ def verbose_allequal (
100+ received : torch .Tensor , expected : torch .Tensor , max_print : int = 5
101+ ):
98102 """
99103 Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches.
100104
@@ -120,32 +124,43 @@ def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print:
120124 i = tuple (index .tolist ())
121125 mismatch_details .append (f"ERROR AT { i } : { received [i ]} { expected [i ]} " )
122126 if num_mismatched > max_print :
123- mismatch_details .append (f"... and { num_mismatched - max_print } more mismatched elements." )
127+ mismatch_details .append (
128+ f"... and { num_mismatched - max_print } more mismatched elements."
129+ )
124130 return mismatch_details
125131
126132 return []
127133
128134
def match_reference(
    data, output, reference: callable, rtol=1e-05, atol=1e-08
) -> tuple[bool, str]:
    """
    Convenient "default" implementation for tasks' `check_implementation` function.

    Compares `output` element-wise against `reference(data)` using
    `verbose_allclose`.

    Args:
        data: Input forwarded to the reference implementation.
        output: Sequence of tensors produced by the candidate implementation.
        reference: Callable producing the expected sequence of tensors.
        rtol: Relative tolerance forwarded to `verbose_allclose`.
        atol: Absolute tolerance forwarded to `verbose_allclose`.

    Returns:
        `(True, "")` when every element matches; otherwise `(False, reason)`
        describing the first mismatch found.
    """
    expected = reference(data)

    # Length mismatch means the candidate produced the wrong number of outputs;
    # report before attempting any element-wise comparison.
    if len(output) != len(expected):
        return (
            False,
            f"output length mismatch: got {len(output)}, expected {len(expected)}",
        )

    for i, (output_i, expected_i) in enumerate(zip(output, expected)):
        reasons = verbose_allclose(output_i, expected_i, rtol=rtol, atol=atol)
        # Idiomatic truthiness check instead of `len(reasons) > 0`.
        if reasons:
            return (
                False,
                f"mismatch found! custom implementation doesn't match reference: {i} {reasons}",
            )

    return True, ""
144158
145159
def make_match_reference(reference: callable, **kwargs):
    """
    Build a `check_implementation`-style checker bound to `reference`.

    Args:
        reference: Callable producing expected outputs from the task data.
        **kwargs: Extra options (e.g. `rtol`, `atol`) forwarded to
            `match_reference`.

    Returns:
        A function `(data, output) -> tuple[bool, str]` that delegates to
        `match_reference` with `reference` and `kwargs` pre-bound.
    """

    def wrapped(data, output):
        return match_reference(data, output, reference=reference, **kwargs)

    return wrapped
150165
151166
@@ -156,7 +171,7 @@ def __init__(self):
156171 self .cublas = None
157172
158173 def __enter__ (self ):
159- self .cublas = os .environ .get (' CUBLAS_WORKSPACE_CONFIG' , '' )
174+ self .cublas = os .environ .get (" CUBLAS_WORKSPACE_CONFIG" , "" )
160175 self .allow_tf32 = torch .backends .cudnn .allow_tf32
161176 self .deterministic = torch .backends .cudnn .deterministic
162177 torch .backends .cudnn .allow_tf32 = False
@@ -168,7 +183,8 @@ def __exit__(self, exc_type, exc_value, traceback):
168183 torch .backends .cudnn .allow_tf32 = self .allow_tf32
169184 torch .backends .cudnn .deterministic = self .deterministic
170185 torch .use_deterministic_algorithms (False )
171- os .environ ['CUBLAS_WORKSPACE_CONFIG' ] = self .cublas
186+ os .environ ["CUBLAS_WORKSPACE_CONFIG" ] = self .cublas
187+
172188
173189def clear_l2_cache ():
174190 # import cupy as cp
@@ -177,4 +193,12 @@ def clear_l2_cache():
177193 dummy = torch .empty ((32 , 1024 , 1024 ), dtype = torch .int64 , device = "cuda" )
178194 # write stuff to
179195 dummy .fill_ (42 )
180- del dummy
196+ del dummy
197+
198+
def clear_l2_cache_large():
    """
    Evict the GPU L2 cache by writing through a very large scratch buffer.

    Same idea as `clear_l2_cache`, but with a much larger footprint
    (16000 x 1024 x 1024 float32 elements, ~64 GiB) for devices/benchmarks
    where the small buffer is insufficient.
    """
    # import cupy as cp
    # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0)
    # Allocate-and-write like `clear_l2_cache`: `empty` + `fill_` keeps the
    # same float32 allocation size as the previous `randn` call but avoids the
    # expensive RNG kernel and leaves the global CUDA RNG state untouched.
    dummy = torch.empty((16000, 1024, 1024), device="cuda")
    dummy.fill_(42)
    del dummy
0 commit comments