perf: gpu memory optimization by omitting the bottom rows of the merkle tree #2840

Open
GunaDD wants to merge 4 commits into develop-v2.1.0-rv64 from perf/merkle-tree-mem-opt

Conversation

@GunaDD

@GunaDD GunaDD commented Jun 3, 2026

Contributor

Motivation: The GPU merkle tree needs a memory optimization once we go to a 2^32 address space size. Currently, if we have N bytes, the MemoryMerkleTree has N/16 leaves (and N/8 nodes in total), and each node is stored as 8 field elements. This means we need 4N bytes of VRAM to store the memory merkle tree on the GPU. With N = 2^32, 4N is 2^34 bytes, roughly 17 GB, which will OOM the GPU (our current limit is ~15 G).
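The arithmetic above can be checked with a small sketch. The constant and function names are illustrative and not taken from the PR, and the 4-byte field element size is an assumption implied by the 4N figure rather than something stated in the description.

```rust
/// Bytes of raw memory covered by one leaf (from the PR description: N/16 leaves).
const BYTES_PER_LEAF: u64 = 16;
/// Field elements stored per tree node (from the PR description).
const DIGEST_WIDTH: u64 = 8;
/// Assumption: one field element occupies 4 bytes.
const FIELD_ELEMENT_BYTES: u64 = 4;

/// Approximate VRAM needed to store every node of the merkle tree
/// for an address space of `n` bytes.
pub fn full_tree_bytes(n: u64) -> u64 {
    let leaves = n / BYTES_PER_LEAF;
    // A full binary tree has 2 * leaves - 1 nodes; the description rounds this to N/8.
    let nodes = 2 * leaves;
    nodes * DIGEST_WIDTH * FIELD_ELEMENT_BYTES
}
```

For N = 2^32 this evaluates to 2^34 bytes (16 GiB, about 17.2 GB), which matches the "roughly 17 GB" estimate.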

Optimization idea of this PR: Don't store the last OMITTED_BOTTOM_LEVELS levels of the MemoryMerkleSubTree in GPU memory; compute them only when needed. This applies to address spaces with large sizes (AS2 and the deferral AS). The omitted nodes are recomputed from the raw memory by re-running the poseidon2 hash. This reduces the GPU memory required for the Merkle tree itself by a factor of 8.

Reth benchmark results: https://github.com/axiom-crypto/openvm-eth/actions/runs/26959247863

In summary, proving time increased by 0.77s, but generate mem proving ctxs went down from 3.84 G to 2.09 G and set initial memory went down from 3.63 G to 1.88 G. Note that these didn't go down by a factor of 8 because other things also use the memory, including the boundarychipgpu and the initial_memory buffer (see the MemoryInventoryGPU struct for details).

Closes INT-8079

@github-actions

github-actions Bot commented Jun 3, 2026


Claude finished @GunaDD's task in 4m 47s


Code Review: GPU Memory Optimization for Merkle Tree

  • Gather context and read the PR diff
  • Review CUDA kernel changes (merkle_tree.cu)
  • Review Rust-side changes (mod.rs, cuda.rs)
  • Review CUDA header changes (params.cuh)
  • Post review findings

Summary

This PR reduces GPU memory consumption by omitting the bottom 3 levels of large merkle subtrees from GPU buffers. Instead of storing the full 2^(H+1) - 1 nodes, subtrees with height > 3 store only 2^(H-2) - 1 retained nodes — roughly 87.5% memory savings on the subtree heap. Omitted nodes are recomputed on demand from the original memory data via recompute_omitted_node.
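The node counts quoted in the summary can be reproduced with a short sketch. The function names are hypothetical, and the rule "omit only when height > OMITTED_BOTTOM_LEVELS" is taken from the summary's wording, not from the actual `MemoryMerkleSubTreeLayout` code.

```rust
/// Number of bottom levels left out of the GPU buffer (value from params.cuh in the PR).
pub const OMITTED_BOTTOM_LEVELS: u32 = 3;

/// Nodes in a full binary tree whose leaves sit at depth `height`.
pub fn full_node_count(height: u32) -> u64 {
    (1u64 << (height + 1)) - 1
}

/// Nodes actually stored: tall subtrees drop their bottom levels,
/// short subtrees keep everything.
pub fn retained_node_count(height: u32) -> u64 {
    if height > OMITTED_BOTTOM_LEVELS {
        // Depths 0..=height-3 remain, i.e. 2^(height-2) - 1 nodes.
        (1u64 << (height + 1 - OMITTED_BOTTOM_LEVELS)) - 1
    } else {
        full_node_count(height)
    }
}
```

For a subtree of height 10 this gives 255 retained nodes out of 2047, about one eighth, which is where the 87.5% figure comes from.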

The optimization is well-structured: a MemoryMerkleSubTreeLayout enum (mirrored in Rust and CUDA) controls the behavior, and a clean "virtual node" abstraction (load_virtual_node / store_virtual_node) encapsulates the four-way dispatch (zero hash, omitted recompute, vertical path, heap lookup). The removal of initial_subtrees_advance and adjust_subtrees_before_layer_update kernels is a nice simplification enabled by the new index math.

I verified the correctness of the index arithmetic, the guard ordering in load_virtual_node/store_virtual_node, and the layer-by-layer update flow. The old-value reads are consistent with the original initial data, and the no-op stores for omitted nodes are correct because the update kernel propagates new values through the layer working buffer, not through the subtree.

Issues

1. Unused layout parameter in stored_node_index (nit)

stored_node_index (merkle_tree.cu:326-338) accepts a layout parameter but never reads it. The callers already guard against omitted nodes before calling this function, so it isn't needed here. Consider removing it to avoid confusion about who is responsible for the layout dispatch.

2. OMITTED_BOTTOM_LEVELS constant is duplicated without a sync check

The constant is defined independently in Rust (mod.rs:33) and CUDA (params.cuh:64). If they ever diverge, the build and update phases would silently compute incorrect indices and produce wrong proofs. Consider adding a debug_assert! or a compile-time check on the Rust side that validates the CUDA value (e.g., via a C FFI getter), or at minimum add a prominent comment in both locations pointing to the other.

3. initial_data_ptr extends the lifetime requirement of d_data

initial_data_ptr (mod.rs:168) stores a raw device pointer to the initial memory data. Previously, d_data only needed to be alive for the build_async kernel stream. Now it must also survive until update_with_touched_blocks completes, since the update kernel dereferences it to recompute omitted nodes. This is a subtle change that could cause a use-after-free on the GPU if the caller frees d_data between build and update.

The existing doc comment on build_async says the caller must ensure d_data lives long enough for the build — this should be updated to say it must live until after the last update_with_touched_blocks call.
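One way to make this lifetime requirement structural instead of relying on a doc comment is to let the subtree co-own the buffer. The sketch below uses a host-side stand-in type; `DeviceBuffer`, `SubTree`, and their methods are hypothetical here and do not reflect the real CUDA wrapper API.

```rust
use std::sync::Arc;

/// Hypothetical stand-in for a device allocation. A real wrapper would
/// free VRAM in its Drop impl.
pub struct DeviceBuffer {
    bytes: Vec<u8>,
}

impl DeviceBuffer {
    pub fn new(bytes: Vec<u8>) -> Self {
        DeviceBuffer { bytes }
    }
}

/// Hypothetical subtree that keeps the initial data alive for later updates.
pub struct SubTree {
    /// `None` would model an empty subtree with no backing buffer.
    initial_data: Option<Arc<DeviceBuffer>>,
}

impl SubTree {
    /// Captures a shared handle instead of a raw pointer.
    pub fn build(data: &Arc<DeviceBuffer>) -> Self {
        SubTree { initial_data: Some(Arc::clone(data)) }
    }

    /// Models an update step that still needs to read the initial data.
    pub fn first_byte(&self) -> Option<u8> {
        self.initial_data.as_ref().and_then(|d| d.bytes.first().copied())
    }
}
```

With this shape the caller can drop its own handle after the build and the buffer stays allocated as long as the subtree exists. Shared ownership only covers the host side; kernels already enqueued on a stream still need an explicit synchronization before the last handle is released.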

4. Register pressure in recompute_omitted_node

The function allocates digest_t layer[1 << OMITTED_BOTTOM_LEVELS] = 256 bytes of thread-local scratch plus 64 bytes for cells (merkle_tree.cu:68-83). At OMITTED_BOTTOM_LEVELS=3 this is manageable (~320 bytes per thread), but increasing the constant even to 4 would double the scratch to 512 bytes and significantly impact occupancy. Worth adding a compile-time assertion or comment that OMITTED_BOTTOM_LEVELS must not exceed 3–4 without profiling.
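The scratch estimate and the suggested compile-time guard could look roughly like this on the Rust side. The byte sizes assume 4-byte field elements (so a digest is 32 bytes and `cells` is 64 bytes), and `MAX_SCRATCH_BYTES` is a made-up budget for illustration, not a profiled limit.

```rust
/// Value mirrored from params.cuh in the PR.
pub const OMITTED_BOTTOM_LEVELS: usize = 3;

/// Assumption: 8 field elements of 4 bytes each.
const DIGEST_BYTES: usize = 32;
/// Assumption: the 2-to-1 compression buffer holds two digests.
const CELLS_BYTES: usize = 64;
/// Hypothetical per-thread scratch budget; pick a real value after profiling.
const MAX_SCRATCH_BYTES: usize = 320;

/// Thread-local scratch used when recomputing one omitted node.
pub const fn scratch_bytes(levels: usize) -> usize {
    (1 << levels) * DIGEST_BYTES + CELLS_BYTES
}

// Fails the build if someone raises the constant past the budget.
const _: () = assert!(
    scratch_bytes(OMITTED_BOTTOM_LEVELS) <= MAX_SCRATCH_BYTES,
    "OMITTED_BOTTOM_LEVELS exceeds the per-thread scratch budget; profile before raising it"
);
```

With the constant at 3 the total is 320 bytes. At 4 the `layer` array alone is 512 bytes, 576 bytes with `cells`, and the assertion would stop the build.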

Minor observations

  • hash_raw_memory_leaf runtime branch: The refactor replaces if constexpr with a runtime if (merkle_tree.cu:32). For the existing merkle_tree_init template path (FULL layout), the compiler can likely constant-fold after inlining. For merkle_tree_init_omitted, it's a real branch, but all threads in a warp take the same path, so no divergence cost.

  • Good test coverage: The new test_cuda_merkle_subtree_layout_and_buffer_sizes test covers the boundary cases (below, equal, above OMITTED_BOTTOM_LEVELS), and the existing CPU-GPU equivalence test now also asserts layout assignments. The equivalence test is the critical one — it confirms that the optimized GPU tree produces the same root as the CPU reference, both for initial build and after updates.

  • Clean removal of initial_subtrees_advance / adjust_subtrees_before_layer_update: The old pointer-arithmetic kernels are no longer needed because stored_node_index handles the mapping. This reduces kernel launch overhead and simplifies the update loop.

Verdict

The optimization is sound and well-implemented. The correctness of the index math and the guard ordering check out. The main actionable items are:

  1. Remove the unused layout param from stored_node_index (easy cleanup)
  2. Add documentation about the extended lifetime requirement of initial_data_ptr (important for safety)
  3. Consider adding a cross-language sync check for OMITTED_BOTTOM_LEVELS (defensive)


@shuklaayush shuklaayush force-pushed the develop-v2.1.0-rv64 branch 3 times, most recently from 296adcd to fd69b88 Compare June 4, 2026 14:03
@GunaDD GunaDD force-pushed the perf/merkle-tree-mem-opt branch from c279370 to c70a343 Compare June 4, 2026 14:45

@github-actions

github-actions Bot commented Jun 4, 2026


Code review

No issues found. Checked for bugs and CLAUDE.md compliance.


@GunaDD

GunaDD commented Jun 4, 2026

Contributor Author

reth benchmark: https://github.com/axiom-crypto/openvm-eth/actions/runs/26959247863

proving time increased by 0.77s but the generate mem proving ctxs went down from 3.84 G to 2.09 G and set initial memory went down from 3.63 G to 1.88 G


@GunaDD GunaDD force-pushed the perf/merkle-tree-mem-opt branch from c0b1165 to 1886d1b Compare June 4, 2026 20:59

@shuklaayush shuklaayush left a comment

Collaborator

dropped some comments. i'll leave it to @Golovanov399 and/or @gaxiom to review the cuda changes in detail

inline constexpr size_t BLOCKS_PER_LEAF = DIGEST_WIDTH / BLOCK_FE_WIDTH;

// Number of bottom merkle levels omitted from large GPU subtree buffers.
inline constexpr size_t OMITTED_BOTTOM_LEVELS = 3;

Collaborator

i feel like this is conceptually similar to paging so let's call this MERKLE_PAGE_BITS instead. see how metered execution also does paging and this old doc which has some diagrams

Contributor Author

hmmm what does PAGE mean there actually

Collaborator

essentially what you're doing is saying that instead of storing the commitment to individual memory addresses, you'll only store a commitment to a collection of memory addresses (a page)
this is similar to how virtual memory is divided into pages (not exactly but close enough)
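The paging analogy can be made concrete with a small index sketch. `MERKLE_PAGE_BITS` is the name proposed in this thread, not an existing constant, and its value here simply mirrors OMITTED_BOTTOM_LEVELS; the helper functions are hypothetical.

```rust
use std::ops::Range;

/// Proposed name from the review; the value mirrors OMITTED_BOTTOM_LEVELS = 3.
pub const MERKLE_PAGE_BITS: u32 = 3;

/// Index of the page (the lowest stored node) that commits to a given leaf.
pub fn page_of_leaf(leaf_index: u64) -> u64 {
    leaf_index >> MERKLE_PAGE_BITS
}

/// The leaves covered by one page.
pub fn leaves_in_page(page: u64) -> Range<u64> {
    (page << MERKLE_PAGE_BITS)..((page + 1) << MERKLE_PAGE_BITS)
}
```

Under this view each stored bottom node commits to a page of 8 leaves, and everything inside a page is recomputed from raw memory when it is needed.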

Comment thread crates/vm/src/system/cuda/merkle_tree/cuda.rs
/// Returns the bounds [start, end) of the layer at the given depth.
/// These bounds correspond to the indices of the layer in the buffer.
/// depth: 0 = root, 1 = root's children, ..., height-1 = leaves
pub fn layer_bounds(&self, depth: usize) -> (usize, usize) {

Collaborator

we can probably just remove this

Contributor Author

should we do it though ?

Collaborator

yeah, let's just get rid of it

Comment thread crates/vm/src/system/cuda/merkle_tree/mod.rs Outdated
Comment thread crates/vm/cuda/src/system/memory/merkle_tree.cu
@shuklaayush shuklaayush force-pushed the develop-v2.1.0-rv64 branch 2 times, most recently from 2995a72 to e582d73 Compare June 7, 2026 11:31
@Golovanov399

Contributor

Should it be rebased?

@GunaDD GunaDD force-pushed the perf/merkle-tree-mem-opt branch from 1886d1b to 4c9cea5 Compare June 8, 2026 14:40

@GunaDD GunaDD force-pushed the perf/merkle-tree-mem-opt branch 2 times, most recently from 52efe21 to 4c9cea5 Compare June 8, 2026 16:01

@Golovanov399 Golovanov399 left a comment

Contributor

I left some comments. In principle I understand if avoiding redundant recomputation for close addresses is out of scope for this task, but I think it's worth doing in the long term. If it's not in this PR, please make a corresponding ticket for the future.


// cells is a 2-to-1 compression buffer
Fp cells[CELLS];
for (size_t width = num_leaves / 2; width > 0; width /= 2) {

Contributor

You either don't need 1 << OMITTED_BOTTOM_LEVELS entries in layer (you only use half of this memory, yet it probably affects register spilling or something), or you may want to parallelize the computation on each of the omitted levels (probably not worth it). Or maybe you want to delegate all of this to the same SM and do some device synchronization and whatnot, but that can be postponed.
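The first option, a scratch array of half the size, can be illustrated with a host-side sketch. It is written in Rust rather than CUDA, `compress` is a toy stand-in and not Poseidon2, the `leaf` closure models hashing a leaf from raw memory on the fly, and all names are hypothetical.

```rust
const OMITTED_BOTTOM_LEVELS: usize = 3;
const NUM_LEAVES: usize = 1 << OMITTED_BOTTOM_LEVELS;

/// Stand-in digest type; the real digest is 8 field elements.
type Digest = u64;

/// Toy 2-to-1 compression, only here to make the sketch runnable. NOT Poseidon2.
fn compress(left: Digest, right: Digest) -> Digest {
    left.rotate_left(17).wrapping_mul(0x9E37_79B9_7F4A_7C15) ^ right.wrapping_add(1)
}

/// Baseline: materialize all leaves, then reduce in place.
pub fn root_full_scratch(leaf: impl Fn(usize) -> Digest) -> Digest {
    let mut layer: [Digest; NUM_LEAVES] = [0; NUM_LEAVES];
    for (i, slot) in layer.iter_mut().enumerate() {
        *slot = leaf(i);
    }
    let mut width = NUM_LEAVES / 2;
    while width > 0 {
        for i in 0..width {
            layer[i] = compress(layer[2 * i], layer[2 * i + 1]);
        }
        width /= 2;
    }
    layer[0]
}

/// Half-size scratch: compress leaf pairs as they are produced, so only
/// NUM_LEAVES / 2 digests are ever held at once.
pub fn root_half_scratch(leaf: impl Fn(usize) -> Digest) -> Digest {
    let mut layer: [Digest; NUM_LEAVES / 2] = [0; NUM_LEAVES / 2];
    for (i, slot) in layer.iter_mut().enumerate() {
        *slot = compress(leaf(2 * i), leaf(2 * i + 1));
    }
    let mut width = NUM_LEAVES / 4;
    while width > 0 {
        for i in 0..width {
            // Safe in place: slot i is written only after slots 2i and 2i+1 are read.
            layer[i] = compress(layer[2 * i], layer[2 * i + 1]);
        }
        width /= 2;
    }
    layer[0]
}
```

Both functions return the same root for the same leaves; the second one holds four digests instead of eight at OMITTED_BOTTOM_LEVELS = 3. Whether this reduces register or local-memory use in the actual kernel would need profiling.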

auto const old_right_digest = layer_value_on_height(
subtree_layer,
digest_t old_right_digest;
load_virtual_node(

Contributor

Noting that you may save some recomputing on the lowest levels here by somehow reusing what you need from old_left_digest for old_right_digest

@jonathanpwang jonathanpwang removed their request for review June 9, 2026 01:07
Comment on lines +61 to +74
/// Shared handle to the initial-memory buffer (`d_data`) captured in [`Self::build_async`].
/// `None` for empty/dummy subtrees that have no backing buffer.
///
/// Holding an `Arc` makes the subtree a co-owner of the buffer, so the host cannot free it
/// while the subtree is alive. Under the `OmitBottomLevels` layout the omitted bottom levels
/// are never materialized into `buf`; they are recomputed on demand from this buffer during
/// [`MemoryMerkleTree::update_with_touched_blocks`] (see `recompute_omitted_node` in
/// `merkle_tree.cu`), so the buffer must stay alive until that update completes.
///
/// NOTE: this only fixes host-side ownership. The buffer is also consumed by GPU kernels
/// enqueued on the stream, so it must additionally outlive those kernels — that ordering is
/// guaranteed by the `stream.synchronize()` in [`MemoryMerkleTree::drop_subtrees`], not by the
/// `Arc`. Drop the subtrees (which releases these handles) only after that sync.
initial_data: Option<Arc<DeviceBuffer<u8>>>,

Collaborator

comments like this are too verbose and borderline slop. see if you can trim them down

@shuklaayush shuklaayush force-pushed the develop-v2.1.0-rv64 branch from dac2c10 to db32ec4 Compare June 10, 2026 18:01
@GunaDD GunaDD force-pushed the perf/merkle-tree-mem-opt branch from 814c620 to 0e6d6a6 Compare June 10, 2026 22:09
@github-actions

| group | app.proof_time_ms | app.cycles | leaf.proof_time_ms |
| --- | --- | --- | --- |
| fibonacci | 1,523 | 4,000,051 | 530 |
| keccak | 16,387 | 14,365,133 | 3,047 |
| sha2_bench | 10,518 | 11,167,961 | 1,958 |
| regex | 1,520 | 4,090,656 | 436 |
| ecrecover | 447 | 112,210 | 311 |
| pairing | 602 | 592,827 | 295 |
| kitchen_sink | 3,920 | 1,979,971 | 862 |

Note: cells_used metrics omitted because CUDA tracegen does not expose unpadded trace heights.

Commit: 0e6d6a6

Benchmark Workflow
