Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
315 changes: 315 additions & 0 deletions test/prototype/mx_formats/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,318 @@ def test_cuda_mx_dim0_not_supported():
rowwise=True,
colwise=False,
)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_special_values(scaling_mode: ScaleCalculationMode):
    """The triton dim0 mxfp8 kernel must agree with the reference
    implementation on special values: +/-inf, NaN, zero, and the float32
    extremes (max, min, tiny).
    """
    # Two rows, one block each (block_size=32); each row begins with a
    # handful of special values and is zero elsewhere.
    block_size = 32
    special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda")

    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    f32_info = torch.finfo(torch.float32)
    special_vals[1, :4] = torch.tensor(
        [f32_info.max, f32_info.min, f32_info.tiny, -f32_info.tiny],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Widen to float32 so isnan/isinf/equal comparisons below are well-defined.
    x_mx_t = x_mx_t.to(torch.float32)
    x_mx_ref = x_mx_ref.to(torch.float32)

    # NaN propagation: any NaN inside a block must poison both the block's
    # scale and every quantized element of that block.
    for row_idx in range(special_vals.shape[0]):
        scale_is_nan = torch.isnan(x_s_t[row_idx].to(torch.float32))
        if special_vals[row_idx].isnan().any():
            assert scale_is_nan, (
                f"Row {row_idx}: Block with any NaN should have NaN scale"
            )
            assert torch.all(torch.isnan(x_mx_t[row_idx])), (
                f"Row {row_idx}: Block with any NaN should have all NaN quantized values"
            )
        else:
            assert not scale_is_nan, (
                f"Row {row_idx}: Block without NaN should not have NaN scale"
            )

    # NaN != NaN, so compare NaN placement explicitly instead of raw equality.
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Finite entries must match bit-exactly.
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Infinities must match in both placement and sign.
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )
    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Compare raw e8m0 scale bytes for an exact match.
    x_s_ref_uint8 = x_s_ref.to(torch.uint8)
    x_s_t_uint8 = x_s_t.to(torch.uint8)
    assert torch.equal(x_s_t_uint8, x_s_ref_uint8), (
        "Scale values mismatch between reference and triton"
    )


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode: ScaleCalculationMode):
    """Test with values near overflow and underflow thresholds.

    Verifies that the triton kernel matches the reference implementation for
    inputs near the e4m3 representable limits, that quantization preserves the
    sign bit, and that values scaled below the smallest e4m3 subnormal flush
    to zero after dequantization.
    """
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    fp8_subnormal_min = 2e-9  # smallest positive subnormal for e4m3: https://www.emergentmind.com/topics/mxfp8-e4m3-floating-point-format
    block_size = 32

    test_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Row 0: elem 0 is near max, elems 1-3 are above max
    test_vals[0, :4] = torch.tensor(
        [fp8_max * 0.9, fp8_max * 1.1, fp8_max * 2.0, fp8_max * 10.0],
        dtype=torch.bfloat16,
    )

    # Row 1: elem 0 is near min, elems 1-3 are below min
    test_vals[1, :4] = torch.tensor(
        [-fp8_max * 0.9, -fp8_max * 1.1, -fp8_max * 2.0, -fp8_max * 10.0],
        dtype=torch.bfloat16,
    )

    # Row 2: elem 0-1 are below positive subnormal min representable in e4m3, should underflow to zero if scaled down
    test_vals[2, :3] = torch.tensor(
        [
            fp8_subnormal_min * 0.1,
            fp8_subnormal_min * 0.5,
            fp8_max
            * 0.9,  # include a large value to result in scale that would underflow the subnormals
        ],
        dtype=torch.bfloat16,
    )
    # Row 3: elems 0-1 are negative values smaller in magnitude than the e4m3 subnormal min, should underflow to zero
    test_vals[3, :3] = torch.tensor(
        [
            -fp8_subnormal_min * 0.1,
            -fp8_subnormal_min * 0.5,
            fp8_max
            * 0.9,  # include a large value to result in scale that would underflow the subnormals
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        test_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        test_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    # Test 1: Verify triton matches reference
    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)

    # Round-trip back to bf16 so we can inspect sign and underflow behavior.
    dequantized = to_dtype(
        x_mx_t,
        x_s_t.view(torch.float8_e8m0fnu),
        torch.float8_e4m3fn,
        block_size,
        torch.bfloat16,
    )

    # Verify quantization preserves sign
    original_signbits = torch.signbit(test_vals)
    dequant_signbits = torch.signbit(dequantized)
    assert torch.equal(original_signbits, dequant_signbits), (
        "Sign bit mismatch between original and dequantized values"
    )

    # Verify underflow behavior
    # Check rows 2 and 3 which contain underflow test cases
    for row_idx in [2, 3]:
        # The first two elements should be scaled below the min representable subnormal in e4m3, and thus underflow to zero
        assert torch.all(dequantized[row_idx, :2] == 0.0), (
            f"Row {row_idx}: should underflow to zero"
        )
        # Normal val shouldn't underflow
        assert torch.all(dequantized[row_idx, 2] != 0.0), (
            f"Row {row_idx}: should not underflow to zero"
        )


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8), reason="CUDA version >= 12.8 required"
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_all_nan_block_scale_behavior(scaling_mode: ScaleCalculationMode):
    """
    Test that PyTorch and Triton implementations align on NaN scale behavior:
    - Any NaN in block: scale = NaN, entire quantized block becomes NaN
    """
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    block_size = 32

    # Create test case with both mixed and all-NaN blocks
    # First 32 elements: mixed NaN + real values
    # Second 32 elements: all NaN values
    # Third 32 elements: normal values for reference
    test_vals = torch.zeros(3 * block_size, dtype=torch.bfloat16, device="cuda")

    # Block 1: Mixed NaN + real values [NaN, 1.0, NaN, 5.0, NaN, 3.0, ...]
    test_vals[:block_size:3] = float("nan")  # Every 3rd element is NaN
    test_vals[1:block_size:3] = 1.0  # Some real values
    test_vals[2:block_size:3] = 5.0

    # Block 2: All NaN values
    test_vals[block_size : 2 * block_size] = float("nan")

    # Block 3: Normal values for reference
    test_vals[2 * block_size :] = torch.linspace(1.0, 10.0, block_size)

    # Test PyTorch implementation through to_mx
    scale_pytorch, data_pytorch = to_mx(
        test_vals, torch.float8_e4m3fn, block_size, scaling_mode
    )

    # Convert to regular tensor for easier inspection
    scale_pytorch_vals = scale_pytorch.to(torch.float32)
    data_pytorch_vals = data_pytorch.to(torch.float32)

    # Test expectations: If any value in a block is NaN, scale = NaN and entire block becomes NaN

    # Block 0 (mixed NaN + real): Should have NaN scale and all NaN data
    assert torch.isnan(scale_pytorch_vals[0]), (
        "Block with any NaN should have NaN scale"
    )
    assert torch.all(torch.isnan(data_pytorch_vals[:block_size])), (
        "Block with any NaN should have all NaN quantized values"
    )

    # Block 1 (all NaN): Should have NaN scale and all NaN data
    assert torch.isnan(scale_pytorch_vals[1]), "All-NaN block should have NaN scale"
    assert torch.all(torch.isnan(data_pytorch_vals[block_size : 2 * block_size])), (
        "All-NaN block should have all NaN quantized values"
    )

    # Block 2 (normal): Should have real scale and finite data
    assert not torch.isnan(scale_pytorch_vals[2]), "Normal block should have real scale"
    assert torch.all(torch.isfinite(data_pytorch_vals[2 * block_size :])), (
        "Normal block should have finite quantized values"
    )

    # Also test the Triton implementation to ensure consistency
    test_vals_2d = test_vals.reshape(3, block_size)
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        test_vals_2d,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    # Convert for comparison
    x_s_t_vals = x_s_t.to(torch.float32)
    x_mx_t_vals = x_mx_t.to(torch.float32)

    # Test Triton implementation matches PyTorch behavior
    # Block 0 (mixed NaN + real): Should have NaN scale and all NaN data
    assert torch.isnan(x_s_t_vals[0]), (
        "Triton: Block with any NaN should have NaN scale"
    )
    assert torch.all(torch.isnan(x_mx_t_vals[0])), (
        "Triton: Block with any NaN should have all NaN quantized values"
    )

    # Block 1 (all NaN): Should have NaN scale and all NaN data
    assert torch.isnan(x_s_t_vals[1]), "Triton: All-NaN block should have NaN scale"
    assert torch.all(torch.isnan(x_mx_t_vals[1])), (
        "Triton: All-NaN block should have all NaN quantized values"
    )

    # Block 2 (normal): Should have real scale and finite data
    assert not torch.isnan(x_s_t_vals[2]), "Triton: Normal block should have real scale"
    assert torch.all(torch.isfinite(x_mx_t_vals[2])), (
        "Triton: Normal block should have finite quantized values"
    )


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
@pytest.mark.parametrize(
    "scaling_mode", (ScaleCalculationMode.RCEIL, ScaleCalculationMode.FLOOR)
)
def test_triton_mxfp8_dim0_large_tensor_offset_no_overflow(scaling_mode):
    """Test with a large tensor whose offsets exceed the max int32 value.

    A (184320, 14336) bf16 tensor has more elements than int32 can index, so
    the triton kernel must not overflow when computing offsets; its output
    must still match the reference exactly.
    """
    block_size = 32
    big_input = torch.randn((184320, 14336), dtype=torch.bfloat16, device="cuda")

    expected_data, expected_scales = triton_to_mxfp8_dim0_reference(
        big_input, block_size=block_size, scaling_mode=scaling_mode
    )
    actual_data, actual_scales = triton_to_mxfp8_dim0(
        big_input,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    # An int32 offset overflow would typically surface as NaNs or garbage in
    # the tail of the output, so check NaN-freeness and exact equality.
    assert not actual_data.isnan().any(), "quantized tensor should not contain NaNs"
    assert not actual_scales.isnan().any(), "scales should not contain NaNs"
    torch.testing.assert_close(actual_data, expected_data, rtol=0, atol=0)
    torch.testing.assert_close(actual_scales, expected_scales, rtol=0, atol=0)
Loading
Loading