Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions expander_compiler/src/frontend/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,26 @@ pub trait BasicAPI<C: Config> {
}
res
}

/// compute constant + sum_i(coef_i * var_i) in a single instruction
/// Builder<C> overrides this with a LinComb instruction (O(1) per neuron vs O(2n) for mul+add).
/// default fallback uses repeated mul+add and is semantically equivalent.
fn linear_combination(
&mut self,
terms: &[(Variable, CircuitField<C>)],
constant: CircuitField<C>,
) -> Variable {
// fallback: 2 instructions per non-zero term; Builder<C> overrides with LinComb.
// Skip zero-coef terms to mirror Builder's behaviour and avoid useless mul+add pairs.
let mut acc = self.constant(constant);
for (var, coef) in terms {
if !coef.is_zero() {
let scaled = self.mul(*var, *coef);
acc = self.add(acc, scaled);
}
}
acc
}
Comment thread
npow marked this conversation as resolved.
}

pub trait UnconstrainedAPI<C: Config> {
Expand Down
45 changes: 45 additions & 0 deletions expander_compiler/src/frontend/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,43 @@ impl<C: Config> BasicAPI<C> for Builder<C> {
}
}

/// emits a single LinComb instruction for constant + sum_i(coef_i * var_i)
/// reduces instruction count from O(2n) to O(1) per output neuron;
/// eliminates the optimizer's expression-expansion blowup on wide linear layers
/// (naive mul+add hits ~70 GB RSS at 1.47M gates; LinComb stays ~1.1 GB)
fn linear_combination(
&mut self,
terms: &[(Variable, CircuitField<C>)],
constant: CircuitField<C>,
) -> Variable {
// validate all variables before accessing their ids
for (var, _) in terms {
ensure_variable_valid(*var);
}
// drop zero-coefficient terms; they add no constraint and inflate the LinComb vec
let mut lc_terms: Vec<LinCombTerm<C>> = Vec::with_capacity(terms.len());
lc_terms.extend(
terms
.iter()
.filter(|(_, coef)| !coef.is_zero())
.map(|(var, coef)| LinCombTerm {
var: var.id,
coef: *coef,
}),
);

if lc_terms.is_empty() {
// pure constant
return self.constant(constant);
}

self.instructions.push(SourceInstruction::LinComb(LinComb {
terms: lc_terms,
constant,
}));
self.new_var()
}

// return 1 if x > y; 0 otherwise
//
fn gt(
Expand Down Expand Up @@ -706,6 +743,14 @@ impl<C: Config> BasicAPI<C> for RootBuilder<C> {
) -> Variable {
self.last_builder().geq(x, y)
}

fn linear_combination(
&mut self,
terms: &[(Variable, CircuitField<C>)],
constant: CircuitField<C>,
) -> Variable {
self.last_builder().linear_combination(terms, constant)
}
}

impl<C: Config> RootAPI<C> for RootBuilder<C> {
Expand Down
134 changes: 133 additions & 1 deletion expander_compiler/src/frontend/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::frontend::M31Config as C;
use crate::{
compile::CompileOptions,
field::{FieldArith, M31},
frontend::{compile, RootAPI},
frontend::{compile, BasicAPI, RootAPI},
};

use super::{builder::Variable, circuit::*, variables::DumpLoadTwoVariables};
Expand Down Expand Up @@ -88,3 +88,135 @@ fn test_circuit_eval_simple() {
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![false]);
}

// linear_combination tests:
// verify that linear_combination([(x,a),(y,b)], c) == c + a*x + b*y via circuit eval

declare_circuit!(LinCombCircuit {
out_lc: Variable, // computed via linear_combination
out_ref: Variable, // computed via repeated mul+add (reference)
x: Variable,
y: Variable,
});

impl Define<C> for LinCombCircuit<Variable> {
fn define<Builder: RootAPI<C>>(&self, builder: &mut Builder) {
use crate::field::M31;
let a = M31::from(3u32);
let b = M31::from(5u32);
let c = M31::from(7u32);

let lc = builder.linear_combination(&[(self.x, a), (self.y, b)], c);
builder.assert_is_equal(lc, self.out_lc);

// reference: c + a*x + b*y via repeated mul+add
let ax = builder.mul(self.x, a);
let by = builder.mul(self.y, b);
let axby = builder.add(ax, by);
let ref_val = builder.add(axby, c);
builder.assert_is_equal(ref_val, self.out_ref);
}
}

#[test]
fn test_linear_combination_matches_mul_add() {
let compile_result = compile(&LinCombCircuit::default(), CompileOptions::default()).unwrap();

// x=2, y=4 → lc = 7 + 3*2 + 5*4 = 7+6+20 = 33
let assignment = LinCombCircuit::<M31> {
out_lc: M31::from(33u32),
out_ref: M31::from(33u32),
x: M31::from(2u32),
y: M31::from(4u32),
};
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);
}

declare_circuit!(LinCombZeroCoefCircuit {
out: Variable,
x: Variable,
y: Variable,
});

impl Define<C> for LinCombZeroCoefCircuit<Variable> {
fn define<Builder: RootAPI<C>>(&self, builder: &mut Builder) {
use crate::field::M31;
// zero coefficient on y; linear_combination must ignore it
let a = M31::from(3u32);
let zero = M31::from(0u32);
let c = M31::from(1u32);
let lc = builder.linear_combination(&[(self.x, a), (self.y, zero)], c);
builder.assert_is_equal(lc, self.out);
}
}

#[test]
fn test_linear_combination_zero_coef_ignored() {
let compile_result = compile(
&LinCombZeroCoefCircuit::default(),
CompileOptions::default(),
)
.unwrap();

// y has zero coef → out = 1 + 3*x regardless of y
let assignment = LinCombZeroCoefCircuit::<M31> {
out: M31::from(7u32), // 1 + 3*2 = 7
x: M31::from(2u32),
y: M31::from(999u32), // ignored
};
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);
}

declare_circuit!(LinCombAllZeroCircuit {
out: Variable,
x: Variable,
});

impl Define<C> for LinCombAllZeroCircuit<Variable> {
fn define<Builder: RootAPI<C>>(&self, builder: &mut Builder) {
use crate::field::M31;
// all zero coefs → falls back to pure constant
let zero = M31::from(0u32);
let c = M31::from(42u32);
let lc = builder.linear_combination(&[(self.x, zero)], c);
builder.assert_is_equal(lc, self.out);
}
}

#[test]
fn test_linear_combination_all_zero_coefs_returns_constant() {
let compile_result =
compile(&LinCombAllZeroCircuit::default(), CompileOptions::default()).unwrap();

let assignment = LinCombAllZeroCircuit::<M31> {
out: M31::from(42u32),
x: M31::from(5u32), // irrelevant
};
let witness = compile_result
.witness_solver
.solve_witness(&assignment)
.unwrap();
let output = compile_result.layered_circuit.run(&witness);
assert_eq!(output, vec![true]);
}

#[test]
#[should_panic]
fn test_linear_combination_invalid_variable_panics() {
use crate::frontend::builder::Builder;
let (mut b, _inputs) = Builder::<C>::new(2);
let bad = Variable::default(); // id=0, invalid
let c = M31::from(0u32);
// must panic due to ensure_variable_valid
let _ = b.linear_combination(&[(bad, M31::from(1u32))], c);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use gkr::{gkr_prove_batch, gkr_verify};
use gkr_engine::{ExpanderPCS, FieldEngine, GKREngine, MPIConfig, Transcript};
use crate::{frontend::{Config, SIMDField}, utils::misc::next_power_of_two,
zkcuda::{context::ComputationGraph, proving_system::{common::check_inputs,
expander::{prove_impl::{get_local_vals, prepare_expander_circuit, prepare_inputs_with_local_vals},
expander::{prove_impl::{get_local_vals, prepare_expander_circuit},
structs::{ExpanderProof, ExpanderProverSetup, ExpanderVerifierSetup}},
CombinedProof, Expander, ProvingSystem}}};

Expand Down Expand Up @@ -92,7 +92,7 @@ impl<C: GKREngine, ECCConfig: Config<FieldConfig = C::FieldConfig>> ProvingSyste
if !ok { return false; }
let chs = if let Some(cy) = ch.challenge_y() { vec![ch.challenge_x(), cy] } else { vec![ch.challenge_x()] };
for sc in &chs {
for (&ref comm, &_ib) in comms.iter().zip(tmpl.is_broadcast().iter()) {
for (comm, &_ib) in comms.iter().zip(tmpl.is_broadcast().iter()) {
let commitment_len = comm.vals_len;
let local_size = commitment_len >> sc.r_mpi.len();
let n_local = if local_size > 0 { local_size.ilog2() as usize } else { 0 };
Expand Down Expand Up @@ -142,7 +142,7 @@ fn prove_one<C: GKREngine, ECCConfig: Config<FieldConfig = C::FieldConfig>>(
if pc > 1 {
let mut tr = C::TranscriptConfig::new();
let mut tc = bc.clone(); tc.fill_rnd_coefs(&mut tr);
let is = 1 << tc.log_input_size();
let _is = 1 << tc.log_input_size();
// Flat-buffer batch: zero malloc during circuit prep
let ki = kernel.layered_circuit_input();
let (mut circuits, _flat_bufs) = unsafe { tc.create_batch(pc) };
Expand All @@ -151,12 +151,11 @@ fn prove_one<C: GKREngine, ECCConfig: Config<FieldConfig = C::FieldConfig>>(
// Inline get_local_vals: zero alloc, write directly into flat buffer
let input = &mut circuits[pi].layers[0].input_vals;
for v in input.iter_mut() { *v = Default::default(); }
for (ci, (partition, (&ref vals, &ib))) in ki.iter()
for (partition, (vals, &ib)) in ki.iter()
.zip(cvs.iter().zip(is_bc.iter()))
.enumerate()
{
let local_slice = if ib {
vals.as_ref()
vals
} else {
let chunk = vals.len() / pc;
&vals[chunk * pi..chunk * (pi + 1)]
Expand Down Expand Up @@ -193,7 +192,7 @@ fn prove_one<C: GKREngine, ECCConfig: Config<FieldConfig = C::FieldConfig>>(
let t2 = std::time::Instant::now();
let chs = if let Some(cy) = ch.challenge_y() { vec![ch.challenge_x(), cy] } else { vec![ch.challenge_x()] };
for sc in &chs {
for (ci, (&ref v, &_ib)) in cvs.iter().zip(tmpl.is_broadcast().iter()).enumerate() {
for (ci, (v, &_ib)) in cvs.iter().zip(tmpl.is_broadcast().iter()).enumerate() {
let pc2 = sc.clone();
let comm_idx = tmpl.commitment_indices()[ci];
let scratch = &commit_states[comm_idx].scratch;
Expand Down Expand Up @@ -237,13 +236,13 @@ fn dump_circuits_for_gpu<F: gkr_engine::FieldEngine>(
circuits: &[expander_circuit::Circuit<F>],
) {
use std::io::Write;
let dir = format!("gpu_data/tmpl_{}", ti);
let dir = format!("gpu_data/tmpl_{ti}");
std::fs::create_dir_all(&dir).ok();

let num_layers = template_circuit.layers.len();

// Write header: N, num_layers, per-layer sizes
let mut hdr = std::fs::File::create(format!("{}/header.bin", dir)).unwrap();
let mut hdr = std::fs::File::create(format!("{dir}/header.bin")).unwrap();
hdr.write_all(&(pc as u32).to_le_bytes()).unwrap();
hdr.write_all(&(num_layers as u32).to_le_bytes()).unwrap();
for layer in &template_circuit.layers {
Expand All @@ -256,7 +255,7 @@ fn dump_circuits_for_gpu<F: gkr_engine::FieldEngine>(
// Write gates per layer (shared across all instances)
for (li, layer) in template_circuit.layers.iter().enumerate() {
// Mul gates: [o_id, x_id, y_id, coef] x n_mul
let mut gf = std::fs::File::create(format!("{}/layer_{}_mul.bin", dir, li)).unwrap();
let mut gf = std::fs::File::create(format!("{dir}/layer_{li}_mul.bin")).unwrap();
for gate in &layer.mul {
gf.write_all(&(gate.o_id as u32).to_le_bytes()).unwrap();
gf.write_all(&(gate.i_ids[0] as u32).to_le_bytes()).unwrap();
Expand All @@ -265,11 +264,11 @@ fn dump_circuits_for_gpu<F: gkr_engine::FieldEngine>(
let coef_bytes: &[u8] = unsafe {
std::slice::from_raw_parts(&gate.coef as *const _ as *const u8, 4)
};
gf.write_all(&coef_bytes).unwrap();
gf.write_all(coef_bytes).unwrap();
}

// Add gates: [o_id, x_id, coef] x n_add
let mut af = std::fs::File::create(format!("{}/layer_{}_add.bin", dir, li)).unwrap();
let mut af = std::fs::File::create(format!("{dir}/layer_{li}_add.bin")).unwrap();
for gate in &layer.add {
af.write_all(&(gate.o_id as u32).to_le_bytes()).unwrap();
af.write_all(&(gate.i_ids[0] as u32).to_le_bytes()).unwrap();
Expand All @@ -284,7 +283,7 @@ fn dump_circuits_for_gpu<F: gkr_engine::FieldEngine>(
// Layout: [instance_0_layer0_input_vals | instance_1_layer0_input_vals | ...]
// Each instance = layer0.input_vals as raw M31x16 bytes (contiguous)
{
let mut wf = std::fs::File::create(format!("{}/witness.bin", dir)).unwrap();
let mut wf = std::fs::File::create(format!("{dir}/witness.bin")).unwrap();
for circuit in circuits.iter() {
let vals = &circuit.layers[0].input_vals;
let bytes: &[u8] = unsafe {
Expand All @@ -297,5 +296,5 @@ fn dump_circuits_for_gpu<F: gkr_engine::FieldEngine>(
}
}

eprintln!(" [dump] tmpl[{}] N={} layers={} → {}/", ti, pc, num_layers, dir);
eprintln!(" [dump] tmpl[{ti}] N={pc} layers={num_layers} -> {dir}/");
}