@@ -625,3 +625,223 @@ def test_cuda_mx_dim0_not_supported():
625625 rowwise = True ,
626626 colwise = False ,
627627 )
628+
629+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
def test_triton_mxfp8_dim0_special_values():
    """Triton mxfp8 dim0 quantization matches the reference on special values
    (inf, -inf, nan, 0.0, and float32 extremes cast to bfloat16).

    Only RCEIL mode is tested, to match canonical PyTorch behavior.
    """
    scaling_mode = ScaleCalculationMode.RCEIL

    # Tensor must be compatible with block_size=32: one 32-wide scaling
    # block per row, with the special values in the first 4 columns.
    block_size = 32
    special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda")

    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    special_vals[1, :4] = torch.tensor(
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Compare quantized data as float32 and scales as uint8 so torch.equal
    # gives exact, dtype-independent comparisons. (Previously the scales
    # were converted to uint8 a second time just before the final assert;
    # the duplicate conversion has been removed.)
    x_mx_t = x_mx_t.to(torch.float32)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_s_ref = x_s_ref.to(torch.uint8)

    # The input deliberately contains NaN, so the quantized output is
    # allowed to contain NaN as well. (A previous "no NaNs unless the input
    # had NaNs" branch was dead code here: the input always has a NaN, so
    # it never executed.) Since nan != nan, compare NaN *patterns* instead
    # of comparing the raw tensors directly.
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Finite values must match exactly.
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Infinity positions and signs must match exactly.
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )
    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Scales must match exactly (compared as uint8, converted above).
    assert torch.equal(x_s_t, x_s_ref), (
        "Scale values mismatch between reference and triton"
    )
712+
713+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds."""
    # Values near float8_e4m3fn limits.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    # NOTE: .tiny is the smallest *normal* value, 2**-6 = 0.015625
    # (the original comment claimed ~1.95e-06, which is incorrect).
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny
    block_size = 32

    # One 32-wide scaling block per row; only the first 4 columns of each
    # row are non-zero.
    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Rows 0/1: positive/negative values straddling the fp8 max (overflow side).
    overflow_vals[0, :4] = torch.tensor(
        [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[1, :4] = torch.tensor(
        [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
        dtype=torch.bfloat16,
    )
    # Rows 2/3: positive/negative values straddling the fp8 smallest normal
    # (underflow side).
    overflow_vals[2, :4] = torch.tensor(
        [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[3, :4] = torch.tensor(
        [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # Quantization is deterministic, so require bit-exact agreement.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
760+
761+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values."""
    # Mixing extreme magnitudes in one tensor stresses the per-block scale
    # computation. Each row is one 32-wide scaling block; only the first
    # 4 columns are non-zero.
    blk = 32
    f32_info = torch.finfo(torch.float32)
    rows = [
        [1e30, 1e-30, 1e20, 1e-20],
        [-1e30, -1e-30, -1e20, -1e-20],
        [f32_info.max, f32_info.tiny, 1.0, -1.0],
        [0.0, 1e-40, 1e40, -1e40],
    ]
    data = torch.zeros(len(rows), blk, dtype=torch.bfloat16, device="cuda")
    for row_idx, vals in enumerate(rows):
        data[row_idx, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    ref_mx, ref_scale = triton_to_mxfp8_dim0_reference(
        data, block_size=blk, scaling_mode=scaling_mode
    )
    tri_mx, tri_scale = triton_to_mxfp8_dim0(
        data,
        inner_block_size=blk,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_mx.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scale.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_mx, ref_mx, rtol=0, atol=0)
    torch.testing.assert_close(tri_scale, ref_scale, rtol=0, atol=0)
800+
801+
@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues."""
    # Values at and below the smallest normals of bfloat16 and float32.
    # Each row is one 32-wide scaling block; only the first 4 columns are
    # non-zero.
    tiny_bf16 = torch.finfo(torch.bfloat16).tiny
    tiny_f32 = torch.finfo(torch.float32).tiny
    blk = 32
    rows = [
        [tiny_bf16, tiny_bf16 * 0.5, tiny_bf16 * 0.1, tiny_bf16 * 2.0],
        [tiny_f32, tiny_f32 * 0.5, tiny_f32 * 0.1, tiny_f32 * 2.0],
        [-tiny_bf16, -tiny_bf16 * 0.5, -tiny_bf16 * 0.1, -tiny_bf16 * 2.0],
        [1e-40, 1e-38, 1e-36, 1e-34],  # very small values
    ]
    data = torch.zeros(len(rows), blk, dtype=torch.bfloat16, device="cuda")
    for row_idx, vals in enumerate(rows):
        data[row_idx, :4] = torch.tensor(vals, dtype=torch.bfloat16)

    ref_mx, ref_scale = triton_to_mxfp8_dim0_reference(
        data, block_size=blk, scaling_mode=scaling_mode
    )
    tri_mx, tri_scale = triton_to_mxfp8_dim0(
        data,
        inner_block_size=blk,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not tri_mx.isnan().any(), "quantized tensor should not contain NaNs"
    assert not tri_scale.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(tri_mx, ref_mx, rtol=0, atol=0)
    torch.testing.assert_close(tri_scale, ref_scale, rtol=0, atol=0)
0 commit comments