[nvfp4_training] Add Triton kernel for global amax of columnwise RHT (SM90+)#4247
rdspring1 wants to merge 1 commit into pytorch:main
Conversation
```python
def get_wgrad_sign_vector(device) -> torch.Tensor:
    """Hard-coded random signs for Hadamard transform."""
```
Wait, why is this hard-coded vs. generated?
Converted `get_wgrad_sign_vector` to generate a random sign vector.
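For reference, a minimal sketch of what a generated sign vector could look like (the function name, the seed, and the defaults here are assumptions for illustration, not the PR's actual code):

```python
import torch

def make_sign_vector(dim: int = 16, seed: int = 0, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    # Draw {0, 1} uniformly and map to {-1, +1}; a fixed seed keeps the
    # vector reproducible across calls (and hence cache-friendly).
    g = torch.Generator().manual_seed(seed)
    return (torch.randint(0, 2, (dim,), generator=g).to(torch.float32) * 2 - 1).to(dtype)
```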
```python
        [1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1],
        [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
    ],
    dtype=torch.float32,
```
Is there a reason this needs to be 32b? Won't we save BW later if we store this natively in something lower (4b for instance), along with avoiding casts (See here)
Added a `dtype` argument with `torch.bfloat16` default.
```python
@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool, device) -> torch.Tensor:
```
Should have an argument for the Hadamard dimension (kwarg-only perhaps, default=16?). Should probably also have an optional dtype argument too (seeing notes above)
We could add the kwarg with default 16 and for now just assert == 16, wdyt? Only support 16 for now, build out the prototype quickly, and support generating other RHT matrix sizes later?
Yep, that's perfectly fine - I'm concerned with getting the API as correct as possible here
Changed `get_rht_matrix` to:

```python
def get_rht_matrix(
    sign_vector: tuple[int, ...] | None,
    device,
    dtype: torch.dtype = torch.bfloat16,
    hadamard_dimension: int = 16,
) -> torch.Tensor:
```

If `sign_vector` is `None`, it calls `get_wgrad_sign_vector`. Otherwise the `sign_vector` tuple is converted with `torch.tensor`. It is a tuple so it is hashable for `lru_cache`.
```python
tl.atomic_max(global_max_ptr, tile_max.to(tl.float32))
```
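The quoted `tl.atomic_max` folds each CTA's local tile maximum into a single global scalar. A host-side Python analogue of that reduction (purely illustrative, not the kernel's code):

```python
import torch

def tiled_global_amax(x: torch.Tensor, tile: int = 16) -> torch.Tensor:
    # Each "CTA" reduces its own tile, then folds the result into one
    # running scalar, mirroring per-CTA atomic_max without ever
    # materializing the full post-RHT tensor.
    global_max = torch.zeros((), dtype=torch.float32)
    for chunk in x.reshape(-1, tile):
        tile_max = chunk.abs().max().float()              # local reduction
        global_max = torch.maximum(global_max, tile_max)  # atomic_max analogue
    return global_max
```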
```python
def triton_rht_amax(
```
Added `scaling_type: F.ScalingType = F.ScalingType.TensorWise`. The function throws a ValueError if it is anything except `F.ScalingType.TensorWise`.
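A sketch of that guard (the `ScalingType` enum here is a stand-in for the real one in torchao, purely for illustration):

```python
from enum import Enum
import torch

class ScalingType(Enum):
    TensorWise = "tensorwise"
    RowWise = "rowwise"

def triton_rht_amax_stub(x: torch.Tensor, scaling_type: ScalingType = ScalingType.TensorWise) -> torch.Tensor:
    # Reject anything other than tensor-wise scaling up front.
    if scaling_type is not ScalingType.TensorWise:
        raise ValueError(f"only TensorWise scaling is supported, got {scaling_type}")
    return x.abs().max()
```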
```python
    num_warps=cfg.NUM_WARPS,
)

best = get_best_config(cache_key, HADAMARD_CONFIGS, benchmark_fn)
```
Is there a reason you're re-implementing autotune?
```python
# Reference: same deterministic matrix (lru_cached, hard-coded sign vector)
B = get_rht_matrix(with_random_sign_mask=True, device="cuda")
ref_amax = (A.t().reshape(N * M // 16, 16) @ B).to(torch.bfloat16).abs().max().float()
```
This is (as written) doing A (bf16) @ B (fp32) - intended?
`B` from `get_rht_matrix` is `return rht_matrix.to(dtype=torch.bfloat16)`, so it is always bf16.
```python
    signs = get_wgrad_sign_vector(device=device)
else:
    signs = torch.ones(1, dtype=torch.float32, device=device)
sign_matrix = signs * torch.eye(
```
We should also be able to specify different sign vectors if desired.
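For what it's worth, broadcasting any sign vector against an identity matrix yields the diagonal sign matrix, so supporting user-supplied vectors is mostly a plumbing change. A tiny illustration (the 4-element vector is arbitrary):

```python
import torch

# Broadcasting a (d,) sign vector against eye(d) scales column j by
# signs[j], producing diag(signs): the sign matrix that randomizes the
# Hadamard transform.
signs = torch.tensor([1.0, -1.0, -1.0, 1.0])
sign_matrix = signs * torch.eye(4)
```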
danielvegamyhre left a comment
When adding new kernels, could you add microbenchmarks in ao/benchmarks/prototype? Feel free to start a nvfp4_training directory in there.
```python
    return times[len(times) // 2]

def get_best_config(
```
why is this necessary? triton autotuner automatically caches the best config for the given keys.
Removed. I replaced the host-side TMA descriptor with in-kernel ones, so it is compatible with the `@triton.autotune` decorator.
```python
_autotune_cache: dict[tuple, KernelConfig] = {}

def do_bench(fn: Callable, warmup_iters: int = 3, bench_iters: int = 10) -> float:
```
Can you use our existing benchmark util for this instead of defining a new one? (Line 107 in 707bee8)
Added a microbenchmark for the Triton kernel.
Summary
- `triton_rht_amax` (`hadamard_amax_triton.py`): a persistent, warp-specialized Triton kernel that applies the Randomized Hadamard Transform (RHT) to the input and reduces to a scalar global absolute maximum, without materializing the full post-RHT tensor
- `triton_rht_quantize_row_col`: the global amax determines the per-tensor decode scale (`global_amax / (FP8_E4M3_MAX × FP4_E2M1_MAX)`) used in two-level NVFP4 quantization
- Moved `_compute_pid` and related RHT matrix helpers to `hadamard_utils.py`

Key design choices
- Launches `NUM_SMS` CTAs and each CTA iterates over all tiles, amortizing launch overhead
- Uses `wgmma` for the RHT matrix multiply, matching the SM90+ warp-specialized pipeline
- One `atomic_max` per CTA, avoiding a full `(N, M)` intermediate allocation

Test plan
- `pytest test/prototype/mx_formats/test_hadamard_amax_triton.py -v`
- `test_hadamard_quantize_row_col_triton.py`: incorrect amax would break scale computation and fail SQNR/bitwise tests
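The per-tensor decode scale from the summary can be sketched as follows (the constant names are assumptions; 448 and 6 are the format maxima of fp8 e4m3 and fp4 e2m1 respectively):

```python
import torch

# Format maxima: largest finite value representable in each encoding.
FP8_E4M3_MAX = 448.0  # fp8 e4m3
FP4_E2M1_MAX = 6.0    # fp4 e2m1

def per_tensor_scale(global_amax: torch.Tensor) -> torch.Tensor:
    # Two-level NVFP4: the global amax sets the per-tensor decode scale.
    return global_amax.float() / (FP8_E4M3_MAX * FP4_E2M1_MAX)
```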