-
Notifications
You must be signed in to change notification settings - Fork 497
[mxfp8 training] update triton_to_mxfp8_dim0 nan handling; fix offset int32 overflow issue #4201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
16169df
c94f1d4
8b94f73
5b98a27
f274137
2d64d9a
5192690
5aba1d0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -625,3 +625,266 @@ def test_cuda_mx_dim0_not_supported(): | |
| rowwise=True, | ||
| colwise=False, | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") | ||
| @pytest.mark.skipif( | ||
| not is_sm_at_least_100() and not is_MI350(), | ||
| reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.", | ||
| ) | ||
| @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,)) | ||
| def test_triton_mxfp8_dim0_special_values(scaling_mode: ScaleCalculationMode): | ||
| # Create tensor with special values - make it compatible with block_size=32 | ||
| block_size = 32 | ||
| special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
| # Fill first few elements of each row with special values | ||
| special_vals[0, :4] = torch.tensor( | ||
| [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16 | ||
| ) | ||
| special_vals[1, :4] = torch.tensor( | ||
| [ | ||
| torch.finfo(torch.float32).max, | ||
| torch.finfo(torch.float32).min, | ||
| torch.finfo(torch.float32).tiny, | ||
| -torch.finfo(torch.float32).tiny, | ||
| ], | ||
| dtype=torch.bfloat16, | ||
| ) | ||
|
|
||
| x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference( | ||
| special_vals, block_size=block_size, scaling_mode=scaling_mode | ||
| ) | ||
| x_mx_t, x_s_t = triton_to_mxfp8_dim0( | ||
| special_vals, | ||
| inner_block_size=block_size, | ||
| scaling_mode=scaling_mode.value.lower(), | ||
| ) | ||
| x_mx_t = x_mx_t.to(torch.float32) | ||
| x_s_t = x_s_t.to(torch.uint8) | ||
| x_mx_ref = x_mx_ref.to(torch.float32) | ||
| x_s_ref = x_s_ref.to(torch.uint8) | ||
|
|
||
| # Check for NaNs in output (allow NaNs if input had NaNs, but check scales) | ||
| input_has_nan = special_vals.isnan().any() | ||
| if not input_has_nan: | ||
| assert not x_mx_t.isnan().any(), ( | ||
| "quantized tensor should not contain NaNs when input has no NaNs" | ||
| ) | ||
| assert not x_s_t.isnan().any(), ( | ||
| "scales should not contain NaNs when input has no NaNs" | ||
| ) | ||
|
|
||
| # Use NaN-aware comparison to handle nan != nan case properly | ||
| # Check NaN patterns match | ||
| nan_ref = torch.isnan(x_mx_ref) | ||
| nan_triton = torch.isnan(x_mx_t) | ||
| assert torch.equal(nan_ref, nan_triton), ( | ||
| "NaN pattern mismatch between reference and triton" | ||
| ) | ||
|
|
||
| # Check finite values | ||
| finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t) | ||
| if finite_mask.any(): | ||
| assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), ( | ||
| "Finite values mismatch" | ||
| ) | ||
|
|
||
| # Check infinity patterns | ||
| inf_ref = torch.isinf(x_mx_ref) | ||
| inf_triton = torch.isinf(x_mx_t) | ||
| assert torch.equal(inf_ref, inf_triton), ( | ||
| "Infinity pattern mismatch between reference and triton" | ||
| ) | ||
| if inf_ref.any(): | ||
| assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), ( | ||
| "Infinity values mismatch" | ||
| ) | ||
|
|
||
| # Check scales using exact comparison | ||
| x_s_ref_uint8 = x_s_ref.to(torch.uint8) | ||
| x_s_t_uint8 = x_s_t.to(torch.uint8) | ||
| assert torch.equal(x_s_t_uint8, x_s_ref_uint8), ( | ||
| "Scale values mismatch between reference and triton" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") | ||
| @pytest.mark.skipif( | ||
| not is_sm_at_least_100() and not is_MI350(), | ||
| reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.", | ||
| ) | ||
| @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,)) | ||
| def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode): | ||
| """Test with values near overflow and underflow thresholds.""" | ||
| # Values near float8_e4m3fn limits | ||
| f8_max = torch.finfo(torch.float8_e4m3fn).max # ~448 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tiny nit, we have exact values for these, the approx |
||
| f8_min = torch.finfo(torch.float8_e4m3fn).tiny # ~1.95e-06 | ||
| block_size = 32 | ||
|
|
||
| overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
| # Fill first few elements of each row with overflow/underflow values | ||
| overflow_vals[0, :4] = torch.tensor( | ||
| [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16 | ||
| ) | ||
| overflow_vals[1, :4] = torch.tensor( | ||
| [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0], | ||
| dtype=torch.bfloat16, | ||
| ) | ||
| overflow_vals[2, :4] = torch.tensor( | ||
| [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16 | ||
| ) | ||
| overflow_vals[3, :4] = torch.tensor( | ||
| [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0], | ||
| dtype=torch.bfloat16, | ||
| ) | ||
|
|
||
| x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference( | ||
| overflow_vals, block_size=block_size, scaling_mode=scaling_mode | ||
| ) | ||
| x_mx_t, x_s_t = triton_to_mxfp8_dim0( | ||
| overflow_vals, | ||
| inner_block_size=block_size, | ||
| scaling_mode=scaling_mode.value.lower(), | ||
| ) | ||
|
|
||
| assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs" | ||
| assert not x_s_t.isnan().any(), "scales should not contain NaNs" | ||
| torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) | ||
| torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") | ||
| @pytest.mark.skipif( | ||
| not is_sm_at_least_100() and not is_MI350(), | ||
| reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.", | ||
| ) | ||
| @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,)) | ||
| def test_triton_mxfp8_dim0_extreme_range(scaling_mode): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment on what's being tested - we're testing that things match a reference, but not necessarily that the observed behaviour is actually correct -- both imo should be tested. This is also a good candidate to merge with the overflow/underflow test above, as I think they only differ in terms of what the extreme values are (which is parametrize-able)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, I deleted some tests and focused on having fewer, higher value tests. |
||
| """Test with tensors containing both very large and very small values.""" | ||
| # Mix of extreme values in same tensor to test scaling edge cases | ||
| block_size = 32 | ||
| extreme_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
| # Fill first few elements with extreme values | ||
| extreme_vals[0, :4] = torch.tensor([1e30, 1e-30, 1e20, 1e-20], dtype=torch.bfloat16) | ||
| extreme_vals[1, :4] = torch.tensor( | ||
| [-1e30, -1e-30, -1e20, -1e-20], dtype=torch.bfloat16 | ||
| ) | ||
| extreme_vals[2, :4] = torch.tensor( | ||
| [torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0], | ||
| dtype=torch.bfloat16, | ||
| ) | ||
| extreme_vals[3, :4] = torch.tensor([0.0, 1e-40, 1e40, -1e40], dtype=torch.bfloat16) | ||
|
|
||
| x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference( | ||
| extreme_vals, block_size=block_size, scaling_mode=scaling_mode | ||
| ) | ||
| x_mx_t, x_s_t = triton_to_mxfp8_dim0( | ||
| extreme_vals, | ||
| inner_block_size=block_size, | ||
| scaling_mode=scaling_mode.value.lower(), | ||
| ) | ||
|
|
||
| assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs" | ||
| assert not x_s_t.isnan().any(), "scales should not contain NaNs" | ||
| torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) | ||
| torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") | ||
| @pytest.mark.skipif( | ||
| not is_sm_at_least_100() and not is_MI350(), | ||
| reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.", | ||
| ) | ||
| @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,)) | ||
| def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode): | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same comment on what's being tested. |
||
| """Test with denormal/subnormal values that might cause precision issues.""" | ||
| # Create values in the denormal range | ||
| bf16_tiny = torch.finfo(torch.bfloat16).tiny | ||
| f32_tiny = torch.finfo(torch.float32).tiny | ||
| block_size = 32 | ||
|
|
||
| denormal_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
| # Fill first few elements with denormal values | ||
| denormal_vals[0, :4] = torch.tensor( | ||
| [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0], | ||
| dtype=torch.bfloat16, | ||
| ) | ||
| denormal_vals[1, :4] = torch.tensor( | ||
| [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0], dtype=torch.bfloat16 | ||
| ) | ||
| denormal_vals[2, :4] = torch.tensor( | ||
| [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0], | ||
| dtype=torch.bfloat16, | ||
| ) | ||
| denormal_vals[3, :4] = torch.tensor( | ||
| [1e-40, 1e-38, 1e-36, 1e-34], dtype=torch.bfloat16 | ||
| ) # Very small values | ||
|
|
||
| x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference( | ||
| denormal_vals, block_size=block_size, scaling_mode=scaling_mode | ||
| ) | ||
| x_mx_t, x_s_t = triton_to_mxfp8_dim0( | ||
| denormal_vals, | ||
| inner_block_size=block_size, | ||
| scaling_mode=scaling_mode.value.lower(), | ||
| ) | ||
|
|
||
| assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs" | ||
| assert not x_s_t.isnan().any(), "scales should not contain NaNs" | ||
| torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) | ||
| torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") | ||
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") | ||
| @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,)) | ||
| def test_all_nan_block_scale_behavior(scaling_mode): | ||
|
slayton58 marked this conversation as resolved.
|
||
| """ | ||
| Test that PyTorch, CUDA, and Triton implementations align on NaN scale behavior: | ||
| - Mixed real + NaN: scale = max of real values (ignore NaNs) | ||
| - All NaN block: scale = NaN (not -inf) | ||
| """ | ||
| from torchao.prototype.mx_formats.mx_tensor import to_mx | ||
|
|
||
| block_size = 32 | ||
|
|
||
| # Create test case with both mixed and all-NaN blocks | ||
| # First 32 elements: mixed NaN + real values | ||
| # Second 32 elements: all NaN values | ||
| # Third 32 elements: normal values for reference | ||
| test_vals = torch.zeros(3 * block_size, dtype=torch.bfloat16, device="cuda") | ||
|
|
||
| # Block 1: Mixed NaN + real values [NaN, 1.0, NaN, 5.0, NaN, 3.0, ...] | ||
| test_vals[:block_size:3] = float("nan") # Every 3rd element is NaN | ||
| test_vals[1:block_size:3] = 1.0 # Some real values | ||
| test_vals[2:block_size:3] = 5.0 # Some larger real values | ||
|
|
||
| # Block 2: All NaN values | ||
| test_vals[block_size : 2 * block_size] = float("nan") | ||
|
|
||
| # Block 3: Normal values for reference | ||
| test_vals[2 * block_size :] = torch.linspace(1.0, 10.0, block_size) | ||
|
|
||
| # Test PyTorch implementation through to_mx | ||
| scale_pytorch, _ = to_mx(test_vals, torch.float8_e4m3fn, block_size, scaling_mode) | ||
|
|
||
| # Convert to regular tensor for easier inspection | ||
| scale_pytorch_vals = scale_pytorch.to(torch.float32) | ||
|
|
||
| # Test expectations: | ||
| # Block 0 (mixed): Should have real scale value (not NaN), based on max real value | ||
| assert not torch.isnan(scale_pytorch_vals[0]), ( | ||
| "Mixed NaN+real block should have real scale (PyTorch should ignore NaNs)" | ||
| ) | ||
|
|
||
| # Block 1 (all NaN): Should have NaN scale to match CUDA/Triton behavior | ||
| assert torch.isnan(scale_pytorch_vals[1]), ( | ||
| "All-NaN block should have NaN scale to match CUDA/Triton behavior" | ||
| ) | ||
|
|
||
| # Block 2 (normal): Should have real scale | ||
| assert not torch.isnan(scale_pytorch_vals[2]), "Normal block should have real scale" | ||
Uh oh!
There was an error while loading. Please reload this page.