Commit 9ca8ea5

Fix: add better l2 cache clear
1 parent 66065c2 commit 9ca8ea5

1 file changed

Lines changed: 41 additions & 17 deletions

problems/nvidia/nvfp4_group_gemm/utils.py

```diff
@@ -28,11 +28,7 @@ def get_device(use_cuda: bool = True) -> torch.device:
 # Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py
 @torch.no_grad()
 def verbose_allclose(
-    received: torch.Tensor,
-    expected: torch.Tensor,
-    rtol=1e-05,
-    atol=1e-08,
-    max_print=5
+    received: torch.Tensor, expected: torch.Tensor, rtol=1e-05, atol=1e-08, max_print=5
 ) -> list[str]:
     """
     Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches.
@@ -64,9 +60,13 @@ def verbose_allclose(
     nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected))

     # Find +inf mismatched elements
-    posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected))
+    posinf_mismatched = torch.logical_xor(
+        torch.isposinf(received), torch.isposinf(expected)
+    )
     # Find -inf mismatched elements
-    neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected))
+    neginf_mismatched = torch.logical_xor(
+        torch.isneginf(received), torch.isneginf(expected)
+    )

     # Find all mismatched elements
     mismatched = torch.logical_or(
@@ -87,14 +87,18 @@ def verbose_allclose(
             i = tuple(index.tolist())
             mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}")
         if num_mismatched > max_print:
-            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
+            mismatch_details.append(
+                f"... and {num_mismatched - max_print} more mismatched elements."
+            )
         return mismatch_details

     return []


 @torch.no_grad()
-def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5):
+def verbose_allequal(
+    received: torch.Tensor, expected: torch.Tensor, max_print: int = 5
+):
     """
     Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches.

@@ -120,32 +124,43 @@ def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print:
             i = tuple(index.tolist())
             mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}")
         if num_mismatched > max_print:
-            mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.")
+            mismatch_details.append(
+                f"... and {num_mismatched - max_print} more mismatched elements."
+            )
         return mismatch_details

     return []


-def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]:
+def match_reference(
+    data, output, reference: callable, rtol=1e-05, atol=1e-08
+) -> tuple[bool, str]:
     """
     Convenient "default" implementation for tasks' `check_implementation` function.
     """
     expected = reference(data)

     if len(output) != len(expected):
-        return False, f"output length mismatch: got {len(output)}, expected {len(expected)}"
+        return (
+            False,
+            f"output length mismatch: got {len(output)}, expected {len(expected)}",
+        )

     for i, (output_i, expected_i) in enumerate(zip(output, expected)):
         reasons = verbose_allclose(output_i, expected_i, rtol=rtol, atol=atol)
         if len(reasons) > 0:
-            return False, f"mismatch found! custom implementation doesn't match reference: {i} {reasons}"
+            return (
+                False,
+                f"mismatch found! custom implementation doesn't match reference: {i} {reasons}",
+            )

-    return True, ''
+    return True, ""


 def make_match_reference(reference: callable, **kwargs):
     def wrapped(data, output):
         return match_reference(data, output, reference=reference, **kwargs)
+
     return wrapped


@@ -156,7 +171,7 @@ def __init__(self):
         self.cublas = None

     def __enter__(self):
-        self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '')
+        self.cublas = os.environ.get("CUBLAS_WORKSPACE_CONFIG", "")
         self.allow_tf32 = torch.backends.cudnn.allow_tf32
         self.deterministic = torch.backends.cudnn.deterministic
         torch.backends.cudnn.allow_tf32 = False
@@ -168,7 +183,8 @@ def __exit__(self, exc_type, exc_value, traceback):
         torch.backends.cudnn.allow_tf32 = self.allow_tf32
         torch.backends.cudnn.deterministic = self.deterministic
         torch.use_deterministic_algorithms(False)
-        os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = self.cublas
+

 def clear_l2_cache():
     # import cupy as cp
@@ -177,4 +193,12 @@ def clear_l2_cache():
     dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda")
     # write stuff to
     dummy.fill_(42)
-    del dummy
+    del dummy
+
+
+def clear_l2_cache_large():
+    # import cupy as cp
+    # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0)
+    # create a large dummy tensor
+    dummy = torch.randn((16000, 1024, 1024), device="cuda")
+    del dummy
```
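
The substantive change is the new `clear_l2_cache_large()` at the bottom; everything else is mechanical reformatting (Black-style line wrapping and double-quote normalization). Both cache-clearing helpers rely on the same trick: writing a buffer far larger than the GPU's L2 cache evicts whatever earlier kernels left resident, so each timed run starts cold. The original helper fills 32 × 1024 × 1024 int64 values (256 MiB); the large variant allocates 16000 × 1024 × 1024 float32 values, roughly 64 GB, which presumes a device with very large memory. Below is a minimal sketch of how such a helper is typically used in a timing loop; the `bench` wrapper and `kernel` argument are illustrative, not part of this repo.

```python
# Illustrative sketch (not from this commit): timing a CUDA kernel with a
# cold L2 cache between iterations. Assumes the helper above is importable
# from this utils module.
import torch

from utils import clear_l2_cache  # or clear_l2_cache_large on big-memory GPUs

def bench(kernel, *args, iters: int = 10) -> float:
    times = []
    for _ in range(iters):
        clear_l2_cache()  # evict the previous iteration's data from L2
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        kernel(*args)
        end.record()
        torch.cuda.synchronize()  # wait until `end` has fired before reading it
        times.append(start.elapsed_time(end))  # elapsed time in milliseconds
    return min(times)
```

Taking the minimum over iterations is a common choice for kernel microbenchmarks, since it is the statistic least affected by unrelated machine noise.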

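The correctness helpers read the same way after the reformat: `match_reference` runs the reference on the inputs and compares each output tensor with `verbose_allclose`, and `make_match_reference` is a small factory that closes over the reference and tolerances. A task would wire up its checker roughly like this; `ref_kernel` and the tolerances are made up for illustration, not taken from this commit.

```python
# Illustrative only: wiring a task's checker via make_match_reference.
# `ref_kernel` is a stand-in reference implementation returning a list of
# tensors, which is the shape of output match_reference expects.
import torch

from utils import make_match_reference

def ref_kernel(data):
    # toy reference: one output tensor per input tensor
    return [x * 2.0 for x in data]

check_implementation = make_match_reference(ref_kernel, rtol=1e-3, atol=1e-3)

inputs = [torch.randn(4, 4)]
ok, message = check_implementation(inputs, ref_kernel(inputs))
assert ok, message  # the reference trivially matches itself, so message is ""
```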