Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
#!/usr/bin/env python3

import torch

from torchao.float8.float8_utils import compute_error
from torchao.prototype.blockwise_fp8_training.kernels import (
triton_fp8_blockwise_act_quant_lhs,
triton_fp8_blockwise_weight_quant_rhs,
triton_fp8_gemm_1x128_128x128,
)


def dequantize_lhs_block_1x128(
    a_q: torch.Tensor, a_s: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
    """Dequantize an LHS tensor quantized with 1 x block_size column groups.

    Args:
        a_q: Quantized tensor of shape (M, K).
        a_s: Per-group scales of shape (M, ceil(K / block_size)); each scale
            applies to ``block_size`` consecutive columns of ``a_q``.
        block_size: Width of each scaling group. Defaults to 128, matching
            the 1x128 activation-quant kernel layout.

    Returns:
        Float32 tensor of shape (M, K) with scales applied elementwise.
    """
    # Expand each scale across its column group, then trim to K columns in
    # case the last group is ragged (K not a multiple of block_size).
    return a_q.float() * a_s.repeat_interleave(block_size, dim=1)[:, : a_q.size(1)]


def dequantize_rhs_block_128x128(
    b_q: torch.Tensor, b_s: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
    """Dequantize an RHS tensor quantized with square block_size x block_size tiles.

    Args:
        b_q: Quantized tensor of shape (K, N).
        b_s: Per-tile scales of shape (ceil(K / block_size), ceil(N / block_size));
            each scale applies to one ``block_size`` x ``block_size`` tile of ``b_q``.
        block_size: Side length of each square scaling tile. Defaults to 128,
            matching the 128x128 weight-quant kernel layout.

    Returns:
        Float32 tensor of shape (K, N) with scales applied elementwise.
    """
    # Expand each tile scale across its rows and columns, then trim to the
    # quantized tensor's shape in case the trailing tiles are ragged.
    scales = b_s.repeat_interleave(block_size, dim=0).repeat_interleave(
        block_size, dim=1
    )
    return b_q.float() * scales[: b_q.size(0), : b_q.size(1)]


def run_case(k: int, m: int = 256, n: int = 256) -> dict[str, float | int]:
    """Compare fp8 blockwise GEMM backends for one inner-dimension size.

    Quantizes an all-ones (m, k) grad_output with 1x128 groups and a (k, n)
    linear weight with 128x128 tiles, then multiplies them four ways: in
    fp32, via manual dequantization, via ``torch._scaled_mm``, and via the
    Triton 1x128/128x128 GEMM kernel.

    Args:
        k: Inner (contraction) dimension of the matmul.
        m: Number of grad_output rows. Defaults to 256.
        n: Number of weight columns. Defaults to 256.

    Returns:
        Dict with ``k``, the 128-wide block count, SQNR of each fp8 backend
        against both references, and the norms of all four products.
    """
    torch.manual_seed(0)
    w = (
        torch.nn.Linear(n, k, bias=False, device="cuda")
        .weight.detach()
        .contiguous()
    )
    g = torch.ones(m, k, device="cuda", dtype=torch.float32)

    g_fp8, g_scale = triton_fp8_blockwise_act_quant_lhs(g, 128)
    w_fp8, w_scale = triton_fp8_blockwise_weight_quant_rhs(w, 128)

    ref_fp32 = g @ w
    ref_dequant = dequantize_lhs_block_1x128(
        g_fp8, g_scale
    ) @ dequantize_rhs_block_128x128(w_fp8, w_scale)
    out_scaled_mm = torch._scaled_mm(
        g_fp8,
        w_fp8,
        g_scale,
        w_scale,
        out_dtype=torch.bfloat16,
    )
    out_triton = triton_fp8_gemm_1x128_128x128(
        g_fp8,
        w_fp8,
        g_scale,
        w_scale,
        out_dtype=torch.bfloat16,
    )

    metrics: dict[str, float | int] = {"k": k, "k_blocks": k // 128}
    metrics["scaled_mm_sqnr_vs_fp32"] = float(compute_error(ref_fp32, out_scaled_mm))
    metrics["triton_sqnr_vs_fp32"] = float(compute_error(ref_fp32, out_triton))
    metrics["scaled_mm_sqnr_vs_dequant"] = float(
        compute_error(ref_dequant, out_scaled_mm)
    )
    metrics["triton_sqnr_vs_dequant"] = float(compute_error(ref_dequant, out_triton))
    metrics["scaled_mm_norm"] = float(out_scaled_mm.float().norm())
    metrics["triton_norm"] = float(out_triton.float().norm())
    metrics["fp32_ref_norm"] = float(ref_fp32.norm())
    metrics["dequant_ref_norm"] = float(ref_dequant.norm())
    return metrics


def main() -> None:
    """Run the k-dimension sweep on CUDA, printing one result dict per size.

    Raises:
        SystemExit: If no CUDA device is available.
    """
    if not torch.cuda.is_available():
        raise SystemExit("CUDA is required")

    # Sweep k just below, at, and just above a 128-block-aligned boundary.
    for k_dim in (128000, 128128, 128256):
        print(run_case(k_dim))


# Run the sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ def test_blockwise_quant_linear_fwd_bwd(
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 25.0"

# Compare weight grads
sqnr = compute_error(layer_ref.weight, layer_test.weight)
sqnr = compute_error(layer_ref.weight.grad, layer_test.weight.grad)
assert not layer_test.weight.grad.isnan().any(), "Weight grad must not contain NaNs"
assert sqnr >= 30.0, f"SQNR: {sqnr} must be >= 25.0"
13 changes: 4 additions & 9 deletions torchao/prototype/blockwise_fp8_training/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def backward(ctx, grad_output):
x, weight = ctx.saved_tensors
block_size = ctx.block_size
out_dtype = ctx.out_dtype
use_triton = ctx.use_triton

# Reshape input to 2D
x_orig_shape = x.shape
Expand All @@ -86,10 +85,9 @@ def backward(ctx, grad_output):
)

# grad_x = grad_output @ weight
fp8_gemm_1x128_128x128 = (
triton_fp8_gemm_1x128_128x128 if use_triton else torch._scaled_mm
)
grad_x = fp8_gemm_1x128_128x128(
# The torch._scaled_mm path is numerically unstable for this RHS blockwise
# layout in backward, so use the Triton GEMM implementation for gradients.
grad_x = triton_fp8_gemm_1x128_128x128(
grad_output_fp8,
weight_fp8,
grad_output_scale,
Expand All @@ -113,10 +111,7 @@ def backward(ctx, grad_output):
x_fp8, x_scale = triton_fp8_blockwise_act_quant_rhs(x, block_size)

# grad_weight = grad_output.T @ x
fp8_gemm_1x128_128x1 = (
triton_fp8_gemm_1x128_128x1 if use_triton else torch._scaled_mm
)
grad_weight = fp8_gemm_1x128_128x1(
grad_weight = triton_fp8_gemm_1x128_128x1(
grad_output_t_fp8,
x_fp8,
grad_output_t_scale,
Expand Down
Loading