From 80aba972a76a743d36f457f9e2f335ae3bd00d8f Mon Sep 17 00:00:00 2001 From: samuelburnham <45365069+samuelburnham@users.noreply.github.com> Date: Tue, 12 May 2026 17:58:15 +0000 Subject: [PATCH] Parallelize lookup message collection Replace the serial loop that computes per-lookup messages with a rayon parallel iteration over a preallocated flat slice of lookup references. Flattening serially first lets `collect` write straight into the output Vec without tree-reducing worker buffers. --- src/lookup.rs | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/lookup.rs b/src/lookup.rs index bb1a654..ef87730 100644 --- a/src/lookup.rs +++ b/src/lookup.rs @@ -1,6 +1,7 @@ use p3_air::{Air, BaseAir, ExtensionBuilder, WindowAccess}; use p3_field::{PrimeCharacteristicRing, batch_multiplicative_inverse}; use p3_matrix::{Matrix, dense::RowMajorMatrix}; +use p3_maybe_rayon::prelude::*; use crate::{ builder::{TwoStagedBuilder, symbolic::SymbolicExpression}, @@ -113,29 +114,23 @@ impl Lookup { fingerprint_challenge: &ExtVal, mut accumulator: ExtVal, ) -> (Vec>, Vec) { - // Collect the number of lookups per circuit while accumulating the total - // number of lookups. - let mut num_lookups_per_circuit = Vec::with_capacity(lookups.len()); - let mut total_num_lookups = 0; - for circuit_lookups in lookups { - let num_rows = circuit_lookups.len(); - // Every row is assumed to have the same number of lookups, which is - // the number of lookups of the first row. - let num_row_lookups = circuit_lookups[0].len(); - let num_circuit_lookups = num_rows * num_row_lookups; - num_lookups_per_circuit.push(num_circuit_lookups); - total_num_lookups += num_circuit_lookups; - } + // Number of lookups per circuit. Every row in a circuit is assumed to + // have the same number of lookups (the lookups are expected to be fully + // padded), so this is taken from the first row. + let num_lookups_per_circuit: Vec = lookups + .iter() + .map(|circuit_lookups| circuit_lookups.len() * circuit_lookups[0].len()) + .collect(); - // Compute and collect all messages. There's one message per lookup. - let mut messages = Vec::with_capacity(total_num_lookups); - for circuit_lookups in lookups { - let circuit_messages = circuit_lookups - .iter() - .flatten() - .map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge)); - messages.extend(circuit_messages); - } + // Compute the message for each lookup, in flat circuit-major order. + // Flatten the references serially first so the parallel map operates + // on an indexed slice and `collect` can write straight into the + // output Vec without tree-reducing worker buffers. + let flat: Vec<&Self> = lookups.iter().flatten().flatten().collect(); + let messages: Vec = flat + .par_iter() + .map(|lookup| lookup.compute_message(lookup_challenge, fingerprint_challenge)) + .collect(); // Compute the inverses of all messages in batch. let messages_inverses = batch_multiplicative_inverse(&messages);