
[nvfp4_training] Add Triton kernel for global amax of columnwise RHT (SM90+)#4247

Open
rdspring1 wants to merge 1 commit into pytorch:main from rdspring1:triton_rht_amax

Conversation

@rdspring1 rdspring1 commented Apr 7, 2026

Summary

  • Adds triton_rht_amax (hadamard_amax_triton.py): a persistent, warp-specialized Triton kernel that applies the Randomized Hadamard Transform (RHT) to the input and reduces to a scalar global absolute maximum, without materializing the full post-RHT tensor
  • This is a prerequisite building block for triton_rht_quantize_row_col: the global amax determines the per-tensor decode scale (global_amax / (FP8_E4M3_MAX × FP4_E2M1_MAX)) used in two-level NVFP4 quantization
  • Adds _compute_pid and related RHT matrix helpers to hadamard_utils.py
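The two-level scale computation the amax feeds into can be sketched in plain Python. The constants are the standard format maxima (fp8 e4m3 max = 448, fp4 e2m1 max = 6); the helper name is illustrative, not the PR's actual API:

```python
# Sketch of the per-tensor decode-scale computation described above.
# FP8_E4M3_MAX and FP4_E2M1_MAX are the standard format maxima; the
# function name is a hypothetical stand-in for the PR's internals.
FP8_E4M3_MAX = 448.0  # largest representable float8 e4m3 value
FP4_E2M1_MAX = 6.0    # largest representable float4 e2m1 value

def per_tensor_decode_scale(global_amax: float) -> float:
    """Per-tensor scale so (blockwise fp8 scale) x (fp4 value) spans amax."""
    return global_amax / (FP8_E4M3_MAX * FP4_E2M1_MAX)

# Example: a tensor whose post-RHT absolute max is 2688.0 gets scale 1.0.
assert per_tensor_decode_scale(2688.0) == 1.0
```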

Key design choices

  • Persistent grid: kernel launches with NUM_SMS CTAs and each CTA iterates over all tiles, amortizing launch overhead
  • Warp specialization: producer warps issue TMA loads; consumer warps run wgmma for the RHT matrix multiply — matches SM90+ warp-specialized pipeline
  • No output buffer: per-CTA cumulative max is reduced with one atomic_max per CTA, avoiding a full (N, M) intermediate allocation
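The scheduling described above can be emulated on the host to see why no (N, M) intermediate is needed: a fixed number of program instances stride over all tiles, each keeps a running max, and each contributes exactly one atomic-max-style update at the end. This is illustrative Python, not the Triton kernel itself:

```python
# Host-side emulation of the persistent-grid + one-atomic-max-per-CTA
# pattern. num_sms plays the role of NUM_SMS; each "CTA" iterates tiles
# with stride num_sms, so every tile is visited exactly once.
def persistent_amax(tiles: list[list[float]], num_sms: int = 4) -> float:
    global_max = float("-inf")          # stands in for the fp32 scratch scalar
    for pid in range(num_sms):          # one loop body per CTA
        cta_max = float("-inf")
        for tile_id in range(pid, len(tiles), num_sms):  # persistent loop
            cta_max = max(cta_max, max(abs(v) for v in tiles[tile_id]))
        global_max = max(global_max, cta_max)  # one "atomic_max" per CTA
    return global_max

tiles = [[1.0, -3.5], [2.0, 0.5], [-7.25, 4.0], [0.1, 0.2], [6.0, -1.0]]
assert persistent_amax(tiles) == 7.25
```

The result is independent of the grid size, which is what lets the launch use NUM_SMS CTAs regardless of the tile count.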

Test plan

  • pytest test/prototype/mx_formats/test_hadamard_amax_triton.py -v
  • Covered transitively by test_hadamard_quantize_row_col_triton.py — incorrect amax would break scale computation and fail SQNR/bitwise tests


pytorch-bot bot commented Apr 7, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4247

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 7, 2026
@rdspring1 rdspring1 changed the title [nvfp4_training] Add Triton kernel for fused RHT + global amax (SM90+) [nvfp4_training] Add Triton kernel for global amax of columnwise RHT (SM90+) Apr 7, 2026
@rdspring1 rdspring1 marked this pull request as ready for review April 7, 2026 05:22


def get_wgrad_sign_vector(device) -> torch.Tensor:
"""Hard-coded random signs for Hadamard transform."""
Wait, why is this hard-coded vs. generated?

Author:
Converted get_wgrad_sign_vector to generate random sign vector.

[1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1],
[1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
],
dtype=torch.float32,

Is there a reason this needs to be 32b? Won't we save BW later if we store this natively in something lower (4b for instance), along with avoiding casts (See here)

Author:

Added a dtype argument with a torch.bfloat16 default.



@functools.lru_cache(maxsize=None)
def get_rht_matrix(with_random_sign_mask: bool, device) -> torch.Tensor:
Should have an argument for the hadamard dimension (kwarg-only perhaps, default=16?). Should probably also have an optional dtype argument (see notes above).

Contributor:

We could add the kwarg with default 16 and for now just assert == 16, wdyt? Only support 16 for now, build out the prototype quickly, and support generating other RHT matrix sizes later?


Yep, that's perfectly fine - I'm concerned with getting the API as correct as possible here

@rdspring1 (Author) commented Apr 8, 2026:

Changed get_rht_matrix to:

def get_rht_matrix(
    sign_vector: tuple[int, ...] | None,
    device,
    dtype: torch.dtype = torch.bfloat16,
    hadamard_dimension: int = 16,
) -> torch.Tensor:
  • If sign_vector is None, it calls get_wgrad_sign_vector. Otherwise the sign_vector tuple is converted to a torch.Tensor. It is a tuple so that it is hashable for lru_cache.

tl.atomic_max(global_max_ptr, tile_max.to(tl.float32))


def triton_rht_amax(

No option for non-global amax domain

Author:

Added scaling_type: F.ScalingType = F.ScalingType.TensorWise. The function throws a ValueError if it is anything other than F.ScalingType.TensorWise.
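A sketch of that guard, using a stand-in enum (the real F.ScalingType lives in the PR's module; names here are illustrative):

```python
from enum import Enum, auto

class ScalingType(Enum):        # hypothetical stand-in for F.ScalingType
    TensorWise = auto()
    RowWise = auto()

def triton_rht_amax_stub(scaling_type: ScalingType = ScalingType.TensorWise):
    # Only the global (tensor-wise) amax domain is supported for now.
    if scaling_type is not ScalingType.TensorWise:
        raise ValueError(f"unsupported scaling_type: {scaling_type}")
    return "ok"

assert triton_rht_amax_stub() == "ok"
try:
    triton_rht_amax_stub(ScalingType.RowWise)
except ValueError:
    pass
```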

num_warps=cfg.NUM_WARPS,
)

best = get_best_config(cache_key, HADAMARD_CONFIGS, benchmark_fn)

Is there a reason you're re-implementing autotune?

Author:

Removed.


# Reference: same deterministic matrix (lru_cached, hard-coded sign vector)
B = get_rht_matrix(with_random_sign_mask=True, device="cuda")
ref_amax = (A.t().reshape(N * M // 16, 16) @ B).to(torch.bfloat16).abs().max().float()

This is (as written) doing A (bf16) @ B (fp32) - intended?

Author:

get_rht_matrix ends with return rht_matrix.to(dtype=torch.bfloat16), so B is always bf16.
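The reference computation quoted above (reshape the transposed input into blocks of the Hadamard dimension, multiply by the RHT matrix, take the absolute max) can be sketched in plain Python. This uses an unnormalized Sylvester Hadamard matrix of dimension 4 for brevity (the PR uses dimension 16 with a sign mask folded in), and the helper names are illustrative:

```python
def sylvester_hadamard(n: int) -> list[list[int]]:
    """Unnormalized Hadamard matrix of size n (n a power of two)."""
    h = [[1]]
    while len(h) < n:
        # H_{2n} = [[H, H], [H, -H]]
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    return h

def rht_amax_reference(a_t_flat: list[float], had_dim: int = 4) -> float:
    """abs-max of (blocks of had_dim) @ H without keeping the product."""
    H = sylvester_hadamard(had_dim)
    amax = 0.0
    for i in range(0, len(a_t_flat), had_dim):
        block = a_t_flat[i:i + had_dim]
        for col in range(had_dim):
            v = sum(block[k] * H[k][col] for k in range(had_dim))
            amax = max(amax, abs(v))
    return amax

# [1,1,1,1] @ H_4 = [4,0,0,0], so the amax is 4.
assert rht_amax_reference([1.0, 1.0, 1.0, 1.0]) == 4.0
```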

signs = get_wgrad_sign_vector(device=device)
else:
signs = torch.ones(1, dtype=torch.float32, device=device)
sign_matrix = signs * torch.eye(

We should also be able to specify different sign vectors if desired.

Author:

Added sign_vector argument.
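The signs * torch.eye(...) expression quoted above builds diag(signs), and left-multiplying the Hadamard matrix by diag(signs) flips whole rows. A plain-Python sketch of that equivalence (helper name is illustrative):

```python
def apply_sign_mask(signs: list[int], H: list[list[int]]) -> list[list[int]]:
    """diag(signs) @ H: row i of H scaled by signs[i]."""
    n = len(H)
    return [[signs[i] * H[i][j] for j in range(n)] for i in range(n)]

H = [[1, 1], [1, -1]]
# Flipping the sign of row 1 only:
assert apply_sign_mask([1, -1], H) == [[1, 1], [-1, 1]]
```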

@danielvegamyhre (Contributor) left a comment:

When adding new kernels, could you add microbenchmarks in ao/benchmarks/prototype? Feel free to start an nvfp4_training directory in there.

return times[len(times) // 2]


def get_best_config(
Contributor:

why is this necessary? triton autotuner automatically caches the best config for the given keys.

Author:

Removed. I replaced the host-side TMA descriptors with in-kernel ones so the kernel is compatible with the @triton.autotune decorator.

_autotune_cache: dict[tuple, KernelConfig] = {}


def do_bench(fn: Callable, warmup_iters: int = 3, bench_iters: int = 10) -> float:
Contributor:

can you use our existing benchmark util for this instead of defining a new one:

def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):

Author:

Added a microbenchmark for the Triton kernel.
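The median-of-timings pattern from the removed do_bench helper (warmup, then return times[len(times) // 2]) can be sketched with wall-clock timing. This is a CPU stand-in for illustration only; real kernel benchmarks should use CUDA-event timing such as benchmark_cuda_function_in_microseconds:

```python
import time

def bench_median_us(fn, warmup_iters: int = 3, bench_iters: int = 10) -> float:
    """Median wall-clock time of fn() in microseconds.

    CPU stand-in for the CUDA-event timing a real kernel benchmark uses;
    the warmup/median structure mirrors the removed do_bench helper.
    """
    for _ in range(warmup_iters):
        fn()                               # warm caches / JIT before timing
    times = []
    for _ in range(bench_iters):
        t0 = time.perf_counter()
        fn()
        times.append((time.perf_counter() - t0) * 1e6)
    times.sort()
    return times[len(times) // 2]          # median is robust to outliers

elapsed = bench_median_us(lambda: sum(range(1000)))
assert elapsed > 0.0
```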

@danielvegamyhre danielvegamyhre added module: training quantize_ api training flow nvfp4 labels Apr 7, 2026