Skip to content

fix(cuda/gemm): make caches safe for threaded callers#632

Open
voltjia wants to merge 1 commit into
masterfrom
fix/cuda-gemm-thread-safe-caches
Open

fix(cuda/gemm): make caches safe for threaded callers#632
voltjia wants to merge 1 commit into
masterfrom
fix/cuda-gemm-thread-safe-caches

Conversation

@voltjia
Copy link
Copy Markdown
Collaborator

@voltjia voltjia commented Jun 3, 2026

Summary

  • Change src/native/cuda/ops/gemm/blas.h so CUDA GEMM uses one thread-local BLAS handle cache per device index instead of a single process-wide handle.
  • Change src/operator.h so Operator::Call uses thread-local operator cache storage with an atomic generation counter for cache invalidation.
  • Extend tests/test_cpp_api.py with a multi-threaded public C++ Operator::Call smoke check.

Motivation

Closes #627

Downstream multi-thread and multi-device callers can call public C++ Operator::Call concurrently. The previous process-wide operator cache was unsynchronized, and CUDA GEMM used one process-wide BLAS handle. This PR removes those shared mutable process-wide caches from the hot path.

The previous draft also changed NVIDIA fp32 GEMM compute type. That change has been removed because it is not necessary for this cache/thread-safety fix.

Type of Change

  • feat — new feature / new operator / new platform
  • fix — bug fix
  • perf — performance improvement (no behavioral change)
  • refactor — code restructuring without behavior change
  • test — adding or fixing tests only
  • docs — documentation only
  • build / ci — build system or CI configuration
  • chore — tooling, formatting, or other non-code changes
  • Breaking change (requires a ! in the Conventional Commits prefix or a BREAKING CHANGE: footer)

Platforms Affected

  • CPU (WITH_CPU)
  • NVIDIA (WITH_NVIDIA)
  • Iluvatar (WITH_ILUVATAR)
  • MetaX (WITH_METAX)
  • Cambricon (WITH_CAMBRICON)
  • Moore (WITH_MOORE)
  • Ascend (WITH_ASCEND)
  • PyTorch C++ bindings (WITH_TORCH)
  • Build system / CMake / CI
  • Python bindings / user-facing API

Test Results on Supported Platforms

Platform Built pytest Result Notes / Hardware
NVIDIA Yes 2000 passed, 1001 skipped in 12.46s Remote nvidia, container infiniops-ci/nvidia:latest, WITH_CPU=ON, WITH_NVIDIA=ON
Iluvatar No Not run Not touched directly; no run requested for this PR update
MetaX No Not run Not touched directly; no run requested for this PR update
Cambricon No Not run Not touched directly; no run requested for this PR update
Moore No Not run Not touched directly; no run requested for this PR update
Ascend No Not run Not touched directly; no run requested for this PR update
Full `pytest` output (optional)
Running 3001 items in this shard
s....................................................................... [  2%]
...
sssssssssssssssssssssssssssssssssssssssssssssssss                        [100%]
2000 passed, 1001 skipped in 12.46s

Benchmark / Performance Impact

N/A. This PR changes cache ownership for thread safety; no benchmark was run.

Notes for Reviewers

  • BlasGemm::GetHandle now keys BLAS handles by Tensor::device().index() within each host thread.
  • Operator::Call cache entries are no longer shared between host threads.
  • clear_cache() still invalidates each thread's cache on that thread's next call by using a shared atomic generation counter.
  • The NVIDIA fp32 compute-type change from PR fix(cuda): make GEMM caches safe for threaded callers #631 was intentionally removed as unrelated to this fix.

Checklist

Title, Branch, and Commits

  • PR title follows Conventional Commits (e.g. feat(nvidia): …, fix(cuda/gemm): …).
  • Branch name follows <type>/xxx-yyyy-zzzz where <type> matches the PR title's Conventional Commits type and words are joined with hyphens (see CONTRIBUTING.md §Branches).
  • Each commit message follows Conventional Commits.
  • Small PR is a single squashable commit; or, for a large PR, every commit is meaningful, well-formed, and independently reviewable (see CONTRIBUTING.md §Pull Requests).
  • No stray merge commits from master — the branch is rebased cleanly on top of the current master.
  • No fixup! / squash! / wip commits remain.

Scope and Design

  • Changes are minimal — nothing unrelated to the stated motivation was added (CONTRIBUTING.md §Code/General).
  • No dead code, commented-out blocks, debug prints, printf/std::cout/print(...) left behind, or TODO without an owner and issue link.
  • No unrelated formatting churn that would obscure the diff.
  • Public API changes (if any) are intentional, documented, and reflected in affected callers/tests.

General Code Hygiene (applies to all languages)

  • The code is self-explanatory; comments were added only where the why is non-obvious (CONTRIBUTING.md §Code/General).
  • Every modified or added file ends with a single trailing newline (CONTRIBUTING.md §Code/General).
  • No trailing whitespace, tab/space mixing, or stray BOMs.
  • Identifiers in comments and error messages are wrapped in backticks (e.g. the `seqlens_k` tensor) (CONTRIBUTING.md §Code/General).
  • All comments and error messages are in English (CONTRIBUTING.md §Code/General).
  • Comments and error messages are complete sentences — capitalized first letter, terminal punctuation — unless the language/framework convention says otherwise (CONTRIBUTING.md §Code/General; §Python).

C++ Specific (if C++ files changed)

  • Code follows the Google C++ Style Guide strictly.
  • clang-format (version 21, per .github/workflows/clang-format.yml) has been run against all modified .h, .cc, .cuh, and .mlu files; the diff is clean. Not run: clang-format 21 was not available in the NVIDIA validation container.
  • clang-tidy concerns (per .clang-tidy) have been reviewed — no new warnings beyond the existing baseline.
  • N/A: Operator parameter order is unchanged.
  • No exceptions are thrown. Error paths use assert with messages that include at least __FILE__, __LINE__, and __func__ (CONTRIBUTING.md §C++).
  • N/A: No new error or warning messages.
  • N/A: No kernel files were added or renamed.
  • N/A: No kernel launchers were added.
  • Constructor initializer list order matches member declaration order (CONTRIBUTING.md §C++).
  • Exactly one blank line between classes, between classes and functions, and between functions (CONTRIBUTING.md §C++).
  • Exactly one blank line between members (functions and variables) within a class (CONTRIBUTING.md §C++).
  • Exactly one blank line before and after the contents of a namespace (CONTRIBUTING.md §C++).
  • N/A: No new operators were added.
  • No raw new/delete; RAII / smart pointers / existing allocators are used.

Python Specific (if Python files changed)

  • Code is PEP 8 compliant; ruff check passes cleanly on CI (see .github/workflows/ruff.yml).
  • ruff format --check passes cleanly — if not, run ruff format and commit the result.
  • N/A: No Python comments were added.
  • Framework-specific conventions (e.g. lowercase pytest.skip messages without terminal period) are honored where applicable (CONTRIBUTING.md §Python).
  • No blank line between the function signature and the body when there is no docstring or comment (CONTRIBUTING.md §Python).
  • A blank line is present before and after if, for, and similar control-flow statements (CONTRIBUTING.md §Python).
  • A blank line appears before each return, except when it directly follows a control-flow statement (CONTRIBUTING.md §Python).
  • N/A: No docstrings were added.
  • N/A: No type hints were added.

Testing

  • pytest was run locally on every supported platform that this PR can affect, and the results are recorded in the "Test Results" table above (CONTRIBUTING.md §Pull Requests). Only NVIDIA was run for this PR update.
  • For any platform that could not be tested, an explicit reason is given in the table and a reviewer with access has been tagged.
  • New functionality has matching tests under tests/ following tests/test_add.py / tests/test_gemm.py patterns (CONTRIBUTING.md §Adding an Operator).
  • Tests use pytest.mark.parametrize correctly: dependent parameters share one decorator (e.g. @pytest.mark.parametrize("dtype, rtol, atol", …)), independent parameters use separate decorators ordered by parameter declaration.
  • N/A: The new C++ API smoke test does not return a Payload.
  • Default dtype / device parameterization is relied on, or overridden with an explicit pytest.mark.parametrize when necessary.
  • Any new test that is flaky under parallelism is marked so, or documented to require pytest -n 1.
  • For bug fixes: a regression test has been added that fails on master and passes with this PR.

Build, CI, and Tooling

  • The project builds cleanly from a fresh directory with pip install .[dev] on at least one affected platform. Built with pip install --no-build-isolation --no-deps . in the NVIDIA CI container.
  • compile_commands.json still regenerates (CMake option CMAKE_EXPORT_COMPILE_COMMANDS=ON in pyproject.toml — required by the code-lint skill and clang-tidy -p).
  • N/A: No new backends or devices were added.
  • Only one CUDA-like GPU backend is selectable at a time — the existing mutual-exclusion check in CMakeLists.txt is not broken.
  • Both CI workflows (clang-format.yml, ruff.yml) are green locally (or expected to be green on CI). ruff was run locally in the container; clang-format 21 was not available.
  • No new runtime dependency was added without updating pyproject.toml's [project.optional-dependencies] (or justified in the PR description).

Documentation

  • N/A: No README, CONTRIBUTING, or workflow behavior changed.
  • N/A: No new operators, dispatch helpers, or public utilities were added.
  • N/A: No user-visible breaking change.

Security and Safety

  • No secrets, access tokens, internal URLs, customer data, or personal hardware identifiers have been committed.
  • N/A: No third-party code was added.
  • No unsafe pointer arithmetic, uninitialized reads, or missing bounds checks were introduced.

@voltjia voltjia force-pushed the fix/cuda-gemm-thread-safe-caches branch from 38eae03 to 80fa335 Compare June 3, 2026 03:55
Copy link
Copy Markdown
Collaborator Author

voltjia commented Jun 3, 2026

Update after review request:

  • Added a dedicated regression test: tests/test_cpp_api.py::test_cpp_operator_call_thread_local_cache_regression.
  • The test defines a local C++ probe operator whose instance records the constructing host thread. It calls Operator::Call once on the main thread and once on a child thread with the same cache key.
  • On current origin/master with only this test copied over, the new regression fails because the process-wide Operator::Call cache reuses the main-thread operator in the child thread.
  • On this PR, the same test passes because the operator cache is thread-local.

Latest remote NVIDIA validation, container infiniops-ci/nvidia:latest:

python3 -m pytest -q -rs tests/test_cpp_api.py tests/test_gemm.py --devices nvidia
2002 passed, 1000 skipped in 8.94s

Also reran ruff format tests/test_cpp_api.py and ruff check tests/test_cpp_api.py; both passed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

CUDA GEMM: process-wide operator/BLAS caches are unsafe for multi-thread + multi-device callers

1 participant