Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 263 additions & 0 deletions test/prototype/mx_formats/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,266 @@ def test_cuda_mx_dim0_not_supported():
rowwise=True,
colwise=False,
)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_special_values(scaling_mode: ScaleCalculationMode):
    """Compare the triton mxfp8 dim0 kernel against the reference on special values.

    Feeds inf/-inf/NaN/0 and float32 extreme magnitudes (cast to bfloat16)
    and checks that NaN patterns, infinity patterns, finite values, and
    scales match exactly between the triton kernel and the reference.
    """
    # Create tensor with special values - make it compatible with block_size=32
    block_size = 32
    special_vals = torch.zeros(2, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements of each row with special values; the rest of
    # each row stays 0 so every scaling block is still well formed.
    special_vals[0, :4] = torch.tensor(
        [float("inf"), -float("inf"), float("nan"), 0.0], dtype=torch.bfloat16
    )
    # float32 max/min overflow bfloat16's representable range when cast, and
    # float32 tiny underflows toward bfloat16 zero/subnormals — both are
    # intended stress inputs here.
    special_vals[1, :4] = torch.tensor(
        [
            torch.finfo(torch.float32).max,
            torch.finfo(torch.float32).min,
            torch.finfo(torch.float32).tiny,
            -torch.finfo(torch.float32).tiny,
        ],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        special_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        special_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )
    # Upcast quantized data to float32 and view scales as raw uint8 (E8M0
    # bit patterns) so the comparisons below are exact.
    x_mx_t = x_mx_t.to(torch.float32)
    x_s_t = x_s_t.to(torch.uint8)
    x_mx_ref = x_mx_ref.to(torch.float32)
    x_s_ref = x_s_ref.to(torch.uint8)

    # Check for NaNs in output (allow NaNs if input had NaNs, but check scales).
    # NOTE(review): the input above always contains NaN (row 0), so this
    # branch is currently dead; kept as a guard in case the inputs change.
    input_has_nan = special_vals.isnan().any()
    if not input_has_nan:
        assert not x_mx_t.isnan().any(), (
            "quantized tensor should not contain NaNs when input has no NaNs"
        )
        assert not x_s_t.isnan().any(), (
            "scales should not contain NaNs when input has no NaNs"
        )

    # Use NaN-aware comparison to handle nan != nan case properly.
    # Check NaN patterns match.
    nan_ref = torch.isnan(x_mx_ref)
    nan_triton = torch.isnan(x_mx_t)
    assert torch.equal(nan_ref, nan_triton), (
        "NaN pattern mismatch between reference and triton"
    )

    # Check finite values
    finite_mask = torch.isfinite(x_mx_ref) & torch.isfinite(x_mx_t)
    if finite_mask.any():
        assert torch.equal(x_mx_ref[finite_mask], x_mx_t[finite_mask]), (
            "Finite values mismatch"
        )

    # Check infinity patterns (signs are covered by the value check below).
    inf_ref = torch.isinf(x_mx_ref)
    inf_triton = torch.isinf(x_mx_t)
    assert torch.equal(inf_ref, inf_triton), (
        "Infinity pattern mismatch between reference and triton"
    )
    if inf_ref.any():
        assert torch.equal(x_mx_ref[inf_ref], x_mx_t[inf_ref]), (
            "Infinity values mismatch"
        )

    # Check scales using exact comparison. Both tensors were already
    # converted to uint8 above, so no second conversion is needed (the
    # original code converted twice).
    assert torch.equal(x_s_t, x_s_ref), (
        "Scale values mismatch between reference and triton"
    )


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_overflow_underflow(scaling_mode):
    """Test with values near overflow and underflow thresholds.

    Scales values just above/below the float8_e4m3fn representable range and
    checks the triton kernel matches the reference exactly, with no NaNs.
    """
    # Values near the float8_e4m3fn limits. These finfo values are exact:
    # max == 448.0 and tiny (smallest positive normal) == 2**-6 == 0.015625.
    f8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0
    f8_min = torch.finfo(torch.float8_e4m3fn).tiny  # 2**-6
    block_size = 32

    overflow_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements of each row with overflow/underflow values:
    # rows 0/1 straddle +/-max, rows 2/3 straddle +/-tiny.
    overflow_vals[0, :4] = torch.tensor(
        [f8_max * 0.9, f8_max * 1.1, f8_max * 2.0, f8_max * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[1, :4] = torch.tensor(
        [-f8_max * 0.9, -f8_max * 1.1, -f8_max * 2.0, -f8_max * 10.0],
        dtype=torch.bfloat16,
    )
    overflow_vals[2, :4] = torch.tensor(
        [f8_min * 0.1, f8_min * 0.5, f8_min * 2.0, f8_min * 10.0], dtype=torch.bfloat16
    )
    overflow_vals[3, :4] = torch.tensor(
        [-f8_min * 0.1, -f8_min * 0.5, -f8_min * 2.0, -f8_min * 10.0],
        dtype=torch.bfloat16,
    )

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        overflow_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        overflow_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # rtol=0, atol=0 makes assert_close an exact bitwise-value comparison.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_extreme_range(scaling_mode):
    """Test with tensors containing both very large and very small values.

    Mixing extreme magnitudes in one tensor stresses per-block scale
    computation; the triton kernel must match the reference exactly.
    """
    # Mix of extreme values in same tensor to test scaling edge cases
    block_size = 32
    extreme_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements with extreme values. Note that magnitudes
    # beyond bfloat16's range (e.g. 1e40) saturate/overflow on cast, and
    # tiny magnitudes (e.g. 1e-40) underflow — both are intended inputs.
    extreme_vals[0, :4] = torch.tensor([1e30, 1e-30, 1e20, 1e-20], dtype=torch.bfloat16)
    extreme_vals[1, :4] = torch.tensor(
        [-1e30, -1e-30, -1e20, -1e-20], dtype=torch.bfloat16
    )
    extreme_vals[2, :4] = torch.tensor(
        [torch.finfo(torch.float32).max, torch.finfo(torch.float32).tiny, 1.0, -1.0],
        dtype=torch.bfloat16,
    )
    extreme_vals[3, :4] = torch.tensor([0.0, 1e-40, 1e40, -1e40], dtype=torch.bfloat16)

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        extreme_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        extreme_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # rtol=0, atol=0 makes assert_close an exact comparison.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(
    not is_sm_at_least_100() and not is_MI350(),
    reason="mxfp8 requires CUDA capability 10.0 or greater or ROCm gfx950 or greater.",
)
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_triton_mxfp8_dim0_denormals_subnormals(scaling_mode):
    """Test with denormal/subnormal values that might cause precision issues.

    Values at and below the smallest normal magnitudes of bfloat16/float32
    exercise underflow paths in the scale computation; the triton kernel
    must match the reference exactly, with no NaNs produced.
    """
    # Create values in the denormal range. finfo.tiny is the smallest
    # positive *normal* value of the respective dtype.
    bf16_tiny = torch.finfo(torch.bfloat16).tiny
    f32_tiny = torch.finfo(torch.float32).tiny
    block_size = 32

    denormal_vals = torch.zeros(4, block_size, dtype=torch.bfloat16, device="cuda")

    # Fill first few elements with denormal values
    denormal_vals[0, :4] = torch.tensor(
        [bf16_tiny, bf16_tiny * 0.5, bf16_tiny * 0.1, bf16_tiny * 2.0],
        dtype=torch.bfloat16,
    )
    denormal_vals[1, :4] = torch.tensor(
        [f32_tiny, f32_tiny * 0.5, f32_tiny * 0.1, f32_tiny * 2.0], dtype=torch.bfloat16
    )
    denormal_vals[2, :4] = torch.tensor(
        [-bf16_tiny, -bf16_tiny * 0.5, -bf16_tiny * 0.1, -bf16_tiny * 2.0],
        dtype=torch.bfloat16,
    )
    denormal_vals[3, :4] = torch.tensor(
        [1e-40, 1e-38, 1e-36, 1e-34], dtype=torch.bfloat16
    )  # Very small values

    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(
        denormal_vals, block_size=block_size, scaling_mode=scaling_mode
    )
    x_mx_t, x_s_t = triton_to_mxfp8_dim0(
        denormal_vals,
        inner_block_size=block_size,
        scaling_mode=scaling_mode.value.lower(),
    )

    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
    assert not x_s_t.isnan().any(), "scales should not contain NaNs"
    # rtol=0, atol=0 makes assert_close an exact comparison.
    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.RCEIL,))
def test_all_nan_block_scale_behavior(scaling_mode):
    """
    Test that PyTorch, CUDA, and Triton implementations align on NaN scale behavior:
    - Mixed real + NaN: scale = max of real values (ignore NaNs)
    - All NaN block: scale = NaN (not -inf)
    """
    # Local import to avoid a module-level dependency for other tests.
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    block_size = 32

    # Create test case with both mixed and all-NaN blocks:
    # First 32 elements: mixed NaN + real values
    # Second 32 elements: all NaN values
    # Third 32 elements: normal values for reference
    test_vals = torch.zeros(3 * block_size, dtype=torch.bfloat16, device="cuda")

    # Block 0: Mixed NaN + real values [NaN, 1.0, 5.0, NaN, 1.0, 5.0, ...]
    test_vals[:block_size:3] = float("nan")  # Every 3rd element is NaN
    test_vals[1:block_size:3] = 1.0  # Some real values
    test_vals[2:block_size:3] = 5.0  # Some larger real values

    # Block 1: All NaN values
    test_vals[block_size : 2 * block_size] = float("nan")

    # Block 2: Normal values for reference
    test_vals[2 * block_size :] = torch.linspace(1.0, 10.0, block_size)

    # Test PyTorch implementation through to_mx (returns (scale, data)).
    scale_pytorch, _ = to_mx(test_vals, torch.float8_e4m3fn, block_size, scaling_mode)

    # Convert to regular float tensor for easier inspection of NaN-ness.
    scale_pytorch_vals = scale_pytorch.to(torch.float32)

    # Block 0 (mixed): should have a real scale value (not NaN), based on
    # the max of the real values — NaNs are ignored.
    assert not torch.isnan(scale_pytorch_vals[0]), (
        "Mixed NaN+real block should have real scale (PyTorch should ignore NaNs)"
    )

    # Block 1 (all NaN): should have NaN scale to match CUDA/Triton behavior.
    assert torch.isnan(scale_pytorch_vals[1]), (
        "All-NaN block should have NaN scale to match CUDA/Triton behavior"
    )

    # Block 2 (normal): should have a real scale.
    assert not torch.isnan(scale_pytorch_vals[2]), "Normal block should have real scale"
68 changes: 68 additions & 0 deletions test/prototype/mx_formats/test_mx_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,74 @@ def test_exponent_nan_in(elem_dtype):
assert not torch.any(torch.isnan(tensor_mx.scale[1:]))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_all_nan_blocks(elem_dtype):
    """
    Test NaN handling for blocks with all NaN values vs mixed NaN + real values.
    Verifies that PyTorch implementation aligns with CUDA/Triton behavior:
    - Mixed real + NaN: scale = max of real values
    - All NaN: scale = NaN
    """
    block_size = 4

    # Test case 1: Mixed NaN + real values (should ignore NaNs, use real values)
    mixed_tensor = torch.tensor(
        [float("nan"), 2.0, float("nan"), 4.0, 1.0, 3.0, 5.0, 2.0],
        device="cuda",
        dtype=torch.bfloat16,
    )
    mixed_mx = MXTensor.to_mx(mixed_tensor, elem_dtype, block_size)

    # First block [NaN, 2.0, NaN, 4.0] should have scale based on max(2.0, 4.0) = 4.0
    # Second block [1.0, 3.0, 5.0, 2.0] should have scale based on max = 5.0
    assert not torch.isnan(mixed_mx.scale[0]), (
        "Mixed NaN+real block should not have NaN scale"
    )
    assert not torch.isnan(mixed_mx.scale[1]), (
        "Real-only block should not have NaN scale"
    )

    # Test case 2: All NaN blocks (should return NaN scale)
    all_nan_tensor = torch.tensor(
        [float("nan"), float("nan"), float("nan"), float("nan"), 1.0, 2.0, 3.0, 4.0],
        device="cuda",
        dtype=torch.bfloat16,
    )
    all_nan_mx = MXTensor.to_mx(all_nan_tensor, elem_dtype, block_size)

    # First block [NaN, NaN, NaN, NaN] should have NaN scale (matches CUDA/Triton)
    # Second block [1.0, 2.0, 3.0, 4.0] should have real scale
    assert torch.isnan(all_nan_mx.scale[0]), (
        "All-NaN block should have NaN scale to match CUDA/Triton"
    )
    assert not torch.isnan(all_nan_mx.scale[1]), (
        "Real-only block should not have NaN scale"
    )

    # Test case 3: Completely all-NaN tensor (two full blocks of NaN)
    completely_nan_tensor = torch.full(
        (2 * block_size,),
        float("nan"),
        device="cuda",
        dtype=torch.bfloat16,
    )
    completely_nan_mx = MXTensor.to_mx(completely_nan_tensor, elem_dtype, block_size)

    # Both blocks should have NaN scales
    assert torch.all(torch.isnan(completely_nan_mx.scale)), (
        "All-NaN tensor should have all NaN scales"
    )


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
def test_exponent_nan_out(elem_dtype):
Expand Down
Loading
Loading