Skip to content

feat(operator):support Cuda-AddRmsNorm-operator.#635

Open
pengcheng888 wants to merge 1 commit into
masterfrom
feat/Cuda-AddRmsNorm-operator
Open

feat(operator):support Cuda-AddRmsNorm-operator.#635
pengcheng888 wants to merge 1 commit into
masterfrom
feat/Cuda-AddRmsNorm-operator

Conversation

@pengcheng888
Copy link
Copy Markdown

@pengcheng888 pengcheng888 commented Jun 3, 2026

Summary

Add a native CUDA AddRmsNorm implementation for NVIDIA GPUs. The operator fuses input + residual, writes the sum to residual_out, and writes the weighted RMS-normalized result to out.

This PR updates the AddRmsNorm base API to (input, residual, weight, eps, out, residual_out), adds shape/dtype/last-dim stride validation, and introduces CUDA kernel plus NVIDIA backend wiring. It also adds pytest coverage (including strided layouts and fp32/fp16/bf16) with a vLLM-aligned PyTorch reference.
Screenshot from 2026-06-05 10-26-12

Motivation

support Cuda-AddRmsNorm-operator for Infinicore.

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • docs — documentation only
  • build / ci — build system or CI configuration
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Test Results on Supported Platforms

Platform Built pytest Result Notes / Hardware
NVIDIA yes 120 passed, 1 warning in 0.33s
Iluvatar
MetaX
Cambricon
Moore
Ascend
Full `pytest` output (optional)
paste here

(base) ubuntu@ubuntu:~/pr666/pr/zhiqiande/InfiniOps$ pytest tests/test_add_rms_norm.py --devices nvidia
============================================================================================================================= test session starts ==============================================================================================================================
platform linux -- Python 3.12.12, pytest-9.0.1, pluggy-1.5.0
rootdir: /home/ubuntu/pr666/pr/zhiqiande/InfiniOps
configfile: pyproject.toml
plugins: anyio-4.13.0, typeguard-4.4.4
collected 120 items                                                                                                                                                                                                                                                            

tests/test_add_rms_norm.py ........................................................................................................................                                                                                                                      [100%]

=============================================================================================================================== warnings summary ===============================================================================================================================
../../../../miniconda3/lib/python3.12/site-packages/torch/cuda/__init__.py:63
  /home/ubuntu/miniconda3/lib/python3.12/site-packages/torch/cuda/__init__.py:63: FutureWarning: The pynvml package is deprecated. Please install nvidia-ml-py instead. If you did not install pynvml directly, please report this to the maintainers of the package that installed pynvml for you.
    import pynvml  # type: ignore[import]

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================================== 120 passed, 1 warning in 0.33s ========================================================================================================================

Benchmark / Performance Impact

Notes for Reviewers


Checklist

Every contributor must verify every item below before requesting
review. Tick each box only after the check has actually been performed —
do not tick speculatively. If an item truly does not apply, replace the
checkbox with N/A and briefly explain why in an inline comment.

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • Public API changes (if any) are intentional, documented, and reflected in affected callers/tests.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks (e.g. the `seqlens_k` tensor) (CONTRIBUTING.md §Code/General).
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation — unless the language/framework convention says otherwise (CONTRIBUTING.md §Code/General; §Python).

C++ Specific (if C++ files changed)

  • Code follows the Google C++ Style Guide strictly.
  • clang-format (version 21, per .github/workflows/clang-format.yml) has been run against all modified .h, .cc, .cuh, and .mlu files; the diff is clean.
  • clang-tidy concerns (per .clang-tidy) have been reviewed — no new warnings beyond the existing baseline.
  • Operator parameter order is inputs first, outputs last; attributes are between inputs and outputs; naming follows PyTorch → ONNX → CUDA API precedence (CONTRIBUTING.md §C++).
  • No exceptions are thrown. Error paths use assert with messages that include at least __FILE__, __LINE__, and __func__ (CONTRIBUTING.md §C++).
  • Error and warning message wording follows the LLVM Coding Standards (CONTRIBUTING.md §C++).
  • Kernel files are named correctly: custom = kernel / kernel_v2 / …; well-known algorithms use the algorithm name; library-based implementations use the library name (CONTRIBUTING.md §C++).
  • Kernel and kernel launcher are in separate files: launcher .h, kernel follows platform conventions (e.g. .cuh + .cu) even when non-templated (CONTRIBUTING.md §C++).
  • Constructor initializer list order matches member declaration order (CONTRIBUTING.md §C++).
  • Exactly one blank line between classes, between classes and functions, and between functions (CONTRIBUTING.md §C++).
  • Exactly one blank line between members (functions and variables) within a class (CONTRIBUTING.md §C++).
  • Exactly one blank line before and after the contents of a namespace (CONTRIBUTING.md §C++).
  • New operators added via src/base/<op>.h (inheriting Operator<Op>) with platform implementations under src/<category>/<platform>/ inheriting the base (CONTRIBUTING.md §Adding an Operator).
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific (if Python files changed)

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (see .github/workflows/ruff.yml).
  • ruff format --check passes cleanly — if not, run ruff format and commit the result.
  • Comments are complete English sentences, starting with a capital letter and ending with punctuation; Markdown backticks are used for code references (CONTRIBUTING.md §Python).
  • Framework-specific conventions (e.g. lowercase pytest.skip messages without terminal period) are honored where applicable (CONTRIBUTING.md §Python).
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement (CONTRIBUTING.md §Python).
  • Docstrings (if any) follow PEP 257 (CONTRIBUTING.md §Python).
  • Type hints are added / kept consistent with the surrounding code.

Testing

  • pytest was run locally on every supported platform that this PR can affect, and the results are recorded in the "Test Results" table above (CONTRIBUTING.md §Pull Requests).
  • For any platform that could not be tested, an explicit reason is given in the table and a reviewer with access has been tagged.
  • New functionality has matching tests under tests/ following tests/test_add.py / tests/test_gemm.py patterns (CONTRIBUTING.md §Adding an Operator).
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator (e.g. @pytest.mark.parametrize("dtype, rtol, atol", …)), independent parameters use separate decorators ordered by parameter declaration.
  • Where appropriate, pytest.mark.auto_act_and_assert is used and the test returns a Payload whose func and ref share the same calling convention.
  • Default dtype / device parameterization is relied on, or overridden with an explicit pytest.mark.parametrize when necessary.
  • Any new test that is flaky under parallelism is marked so, or documented to require pytest -n 1.
  • For bug fixes: a regression test has been added that fails on master and passes with this PR.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install .[dev] on at least one affected platform.
  • compile_commands.json still regenerates (CMake option CMAKE_EXPORT_COMPILE_COMMANDS=ON in pyproject.toml — required by the code-lint skill and clang-tidy -p).
  • New backends / devices have been added to auto-detection in CMakeLists.txt under if(AUTO_DETECT_DEVICES) and to if(AUTO_DETECT_BACKENDS) if applicable.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not broken.
  • Both CI workflows (clang-format.yml, ruff.yml) are green locally (or expected to be green on CI).
  • No new runtime dependency was added without updating pyproject.toml's [project.optional-dependencies] (or justified in the PR description).

Documentation

  • README.md, CONTRIBUTING.md, or inline docs updated when behavior, build flags, or developer workflow changed.
  • New operators, new dispatch helpers, or new public utilities are documented (docstring, header comment, or an addition to CONTRIBUTING.md §Some Code Explanations).
  • Any user-visible breaking change is called out explicitly under "Motivation" and in the commit/PR title with a ! or BREAKING CHANGE: footer.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers have been committed.
  • Third-party code is license-compatible and attributed.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.

@pengcheng888 pengcheng888 requested a review from a team June 3, 2026 11:41
Comment thread src/base/add_rms_norm.h Outdated
// Also consider the same change for `RmsNorm`.
AddRmsNorm(const Tensor input, const Tensor other, const Tensor weight,
float eps, Tensor out, Tensor rstd_out)
AddRmsNorm(Tensor residual, const Tensor input, const Tensor other,
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

main分支已经有了AddRmsNorm的接口,移植InfiniCore时,有了接口的修改。 是否修改,明天再讨论

Comment thread .gitignore Outdated
@@ -1,3 +1,5 @@
third_party
Copy link
Copy Markdown
Author

@pengcheng888 pengcheng888 Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我调试时额外加的third_party/目录,不合适的话,回来删掉

@pengcheng888 pengcheng888 force-pushed the feat/Cuda-AddRmsNorm-operator branch from 597918f to 9898c04 Compare June 5, 2026 02:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant