Skip to content

Commit 46eec1e

Browse files
[mxfp8 training] triton_to_mxfp8_dim0 nan handling consistent with torch reference
stack-info: PR: #4201, branch: danielvegamyhre/stack/162
1 parent ce07646 commit 46eec1e

2 files changed

Lines changed: 309 additions & 43 deletions

File tree

test/prototype/mx_formats/test_kernels.py

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -625,3 +625,223 @@ def test_cuda_mx_dim0_not_supported():
625625
rowwise=True,
626626
colwise=False,
627627
)
628+
629+
630+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
def test_triton_mxfp8_dim0_special_values():
    """Triton mxfp8 dim0 quantization matches the reference on special values.

    Feeds +/-inf, NaN, 0.0, and float32 extreme magnitudes (cast to bfloat16)
    through both ``triton_to_mxfp8_dim0`` and
    ``triton_to_mxfp8_dim0_reference`` and compares quantized data and scales
    elementwise, NaN-aware. Only RCEIL mode is tested, to match canonical
    PyTorch behavior.
    """
    scaling_mode = ScaleCalculationMode.RCEIL

    # Shape the tensor so rows are compatible with block_size=32; only the
    # first four elements of each row carry special values, the rest are zero.
    block_size = 32
    special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda")
    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    special_vals[1, :4] = torch.tensor(
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Convert once up front so every comparison below uses common dtypes.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # NaN-free outputs are only required when the input itself is NaN-free.
    # NOTE: the fixture above always contains a NaN, so this guard is
    # currently a no-op; it is kept so the check activates automatically if
    # the input values are ever changed.
    input_has_nan = special_vals.isnan().any()
    if not input_has_nan:
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # NaN-aware comparison: nan != nan, so compare NaN masks explicitly.
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Finite values must match exactly.
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Infinity locations must match; the value compare also checks signs.
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )
    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Scales: both tensors were already converted to uint8 above, so compare
    # directly (the original converted them to uint8 a second time here,
    # which was redundant).
    assert torch.equal(x_s_t, x_s_ref), (
        "Scale values mismatch between reference and triton"
    )
712+
713+
714+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds.

    Rows straddle the float8_e4m3fn representable range (just below/above the
    max magnitude, and fractions/multiples of the smallest normal) in both
    signs, then the triton kernel output is compared exactly against the
    reference implementation.
    """
    # Values near float8_e4m3fn limits.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448
    # .tiny is the smallest positive *normal* e4m3fn value, 2**-6 = 0.015625.
    # (The previous comment claimed "~1.95e-06", which is incorrect — even the
    # smallest e4m3fn subnormal is 2**-9 ~= 1.95e-03.)
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny
    block_size = 32

    # Only the first four elements of each row are non-zero; rows cover
    # positive/negative overflow and positive/negative underflow.
    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")
    overflow_vals[0, :4] = torch.tensor(
        [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[1, :4] = torch.tensor(
        [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
        dtype=torch.bfloat16,
    )
    overflow_vals[2, :4] = torch.tensor(
        [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[3, :4] = torch.tensor(
        [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    # Input is NaN-free, so neither data nor scales may contain NaNs, and the
    # kernel must agree bit-exactly with the reference.
    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
760+
761+
762+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    block_size = 32
    # Mixing huge and tiny magnitudes in one tensor stresses the per-block
    # scale computation; each row gets four extreme values, rest zeros.
    extreme_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")
    row_fills = (
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    )
    for row_idx, fill in enumerate(row_fills):
        extreme_vals[row_idx, :4] = torch.tensor(fill, dtype=torch.bfloat16)

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        extreme_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        extreme_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    # NaN-free input must produce NaN-free output, bit-exact vs the reference.
    assert not tri_data.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scales.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)
800+
801+
802+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32

    # Each row holds four values around the denormal range (both signs) in
    # its first elements; remaining elements stay zero so rows line up with
    # the 32-element quantization blocks.
    denormal_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")
    fills = (
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0],
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],  # very small values
    )
    for row_idx, fill in enumerate(fills):
        denormal_vals[row_idx, :4] = torch.tensor(fill, dtype=torch.bfloat16)

    ref_data, ref_scales = triton_to_mxfp8_dim0_reference(
        denormal_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    tri_data, tri_scales = triton_to_mxfp8_dim0(
        denormal_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    # NaN-free input must produce NaN-free output, bit-exact vs the reference.
    assert not tri_data.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scales.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_data, ref_data, rtol=0, atol=0)
    torch.testing.assert_close(tri_scales, ref_scales, rtol=0, atol=0)

0 commit comments

Comments
 (0)