From f39a311425489d4634f5aa7928d22e8703b16b53 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 16 May 2025 21:51:38 +0800 Subject: [PATCH 01/28] refactor layer proof done --- Cargo.lock | 3 + Cargo.toml | 1 + ceno_zkvm/src/utils.rs | 18 -- gkr_iop/Cargo.toml | 4 +- gkr_iop/examples/multi_layer_logup.rs | 4 +- gkr_iop/src/error.rs | 5 +- gkr_iop/src/evaluation.rs | 4 +- gkr_iop/src/gkr.rs | 4 +- gkr_iop/src/gkr/layer.rs | 108 ++++++------ gkr_iop/src/gkr/layer/linear_layer.rs | 82 ++++----- gkr_iop/src/gkr/layer/sumcheck_layer.rs | 123 ++++++++----- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 183 ++++++++++++++++---- gkr_iop/src/gkr/mock.rs | 4 +- gkr_iop/src/precompiles/lookup_keccakf.rs | 2 +- multilinear_extensions/src/virtual_polys.rs | 14 ++ sumcheck/Cargo.toml | 1 + sumcheck/src/structs.rs | 13 +- sumcheck/src/util.rs | 15 ++ 18 files changed, 375 insertions(+), 213 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6df104fd1..51e30275b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1112,6 +1112,7 @@ version = "0.1.0" dependencies = [ "ark-std", "criterion", + "either", "ff_ext", "itertools 0.13.0", "multilinear_extensions", @@ -1122,6 +1123,7 @@ dependencies = [ "rayon", "serde", "subprotocols", + "sumcheck", "thiserror 1.0.69", "tiny-keccak", "transcript", @@ -2882,6 +2884,7 @@ dependencies = [ "rayon", "serde", "sumcheck_macro", + "thiserror 1.0.69", "tracing", "transcript", ] diff --git a/Cargo.toml b/Cargo.toml index 8d405dbd6..370340c4f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ secp = "0.4.1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" strum = "0.26" +thiserror = "1" # do we need this? 
strum_macros = "0.26" subprotocols = { path = "subprotocols" } substrate-bn = { version = "0.6.0" } diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index c18ab880b..bdc1d5952 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -61,24 +61,6 @@ pub(crate) fn add_one_to_big_num(limb_modulo: F, limbs: &[F]) -> Vec( - size: usize, - transcript: &mut impl Transcript, -) -> Vec { - // println!("alpha_pow"); - let alpha = transcript - .sample_and_append_challenge(b"combine subset evals") - .elements; - (0..size) - .scan(E::ONE, |state, _| { - let res = *state; - *state *= alpha; - Some(res) - }) - .collect_vec() -} - // split single u64 value into W slices, each slice got C bits. // all the rest slices will be filled with 0 if W x C > 64 pub fn u64vec(x: u64) -> [u64; W] { diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index 668167c2a..df7e26a05 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -21,8 +21,10 @@ rand.workspace = true rayon.workspace = true serde.workspace = true subprotocols = { path = "../subprotocols" } -thiserror = "1" +sumcheck.workspace = true +thiserror.workspace = true tiny-keccak.workspace = true +either.workspace = true transcript = { path = "../transcript" } witness = { path = "../witness" } diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index f8e0a3906..3a664fd46 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -250,7 +250,7 @@ fn main() { #[cfg(debug_assertions)] { - let last = gkr_witness.layers[0].exts.clone(); + let last = gkr_witness.layers[0].bases.clone(); MockProver::check( gkr_circuit.clone(), &gkr_witness, @@ -264,7 +264,7 @@ fn main() { } let out_evals = { - let last = gkr_witness.layers[0].exts.clone(); + let last = gkr_witness.layers[0].bases.clone(); let point = Arc::new(vec![]); assert_eq!(last[0].len(), 1); vec![ diff --git a/gkr_iop/src/error.rs b/gkr_iop/src/error.rs index 02c64f96e..8fab57752 
100644 --- a/gkr_iop/src/error.rs +++ b/gkr_iop/src/error.rs @@ -1,8 +1,9 @@ -use subprotocols::error::VerifierError; +use ff_ext::ExtensionField; +use sumcheck::structs::VerifierError; use thiserror::Error; #[derive(Clone, Debug, Error)] -pub enum BackendError { +pub enum BackendError { #[error("layer verification failed: {0:?}, {1:?}")] LayerVerificationFailed(String, VerifierError), } diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index b35565c7f..1d406c8fa 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; - use ff_ext::ExtensionField; use itertools::{Itertools, izip}; use multilinear_extensions::virtual_poly::build_eq_x_r_vec_sequential; @@ -70,7 +68,7 @@ impl EvalExpression { let eval = izip!(parts, &eq).fold(E::ZERO, |acc, (part, eq)| acc + part.eval * *eq); PointAndEval { - point: Arc::new(new_point), + point: new_point, eval, } } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 98f6de8eb..89c789e78 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -24,8 +24,8 @@ pub struct GKRCircuit { } #[derive(Clone, Debug, Default)] -pub struct GKRCircuitWitness { - pub layers: Vec>, +pub struct GKRCircuitWitness<'a, E: ExtensionField> { + pub layers: Vec>, } #[derive(Clone, Serialize, Deserialize)] diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 6ccad13bb..31dc9ff66 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -1,13 +1,15 @@ use ark_std::log2; use ff_ext::ExtensionField; -use itertools::{chain, izip}; +use itertools::{Itertools, chain, izip}; use linear_layer::LinearLayer; -use serde::{Deserialize, Serialize}; +use multilinear_extensions::{Expression, mle::MultilinearExtension}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use subprotocols::{ - expression::{Constant, Expression, Point}, + expression::{Constant, Point}, sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, }; -use 
sumcheck_layer::SumcheckLayer; +use sumcheck::structs::IOPProof; +use sumcheck_layer::{SumcheckLayer, SumcheckLayerProof}; use transcript::Transcript; use zerocheck_layer::ZerocheckLayer; @@ -29,7 +31,11 @@ pub enum LayerType { } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Layer { +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] +pub struct Layer { pub name: String, pub ty: LayerType, /// Challenges generated at the beginning of the layer protocol. @@ -38,14 +44,14 @@ pub struct Layer { /// each expression corresponds to an output. While in sumcheck, there /// is only 1 expression, which corresponds to the sum of all outputs. /// This design is for the convenience when building the following - /// expression: `e_0 + beta * e_1 = sum_x (eq(p_0, x) + beta * - /// eq(p_1, x)) expr(x)`. where `vec![e_0, beta * e_1]` will be the - /// output evaluation expressions. - pub exprs: Vec, + /// expression: `e_0 + beta * e_1 + /// = \sum_x (r^0 eq_0(X) \cdot expr_0(x) + r^1 eq_1(X) \cdot expr_1(x) + ...)`. + /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. + pub exprs: Vec>, + /// eq expression for zero checks. Length should match with `exprs` + pub eqs: Vec>, /// Positions to place the evaluations of the base inputs of this layer. - pub in_bases: Vec, - /// Positions to place the evaluations of the ext inputs of this layer. - pub in_exts: Vec, + pub in_eval_expr: Vec, /// The expressions of the evaluations from the succeeding layers, which are /// connected to the outputs of this layer. 
pub outs: Vec, @@ -55,21 +61,20 @@ pub struct Layer { } #[derive(Clone, Debug, Default)] -pub struct LayerWitness { - pub bases: Vec>, - pub exts: Vec>, +pub struct LayerWitness<'a, E: ExtensionField> { + pub bases: Vec>, pub num_vars: usize, } -impl Layer { +impl Layer { #[allow(clippy::too_many_arguments)] pub fn new( name: String, ty: LayerType, - exprs: Vec, + exprs: Vec>, + eqs: Vec>, challenges: Vec, - in_bases: Vec, - in_exts: Vec, + in_eval_expr: Vec, outs: Vec, expr_names: Vec, ) -> Self { @@ -85,46 +90,50 @@ impl Layer { ty, challenges, exprs, - in_bases, - in_exts, + eqs, + in_eval_expr, outs, expr_names, } } #[allow(clippy::too_many_arguments)] - pub fn prove>( + pub fn prove>( &self, + num_threads: usize, + max_num_variables: usize, wit: LayerWitness, claims: &mut [PointAndEval], challenges: &mut Vec, - transcript: &mut Trans, + transcript: &mut T, ) -> SumcheckProof { self.update_challenges(challenges, transcript); - #[allow(unused)] let (sigmas, out_points) = self.sigmas_and_points(claims, challenges); - let SumcheckProverOutput { - point: in_point, - proof, + let SumcheckLayerProof { + proof: IOPProof { proofs, point }, + .. 
} = match self.ty { - LayerType::Sumcheck => >::prove( + LayerType::Sumcheck => as SumcheckLayer>::prove( self, + num_threads, + max_num_variables, wit, - &out_points.slice_vector(), challenges, transcript, ), - LayerType::Zerocheck => >::prove( + LayerType::Zerocheck => as ZerocheckLayer>::prove( self, + num_threads, + max_num_variables, wit, - &out_points.slice_vector(), + &out_points, challenges, transcript, ), LayerType::Linear => { - assert!(out_points.iter().all(|point| point == &out_points[0])); - >::prove(self, wit, &out_points[0], transcript) + assert!(out_points.iter().all_equal()); + as LinearLayer>::prove(self, wit, &out_points[0], transcript) } }; @@ -138,7 +147,7 @@ impl Layer { proof } - pub fn verify>( + pub fn verify>( &self, proof: SumcheckProof, claims: &mut [PointAndEval], @@ -182,7 +191,7 @@ impl Layer { Ok(()) } - fn sigmas_and_points( + fn sigmas_and_points( &self, claims: &[PointAndEval], challenges: &[E], @@ -196,13 +205,9 @@ impl Layer { .unzip() } - fn update_challenges( - &self, - challenges: &mut Vec, - transcript: &mut impl Transcript, - ) { + fn update_challenges(&self, challenges: &mut Vec, transcript: &mut impl Transcript) { for challenge in &self.challenges { - let value = transcript.sample_and_append_challenge(b"linear layer challenge"); + let value = transcript.sample_and_append_challenge(b"layer challenge"); match challenge { Constant::Challenge(i) => { if challenges.len() <= *i { @@ -224,7 +229,7 @@ impl Layer { ) { for (value, pos) in izip!( chain![base_mle_evals, ext_mle_evals], - chain![&self.in_bases, &self.in_exts] + chain![&self.in_eval_expr, &self.in_eval_expr] ) { *(pos.entry_mut(claims)) = PointAndEval { point: point.clone(), @@ -234,20 +239,15 @@ impl Layer { } } -impl LayerWitness { - pub fn new(bases: Vec>, exts: Vec>) -> Self { - assert!(!bases.is_empty() || !exts.is_empty()); +impl<'a, E: ExtensionField> LayerWitness<'a, E> { + pub fn new(bases: Vec>) -> Self { + assert!(!bases.is_empty()); 
let num_vars = if bases.is_empty() { - log2(exts[0].len()) + log2(bases[0].evaluations().len()) } else { - log2(bases[0].len()) + log2(bases[0].evaluations().len()) } as usize; - assert!(bases.iter().all(|b| b.len() == 1 << num_vars)); - assert!(exts.iter().all(|e| e.len() == 1 << num_vars)); - Self { - bases, - exts, - num_vars, - } + assert!(bases.iter().all(|b| b.evaluations().len() == 1 << num_vars)); + Self { bases, num_vars } } } diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 8387c4d80..8ab947845 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -1,93 +1,78 @@ use ff_ext::ExtensionField; -use itertools::{Itertools, izip}; -use subprotocols::{ - error::VerifierError, - expression::Point, - sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, - utils::{evaluate_mle_ext, evaluate_mle_inplace}, -}; +use itertools::Itertools; +use multilinear_extensions::{mle::Point, utils::eval_by_expr_with_instance}; +use sumcheck::structs::VerifierError; use transcript::Transcript; use crate::error::BackendError; use super::{Layer, LayerWitness}; +pub struct LinearLayerProof { + evals: Vec, + point: Point, +} + +pub struct LayerClaims { + pub in_point: Point, + pub evals: Vec, +} pub trait LinearLayer { fn prove( &self, wit: LayerWitness, out_point: &Point, transcript: &mut impl Transcript, - ) -> SumcheckProverOutput; + ) -> LinearLayerProof; fn verify( &self, - proof: SumcheckProof, + proof: LinearLayerProof, sigmas: &[E], out_point: &Point, challenges: &[E], transcript: &mut impl Transcript, - ) -> Result, BackendError>; + ) -> Result, BackendError>; } -impl LinearLayer for Layer { +impl LinearLayer for Layer { fn prove( &self, wit: LayerWitness, out_point: &Point, transcript: &mut impl Transcript, - ) -> SumcheckProverOutput { - let base_mle_evals = wit + ) -> LinearLayerProof { + let evals = wit .bases .iter() - .map(|base| evaluate_mle_ext(base, out_point)) - 
.collect_vec(); - - transcript.append_field_element_exts(&base_mle_evals); - - let ext_mle_evals = wit - .exts - .into_iter() - .map(|mut ext| evaluate_mle_inplace(&mut ext, out_point)) + .map(|base| base.evaluate(&out_point)) .collect_vec(); - transcript.append_field_element_exts(&ext_mle_evals); + transcript.append_field_element_exts(&evals); - SumcheckProverOutput { - proof: SumcheckProof { - univariate_polys: vec![], - ext_mle_evals, - base_mle_evals, - }, + LinearLayerProof { + evals, point: out_point.clone(), } } fn verify( &self, - proof: SumcheckProof, + proof: LinearLayerProof, sigmas: &[E], out_point: &Point, challenges: &[E], transcript: &mut impl Transcript, - ) -> Result, BackendError> { - let SumcheckProof { - univariate_polys: _, - ext_mle_evals, - base_mle_evals, - } = proof; + ) -> Result, BackendError> { + let LinearLayerProof { evals, .. } = proof; - transcript.append_field_element_exts(&ext_mle_evals); - transcript.append_field_element_exts(&base_mle_evals); + transcript.append_field_element_exts(&evals); - for (sigma, expr, expr_name) in izip!(sigmas, &self.exprs, &self.expr_names) { - let got = expr.evaluate( - &ext_mle_evals, - &base_mle_evals, - &[out_point], - &[], - challenges, - ); + for ((sigma, expr), expr_name) in sigmas.iter().zip_eq(&self.exprs).zip_eq(&self.expr_names) + { + let got = eval_by_expr_with_instance(&[], &evals, &[], &[], &challenges, &expr) + .right() + .unwrap(); if *sigma != got { return Err(BackendError::LayerVerificationFailed( self.name.clone(), @@ -96,9 +81,8 @@ impl LinearLayer for Layer { } } - Ok(SumcheckClaims { - base_mle_evals, - ext_mle_evals, + Ok(LayerClaims { + evals, in_point: out_point.clone(), }) } diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index 01bd00e8f..8b4841ea8 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -1,76 +1,121 @@ +use std::marker::PhantomData; + +use either::Either; use 
ff_ext::ExtensionField; -use subprotocols::sumcheck::{ - SumcheckClaims, SumcheckProof, SumcheckProverOutput, SumcheckProverState, SumcheckVerifierState, +use itertools::Itertools; +use multilinear_extensions::{ + utils::eval_by_expr_with_instance, virtual_poly::VPAuxInfo, + virtual_polys::VirtualPolynomialsBuilder, +}; +use sumcheck::structs::{ + IOPProof, IOPProverState, IOPVerifierState, SumCheckSubClaim, VerifierError, }; use transcript::Transcript; -use crate::{ - error::BackendError, - utils::{SliceVector, SliceVectorMut}, -}; +use crate::error::BackendError; -use super::{Layer, LayerWitness}; +use super::{Layer, LayerWitness, linear_layer::LayerClaims}; +pub struct SumcheckLayerProof { + pub proof: IOPProof, + pub evals: Vec, +} pub trait SumcheckLayer { #[allow(clippy::too_many_arguments)] - fn prove( + fn prove<'a>( &self, - wit: LayerWitness, - out_points: &[&[E]], + num_threads: usize, + max_num_variables: usize, + wit: LayerWitness<'a, E>, challenges: &[E], transcript: &mut impl Transcript, - ) -> SumcheckProverOutput; + ) -> SumcheckLayerProof; fn verify( &self, - proof: SumcheckProof, + proof: SumcheckLayerProof, sigma: &E, - out_points: Vec<&[E]>, + max_num_variables: usize, challenges: &[E], transcript: &mut impl Transcript, - ) -> Result, BackendError>; + ) -> Result, BackendError>; } -impl SumcheckLayer for Layer { - fn prove( +impl SumcheckLayer for Layer { + fn prove<'a>( &self, - mut wit: LayerWitness, - out_points: &[&[E]], + num_threads: usize, + max_num_variables: usize, + wit: LayerWitness<'a, E>, challenges: &[E], transcript: &mut impl Transcript, - ) -> SumcheckProverOutput { - let prover_state = SumcheckProverState::new( - self.exprs[0].clone(), - out_points, - wit.exts.slice_vector_mut(), - wit.bases.slice_vector(), - challenges, + ) -> SumcheckLayerProof { + let builder = VirtualPolynomialsBuilder::new_with_mles( + num_threads, + max_num_variables, + wit.bases.iter().map(|mle| Either::Left(mle)).collect_vec(), + ); + let (proof, 
prover_state) = IOPProverState::prove( + builder.to_virtual_polys(&[self.exprs[0].clone()], challenges), transcript, ); - - prover_state.prove() + SumcheckLayerProof { + proof, + evals: prover_state.get_mle_flatten_final_evaluations(), + } } fn verify( &self, - proof: SumcheckProof, + proof: SumcheckLayerProof, sigma: &E, - out_points: Vec<&[E]>, + max_num_variables: usize, challenges: &[E], transcript: &mut impl Transcript, - ) -> Result, BackendError> { - let verifier_state = SumcheckVerifierState::new( + ) -> Result, BackendError> { + let SumcheckLayerProof { + proof: IOPProof { proofs, .. }, + evals, + } = proof; + + let SumCheckSubClaim { + point: in_point, + expected_evaluation, + } = IOPVerifierState::verify( *sigma, - self.exprs[0].clone(), - out_points, - proof, - challenges, + &IOPProof { + point: vec![], // final claimed point will be derived from sumcheck protocol + proofs, + }, + &VPAuxInfo { + max_degree: self.exprs[0].degree(), + max_num_variables, + phantom: PhantomData, + }, transcript, - vec![], ); - verifier_state - .verify() - .map_err(|e| BackendError::LayerVerificationFailed(self.name.clone(), e)) + // Check the final evaluations. 
+ let got_claim = + eval_by_expr_with_instance(&[], &evals, &[], &[], challenges, &self.exprs[0]) + .right() + .unwrap(); + + if got_claim != expected_evaluation { + return Err(BackendError::LayerVerificationFailed( + "sumcheck verify failed".to_string(), + VerifierError::ClaimNotMatch( + self.exprs[0].clone(), + expected_evaluation, + got_claim, + self.expr_names[0].clone(), + ), + )); + } + + Ok(LayerClaims { + in_point: in_point.into_iter().map(|c| c.elements).collect_vec(), + evals, + }) } } diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 4df79c405..ead623362 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -1,77 +1,184 @@ +use std::marker::PhantomData; + +use either::Either; use ff_ext::ExtensionField; -use subprotocols::{ - sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, - zerocheck::{ZerocheckProverState, ZerocheckVerifierState}, +use itertools::Itertools; +use multilinear_extensions::{ + Expression, + mle::{MultilinearExtension, Point}, + utils::eval_by_expr_with_instance, + virtual_poly::{VPAuxInfo, build_eq_x_r_vec, eq_eval}, + virtual_polys::VirtualPolynomialsBuilder, +}; +use p3_field::dot_product; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use sumcheck::{ + macros::{entered_span, exit_span}, + structs::{IOPProof, IOPProverState, IOPVerifierState, SumCheckSubClaim, VerifierError}, + util::get_challenge_pows, }; use transcript::Transcript; -use crate::{ - error::BackendError, - utils::{SliceVector, SliceVectorMut}, -}; +use crate::error::BackendError; -use super::{Layer, LayerWitness}; +use super::{Layer, LayerWitness, linear_layer::LayerClaims, sumcheck_layer::SumcheckLayerProof}; pub trait ZerocheckLayer { #[allow(clippy::too_many_arguments)] fn prove( &self, + num_threads: usize, + max_num_variables: usize, wit: LayerWitness, - out_points: &[&[E]], + out_points: &[Point], challenges: &[E], transcript: &mut 
impl Transcript, - ) -> SumcheckProverOutput; + ) -> SumcheckLayerProof; fn verify( &self, - proof: SumcheckProof, + max_num_variables: usize, + proof: SumcheckLayerProof, sigmas: Vec, - out_points: Vec<&[E]>, + out_points: &[Point], challenges: &[E], transcript: &mut impl Transcript, - ) -> Result, BackendError>; + ) -> Result, BackendError>; } -impl ZerocheckLayer for Layer { +impl ZerocheckLayer for Layer { fn prove( &self, + num_threads: usize, + max_num_variables: usize, mut wit: LayerWitness, - out_points: &[&[E]], + out_points: &[Point], challenges: &[E], transcript: &mut impl Transcript, - ) -> SumcheckProverOutput { - let prover_state = ZerocheckProverState::new( - self.exprs.clone(), - out_points, - wit.exts.slice_vector_mut(), - wit.bases.slice_vector(), - challenges, - transcript, - ); + ) -> SumcheckLayerProof { + assert_eq!(self.exprs.len(), out_points.len()); + + let span = entered_span!("build_out_points_eq"); + let mut eqs = out_points + .par_iter() + .map(|point| { + MultilinearExtension::from_evaluations_ext_vec( + point.len(), + build_eq_x_r_vec(&point), + ) + }) + .collect::>(); + exit_span!(span); - prover_state.prove() + let builder = VirtualPolynomialsBuilder::new_with_mles( + num_threads, + max_num_variables, + wit.bases + .iter_mut() + .map(|mle| Either::Right(mle)) + // extend eqs to the end of wit + .chain(eqs.iter_mut().map(|eq| Either::Right(eq))) + .collect_vec(), + ); + let alpha_pows = get_challenge_pows(self.exprs.len(), transcript) + .into_iter() + .map(|r| Expression::Constant(Either::Right(r))) + .collect_vec(); + let expr = self + .exprs + .iter() + .zip_eq(self.eqs.clone()) + .zip_eq(alpha_pows) + .map(|((expr, eq), alpha)| alpha * eq * expr) + .sum::>(); + let (proof, prover_state) = + IOPProverState::prove(builder.to_virtual_polys(&[expr], challenges), transcript); + SumcheckLayerProof { + proof, + evals: prover_state.get_mle_flatten_final_evaluations(), + } } fn verify( &self, - proof: SumcheckProof, + max_num_variables: usize, + 
proof: SumcheckLayerProof, sigmas: Vec, - out_points: Vec<&[E]>, + out_points: &[Point], challenges: &[E], transcript: &mut impl Transcript, - ) -> Result, BackendError> { - let verifier_state = ZerocheckVerifierState::new( - sigmas, - self.exprs.clone(), - vec![], - out_points, - proof, - challenges, + ) -> Result, BackendError> { + assert_eq!(sigmas.len(), out_points.len()); + let SumcheckLayerProof { + proof: IOPProof { proofs, .. }, + mut evals, + } = proof; + + let alpha_pows = get_challenge_pows(self.exprs.len(), transcript); + + let sigma: E = dot_product(alpha_pows.iter().cloned(), sigmas.iter().cloned()); + + let SumCheckSubClaim { + point: in_point, + expected_evaluation, + } = IOPVerifierState::verify( + sigma, + &IOPProof { + point: vec![], // final claimed point will be derived from sumcheck protocol + proofs, + }, + &VPAuxInfo { + max_degree: self.exprs[0].degree(), + max_num_variables, + phantom: PhantomData, + }, transcript, ); + let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); + + // eval eq and set to respective witin + out_points + .iter() + .map(|out_point| eq_eval(out_point, &in_point)) + .zip(&self.eqs) + .for_each(|(eval, eq_expr)| match eq_expr { + Expression::WitIn(witin_id) => evals[*witin_id as usize] = eval, + _ => unreachable!(), + }); + + // check the final evaluations. 
+ let got_claim = self + .exprs + .iter() + .zip_eq(&self.eqs) + .zip_eq(alpha_pows) + .map(|((expr, eq_expr), alpha)| { + alpha + * eval_by_expr_with_instance( + &[], + &evals, + &[], + &[], + challenges, + &(expr * eq_expr), + ) + .right() + .unwrap() + }) + .sum::(); + + if got_claim != expected_evaluation { + return Err(BackendError::LayerVerificationFailed( + "sumcheck verify failed".to_string(), + VerifierError::ClaimNotMatch( + self.exprs[0].clone(), + expected_evaluation, + got_claim, + self.expr_names[0].clone(), + ), + )); + } - verifier_state - .verify() - .map_err(|e| BackendError::LayerVerificationFailed(self.name.clone(), e)) + Ok(LayerClaims { in_point, evals }) } } diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 3c359bdfe..2838cc8a9 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -101,10 +101,10 @@ impl MockProver { } } } - for (in_pos, base) in izip!(&layer.in_bases, &layer_wit.bases) { + for (in_pos, base) in izip!(&layer.in_eval_expr, &layer_wit.bases) { *(in_pos.entry_mut(&mut evaluations)) = VectorType::Base(base.clone()); } - for (in_pos, ext) in izip!(&layer.in_exts, &layer_wit.exts) { + for (in_pos, ext) in izip!(&layer.in_eval_expr, &layer_wit.exts) { *(in_pos.entry_mut(&mut evaluations)) = VectorType::Ext(ext.clone()); } } diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 50dd09384..3d097e7e2 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -768,7 +768,6 @@ where let mut layer_wits = vec![ LayerWitness { bases: vec![], - exts: vec![], num_vars: 1 }; n_layers diff --git a/multilinear_extensions/src/virtual_polys.rs b/multilinear_extensions/src/virtual_polys.rs index 04b435f9d..3580990a7 100644 --- a/multilinear_extensions/src/virtual_polys.rs +++ b/multilinear_extensions/src/virtual_polys.rs @@ -53,6 +53,20 @@ impl<'a, E: ExtensionField> VirtualPolynomialsBuilder<'a, E> {
_phantom: PhantomData, } } + + /// create a new `VirtualPolynomialsBuilder` with the given number and max number of vars + pub fn new_with_mles( + num_threads: usize, + max_num_variables: usize, + mles: Vec, &'a mut MultilinearExtension<'a, E>>>, + ) -> Self { + let mut builder = Self::new(num_threads, max_num_variables); + mles.into_iter().for_each(|mle| { + let _ = builder.lift(mle); + }); + builder + } + /// lifts a reference to a `MultilinearExtension` into an `Expression::WitIn` /// /// assigns a unique witness index based on pointer address, reusing the same index diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index 520276b0f..9c822c386 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -16,6 +16,7 @@ ff_ext = { path = "../ff_ext" } itertools.workspace = true multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] } p3.workspace = true +thiserror.workspace = true rayon.workspace = true serde.workspace = true sumcheck_macro = { path = "../sumcheck_macro" } diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index 21db3b164..cc4a3e05e 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -1,6 +1,9 @@ use ff_ext::ExtensionField; -use multilinear_extensions::{virtual_poly::VirtualPolynomial, virtual_polys::PolyMeta}; +use multilinear_extensions::{ + Expression, mle::Point, virtual_poly::VirtualPolynomial, virtual_polys::PolyMeta, +}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; +use thiserror::Error; use transcript::Challenge; /// An IOP proof is a collections of @@ -12,7 +15,7 @@ use transcript::Challenge; deserialize = "E::BaseField: DeserializeOwned" ))] pub struct IOPProof { - pub point: Vec, + pub point: Point, pub proofs: Vec>, } impl IOPProof { @@ -75,3 +78,9 @@ pub struct SumCheckSubClaim { /// the expected evaluation pub expected_evaluation: E, } + +#[derive(Clone, Debug, Error)] +pub enum VerifierError { + #[error("Claim not match: expr: {0:?}\n (expr name: 
{3:?})\n expect: {1:?}, got: {2:?}")] + ClaimNotMatch(Expression, E, E, String), +} diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index c72ffad15..9c958eabc 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -18,6 +18,7 @@ use multilinear_extensions::{ }; use p3::field::Field; use rayon::{prelude::ParallelIterator, slice::ParallelSliceMut}; +use transcript::Transcript; use crate::structs::IOPProverState; @@ -326,6 +327,20 @@ pub fn optimal_sumcheck_threads(num_vars: usize) -> usize { } } +/// Derive challenge from transcript and return all power results of the challenge. +pub fn get_challenge_pows( + size: usize, + transcript: &mut impl Transcript, +) -> Vec { + let alpha = transcript + .sample_and_append_challenge(b"combine subset evals") + .elements; + + std::iter::successors(Some(E::ONE), move |prev| Some(*prev * alpha)) + .take(size) + .collect() +} + #[derive(Clone, Copy, Debug)] /// util collection to support fundamental operation pub struct AdditiveArray(pub [F; N]); From 123cca33f3f75bbc2b4858c496aacdc65476a06f Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 16 May 2025 23:26:25 +0800 Subject: [PATCH 02/28] finish layer logic --- ceno_zkvm/src/structs.rs | 7 +-- gkr_iop/examples/multi_layer_logup.rs | 2 +- gkr_iop/src/evaluation.rs | 44 +++++++------- gkr_iop/src/gkr.rs | 42 +++++++------ gkr_iop/src/gkr/layer.rs | 79 +++++++++---------------- gkr_iop/src/gkr/layer/sumcheck_layer.rs | 4 +- multilinear_extensions/src/mle.rs | 7 +++ 7 files changed, 86 insertions(+), 99 deletions(-) diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index 02296a7f1..319a8c9f4 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -68,12 +68,7 @@ pub enum RAMType { impl_expr_from_unsigned!(RAMType); -/// A point and the evaluation of this point. 
-#[derive(Clone, Debug, PartialEq)] -pub struct PointAndEval { - pub point: Point, - pub eval: F, -} +pub type PointAndEval = multilinear_extensions::mle::PointAndEval; impl Default for PointAndEval { fn default() -> Self { diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index 3a664fd46..36acc812f 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -4,7 +4,7 @@ use ff_ext::ExtensionField; use gkr_iop::{ ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - evaluation::{EvalExpression, PointAndEval}, + evaluation::EvalExpression, gkr::{ GKRCircuitWitness, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index 1d406c8fa..20f7d9cf7 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -1,55 +1,55 @@ use ff_ext::ExtensionField; use itertools::{Itertools, izip}; -use multilinear_extensions::virtual_poly::build_eq_x_r_vec_sequential; -use serde::{Deserialize, Serialize}; -use subprotocols::expression::{Constant, Point}; +use multilinear_extensions::{ + Expression, mle::PointAndEval, utils::eval_by_expr_with_fixed, + virtual_poly::build_eq_x_r_vec_sequential, +}; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; /// Evaluation expression for the gkr layer reduction and PCS opening /// preparation. #[derive(Clone, Debug, Serialize, Deserialize)] -pub enum EvalExpression { +#[serde(bound = "E: ExtensionField + DeserializeOwned")] +pub enum EvalExpression { /// Single entry in the evaluation vector. Single(usize), /// Linear expression of an entry with the scalar and offset. - Linear(usize, Constant, Constant), + Linear(usize, Box>, Box>), /// Merging multiple evaluations which denotes a partition of the original /// polynomial. `(usize, Constant)` denote the modification of the point. 
/// For example, when it receive a point `(p0, p1, p2, p3)` from a /// succeeding layer, `vec![(2, c0), (4, c1)]` will modify the point to /// `(p0, p1, c0, p2, c1, p3)`. where the indices specify how the /// partition applied to the original polynomial. - Partition(Vec>, Vec<(usize, Constant)>), + Partition( + Vec>>, + Vec<(usize, Box>)>, + ), } -#[derive(Clone, Debug, Default)] -pub struct PointAndEval { - pub point: Point, - pub eval: E, -} - -impl Default for EvalExpression { +impl Default for EvalExpression { fn default() -> Self { EvalExpression::Single(0) } } -impl EvalExpression { - pub fn evaluate( - &self, - evals: &[PointAndEval], - challenges: &[E], - ) -> PointAndEval { +fn evaluate(expr: &Expression, challenges: &[E]) -> E { + eval_by_expr_with_fixed(&[], &[], &[], challenges, expr) +} + +impl EvalExpression { + pub fn evaluate(&self, evals: &[PointAndEval], challenges: &[E]) -> PointAndEval { match self { EvalExpression::Single(i) => evals[*i].clone(), EvalExpression::Linear(i, c0, c1) => PointAndEval { point: evals[*i].point.clone(), - eval: evals[*i].eval * c0.evaluate(challenges) + c1.evaluate(challenges), + eval: evals[*i].eval * evaluate(c0, challenges) + evaluate(c1, challenges), }, EvalExpression::Partition(parts, indices) => { assert!(izip!(indices.iter(), indices.iter().skip(1)).all(|(a, b)| a.0 < b.0)); let vars = indices .iter() - .map(|(_, c)| c.evaluate(challenges)) + .map(|(_, c)| evaluate(c, challenges)) .collect_vec(); let parts = parts @@ -61,7 +61,7 @@ impl EvalExpression { let mut new_point = parts[0].point.to_vec(); for (index_in_point, c) in indices { - new_point.insert(*index_in_point, c.evaluate(challenges)); + new_point.insert(*index_in_point, evaluate(c, challenges)); } let eq = build_eq_x_r_vec_sequential(&vars); diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 89c789e78..874f1f744 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,26 +1,25 @@ use ff_ext::ExtensionField; use itertools::{Itertools, 
chain, izip}; use layer::{Layer, LayerWitness}; +use multilinear_extensions::mle::{Point, PointAndEval}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use subprotocols::{expression::Point, sumcheck::SumcheckProof}; +use subprotocols::sumcheck::SumcheckProof; use transcript::Transcript; -use crate::{ - error::BackendError, - evaluation::{EvalExpression, PointAndEval}, -}; +use crate::{error::BackendError, evaluation::EvalExpression}; pub mod layer; pub mod mock; #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct GKRCircuit { - pub layers: Vec, +#[serde(bound = "E: ExtensionField + DeserializeOwned")] +pub struct GKRCircuit { + pub layers: Vec>, pub n_challenges: usize, pub n_evaluations: usize, - pub base_openings: Vec<(usize, EvalExpression)>, - pub ext_openings: Vec<(usize, EvalExpression)>, + pub base_openings: Vec<(usize, EvalExpression)>, + pub ext_openings: Vec<(usize, EvalExpression)>, } #[derive(Clone, Debug, Default)] @@ -58,23 +57,29 @@ pub struct Evaluation { pub struct GKRClaims(pub Vec); -impl GKRCircuit { - pub fn prove( +impl GKRCircuit { + pub fn prove( &self, + num_threads: usize, + max_num_variables: usize, circuit_wit: GKRCircuitWitness, out_evals: &[PointAndEval], challenges: &[E], transcript: &mut impl Transcript, - ) -> Result>, BackendError> - where - E: ExtensionField, - { + ) -> Result>, BackendError> { let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) .map(|(layer, layer_wit)| { - layer.prove(layer_wit, &mut evaluations, &mut challenges, transcript) + layer.prove( + num_threads, + max_num_variables, + layer_wit, + &mut evaluations, + &mut challenges, + transcript, + ) }) .collect_vec(); @@ -86,8 +91,9 @@ impl GKRCircuit { }) } - pub fn verify( + pub fn verify( &self, + max_num_variables: usize, gkr_proof: GKRProof, out_evals: &[PointAndEval], challenges: &[E], 
@@ -110,7 +116,7 @@ impl GKRCircuit { )) } - fn opening_evaluations( + fn opening_evaluations( &self, evaluations: &[PointAndEval], challenges: &[E], diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 31dc9ff66..1637404e6 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -1,23 +1,17 @@ use ark_std::log2; use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; -use linear_layer::LinearLayer; -use multilinear_extensions::{Expression, mle::MultilinearExtension}; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use subprotocols::{ - expression::{Constant, Point}, - sumcheck::{SumcheckClaims, SumcheckProof, SumcheckProverOutput}, +use linear_layer::{LayerClaims, LinearLayer}; +use multilinear_extensions::{ + Expression, + mle::{MultilinearExtension, Point, PointAndEval}, }; -use sumcheck::structs::IOPProof; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use sumcheck_layer::{SumcheckLayer, SumcheckLayerProof}; use transcript::Transcript; use zerocheck_layer::ZerocheckLayer; -use crate::{ - error::BackendError, - evaluation::{EvalExpression, PointAndEval}, - utils::SliceVector, -}; +use crate::{error::BackendError, evaluation::EvalExpression}; pub mod linear_layer; pub mod sumcheck_layer; @@ -39,7 +33,7 @@ pub struct Layer { pub name: String, pub ty: LayerType, /// Challenges generated at the beginning of the layer protocol. - pub challenges: Vec, + pub challenges: Vec>, /// Expressions to prove in this layer. For zerocheck and linear layers, /// each expression corresponds to an output. While in sumcheck, there /// is only 1 expression, which corresponds to the sum of all outputs. @@ -51,10 +45,10 @@ pub struct Layer { /// eq expression for zero checks. Length should match with `exprs` pub eqs: Vec>, /// Positions to place the evaluations of the base inputs of this layer. 
- pub in_eval_expr: Vec, + pub in_eval_expr: Vec>, /// The expressions of the evaluations from the succeeding layers, which are /// connected to the outputs of this layer. - pub outs: Vec, + pub outs: Vec>, // For debugging purposes pub expr_names: Vec, @@ -73,9 +67,9 @@ impl Layer { ty: LayerType, exprs: Vec>, eqs: Vec>, - challenges: Vec, - in_eval_expr: Vec, - outs: Vec, + challenges: Vec>, + in_eval_expr: Vec>, + outs: Vec>, expr_names: Vec, ) -> Self { let mut expr_names = expr_names; @@ -106,14 +100,11 @@ impl Layer { claims: &mut [PointAndEval], challenges: &mut Vec, transcript: &mut T, - ) -> SumcheckProof { + ) -> SumcheckLayerProof { self.update_challenges(challenges, transcript); let (sigmas, out_points) = self.sigmas_and_points(claims, challenges); - let SumcheckLayerProof { - proof: IOPProof { proofs, point }, - .. - } = match self.ty { + let sumcheck_layer_proof = match self.ty { LayerType::Sumcheck => as SumcheckLayer>::prove( self, num_threads, @@ -139,17 +130,17 @@ impl Layer { self.update_claims( claims, - &proof.base_mle_evals, - &proof.ext_mle_evals, - &in_point, + &sumcheck_layer_proof.evals, + &sumcheck_layer_proof.proof.point, ); - proof + sumcheck_layer_proof } pub fn verify>( &self, - proof: SumcheckProof, + max_num_variables: usize, + proof: SumcheckLayerProof, claims: &mut [PointAndEval], challenges: &mut Vec, transcript: &mut Trans, @@ -157,36 +148,33 @@ impl Layer { self.update_challenges(challenges, transcript); let (sigmas, points) = self.sigmas_and_points(claims, challenges); - let SumcheckClaims { - in_point, - base_mle_evals, - ext_mle_evals, - } = match self.ty { - LayerType::Sumcheck => >::verify( + let LayerClaims { in_point, evals } = match self.ty { + LayerType::Sumcheck => as SumcheckLayer>::verify( self, + max_num_variables, proof, &sigmas.iter().cloned().sum(), - points.slice_vector(), challenges, transcript, )?, - LayerType::Zerocheck => >::verify( + LayerType::Zerocheck => as ZerocheckLayer>::verify( self, + 
max_num_variables, proof, sigmas, - points.slice_vector(), + &points, challenges, transcript, )?, LayerType::Linear => { assert!(points.iter().all(|point| point == &points[0])); - >::verify( + as LinearLayer>::verify( self, proof, &sigmas, &points[0], challenges, transcript, )? } }; - self.update_claims(claims, &base_mle_evals, &ext_mle_evals, &in_point); + self.update_claims(claims, &evals, &in_point); Ok(()) } @@ -220,17 +208,8 @@ impl Layer { } } - fn update_claims( - &self, - claims: &mut [PointAndEval], - base_mle_evals: &[E], - ext_mle_evals: &[E], - point: &Point, - ) { - for (value, pos) in izip!( - chain![base_mle_evals, ext_mle_evals], - chain![&self.in_eval_expr, &self.in_eval_expr] - ) { + fn update_claims(&self, claims: &mut [PointAndEval], evals: &[E], point: &Point) { + for (value, pos) in izip!(chain![evals], chain![&self.in_eval_expr]) { *(pos.entry_mut(claims)) = PointAndEval { point: point.clone(), eval: *value, diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index 8b4841ea8..7173b6c08 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -33,9 +33,9 @@ pub trait SumcheckLayer { fn verify( &self, + max_num_variables: usize, proof: SumcheckLayerProof, sigma: &E, - max_num_variables: usize, challenges: &[E], transcript: &mut impl Transcript, ) -> Result, BackendError>; @@ -67,9 +67,9 @@ impl SumcheckLayer for Layer { fn verify( &self, + max_num_variables: usize, proof: SumcheckLayerProof, sigma: &E, - max_num_variables: usize, challenges: &[E], transcript: &mut impl Transcript, ) -> Result, BackendError> { diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index e165be0f9..93c87063f 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -19,6 +19,13 @@ use std::fmt::Debug; /// A point is a vector of num_var length pub type Point = Vec; +/// A point and the evaluation of this point. 
+#[derive(Clone, Debug, PartialEq, Default)] +pub struct PointAndEval { + pub point: Point, + pub eval: F, +} + impl Debug for MultilinearExtension<'_, E> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{:?}", self.evaluations()) From 2d4224f21d8e2d567bb020cc51a73497a0aa026e Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Sat, 17 May 2025 18:39:20 +0800 Subject: [PATCH 03/28] gkr layer fixed --- ceno_zkvm/src/scheme/mock_prover.rs | 2 +- ceno_zkvm/src/scheme/utils.rs | 176 +----------------- gkr_iop/src/chip.rs | 10 +- gkr_iop/src/chip/builder.rs | 34 +--- gkr_iop/src/chip/protocol.rs | 9 +- gkr_iop/src/evaluation.rs | 4 +- gkr_iop/src/gkr.rs | 18 +- gkr_iop/src/gkr/layer.rs | 21 ++- gkr_iop/src/gkr/layer/linear_layer.rs | 22 +-- gkr_iop/src/gkr/layer/sumcheck_layer.rs | 6 + gkr_iop/src/gkr/mock.rs | 137 ++++++++++---- gkr_iop/src/lib.rs | 8 +- gkr_iop/src/precompiles/bitwise_keccakf.rs | 27 +-- gkr_iop/src/precompiles/utils.rs | 9 +- multilinear_extensions/src/expression.rs | 198 ++++++++++++++++++++- multilinear_extensions/src/mle.rs | 24 +++ 16 files changed, 402 insertions(+), 303 deletions(-) diff --git a/ceno_zkvm/src/scheme/mock_prover.rs b/ceno_zkvm/src/scheme/mock_prover.rs index b422cd8fa..a9c46325f 100644 --- a/ceno_zkvm/src/scheme/mock_prover.rs +++ b/ceno_zkvm/src/scheme/mock_prover.rs @@ -530,7 +530,7 @@ impl<'a, E: ExtensionField + Hash> MockProver { { let right = -right.as_ref(); - let left_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, left); + let left_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, &left); let left_evaluated = left_evaluated.get_base_field_vec(); let right_evaluated = wit_infer_by_expr(&[], wits_in, &[], pi, &challenge, &right); diff --git a/ceno_zkvm/src/scheme/utils.rs b/ceno_zkvm/src/scheme/utils.rs index 74fb2cb07..dd50b3f3b 100644 --- a/ceno_zkvm/src/scheme/utils.rs +++ b/ceno_zkvm/src/scheme/utils.rs @@ -2,6 +2,7 @@ use std::sync::Arc; use 
ff_ext::ExtensionField; use itertools::Itertools; +pub use multilinear_extensions::wit_infer_by_expr; use multilinear_extensions::{ commutative_op_mle_pair, mle::{ArcMultilinearExtension, FieldType, IntoMLE, MultilinearExtension}, @@ -242,123 +243,6 @@ pub(crate) fn infer_tower_product_witness( wit_layers } -pub(crate) fn wit_infer_by_expr<'a, E: ExtensionField, const N: usize>( - fixed: &[ArcMultilinearExtension<'a, E>], - witnesses: &[ArcMultilinearExtension<'a, E>], - structual_witnesses: &[ArcMultilinearExtension<'a, E>], - instance: &[ArcMultilinearExtension<'a, E>], - challenges: &[E; N], - expr: &Expression, -) -> ArcMultilinearExtension<'a, E> { - expr.evaluate_with_instance::>( - &|f| fixed[f.0].clone(), - &|witness_id| witnesses[witness_id as usize].clone(), - &|witness_id, _, _, _| structual_witnesses[witness_id as usize].clone(), - &|i| instance[i.0].clone(), - &|scalar| { - let scalar: ArcMultilinearExtension = - Arc::new(MultilinearExtension::from_evaluations_vec( - 0, - vec![scalar.left().expect("do not support extension field")], - )); - scalar - }, - &|challenge_id, pow, scalar, offset| { - // TODO cache challenge power to be acquired once for each power - let challenge = challenges[challenge_id as usize]; - let challenge: ArcMultilinearExtension = - Arc::new(MultilinearExtension::from_evaluations_ext_vec( - 0, - vec![challenge.exp_u64(pow as u64) * scalar + offset], - )); - challenge - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(MultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] + b[0]], - )), - (1, _) => Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] + *b) - .collect(), - )), - (_, 1) => Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a + b[0]) - .collect(), - )), - (_, _) => 
Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a + *b) - .collect(), - )), - } - }) - }, - &|a, b| { - commutative_op_mle_pair!(|a, b| { - match (a.len(), b.len()) { - (1, 1) => Arc::new(MultilinearExtension::from_evaluation_vec_smart( - 0, - vec![a[0] * b[0]], - )), - (1, _) => Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(b.len()), - b.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|b| a[0] * *b) - .collect(), - )), - (_, 1) => Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|a| *a * b[0]) - .collect(), - )), - (_, _) => { - assert_eq!(a.len(), b.len()); - // we do the pointwise evaluation multiplication here without involving FFT - // the evaluations outside of range will be checked via sumcheck + identity polynomial - Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(a.len()), - a.par_iter() - .zip(b.par_iter()) - .with_min_len(MIN_PAR_SIZE) - .map(|(a, b)| *a * *b) - .collect(), - )) - } - } - }) - }, - &|x, a, b| { - op_mle_xa_b!(|x, a, b| { - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); - let (a, b) = (a[0], b[0]); - Arc::new(MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(x.len()), - x.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|x| a * *x + b) - .collect(), - )) - }) - }, - ) -} - #[cfg(test)] mod tests { @@ -643,62 +527,4 @@ mod tests { ])) ); } - - #[test] - fn test_wit_infer_by_expr_base_field() { - type E = ff_ext::GoldilocksExt2; - type B = p3::goldilocks::Goldilocks; - let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let a = cb.create_witin(|| "a"); - let b = cb.create_witin(|| "b"); - let c = cb.create_witin(|| "c"); - - let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); - - let res = wit_infer_by_expr( - 
&[], - &[ - vec![B::from_u64(1)].into_mle().into(), - vec![B::from_u64(2)].into_mle().into(), - vec![B::from_u64(3)].into_mle().into(), - ], - &[], - &[], - &[], - &expr, - ); - res.get_base_field_vec(); - } - - #[test] - fn test_wit_infer_by_expr_ext_field() { - type E = ff_ext::GoldilocksExt2; - type B = p3::goldilocks::Goldilocks; - let mut cs = ConstraintSystem::::new(|| "test"); - let mut cb = CircuitBuilder::new(&mut cs); - let a = cb.create_witin(|| "a"); - let b = cb.create_witin(|| "b"); - let c = cb.create_witin(|| "c"); - - let expr: Expression = a.expr() - + b.expr() - + a.expr() * b.expr() - + (c.expr() * 3 + 2) - + Expression::Challenge(0, 1, E::ONE, E::ONE); - - let res = wit_infer_by_expr( - &[], - &[ - vec![B::from_u64(1)].into_mle().into(), - vec![B::from_u64(2)].into_mle().into(), - vec![B::from_u64(3)].into_mle().into(), - ], - &[], - &[], - &[E::ONE], - &expr, - ); - res.get_ext_field_vec(); - } } diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index b40b66f9c..4089ac7c9 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -1,3 +1,5 @@ +use ff_ext::ExtensionField; + use crate::{evaluation::EvalExpression, gkr::layer::Layer}; pub mod builder; @@ -6,7 +8,7 @@ pub mod protocol; /// Chip stores all information required in the GKR protocol, including the /// commit phases, the GKR phase and the opening phase. #[derive(Clone, Debug, Default)] -pub struct Chip { +pub struct Chip { /// The number of base inputs committed in the whole protocol. pub n_committed_bases: usize, /// The number of ext inputs committed in the whole protocol. @@ -19,10 +21,8 @@ pub struct Chip { /// in a vector and this is the length. pub n_evaluations: usize, /// The layers of the GKR circuit, in the order outputs-to-inputs. - pub layers: Vec, + pub layers: Vec>, /// The polynomial index and evaluation expressions of the base inputs. - pub base_openings: Vec<(usize, EvalExpression)>, - /// The polynomial index and evaluation expressions of the ext inputs. 
- pub ext_openings: Vec<(usize, EvalExpression)>, + pub openings: Vec<(usize, EvalExpression)>, } diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index e8c5079e1..2703b91f2 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -1,5 +1,6 @@ use std::array; +use ff_ext::ExtensionField; use itertools::Itertools; use subprotocols::expression::{Constant, Witness}; @@ -10,7 +11,7 @@ use crate::{ use super::Chip; -impl Chip { +impl Chip { /// Allocate indices for committing base field polynomials. pub fn allocate_committed_base(&mut self) -> [usize; N] { self.n_committed_bases += N; @@ -29,31 +30,19 @@ impl Chip { /// processing the layer prover for each polynomial. This should be /// called at most once for each layer! #[allow(clippy::type_complexity)] - pub fn allocate_wits_in_layer( - &mut self, - ) -> ( - [(Witness, EvalExpression); M], - [(Witness, EvalExpression); N], - ) { + pub fn allocate_wits_in_layer(&mut self) -> [(Witness, EvalExpression); N] { let bases = array::from_fn(|i| { ( Witness::BasePoly(i), EvalExpression::Single(i + self.n_evaluations), ) }); - self.n_evaluations += M; - let exts = array::from_fn(|i| { - ( - Witness::ExtPoly(i), - EvalExpression::Single(i + self.n_evaluations), - ) - }); self.n_evaluations += N; - (bases, exts) + bases } /// Generate the evaluation expression for each output. - pub fn allocate_output_evals(&mut self) -> Vec + pub fn allocate_output_evals(&mut self) -> Vec> // -> [EvalExpression; N] { self.n_evaluations += N; @@ -73,19 +62,12 @@ impl Chip { /// Allocate a PCS opening action to a base polynomial with index /// `wit_index`. The `EvalExpression` represents the expression to /// compute the evaluation. - pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { - self.base_openings.push((wit_index, eval)); - } - - /// Allocate a PCS opening action to an ext polynomial with index - /// `wit_index`. 
The `EvalExpression` represents the expression to - /// compute the evaluation. - pub fn allocate_ext_opening(&mut self, wit_index: usize, eval: EvalExpression) { - self.ext_openings.push((wit_index, eval)); + pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { + self.openings.push((wit_index, eval)); } /// Add a layer to the circuit. - pub fn add_layer(&mut self, layer: Layer) { + pub fn add_layer(&mut self, layer: Layer) { assert_eq!(layer.outs.len(), layer.exprs.len()); match layer.ty { LayerType::Linear => { diff --git a/gkr_iop/src/chip/protocol.rs b/gkr_iop/src/chip/protocol.rs index 0374d8e6c..43b08fddf 100644 --- a/gkr_iop/src/chip/protocol.rs +++ b/gkr_iop/src/chip/protocol.rs @@ -1,16 +1,17 @@ +use ff_ext::ExtensionField; + use crate::gkr::GKRCircuit; use super::Chip; -impl Chip { +impl Chip { /// Extract information from Chip that required in the GKR phase. - pub fn gkr_circuit(&self) -> GKRCircuit { + pub fn gkr_circuit(&self) -> GKRCircuit { GKRCircuit { layers: self.layers.clone(), n_challenges: self.n_challenges, n_evaluations: self.n_evaluations, - base_openings: self.base_openings.clone(), - ext_openings: self.ext_openings.clone(), + openings: self.openings.clone(), } } } diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index 20f7d9cf7..af0d59f2d 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -78,14 +78,14 @@ impl EvalExpression { pub fn entry<'a, T>(&self, evals: &'a [T]) -> &'a T { match self { EvalExpression::Single(i) => &evals[*i], - _ => unreachable!(), + _ => panic!("invalid operation"), } } pub fn entry_mut<'a, T>(&self, evals: &'a mut [T]) -> &'a mut T { match self { EvalExpression::Single(i) => &mut evals[*i], - _ => unreachable!(), + _ => panic!("invalid operation"), } } } diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 874f1f744..6522c2105 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,9 +1,8 @@ use ff_ext::ExtensionField; use 
itertools::{Itertools, chain, izip}; -use layer::{Layer, LayerWitness}; +use layer::{Layer, LayerWitness, sumcheck_layer::SumcheckLayerProof}; use multilinear_extensions::mle::{Point, PointAndEval}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use subprotocols::sumcheck::SumcheckProof; use transcript::Transcript; use crate::{error::BackendError, evaluation::EvalExpression}; @@ -18,8 +17,7 @@ pub struct GKRCircuit { pub n_challenges: usize, pub n_evaluations: usize, - pub base_openings: Vec<(usize, EvalExpression)>, - pub ext_openings: Vec<(usize, EvalExpression)>, + pub openings: Vec<(usize, EvalExpression)>, } #[derive(Clone, Debug, Default)] @@ -42,7 +40,7 @@ pub struct GKRProverOutput { serialize = "E::BaseField: Serialize", deserialize = "E::BaseField: DeserializeOwned" ))] -pub struct GKRProof(pub Vec>); +pub struct GKRProof(pub Vec>); #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound( @@ -108,7 +106,13 @@ impl GKRCircuit { let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); for (layer, layer_proof) in izip!(&self.layers, sumcheck_proofs) { - layer.verify(layer_proof, &mut evaluations, &mut challenges, transcript)?; + layer.verify( + max_num_variables, + layer_proof, + &mut evaluations, + &mut challenges, + transcript, + )?; } Ok(GKRClaims( @@ -121,7 +125,7 @@ impl GKRCircuit { evaluations: &[PointAndEval], challenges: &[E], ) -> Vec> { - chain!(&self.base_openings, &self.ext_openings) + chain!(&self.openings, &self.openings) .map(|(poly, eval)| { let poly = *poly; let PointAndEval { point, eval: value } = eval.evaluate(evaluations, challenges); diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 1637404e6..45d022e4c 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -102,7 +102,7 @@ impl Layer { transcript: &mut T, ) -> SumcheckLayerProof { self.update_challenges(challenges, transcript); - let (sigmas, out_points) = 
self.sigmas_and_points(claims, challenges); + let (_, out_points) = self.extract_claim_and_point(claims, challenges); let sumcheck_layer_proof = match self.ty { LayerType::Sumcheck => as SumcheckLayer>::prove( @@ -146,7 +146,7 @@ impl Layer { transcript: &mut Trans, ) -> Result<(), BackendError> { self.update_challenges(challenges, transcript); - let (sigmas, points) = self.sigmas_and_points(claims, challenges); + let (sigmas, points) = self.extract_claim_and_point(claims, challenges); let LayerClaims { in_point, evals } = match self.ty { LayerType::Sumcheck => as SumcheckLayer>::verify( @@ -179,7 +179,7 @@ impl Layer { Ok(()) } - fn sigmas_and_points( + fn extract_claim_and_point( &self, claims: &[PointAndEval], challenges: &[E], @@ -187,21 +187,24 @@ impl Layer { self.outs .iter() .map(|out| { - let tmp = out.evaluate(claims, challenges); - (tmp.eval, tmp.point) + let PointAndEval { point, eval } = out.evaluate(claims, challenges); + (eval, point) }) .unzip() } + // generate layer challenge, if have, and set to respective challenge_id index + // optional resize raw challenges vector to adapt new challenge fn update_challenges(&self, challenges: &mut Vec, transcript: &mut impl Transcript) { for challenge in &self.challenges { let value = transcript.sample_and_append_challenge(b"layer challenge"); match challenge { - Constant::Challenge(i) => { - if challenges.len() <= *i { - challenges.resize(*i + 1, E::ZERO); + Expression::Challenge(challange_id, ..) 
=> { + let challange_id = *challange_id as usize; + if challenges.len() <= challange_id as usize { + challenges.resize(challange_id + 1, E::default()); } - challenges[*i] = value.elements; + challenges[challange_id] = value.elements; } _ => unreachable!(), } diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 8ab947845..68f7acb18 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -1,12 +1,12 @@ use ff_ext::ExtensionField; use itertools::Itertools; use multilinear_extensions::{mle::Point, utils::eval_by_expr_with_instance}; -use sumcheck::structs::VerifierError; +use sumcheck::structs::{IOPProof, VerifierError}; use transcript::Transcript; use crate::error::BackendError; -use super::{Layer, LayerWitness}; +use super::{Layer, LayerWitness, sumcheck_layer::SumcheckLayerProof}; pub struct LinearLayerProof { evals: Vec, @@ -23,11 +23,11 @@ pub trait LinearLayer { wit: LayerWitness, out_point: &Point, transcript: &mut impl Transcript, - ) -> LinearLayerProof; + ) -> SumcheckLayerProof; fn verify( &self, - proof: LinearLayerProof, + proof: SumcheckLayerProof, sigmas: &[E], out_point: &Point, challenges: &[E], @@ -41,7 +41,7 @@ impl LinearLayer for Layer { wit: LayerWitness, out_point: &Point, transcript: &mut impl Transcript, - ) -> LinearLayerProof { + ) -> SumcheckLayerProof { let evals = wit .bases .iter() @@ -50,22 +50,24 @@ impl LinearLayer for Layer { transcript.append_field_element_exts(&evals); - LinearLayerProof { + SumcheckLayerProof { evals, - point: out_point.clone(), + proof: IOPProof { + point: out_point.clone(), + proofs: vec![], + }, } } fn verify( &self, - proof: LinearLayerProof, + proof: SumcheckLayerProof, sigmas: &[E], out_point: &Point, challenges: &[E], transcript: &mut impl Transcript, ) -> Result, BackendError> { - let LinearLayerProof { evals, .. } = proof; - + let SumcheckLayerProof { evals, .. 
} = proof; transcript.append_field_element_exts(&evals); for ((sigma, expr), expr_name) in sigmas.iter().zip_eq(&self.exprs).zip_eq(&self.expr_names) diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index 7173b6c08..56031c6b9 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -7,6 +7,7 @@ use multilinear_extensions::{ utils::eval_by_expr_with_instance, virtual_poly::VPAuxInfo, virtual_polys::VirtualPolynomialsBuilder, }; +use serde::{Deserialize, Serialize, de::DeserializeOwned}; use sumcheck::structs::{ IOPProof, IOPProverState, IOPVerifierState, SumCheckSubClaim, VerifierError, }; @@ -16,6 +17,11 @@ use crate::error::BackendError; use super::{Layer, LayerWitness, linear_layer::LayerClaims}; +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound( + serialize = "E::BaseField: Serialize", + deserialize = "E::BaseField: DeserializeOwned" +))] pub struct SumcheckLayerProof { pub proof: IOPProof, pub evals: Vec, diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 2838cc8a9..0dfc15349 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -1,57 +1,90 @@ -use std::marker::PhantomData; +use std::{marker::PhantomData, sync::Arc}; use ff_ext::ExtensionField; use itertools::{Itertools, izip}; -use rand::rngs::OsRng; -use subprotocols::{ - expression::{Expression, VectorType}, - test_utils::random_point, - utils::eq_vecs, +use multilinear_extensions::{ + WitnessId, mle::MultilinearExtension, util::ceil_log2, + virtual_poly::build_eq_x_r_vec_with_scalar, }; +use rand::{rngs::OsRng, thread_rng}; use thiserror::Error; use crate::{evaluation::EvalExpression, utils::SliceIterator}; +use multilinear_extensions::{ + Expression, mle::FieldType, smart_slice::SmartSlice, wit_infer_by_expr, +}; +use rand::Rng; use super::{GKRCircuit, GKRCircuitWitness, layer::LayerType}; pub struct MockProver(PhantomData); #[derive(Clone, Debug, Error)] -pub enum 
MockProverError { +pub enum MockProverError<'a, E: ExtensionField> { #[error("sumcheck layer should have only one expression, got {0}")] SumcheckExprLenError(usize), #[error("sumcheck expression not match, out: {0:?}, expr: {1:?}, expect: {2:?}. got: {3:?}")] SumcheckExpressionNotMatch( - Vec, - Expression, - VectorType, - VectorType, + Vec>, + Expression, + FieldType<'a, E>, + FieldType<'a, E>, ), #[error("zerocheck expression not match, out: {0:?}, expr: {1:?}, expect: {2:?}. got: {3:?}")] - ZerocheckExpressionNotMatch(EvalExpression, Expression, VectorType, VectorType), + ZerocheckExpressionNotMatch( + EvalExpression, + Expression, + FieldType<'a, E>, + FieldType<'a, E>, + ), #[error("linear expression not match, out: {0:?}, expr: {1:?}, expect: {2:?}. got: {3:?}")] - LinearExpressionNotMatch(EvalExpression, Expression, VectorType, VectorType), + LinearExpressionNotMatch( + EvalExpression, + Expression, + FieldType<'a, E>, + FieldType<'a, E>, + ), } impl MockProver { - pub fn check( - circuit: GKRCircuit, + pub fn check<'a>( + circuit: GKRCircuit, circuit_wit: &GKRCircuitWitness, - mut evaluations: Vec>, + mut evaluations: Vec>, mut challenges: Vec, - ) -> Result<(), MockProverError> { - evaluations.resize(circuit.n_evaluations, VectorType::Base(vec![])); - challenges.resize_with(circuit.n_challenges, || E::random(OsRng)); + ) -> Result<(), MockProverError<'a, E>> { + let mut rng = thread_rng(); + evaluations.resize( + circuit.n_evaluations, + FieldType::Base(SmartSlice::Owned(vec![])), + ); + challenges.resize_with(circuit.n_challenges, || E::random(&mut rng)); for (layer, layer_wit) in izip!(circuit.layers, &circuit_wit.layers) { let num_vars = layer_wit.num_vars; let points = (0..layer.outs.len()) .map(|_| random_point::(OsRng, num_vars)) .collect_vec(); - let eqs = eq_vecs(points.slice_iter(), &vec![E::ONE; points.len()]); + let eqs = eq_mles(points.slice_iter(), &vec![E::ONE; points.len()]); let gots = layer .exprs .iter() - .map(|expr| 
expr.calc(&layer_wit.exts, &layer_wit.bases, &eqs, &challenges)) + .map(|expr| { + Arc::into_inner(wit_infer_by_expr( + &[], + &layer_wit + .bases + .iter() + .map(|mle| mle.as_view().into()) + .chain(eqs.into_iter().map(|eq| eq.into())) + .collect_vec(), + &[], + &[], + &challenges, + expr, + )) + .unwrap() + .evaluations_to_owned() + }) .collect_vec(); let expects = layer .outs @@ -102,29 +135,41 @@ impl MockProver { } } for (in_pos, base) in izip!(&layer.in_eval_expr, &layer_wit.bases) { - *(in_pos.entry_mut(&mut evaluations)) = VectorType::Base(base.clone()); - } - for (in_pos, ext) in izip!(&layer.in_eval_expr, &layer_wit.exts) { - *(in_pos.entry_mut(&mut evaluations)) = VectorType::Ext(ext.clone()); + *(in_pos.entry_mut(&mut evaluations)) = base.evaluations().as_borrowed_view(); } } Ok(()) } } -impl EvalExpression { - pub fn mock_evaluate( +impl EvalExpression { + pub fn mock_evaluate<'a>( &self, - evals: &[VectorType], + evals: &[FieldType<'a, E>], challenges: &[E], len: usize, - ) -> VectorType { + ) -> FieldType<'a, E> { match self { EvalExpression::Single(i) => evals[*i].clone(), - EvalExpression::Linear(i, c0, c1) => { - evals[*i].clone() * VectorType::Ext(vec![c0.evaluate(challenges); len]) - + VectorType::Ext(vec![c1.evaluate(challenges); len]) - } + EvalExpression::Linear(i, c0, c1) => Arc::into_inner(wit_infer_by_expr( + &[], + &evals + .iter() + .map(|field_type| { + MultilinearExtension::from_field_type_borrowed( + ceil_log2(field_type.len()), + field_type, + ) + .into() + }) + .collect_vec(), + &[], + &[], + &challenges, + &(Expression::WitIn(*i as WitnessId) * *c0.clone() + *c1.clone()), + )) + .unwrap() + .evaluations_to_owned(), EvalExpression::Partition(parts, indices) => { assert_eq!(parts.len(), 1 << indices.len()); let parts = parts @@ -137,7 +182,7 @@ impl EvalExpression { let step_size = 1 << i; acc.chunks_exact(2) .map(|chunk| match (&chunk[0], &chunk[1]) { - (VectorType::Base(v0), VectorType::Base(v1)) => { + (FieldType::Base(v0), 
FieldType::Base(v1)) => { let res = (0..v0.len()) .step_by(step_size) .flat_map(|j| { @@ -147,9 +192,9 @@ impl EvalExpression { .cloned() }) .collect_vec(); - VectorType::Base(res) + FieldType::Base(SmartSlice::Owned(res)) } - (VectorType::Ext(v0), VectorType::Ext(v1)) => { + (FieldType::Ext(v0), FieldType::Ext(v1)) => { let res = (0..v0.len()) .step_by(step_size) .flat_map(|j| { @@ -159,7 +204,7 @@ impl EvalExpression { .cloned() }) .collect_vec(); - VectorType::Ext(res) + FieldType::Ext(SmartSlice::Owned(res)) } _ => unreachable!(), }) @@ -171,3 +216,21 @@ impl EvalExpression { } } } + +fn eq_mles<'a, E: ExtensionField>( + points: impl Iterator, + scalars: &[E], +) -> Vec> { + izip!(points, scalars) + .map(|(point, scalar)| { + MultilinearExtension::from_evaluations_ext_vec( + point.len(), + build_eq_x_r_vec_with_scalar(point, *scalar), + ) + }) + .collect_vec() +} + +fn random_point(mut rng: impl Rng, num_vars: usize) -> Vec { + (0..num_vars).map(|_| E::random(&mut rng)).collect_vec() +} diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 06c83f486..b946984e3 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -12,13 +12,13 @@ pub mod gkr; pub mod precompiles; pub mod utils; -pub trait ProtocolBuilder: Sized { +pub trait ProtocolBuilder: Sized { type Params; fn init(params: Self::Params) -> Self; /// Build the protocol for GKR IOP. - fn build(params: Self::Params) -> (Self, Chip) { + fn build(params: Self::Params) -> (Self, Chip) { let mut chip_spec = Self::init(params); let mut chip = Chip::default(); chip_spec.build_commit_phase(&mut chip); @@ -29,11 +29,11 @@ pub trait ProtocolBuilder: Sized { /// Specify the polynomials and challenges to be committed and generated in /// Phase 1. - fn build_commit_phase(&mut self, spec: &mut Chip); + fn build_commit_phase(&mut self, spec: &mut Chip); /// Create the GKR layers in the reverse order. 
For each layer, specify the /// polynomial expressions, evaluation expressions of outputs and evaluation /// positions of the inputs. - fn build_gkr_phase(&mut self, spec: &mut Chip); + fn build_gkr_phase(&mut self, spec: &mut Chip); } pub trait ProtocolWitnessGenerator diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index 153e01d2b..52ad6c22b 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -3,7 +3,7 @@ use std::{array::from_fn, marker::PhantomData, sync::Arc}; use crate::{ ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, - evaluation::{EvalExpression, PointAndEval}, + evaluation::EvalExpression, gkr::{ GKRCircuitWitness, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, @@ -11,10 +11,11 @@ use crate::{ }; use ff_ext::ExtensionField; use itertools::{Itertools, chain, iproduct}; +use multilinear_extensions::{Expression, ToExpr}; use p3_field::{Field, PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; -use subprotocols::expression::{Constant, Expression, Witness}; +use subprotocols::expression::{Constant, Witness}; use tiny_keccak::keccakf; use transcript::BasicTranscript; @@ -53,24 +54,24 @@ fn xor(a: F, b: F) -> F { a + b - a * b - a * b } -fn and_expr(a: Expression, b: Expression) -> Expression { +fn and_expr(a: Expression, b: Expression) -> Expression { a.clone() * b.clone() } -fn not_expr(a: Expression) -> Expression { +fn not_expr(a: Expression) -> Expression { one_expr() - a } -fn xor_expr(a: Expression, b: Expression) -> Expression { - a.clone() + b.clone() - Expression::Const(Constant::Base(2)) * a * b +fn xor_expr(a: Expression, b: Expression) -> Expression { + a.clone() + b.clone() - E::BaseField::from_u32(2).expr() * a * b } -fn zero_expr() -> Expression { - Expression::Const(Constant::Base(0)) +fn zero_expr() -> Expression { + E::BaseField::ZERO.expr() } -fn one_expr() -> Expression { - 
Expression::Const(Constant::Base(1)) +fn one_expr() -> Expression { + E::BaseField::ONE.expr() } fn c(x: usize, z: usize, bits: &[F]) -> F { @@ -95,7 +96,7 @@ fn d(x: usize, z: usize, c_vals: &[F]) -> F { xor(c_vals[lhs], c_vals[rhs]) } -fn d_expr(x: usize, z: usize, c_wits: &[Witness]) -> Expression { +fn d_expr(x: usize, z: usize, c_wits: &[Witness]) -> Expression { let lhs = from_xz((x + 5 - 1) % 5, z); let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); xor_expr(c_wits[lhs].into(), c_wits[rhs].into()) @@ -212,7 +213,7 @@ fn iota_expr(bits: &[Witness], index: usize, round_value: u64) -> Expression { } } -fn chi_expr(i: usize, bits: &[Witness]) -> Expression { +fn chi_expr(i: usize, bits: &[Witness]) -> Expression { assert_eq!(bits.len(), STATE_SIZE); let (x, y, z) = to_xyz(i); @@ -223,7 +224,7 @@ fn chi_expr(i: usize, bits: &[Witness]) -> Expression { xor_expr((bits[i]).into(), rhs) } -impl ProtocolBuilder for KeccakLayout { +impl ProtocolBuilder for KeccakLayout { type Params = KeccakParams; fn init(params: Self::Params) -> Self { diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index e5b17b3bd..58af29c11 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -1,5 +1,6 @@ use ff_ext::ExtensionField; use itertools::Itertools; +use multilinear_extensions::ToExpr; use p3_field::PrimeCharacteristicRing; use subprotocols::expression::{Constant, Expression}; @@ -13,8 +14,12 @@ pub fn not8_expr(expr: Expression) -> Expression { Expression::Const(Constant::Base(0xFF)) - expr } -pub fn zero_eval() -> EvalExpression { - EvalExpression::Linear(0, Constant::Base(0), Constant::Base(0)) +pub fn zero_eval() -> EvalExpression { + EvalExpression::Linear( + 0, + Box::new(E::BaseField::ZERO.expr()), + Box::new(E::BaseField::ZERO.expr()), + ) } pub fn nest(v: &[E::BaseField]) -> Vec> { diff --git a/multilinear_extensions/src/expression.rs b/multilinear_extensions/src/expression.rs index 8b68db5f2..f9fea60b4 100644 --- 
a/multilinear_extensions/src/expression.rs +++ b/multilinear_extensions/src/expression.rs @@ -1,6 +1,17 @@ pub mod monomial; pub mod utils; +use crate::{ + commutative_op_mle_pair, + mle::{ArcMultilinearExtension, MultilinearExtension}, + op_mle_xa_b, op_mle3_range, + util::ceil_log2, +}; +use ff_ext::{ExtensionField, SmallField}; +use itertools::Either; +use p3::field::PrimeCharacteristicRing; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use serde::de::DeserializeOwned; use std::{ cmp::max, fmt::Display, @@ -8,14 +19,9 @@ use std::{ ops::{Add, AddAssign, Deref, Mul, MulAssign, Neg, Shl, ShlAssign, Sub, SubAssign}, }; -use serde::de::DeserializeOwned; - -use ff_ext::{ExtensionField, SmallField}; -use itertools::Either; -use p3::field::PrimeCharacteristicRing; - pub type WitnessId = u32; pub type ChallengeId = u16; +pub const MIN_PAR_SIZE: usize = 64; #[derive( Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, serde::Serialize, serde::Deserialize, @@ -909,6 +915,129 @@ impl> ToExpr for F { } } +pub fn wit_infer_by_expr<'a, E: ExtensionField>( + fixed: &[ArcMultilinearExtension<'a, E>], + witnesses: &[ArcMultilinearExtension<'a, E>], + structual_witnesses: &[ArcMultilinearExtension<'a, E>], + instance: &[ArcMultilinearExtension<'a, E>], + challenges: &[E], + expr: &Expression, +) -> ArcMultilinearExtension<'a, E> { + expr.evaluate_with_instance::>( + &|f| fixed[f.0].clone(), + &|witness_id| witnesses[witness_id as usize].clone(), + &|witness_id, _, _, _| structual_witnesses[witness_id as usize].clone(), + &|i| instance[i.0].clone(), + &|scalar| { + let scalar: ArcMultilinearExtension = MultilinearExtension::from_evaluations_vec( + 0, + vec![scalar.left().expect("do not support extension field")], + ) + .into(); + scalar + }, + &|challenge_id, pow, scalar, offset| { + // TODO cache challenge power to be acquired once for each power + let challenge = challenges[challenge_id as usize]; + let challenge: 
ArcMultilinearExtension = + MultilinearExtension::from_evaluations_ext_vec( + 0, + vec![challenge.exp_u64(pow as u64) * scalar + offset], + ) + .into(); + challenge + }, + &|a, b| { + commutative_op_mle_pair!(|a, b| { + match (a.len(), b.len()) { + (1, 1) => { + MultilinearExtension::from_evaluation_vec_smart(0, vec![a[0] + b[0]]).into() + } + (1, _) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(b.len()), + b.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|b| a[0] + *b) + .collect(), + ) + .into(), + (_, 1) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|a| *a + b[0]) + .collect(), + ) + .into(), + (_, _) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .zip(b.par_iter()) + .with_min_len(MIN_PAR_SIZE) + .map(|(a, b)| *a + *b) + .collect(), + ) + .into(), + } + }) + }, + &|a, b| { + commutative_op_mle_pair!(|a, b| { + match (a.len(), b.len()) { + (1, 1) => { + MultilinearExtension::from_evaluation_vec_smart(0, vec![a[0] * b[0]]).into() + } + (1, _) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(b.len()), + b.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|b| a[0] * *b) + .collect(), + ) + .into(), + (_, 1) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|a| *a * b[0]) + .collect(), + ) + .into(), + (_, _) => { + assert_eq!(a.len(), b.len()); + // we do the pointwise evaluation multiplication here without involving FFT + // the evaluations outside of range will be checked via sumcheck + identity polynomial + MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .zip(b.par_iter()) + .with_min_len(MIN_PAR_SIZE) + .map(|(a, b)| *a * *b) + .collect(), + ) + .into() + } + } + }) + }, + &|x, a, b| { + op_mle_xa_b!(|x, a, b| { + assert_eq!(a.len(), 1); + assert_eq!(b.len(), 1); + let (a, b) = (a[0], b[0]); + 
MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(x.len()), + x.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|x| a * *x + b) + .collect(), + ) + .into() + }) + }, + ) +} + macro_rules! impl_from_via_ToExpr { ($($t:ty),*) => { $( @@ -1105,9 +1234,8 @@ pub mod fmt { #[cfg(test)] mod tests { - use crate::expression::WitIn; - use super::{Expression, ToExpr, fmt}; + use crate::{expression::WitIn, mle::IntoMLE, wit_infer_by_expr}; use either::Either; use ff_ext::{FieldInto, GoldilocksExt2}; use p3::field::PrimeCharacteristicRing; @@ -1257,4 +1385,58 @@ mod tests { assert_eq!(s, "WitIn(0)"); assert_eq!(wtns_acc, vec![0]); } + + #[test] + fn test_wit_infer_by_expr_base_field() { + type E = ff_ext::GoldilocksExt2; + type B = p3::goldilocks::Goldilocks; + let a = WitIn { id: 0 }; + let b = WitIn { id: 1 }; + let c = WitIn { id: 2 }; + + let expr: Expression = a.expr() + b.expr() + a.expr() * b.expr() + (c.expr() * 3 + 2); + + let res = wit_infer_by_expr( + &[], + &[ + vec![B::from_u64(1)].into_mle().into(), + vec![B::from_u64(2)].into_mle().into(), + vec![B::from_u64(3)].into_mle().into(), + ], + &[], + &[], + &[], + &expr, + ); + res.get_base_field_vec(); + } + + #[test] + fn test_wit_infer_by_expr_ext_field() { + type E = ff_ext::GoldilocksExt2; + type B = p3::goldilocks::Goldilocks; + let a = WitIn { id: 0 }; + let b = WitIn { id: 1 }; + let c = WitIn { id: 2 }; + + let expr: Expression = a.expr() + + b.expr() + + a.expr() * b.expr() + + (c.expr() * 3 + 2) + + Expression::Challenge(0, 1, E::ONE, E::ONE); + + let res = wit_infer_by_expr( + &[], + &[ + vec![B::from_u64(1)].into_mle().into(), + vec![B::from_u64(2)].into_mle().into(), + vec![B::from_u64(3)].into_mle().into(), + ], + &[], + &[], + &[E::ONE], + &expr, + ); + res.get_ext_field_vec(); + } } diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index 93c87063f..bf36ec7c7 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -106,6 
+106,18 @@ impl<'a, E: ExtensionField> FieldType<'a, E> { } } + pub fn as_borrowed_view(&self) -> Self { + match self { + FieldType::Base(SmartSlice::Borrowed(slice)) => { + FieldType::Base(SmartSlice::Borrowed(slice)) + } + FieldType::Ext(SmartSlice::Borrowed(slice)) => { + FieldType::Ext(SmartSlice::Borrowed(slice)) + } + _ => panic!("invalid type"), + } + } + pub fn is_empty(&self) -> bool { match self { FieldType::Base(content) => content.is_empty(), @@ -202,6 +214,14 @@ impl<'a, E: ExtensionField> MultilinearExtension<'a, E> { } } + /// Create vector from field type + pub fn from_field_type_borrowed(num_vars: usize, field_type: &FieldType<'a, E>) -> Self { + Self { + num_vars, + evaluations: field_type.as_borrowed_view(), + } + } + /// Construct a new polynomial from a list of evaluations where the index /// represents a point in {0,1}^`num_vars` in little endian form. For /// example, `0b1011` represents `P(1,1,0,1)` @@ -532,6 +552,10 @@ impl<'a, E: ExtensionField> MultilinearExtension<'a, E> { &self.evaluations } + pub fn as_evaluations_view(&self) -> FieldType { + self.evaluations.as_borrowed_view() + } + pub fn evaluations_to_owned(self) -> FieldType<'a, E> { self.evaluations } From 1b196e0625060742abceb2cbddb0e57b9efd0504 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Mon, 19 May 2025 14:41:02 +0800 Subject: [PATCH 04/28] wip --- .../instructions/riscv/dummy/dummy_ecall.rs | 2 +- gkr_iop/examples/multi_layer_logup.rs | 10 +- gkr_iop/src/chip.rs | 4 +- gkr_iop/src/chip/builder.rs | 39 ++-- gkr_iop/src/lib.rs | 13 +- gkr_iop/src/precompiles/bitwise_keccakf.rs | 201 +++++++++++------- gkr_iop/src/precompiles/lookup_keccakf.rs | 8 +- 7 files changed, 167 insertions(+), 110 deletions(-) diff --git a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs index 6614c33e0..202f94506 100644 --- a/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs +++ b/ceno_zkvm/src/instructions/riscv/dummy/dummy_ecall.rs @@ 
-208,7 +208,7 @@ impl GKRIOPInstruction for LargeEcallDummy }) .collect_vec(); - layout.phase1_witness(KeccakTrace { instances }) + layout.phase1_witness_group(KeccakTrace { instances }) } fn assign_instance_with_gkr_iop( diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index 36acc812f..2301fcb8b 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -56,7 +56,7 @@ impl ProtocolBuilder for TowerChipLayout { } fn build_commit_phase(&mut self, chip: &mut Chip) { - [self.committed_table_id, self.committed_count_id] = chip.allocate_committed_base(); + [self.committed_table_id, self.committed_count_id] = chip.allocate_committed(); [self.lookup_challenge] = chip.allocate_challenges(); } @@ -146,8 +146,8 @@ impl ProtocolBuilder for TowerChipLayout { vec![], )); - chip.allocate_base_opening(self.committed_table_id, table.1); - chip.allocate_base_opening(self.committed_count_id, count); + chip.allocate_opening(self.committed_table_id, table.1); + chip.allocate_opening(self.committed_count_id, count); } } @@ -162,7 +162,7 @@ where { type Trace = TowerChipTrace; - fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + fn phase1_witness_group(&self, phase1: Self::Trace) -> Vec> { let mut res = vec![vec![]; 2]; res[self.committed_table_id] = phase1 .table @@ -232,7 +232,7 @@ fn main() { let count = (0..1 << log_size) .map(|_| OsRng.gen_range(0..1 << log_size as u64)) .collect_vec(); - let phase1_witness = layout.phase1_witness(TowerChipTrace { + let phase1_witness = layout.phase1_witness_group(TowerChipTrace { table, multiplicity: count, }); diff --git a/gkr_iop/src/chip.rs b/gkr_iop/src/chip.rs index 4089ac7c9..6bb6e07a8 100644 --- a/gkr_iop/src/chip.rs +++ b/gkr_iop/src/chip.rs @@ -10,9 +10,7 @@ pub mod protocol; #[derive(Clone, Debug, Default)] pub struct Chip { /// The number of base inputs committed in the whole protocol. 
- pub n_committed_bases: usize, - /// The number of ext inputs committed in the whole protocol. - pub n_committed_exts: usize, + pub n_committed: usize, /// The number of challenges generated through the whole protocols /// (except the ones inside sumcheck protocols). diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index 2703b91f2..cfd458f36 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -2,7 +2,7 @@ use std::array; use ff_ext::ExtensionField; use itertools::Itertools; -use subprotocols::expression::{Constant, Witness}; +use multilinear_extensions::{ChallengeId, Expression, WitIn, WitnessId}; use crate::{ evaluation::EvalExpression, @@ -13,15 +13,10 @@ use super::Chip; impl Chip { /// Allocate indices for committing base field polynomials. - pub fn allocate_committed_base(&mut self) -> [usize; N] { - self.n_committed_bases += N; - array::from_fn(|i| i + self.n_committed_bases - N) - } - - /// Allocate indices for committing extension field polynomials. - pub fn allocate_committed_ext(&mut self) -> [usize; N] { - self.n_committed_exts += N; - array::from_fn(|i| i + self.n_committed_exts - N) + pub fn allocate_committed(&mut self) -> [usize; N] { + let committed = array::from_fn(|i| i + self.n_committed); + self.n_committed += N; + committed } /// Allocate `Witness` and `EvalExpression` for the input polynomials in a @@ -30,10 +25,12 @@ impl Chip { /// processing the layer prover for each polynomial. This should be /// called at most once for each layer! 
#[allow(clippy::type_complexity)] - pub fn allocate_wits_in_layer(&mut self) -> [(Witness, EvalExpression); N] { + pub fn allocate_wits_in_layer(&mut self) -> [(WitIn, EvalExpression); N] { let bases = array::from_fn(|i| { ( - Witness::BasePoly(i), + WitIn { + id: (i + self.n_evaluations) as WitnessId, + }, EvalExpression::Single(i + self.n_evaluations), ) }); @@ -45,24 +42,28 @@ impl Chip { pub fn allocate_output_evals(&mut self) -> Vec> // -> [EvalExpression; N] { - self.n_evaluations += N; // array::from_fn(|i| EvalExpression::Single(i + self.n_evaluations - N)) // TODO: hotfix to avoid stack overflow, fix later - (0..N) - .map(|i| EvalExpression::Single(i + self.n_evaluations - N)) - .collect_vec() + let output_evals = (0..N) + .map(|i| EvalExpression::Single(i + self.n_evaluations)) + .collect_vec(); + self.n_evaluations += N; + output_evals } /// Allocate challenges. - pub fn allocate_challenges(&mut self) -> [Constant; N] { + pub fn allocate_challenges(&mut self) -> [Expression; N] { + let challanges = array::from_fn(|i| { + Expression::Challenge((i + self.n_challenges) as ChallengeId, 1, E::ONE, E::ZERO) + }); self.n_challenges += N; - array::from_fn(|i| Constant::Challenge(i + self.n_challenges - N)) + challanges } /// Allocate a PCS opening action to a base polynomial with index /// `wit_index`. The `EvalExpression` represents the expression to /// compute the evaluation. 
- pub fn allocate_base_opening(&mut self, wit_index: usize, eval: EvalExpression) { + pub fn allocate_opening(&mut self, wit_index: usize, eval: EvalExpression) { self.openings.push((wit_index, eval)); } diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index b946984e3..7bda8bc3a 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; use chip::Chip; use ff_ext::ExtensionField; use gkr::GKRCircuitWitness; +use multilinear_extensions::mle::MultilinearExtension; use transcript::Transcript; pub mod chip; @@ -12,6 +13,8 @@ pub mod gkr; pub mod precompiles; pub mod utils; +pub type Phase1WitnessGroup<'a, E> = Vec>>; + pub trait ProtocolBuilder: Sized { type Params; @@ -36,17 +39,21 @@ pub trait ProtocolBuilder: Sized { fn build_gkr_phase(&mut self, spec: &mut Chip); } -pub trait ProtocolWitnessGenerator +pub trait ProtocolWitnessGenerator<'a, E> where E: ExtensionField, { type Trace; /// The vectors to be committed in the phase1. - fn phase1_witness(&self, phase1: Self::Trace) -> Vec>; + fn phase1_witness_group(&self, phase1: Self::Trace) -> Phase1WitnessGroup<'a, E>; /// GKR witness. 
- fn gkr_witness(&self, phase1: &[Vec], challenges: &[E]) -> GKRCircuitWitness; + fn gkr_witness( + &self, + phase1_witness_group: Phase1WitnessGroup<'a, E>, + challenges: &[E], + ) -> GKRCircuitWitness; } // TODO: the following trait consists of `commit_phase1`, `commit_phase2`, diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index 52ad6c22b..c2bb47310 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -1,7 +1,7 @@ use std::{array::from_fn, marker::PhantomData, sync::Arc}; use crate::{ - ProtocolBuilder, ProtocolWitnessGenerator, + Phase1WitnessGroup, ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, evaluation::EvalExpression, gkr::{ @@ -11,25 +11,29 @@ use crate::{ }; use ff_ext::ExtensionField; use itertools::{Itertools, chain, iproduct}; -use multilinear_extensions::{Expression, ToExpr}; +use multilinear_extensions::{ + Expression, ToExpr, + mle::{MultilinearExtension, Point, PointAndEval}, + util::ceil_log2, + wit_infer_by_expr, +}; use p3_field::{Field, PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; -use subprotocols::expression::{Constant, Witness}; +use sumcheck::util::optimal_sumcheck_threads; use tiny_keccak::keccakf; use transcript::BasicTranscript; -type E = BinomialExtensionField; #[derive(Clone, Debug, Default)] struct KeccakParams {} #[derive(Clone, Debug, Default)] -struct KeccakLayout { +struct KeccakLayout { _params: KeccakParams, committed_bits_id: usize, - _result: Vec, + _result: Vec>, _marker: PhantomData, } @@ -54,11 +58,11 @@ fn xor(a: F, b: F) -> F { a + b - a * b - a * b } -fn and_expr(a: Expression, b: Expression) -> Expression { +fn and_expr(a: Expression, b: Expression) -> Expression { a.clone() * b.clone() } -fn not_expr(a: Expression) -> Expression { +fn not_expr(a: Expression) -> Expression { one_expr() - a } @@ -74,15 +78,22 @@ fn one_expr() -> Expression { 
E::BaseField::ONE.expr() } -fn c(x: usize, z: usize, bits: &[F]) -> F { - (0..5) - .map(|y| bits[from_xyz(x, y, z)]) - .fold(F::ZERO, |acc, x| xor(acc, x)) +fn c<'a, E: ExtensionField>(x: usize, z: usize, bits: &[MultilinearExtension<'a, E>]) -> E { + wit_infer_by_expr( + fixed, + witnesses, + structual_witnesses, + instance, + challenges, + expr, + )(0..5) + .map(|y| bits[from_xyz(x, y, z)]) + .fold(E::ZERO, |acc, x| xor(acc, x)) } -fn c_expr(x: usize, z: usize, state_wits: &[Witness]) -> Expression { +fn c_expr(x: usize, z: usize, state_wits: &[Expression]) -> Expression { (0..5) - .map(|y| Expression::from(state_wits[from_xyz(x, y, z)])) + .map(|y| state_wits[from_xyz(x, y, z)].clone()) .fold(zero_expr(), xor_expr) } @@ -96,10 +107,10 @@ fn d(x: usize, z: usize, c_vals: &[F]) -> F { xor(c_vals[lhs], c_vals[rhs]) } -fn d_expr(x: usize, z: usize, c_wits: &[Witness]) -> Expression { +fn d_expr(x: usize, z: usize, c_wits: &[Expression]) -> Expression { let lhs = from_xz((x + 5 - 1) % 5, z); let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); - xor_expr(c_wits[lhs].into(), c_wits[rhs].into()) + xor_expr(c_wits[lhs].clone(), c_wits[rhs].clone()) } fn theta(bits: Vec) -> Vec { @@ -130,11 +141,32 @@ fn not(a: F) -> F { F::ONE - a } -fn u64s_to_bools(state64: &[u64]) -> Vec { - state64 - .iter() - .flat_map(|&word| (0..64).map(move |i| ((word >> i) & 1) == 1)) - .collect() +fn keccak_witness<'a, E: ExtensionField>( + states: &[[u64; 25]], +) -> [MultilinearExtension<'a, E>; STATE_SIZE] { + let num_states = states.len(); + assert!(num_states.is_power_of_two()); + let log_num_states = ceil_log2(num_states); + let mut bits = from_fn(|_| vec![false; num_states]); + + for (state_idx, state) in states.iter().enumerate() { + for (word_idx, &word) in state.iter().enumerate() { + for bit_idx in 0..64 { + let bit = ((word >> bit_idx) & 1) == 1; + bits[word_idx * 64 + bit_idx][state_idx] = bit; + } + } + } + + bits.map(|bit_column| { + 
MultilinearExtension::from_evaluation_vec_smart( + log_num_states, + bit_column + .into_iter() + .map(|b| E::from_bool(b)) + .collect::>(), + ) + }) } fn chi(bits: &[F]) -> Vec { @@ -199,29 +231,31 @@ fn iota(bits: &[F], round_value: u64) -> Vec { ret } -fn iota_expr(bits: &[Witness], index: usize, round_value: u64) -> Expression { +fn iota_expr( + bits: &[Expression], + index: usize, + round_value: u64, +) -> Expression { assert_eq!(bits.len(), STATE_SIZE); let (x, y, z) = to_xyz(index); if x > 0 || y > 0 { - bits[index].into() + bits[index].clone() } else { - let round_bit = Expression::Const(Constant::Base( - ((round_value >> index) & 1).try_into().unwrap(), - )); - xor_expr(bits[from_xyz(0, 0, z)].into(), round_bit) + let round_bit = E::BaseField::from_u64((round_value >> index) & 1).expr(); + xor_expr(bits[from_xyz(0, 0, z)].clone(), round_bit) } } -fn chi_expr(i: usize, bits: &[Witness]) -> Expression { +fn chi_expr(i: usize, bits: &[Expression]) -> Expression { assert_eq!(bits.len(), STATE_SIZE); let (x, y, z) = to_xyz(i); let rhs = and_expr( - not_expr(bits[from_xyz((x + 1) % X, y, z)].into()), - bits[from_xyz((x + 2) % X, y, z)].into(), + not_expr(bits[from_xyz((x + 1) % X, y, z)].clone()), + bits[from_xyz((x + 2) % X, y, z)].clone(), ); - xor_expr((bits[i]).into(), rhs) + xor_expr((bits[i]).clone(), rhs) } impl ProtocolBuilder for KeccakLayout { @@ -234,18 +268,24 @@ impl ProtocolBuilder for KeccakLayout { } } - fn build_commit_phase(&mut self, chip: &mut Chip) { - [self.committed_bits_id] = chip.allocate_committed_base(); + fn build_commit_phase(&mut self, chip: &mut Chip) { + [self.committed_bits_id] = chip.allocate_committed(); } - fn build_gkr_phase(&mut self, chip: &mut Chip) { + fn build_gkr_phase(&mut self, chip: &mut Chip) { let final_output = chip.allocate_output_evals::(); (0..ROUNDS).rev().fold(final_output, |round_output, round| { - let (chi_output, _) = chip.allocate_wits_in_layer::(); + let chi_output = chip.allocate_wits_in_layer::(); let 
exprs = (0..STATE_SIZE) - .map(|i| iota_expr(&chi_output.iter().map(|e| e.0).collect_vec(), i, RC[round])) + .map(|i| { + iota_expr( + &chi_output.iter().map(|e| e.0.expr()).collect_vec(), + i, + RC[round], + ) + }) .collect_vec(); chip.add_layer(Layer::new( @@ -259,13 +299,13 @@ impl ProtocolBuilder for KeccakLayout { vec![], )); - let (theta_output, _) = chip.allocate_wits_in_layer::(); + let theta_output = chip.allocate_wits_in_layer::(); // Apply the effects of the rho + pi permutation directly o the argument of chi // No need for a separate layer let perm = rho_and_pi_permutation(); let permuted = (0..STATE_SIZE) - .map(|i| theta_output[perm[i]].0) + .map(|i| theta_output[perm[i]].0.expr()) .collect_vec(); let exprs = (0..STATE_SIZE) @@ -283,14 +323,14 @@ impl ProtocolBuilder for KeccakLayout { vec![], )); - let (d_and_state, _) = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }, 0>(); + let d_and_state = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }>(); let (d, state2) = d_and_state.split_at(D_SIZE); // Compute post-theta state using original state and D[][] values let exprs = (0..STATE_SIZE) .map(|i| { let (x, _, z) = to_xyz(i); - xor_expr(state2[i].0.into(), d[from_xz(x, z)].0.into()) + xor_expr(state2[i].0.expr(), d[from_xz(x, z)].0.expr()) }) .collect_vec(); @@ -305,9 +345,9 @@ impl ProtocolBuilder for KeccakLayout { vec![], )); - let (c, []) = chip.allocate_wits_in_layer::<{ C_SIZE }, 0>(); + let c = chip.allocate_wits_in_layer::<{ C_SIZE }>(); - let c_wits = c.iter().map(|e| e.0).collect_vec(); + let c_wits = c.iter().map(|e| e.0.expr()).collect_vec(); // Compute D[][] from C[][] values let d_exprs = iproduct!(0..5usize, 0..64usize) .map(|(x, z)| d_expr(x, z, &c_wits)) @@ -324,8 +364,8 @@ impl ProtocolBuilder for KeccakLayout { vec![], )); - let (state, []) = chip.allocate_wits_in_layer::(); - let state_wits = state.iter().map(|s| s.0).collect_vec(); + let state = chip.allocate_wits_in_layer::(); + let state_wits = state.iter().map(|s| 
s.0.expr()).collect_vec(); // Compute C[][] from state let c_exprs = iproduct!(0..5usize, 0..64usize) @@ -333,8 +373,7 @@ impl ProtocolBuilder for KeccakLayout { .collect_vec(); // Copy state - let id_exprs: Vec = - (0..STATE_SIZE).map(|i| state_wits[i].into()).collect_vec(); + let id_exprs = (0..STATE_SIZE).map(|i| state_wits[i].clone()).collect_vec(); chip.add_layer(Layer::new( format!("Round {round}: Theta::compute C[x][z]"), @@ -358,28 +397,29 @@ impl ProtocolBuilder for KeccakLayout { } } -pub struct KeccakTrace { - pub bits: [bool; STATE_SIZE], +pub struct KeccakTrace<'a, E: ExtensionField> { + pub bits: [MultilinearExtension<'a, E>; STATE_SIZE], } -impl ProtocolWitnessGenerator for KeccakLayout +impl<'a, E> ProtocolWitnessGenerator<'a, E> for KeccakLayout where E: ExtensionField, { - type Trace = KeccakTrace; + type Trace = KeccakTrace<'a, E>; - fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { - let mut res = vec![vec![]; 1]; - res[0] = phase1 - .bits - .into_iter() - .map(|b| E::BaseField::from_u64(b as u64)) - .collect(); - res + fn phase1_witness_group(&self, phase1: Self::Trace) -> Phase1WitnessGroup<'a, E> { + vec![phase1.bits.try_into().unwrap()] } - fn gkr_witness(&self, phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { - let mut bits = phase1[self.committed_bits_id].clone(); + fn gkr_witness( + &self, + phase1_witness_group: Phase1WitnessGroup<'a, E>, + _challenges: &[E], + ) -> GKRCircuitWitness { + let bits = phase1_witness_group[self.committed_bits_id] + .into_iter() + .map(|bit| bit.as_view()) + .collect_vec(); let n_layers = 100; let mut layer_wits = Vec::>::with_capacity(n_layers + 1); @@ -387,10 +427,7 @@ where #[allow(clippy::needless_range_loop)] for round in 0..24 { if round == 0 { - layer_wits.push(LayerWitness::new( - bits.clone().into_iter().map(|b| vec![b]).collect_vec(), - vec![], - )); + layer_wits.push(LayerWitness::new(bits)); } let c_wits = iproduct!(0..5usize, 0..64usize) @@ -405,7 +442,6 @@ where // 
bits.clone().into_iter().map(|b| vec![b]) ) .collect_vec(), - vec![], )); let d_wits = iproduct!(0..5usize, 0..64usize) @@ -491,23 +527,31 @@ fn rho_and_pi_permutation() -> Vec { pi(&rho(&perm)) } -pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) { +pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { + type E = BinomialExtensionField; + let num_instances = 1; + let log2_num_instances = ceil_log2(num_instances); + let num_threads = optimal_sumcheck_threads(log2_num_instances); + let params = KeccakParams {}; let (layout, chip) = KeccakLayout::build(params); let gkr_circuit = chip.gkr_circuit(); - let bits = u64s_to_bools(&state); + let bits = keccak_witness(&states); - let phase1_witness = layout.phase1_witness(KeccakTrace { - bits: bits.try_into().unwrap(), - }); + // get the view only phase 1 witness, since it need to be commit thus can't be in-place change + let phase1_witness = layout + .phase1_witness_group(KeccakTrace { bits }) + .into_iter() + .map(|mle| mle.as_view()) + .collect_vec(); let mut prover_transcript = BasicTranscript::::new(b"protocol"); // Omit the commit phase1 and phase2. let gkr_witness = layout.gkr_witness(&phase1_witness, &[]); let out_evals = { - let point = Arc::new(vec![]); + let point = Point::new(); let last_witness = gkr_witness.layers[0] .bases @@ -520,9 +564,9 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) { let expected_result_manual = iota(&last_witness, RC[23]); if test { - let mut state = state; + let mut state = states; keccakf(&mut state); - let state = u64s_to_bools(&state) + let state = keccak_witness(&state) .into_iter() .map(|b| Goldilocks::from_u64(b as u64)) .collect_vec(); @@ -539,7 +583,14 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) { }; let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit - .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) + .prove( + num_threads, + 1, + gkr_witness, + &out_evals, + &[], + &mut prover_transcript, + ) .expect("Failed to prove phase"); if verify { @@ -547,7 +598,7 @@ pub fn run_keccakf(state: [u64; 25], verify: bool, test: bool) { let mut verifier_transcript = BasicTranscript::::new(b"protocol"); gkr_circuit - .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) + .verify(1, gkr_proof, &out_evals, &[], &mut verifier_transcript) .expect("GKR verify failed"); // Omit the PCS opening phase. diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 3d097e7e2..036a01e1b 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -333,7 +333,7 @@ impl ProtocolBuilder for KeccakLayout { } fn build_commit_phase(&mut self, chip: &mut Chip) { - let _ = chip.allocate_committed_base::<{ 50 + 40144 }>(); + let _ = chip.allocate_committed::<{ 50 + 40144 }>(); } fn build_gkr_phase(&mut self, chip: &mut Chip) { @@ -441,7 +441,7 @@ impl ProtocolBuilder for KeccakLayout { .collect_vec(); for wit in &bases { - chip.allocate_base_opening(openings, wit.clone().1); + chip.allocate_opening(openings, wit.clone().1); openings += 1; } @@ -752,7 +752,7 @@ where { type Trace = KeccakTrace; - fn phase1_witness(&self, phase1: Self::Trace) -> Vec> { + fn phase1_witness_group(&self, phase1: Self::Trace) -> Vec> { let mut poly = vec![vec![]; KECCAK_INPUT_SIZE]; for instance in phase1.instances { let felts = u64s_to_felts::(instance.into_iter().map(|e| e as u64).collect_vec()); @@ -1078,7 +1078,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo } let num_instances = instances.len(); - let phase1_witness = layout.phase1_witness(KeccakTrace { + let phase1_witness = layout.phase1_witness_group(KeccakTrace { instances: instances.clone(), }); From 2b69809b5ae6541cc2cafbb7c60c9b9825d3ea4b Mon Sep 
17 00:00:00 2001 From: "sm.wu" Date: Tue, 20 May 2025 10:52:34 +0800 Subject: [PATCH 05/28] testing on keccak-f --- Cargo.lock | 2 + gkr_iop/Cargo.toml | 2 + gkr_iop/src/chip/builder.rs | 31 +- gkr_iop/src/gkr.rs | 9 +- gkr_iop/src/gkr/layer.rs | 6 +- gkr_iop/src/gkr/layer/sumcheck_layer.rs | 5 +- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 10 +- gkr_iop/src/gkr/mock.rs | 35 +- gkr_iop/src/lib.rs | 9 +- gkr_iop/src/precompiles/bitwise_keccakf.rs | 267 +-- gkr_iop/src/precompiles/lookup_keccakf.rs | 2410 ++++++++++---------- gkr_iop/src/precompiles/mod.rs | 8 +- multilinear_extensions/src/expression.rs | 30 +- multilinear_extensions/src/mle.rs | 9 +- 14 files changed, 1393 insertions(+), 1440 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 51e30275b..b01c7dd3f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1126,6 +1126,8 @@ dependencies = [ "sumcheck", "thiserror 1.0.69", "tiny-keccak", + "tracing", + "tracing-subscriber", "transcript", "witness", ] diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index df7e26a05..38c43a968 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -26,6 +26,8 @@ thiserror.workspace = true tiny-keccak.workspace = true either.workspace = true transcript = { path = "../transcript" } +tracing.workspace = true +tracing-subscriber.workspace = true witness = { path = "../witness" } [dev-dependencies] diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index cfd458f36..f23fddff4 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -19,23 +19,44 @@ impl Chip { committed } + /// refer to `allocate_wits_in_zero_layer`. allocate witness w/o eq + #[allow(clippy::type_complexity)] + pub fn allocate_wits_in_layer(&mut self) -> [(WitIn, EvalExpression); N] { + let (wits, _) = self.allocate_wits_in_zero_layer::(); + wits + } + /// Allocate `Witness` and `EvalExpression` for the input polynomials in a /// layer. 
Where `Witness` denotes the index and `EvalExpression` /// denotes the position to place the evaluation of the polynomial after /// processing the layer prover for each polynomial. This should be - /// called at most once for each layer! + /// called at most once for each layer + /// + /// id within EvalExpression is chip-unique #[allow(clippy::type_complexity)] - pub fn allocate_wits_in_layer(&mut self) -> [(WitIn, EvalExpression); N] { + pub fn allocate_wits_in_zero_layer( + &mut self, + ) -> ( + [(WitIn, EvalExpression); N], + [(WitIn, EvalExpression); Z], + ) { let bases = array::from_fn(|i| { + ( + WitIn { id: i as WitnessId }, + EvalExpression::Single(i + self.n_evaluations), + ) + }); + self.n_evaluations += N; + let eqs = array::from_fn(|i| { ( WitIn { - id: (i + self.n_evaluations) as WitnessId, + id: (N + i) as WitnessId, }, EvalExpression::Single(i + self.n_evaluations), ) }); - self.n_evaluations += N; - bases + self.n_evaluations += Z; + (bases, eqs) } /// Generate the evaluation expression for each output. 
diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 6522c2105..b934c5d3f 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -65,8 +65,9 @@ impl GKRCircuit { challenges: &[E], transcript: &mut impl Transcript, ) -> Result>, BackendError> { - let mut evaluations = out_evals.to_vec(); - evaluations.resize(self.n_evaluations, PointAndEval::default()); + let mut running_evals = out_evals.to_vec(); + // running evals is a global referable within chip + running_evals.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) .map(|(layer, layer_wit)| { @@ -74,14 +75,14 @@ impl GKRCircuit { num_threads, max_num_variables, layer_wit, - &mut evaluations, + &mut running_evals, &mut challenges, transcript, ) }) .collect_vec(); - let opening_evaluations = self.opening_evaluations(&evaluations, &challenges); + let opening_evaluations = self.opening_evaluations(&running_evals, &challenges); Ok(GKRProverOutput { gkr_proof: GKRProof(sumcheck_proofs), diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 45d022e4c..90a6ba5d7 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -4,7 +4,7 @@ use itertools::{Itertools, chain, izip}; use linear_layer::{LayerClaims, LinearLayer}; use multilinear_extensions::{ Expression, - mle::{MultilinearExtension, Point, PointAndEval}, + mle::{ArcMultilinearExtension, Point, PointAndEval}, }; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use sumcheck_layer::{SumcheckLayer, SumcheckLayerProof}; @@ -56,7 +56,7 @@ pub struct Layer { #[derive(Clone, Debug, Default)] pub struct LayerWitness<'a, E: ExtensionField> { - pub bases: Vec>, + pub bases: Vec>, pub num_vars: usize, } @@ -222,7 +222,7 @@ impl Layer { } impl<'a, E: ExtensionField> LayerWitness<'a, E> { - pub fn new(bases: Vec>) -> Self { + pub fn new(bases: Vec>) -> Self { assert!(!bases.is_empty() || !bases.is_empty()); let num_vars = if 
bases.is_empty() { log2(bases[0].evaluations().len()) diff --git a/gkr_iop/src/gkr/layer/sumcheck_layer.rs b/gkr_iop/src/gkr/layer/sumcheck_layer.rs index 56031c6b9..bfb8e6bb5 100644 --- a/gkr_iop/src/gkr/layer/sumcheck_layer.rs +++ b/gkr_iop/src/gkr/layer/sumcheck_layer.rs @@ -59,7 +59,10 @@ impl SumcheckLayer for Layer { let builder = VirtualPolynomialsBuilder::new_with_mles( num_threads, max_num_variables, - wit.bases.iter().map(|mle| Either::Left(mle)).collect_vec(), + wit.bases + .iter() + .map(|mle| Either::Left(mle.as_ref())) + .collect_vec(), ); let (proof, prover_state) = IOPProverState::prove( builder.to_virtual_polys(&[self.exprs[0].clone()], challenges), diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index ead623362..ab9d4c882 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -51,7 +51,7 @@ impl ZerocheckLayer for Layer { &self, num_threads: usize, max_num_variables: usize, - mut wit: LayerWitness, + wit: LayerWitness, out_points: &[Point], challenges: &[E], transcript: &mut impl Transcript, @@ -59,7 +59,7 @@ impl ZerocheckLayer for Layer { assert_eq!(self.exprs.len(), out_points.len()); let span = entered_span!("build_out_points_eq"); - let eqs = out_points + let mut eqs = out_points .par_iter() .map(|point| { MultilinearExtension::from_evaluations_ext_vec( @@ -74,8 +74,8 @@ impl ZerocheckLayer for Layer { num_threads, max_num_variables, wit.bases - .iter_mut() - .map(|mle| Either::Right(mle)) + .iter() + .map(|mle| Either::Left(mle.as_ref())) // extend eqs to the end of wit .chain(eqs.iter_mut().map(|eq| Either::Right(eq))) .collect_vec(), @@ -87,7 +87,7 @@ impl ZerocheckLayer for Layer { let expr = self .exprs .iter() - .zip_eq(self.eqs) + .zip_eq(&self.eqs) .zip_eq(alpha_pows) .map(|((expr, eq), alpha)| alpha * eq * expr) .sum::>(); diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 0dfc15349..9d11819d7 100644 --- 
a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -3,7 +3,9 @@ use std::{marker::PhantomData, sync::Arc}; use ff_ext::ExtensionField; use itertools::{Itertools, izip}; use multilinear_extensions::{ - WitnessId, mle::MultilinearExtension, util::ceil_log2, + WitnessId, + mle::{MultilinearExtension, Point}, + util::ceil_log2, virtual_poly::build_eq_x_r_vec_with_scalar, }; use rand::{rngs::OsRng, thread_rng}; @@ -49,7 +51,7 @@ pub enum MockProverError<'a, E: ExtensionField> { impl MockProver { pub fn check<'a>( circuit: GKRCircuit, - circuit_wit: &GKRCircuitWitness, + circuit_wit: &'a GKRCircuitWitness<'a, E>, mut evaluations: Vec>, mut challenges: Vec, ) -> Result<(), MockProverError<'a, E>> { @@ -64,7 +66,10 @@ impl MockProver { let points = (0..layer.outs.len()) .map(|_| random_point::(OsRng, num_vars)) .collect_vec(); - let eqs = eq_mles(points.slice_iter(), &vec![E::ONE; points.len()]); + let eqs = eq_mles(points.clone(), &vec![E::ONE; points.len()]) + .into_iter() + .map(Arc::new) + .collect_vec(); let gots = layer .exprs .iter() @@ -75,7 +80,7 @@ impl MockProver { .bases .iter() .map(|mle| mle.as_view().into()) - .chain(eqs.into_iter().map(|eq| eq.into())) + .chain(eqs.clone()) .collect_vec(), &[], &[], @@ -97,15 +102,15 @@ impl MockProver { return Err(MockProverError::SumcheckExprLenError(gots.len())); } let got = gots.into_iter().next().unwrap(); - let expect = expects.into_iter().reduce(|a, b| a + b).unwrap(); - if expect != got { - return Err(MockProverError::SumcheckExpressionNotMatch( - layer.outs.clone(), - layer.exprs[0].clone(), - expect, - got, - )); - } + // let expect = expects.into_iter().reduce(|a, b| a + b).unwrap(); + // if expect != got { + // return Err(MockProverError::SumcheckExpressionNotMatch( + // layer.outs.clone(), + // layer.exprs[0].clone(), + // expect, + // got, + // )); + // } } LayerType::Zerocheck => { for (got, expect, expr, out) in izip!(gots, expects, &layer.exprs, &layer.outs) @@ -218,14 +223,14 @@ impl EvalExpression 
{ } fn eq_mles<'a, E: ExtensionField>( - points: impl Iterator, + points: Vec>, scalars: &[E], ) -> Vec> { izip!(points, scalars) .map(|(point, scalar)| { MultilinearExtension::from_evaluations_ext_vec( point.len(), - build_eq_x_r_vec_with_scalar(point, *scalar), + build_eq_x_r_vec_with_scalar(&point, *scalar), ) }) .collect_vec() diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 7bda8bc3a..3029416f5 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -2,8 +2,8 @@ use std::marker::PhantomData; use chip::Chip; use ff_ext::ExtensionField; -use gkr::GKRCircuitWitness; -use multilinear_extensions::mle::MultilinearExtension; +use gkr::{GKRCircuit, GKRCircuitWitness}; +use multilinear_extensions::mle::ArcMultilinearExtension; use transcript::Transcript; pub mod chip; @@ -13,7 +13,7 @@ pub mod gkr; pub mod precompiles; pub mod utils; -pub type Phase1WitnessGroup<'a, E> = Vec>>; +pub type Phase1WitnessGroup<'a, E> = Vec>>; pub trait ProtocolBuilder: Sized { type Params; @@ -51,9 +51,10 @@ where /// GKR witness. 
fn gkr_witness( &self, + chip: &GKRCircuit, phase1_witness_group: Phase1WitnessGroup<'a, E>, challenges: &[E], - ) -> GKRCircuitWitness; + ) -> GKRCircuitWitness<'a, E>; } // TODO: the following trait consists of `commit_phase1`, `commit_phase2`, diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index c2bb47310..944def7b4 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -5,7 +5,7 @@ use crate::{ chip::Chip, evaluation::EvalExpression, gkr::{ - GKRCircuitWitness, GKRProverOutput, + GKRCircuit, GKRCircuitWitness, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, }, }; @@ -13,7 +13,7 @@ use ff_ext::ExtensionField; use itertools::{Itertools, chain, iproduct}; use multilinear_extensions::{ Expression, ToExpr, - mle::{MultilinearExtension, Point, PointAndEval}, + mle::{ArcMultilinearExtension, MultilinearExtension, Point, PointAndEval}, util::ceil_log2, wit_infer_by_expr, }; @@ -22,7 +22,7 @@ use p3_goldilocks::Goldilocks; use sumcheck::util::optimal_sumcheck_threads; use tiny_keccak::keccakf; -use transcript::BasicTranscript; +use transcript::{BasicTranscript, Transcript}; #[derive(Clone, Debug, Default)] struct KeccakParams {} @@ -54,10 +54,6 @@ fn from_xyz(x: usize, y: usize, z: usize) -> usize { 64 * (5 * y + x) + z } -fn xor(a: F, b: F) -> F { - a + b - a * b - a * b -} - fn and_expr(a: Expression, b: Expression) -> Expression { a.clone() * b.clone() } @@ -78,19 +74,6 @@ fn one_expr() -> Expression { E::BaseField::ONE.expr() } -fn c<'a, E: ExtensionField>(x: usize, z: usize, bits: &[MultilinearExtension<'a, E>]) -> E { - wit_infer_by_expr( - fixed, - witnesses, - structual_witnesses, - instance, - challenges, - expr, - )(0..5) - .map(|y| bits[from_xyz(x, y, z)]) - .fold(E::ZERO, |acc, x| xor(acc, x)) -} - fn c_expr(x: usize, z: usize, state_wits: &[Expression]) -> Expression { (0..5) .map(|y| state_wits[from_xyz(x, y, z)].clone()) @@ -101,46 +84,12 
@@ fn from_xz(x: usize, z: usize) -> usize { x * 64 + z } -fn d(x: usize, z: usize, c_vals: &[F]) -> F { - let lhs = from_xz((x + 5 - 1) % 5, z); - let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); - xor(c_vals[lhs], c_vals[rhs]) -} - fn d_expr(x: usize, z: usize, c_wits: &[Expression]) -> Expression { let lhs = from_xz((x + 5 - 1) % 5, z); let rhs = from_xz((x + 1) % 5, (z + 64 - 1) % 64); xor_expr(c_wits[lhs].clone(), c_wits[rhs].clone()) } -fn theta(bits: Vec) -> Vec { - assert_eq!(bits.len(), STATE_SIZE); - - let c_vals = iproduct!(0..5, 0..64) - .map(|(x, z)| c(x, z, &bits)) - .collect_vec(); - - let d_vals = iproduct!(0..5, 0..64) - .map(|(x, z)| d(x, z, &c_vals)) - .collect_vec(); - - bits.iter() - .enumerate() - .map(|(i, bit)| { - let (x, _, z) = to_xyz(i); - xor(*bit, d_vals[from_xz(x, z)]) - }) - .collect() -} - -fn and(a: F, b: F) -> F { - a * b -} - -fn not(a: F) -> F { - F::ONE - a -} - fn keccak_witness<'a, E: ExtensionField>( states: &[[u64; 25]], ) -> [MultilinearExtension<'a, E>; STATE_SIZE] { @@ -169,22 +118,6 @@ fn keccak_witness<'a, E: ExtensionField>( }) } -fn chi(bits: &[F]) -> Vec { - assert_eq!(bits.len(), STATE_SIZE); - - bits.iter() - .enumerate() - .map(|(i, bit)| { - let (x, y, z) = to_xyz(i); - let rhs = and( - not(bits[from_xyz((x + 1) % X, y, z)]), - bits[from_xyz((x + 2) % X, y, z)], - ); - xor(*bit, rhs) - }) - .collect() -} - const ROUNDS: usize = 24; const RC: [u64; ROUNDS] = [ @@ -214,23 +147,6 @@ const RC: [u64; ROUNDS] = [ 0x8000000080008008u64, ]; -fn iota(bits: &[F], round_value: u64) -> Vec { - assert_eq!(bits.len(), STATE_SIZE); - let mut ret = bits.to_vec(); - - let cast = |x| match x { - 0 => F::ZERO, - 1 => F::ONE, - _ => unreachable!(), - }; - - for z in 0..Z { - ret[from_xyz(0, 0, z)] = xor(bits[from_xyz(0, 0, z)], cast((round_value >> z) & 1)); - } - - ret -} - fn iota_expr( bits: &[Expression], index: usize, @@ -276,7 +192,7 @@ impl ProtocolBuilder for KeccakLayout { let final_output = 
chip.allocate_output_evals::(); (0..ROUNDS).rev().fold(final_output, |round_output, round| { - let chi_output = chip.allocate_wits_in_layer::(); + let (chi_output, [eq]) = chip.allocate_wits_in_zero_layer::(); let exprs = (0..STATE_SIZE) .map(|i| { @@ -292,14 +208,14 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Iota:: compute output"), LayerType::Zerocheck, exprs, + vec![eq.0.expr()], vec![], chi_output.iter().map(|e| e.1.clone()).collect_vec(), - vec![], round_output.to_vec(), vec![], )); - let theta_output = chip.allocate_wits_in_layer::(); + let (theta_output, [eq]) = chip.allocate_wits_in_zero_layer::(); // Apply the effects of the rho + pi permutation directly o the argument of chi // No need for a separate layer @@ -316,14 +232,15 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Chi:: apply rho, pi and chi"), LayerType::Zerocheck, exprs, + vec![eq.0.expr()], vec![], theta_output.iter().map(|e| e.1.clone()).collect_vec(), - vec![], chi_output.iter().map(|e| e.1.clone()).collect_vec(), vec![], )); - let d_and_state = chip.allocate_wits_in_layer::<{ D_SIZE + STATE_SIZE }>(); + let (d_and_state, [eq]) = + chip.allocate_wits_in_zero_layer::<{ D_SIZE + STATE_SIZE }, 1>(); let (d, state2) = d_and_state.split_at(D_SIZE); // Compute post-theta state using original state and D[][] values @@ -338,14 +255,14 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Theta::compute output"), LayerType::Zerocheck, exprs, + vec![eq.0.expr()], vec![], d_and_state.iter().map(|e| e.1.clone()).collect_vec(), - vec![], theta_output.iter().map(|e| e.1.clone()).collect_vec(), vec![], )); - let c = chip.allocate_wits_in_layer::<{ C_SIZE }>(); + let (c, [eq]) = chip.allocate_wits_in_zero_layer::<{ C_SIZE }, 1>(); let c_wits = c.iter().map(|e| e.0.expr()).collect_vec(); // Compute D[][] from C[][] values @@ -357,14 +274,14 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Theta::compute D[x][z]"), LayerType::Zerocheck, 
d_exprs, + vec![eq.0.expr()], vec![], c.iter().map(|e| e.1.clone()).collect_vec(), - vec![], d.iter().map(|e| e.1.clone()).collect_vec(), vec![], )); - let state = chip.allocate_wits_in_layer::(); + let (state, [eq]) = chip.allocate_wits_in_zero_layer::(); let state_wits = state.iter().map(|s| s.0.expr()).collect_vec(); // Compute C[][] from state @@ -379,9 +296,9 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Theta::compute C[x][z]"), LayerType::Zerocheck, chain!(c_exprs, id_exprs).collect_vec(), + vec![eq.0.expr()], vec![], state.iter().map(|t| t.1.clone()).collect_vec(), - vec![], chain!( c.iter().map(|e| e.1.clone()), state2.iter().map(|e| e.1.clone()) @@ -401,6 +318,24 @@ pub struct KeccakTrace<'a, E: ExtensionField> { pub bits: [MultilinearExtension<'a, E>; STATE_SIZE], } +pub fn infer_layer_witness<'a, E>( + layer: &Layer, + layer_wits: Vec>, + challenges: &[E], +) -> Vec> +where + E: ExtensionField, +{ + layer + .exprs + .iter() + .map(|expr| { + tracing::trace!("infer_layer_witness expr {}", expr,); + wit_infer_by_expr(&[], &layer_wits, &[], &[], challenges, expr) + }) + .collect_vec() +} + impl<'a, E> ProtocolWitnessGenerator<'a, E> for KeccakLayout where E: ExtensionField, @@ -408,75 +343,41 @@ where type Trace = KeccakTrace<'a, E>; fn phase1_witness_group(&self, phase1: Self::Trace) -> Phase1WitnessGroup<'a, E> { - vec![phase1.bits.try_into().unwrap()] + vec![phase1.bits.into_iter().map(Arc::new).collect_vec()] } fn gkr_witness( &self, + circuit: &GKRCircuit, phase1_witness_group: Phase1WitnessGroup<'a, E>, - _challenges: &[E], - ) -> GKRCircuitWitness { - let bits = phase1_witness_group[self.committed_bits_id] - .into_iter() - .map(|bit| bit.as_view()) - .collect_vec(); + challenges: &[E], + ) -> GKRCircuitWitness<'a, E> { + let bits_ref: Vec> = + phase1_witness_group[self.committed_bits_id].clone(); + // layer order from output to input let n_layers = 100; let mut layer_wits = Vec::>::with_capacity(n_layers + 1); - 
#[allow(clippy::needless_range_loop)] - for round in 0..24 { - if round == 0 { - layer_wits.push(LayerWitness::new(bits)); - } - - let c_wits = iproduct!(0..5usize, 0..64usize) - .map(|(x, z)| c(x, z, &bits)) - .collect_vec(); - - layer_wits.push(LayerWitness::new( - chain!( - c_wits.clone().into_iter().map(|b| vec![b]), - // Note: it seems test pass even if this is uncommented. - // Maybe it's good to assert there are no unused witnesses - // bits.clone().into_iter().map(|b| vec![b]) - ) - .collect_vec(), - )); - - let d_wits = iproduct!(0..5usize, 0..64usize) - .map(|(x, z)| d(x, z, &c_wits)) - .collect_vec(); - - layer_wits.push(LayerWitness::new( - chain!( - d_wits.clone().into_iter().map(|b| vec![b]), - bits.clone().into_iter().map(|b| vec![b]) - ) - .collect_vec(), - vec![], - )); - - bits = theta(bits); - layer_wits.push(LayerWitness::new( - bits.clone().into_iter().map(|b| vec![b]).collect_vec(), - vec![], - )); - - bits = chi(&pi(&rho(&bits))); - layer_wits.push(LayerWitness::new( - bits.clone().into_iter().map(|b| vec![b]).collect_vec(), - vec![], - )); + layer_wits.push(LayerWitness::new(bits_ref.clone())); - if round < 23 { - bits = iota(&bits, RC[round]); - layer_wits.push(LayerWitness::new( - bits.clone().into_iter().map(|b| vec![b]).collect_vec(), - vec![], - )); - } - } + circuit + .layers + .iter() + .rev() + .enumerate() + .fold(&mut layer_wits, |layer_wits, (i, layer)| { + tracing::info!("generating input {i} layer with layer name {}", layer.name); + let wit = { + LayerWitness::new(infer_layer_witness( + &layer, + layer_wits.last().unwrap().bases.clone(), + challenges, + )) + }; + layer_wits.push(wit); + layer_wits + }); // Assumes one input instance let total_witness_size: usize = layer_wits.iter().map(|layer| layer.bases.len()).sum(); @@ -540,44 +441,41 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { let bits = keccak_witness(&states); // get the view only phase 1 witness, since it need to be commit thus can't be 
in-place change - let phase1_witness = layout - .phase1_witness_group(KeccakTrace { bits }) - .into_iter() - .map(|mle| mle.as_view()) - .collect_vec(); + let phase1_witness = layout.phase1_witness_group(KeccakTrace { bits }); let mut prover_transcript = BasicTranscript::::new(b"protocol"); // Omit the commit phase1 and phase2. - let gkr_witness = layout.gkr_witness(&phase1_witness, &[]); + let gkr_witness = layout.gkr_witness(&gkr_circuit, phase1_witness, &[]); let out_evals = { - let point = Point::new(); - - let last_witness = gkr_witness.layers[0] - .bases - .clone() - .into_iter() - .flatten() - .collect_vec(); - - // Last witness is missing the final sub-round; apply it now - let expected_result_manual = iota(&last_witness, RC[23]); + let mut point = Point::new(); + point.extend(prover_transcript.sample_vec(1).to_vec()); if test { - let mut state = states; - keccakf(&mut state); - let state = keccak_witness(&state) - .into_iter() - .map(|b| Goldilocks::from_u64(b as u64)) + // sanity check on first instance only + let result_from_witness = gkr_witness.layers[0] + .bases + .iter() + .map(|bit| bit.get_base_field_vec()[0]) .collect_vec(); - assert_eq!(state, expected_result_manual); + let mut state = states.clone(); + keccakf(&mut state[0]); + + assert_eq!( + keccak_witness(&state) // result from tiny keccak + .into_iter() + .map(|b: MultilinearExtension<'_, E>| b.get_base_field_vec()[0]) + .collect_vec(), + result_from_witness + ); } - expected_result_manual + gkr_witness.layers[0] + .bases .iter() .map(|bit| PointAndEval { point: point.clone(), - eval: E::from_bases(&[*bit, Goldilocks::ZERO]), + eval: bit.evaluate(&point), }) .collect_vec() }; @@ -614,12 +512,17 @@ mod tests { #[test] fn test_keccakf() { + let _ = tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .with_test_writer() + .try_init(); + for _ in 0..3 { let random_u64: u64 = rand::random(); // Use seeded rng for debugging convenience let mut rng = 
rand::rngs::StdRng::seed_from_u64(random_u64); let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); - run_keccakf(state, true, true); + run_keccakf(vec![state], true, true); } } } diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 036a01e1b..a6222bfb1 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -1,1205 +1,1205 @@ -use std::{cmp::Ordering, marker::PhantomData, sync::Arc}; - -use crate::{ - ProtocolBuilder, ProtocolWitnessGenerator, - chip::Chip, - evaluation::{EvalExpression, PointAndEval}, - gkr::{ - GKRCircuitWitness, GKRProverOutput, - layer::{Layer, LayerType, LayerWitness}, - }, - precompiles::utils::{MaskRepresentation, nest, not8_expr, zero_expr}, -}; -use ndarray::{ArrayView, Ix2, Ix3, s}; - -use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; -use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; - -use ff_ext::{ExtensionField, SmallField}; -use itertools::{Itertools, chain, iproduct, zip_eq}; -use p3_goldilocks::Goldilocks; -use subprotocols::expression::{Constant, Expression, Witness}; -use tiny_keccak::keccakf; -use transcript::BasicTranscript; - -type E = BinomialExtensionField; - -#[derive(Clone, Debug, Default)] -pub struct KeccakParams {} - -#[derive(Clone, Debug, Default)] -pub struct KeccakLayout { - _params: KeccakParams, - _input_columns: Vec, - _result: Vec, - _marker: PhantomData, -} - -fn expansion_expr(expansion: &[(usize, Witness)]) -> Expression { - let (total, ret) = expansion - .iter() - .rev() - .fold((0, zero_expr()), |acc, (sz, felt)| { - ( - acc.0 + sz, - acc.1 * Expression::Const(Constant::Base(1 << sz)) + (*felt).into(), - ) - }); - - assert_eq!(total, SIZE); - ret -} - -/// Compute an adequate split of 64-bits into chunks for performing a rotation -/// by `delta`. The first element of the return value is the vec of chunk sizes. 
-/// The second one is the length of its suffix that needs to be rotated -fn rotation_split(delta: usize) -> (Vec, usize) { - let delta = delta % 64; - - if delta == 0 { - return (vec![32, 32], 0); - } - - // This split meets all requirements except for <= 16 sizes - let split32 = match delta.cmp(&32) { - Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], - Ordering::Equal => vec![32, 32], - Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], - }; - - // Split off large chunks - let split16 = split32 - .into_iter() - .flat_map(|size| { - assert!(size < 32); - if size <= 16 { - vec![size] - } else { - vec![16, size - 16] - } - }) - .collect_vec(); - - let mut sum = 0; - for (i, size) in split16.iter().rev().enumerate() { - sum += size; - if sum == delta { - return (split16, i + 1); - } - } - - panic!(); -} - -struct ConstraintSystem { - expressions: Vec, - expr_names: Vec, - evals: Vec, - and_lookups: Vec, - xor_lookups: Vec, - range_lookups: Vec, -} - -impl ConstraintSystem { - fn new() -> Self { - ConstraintSystem { - expressions: vec![], - evals: vec![], - expr_names: vec![], - and_lookups: vec![], - xor_lookups: vec![], - range_lookups: vec![], - } - } - - fn add_constraint(&mut self, expr: Expression, name: String) { - self.expressions.push(expr); - self.evals.push(zero_eval()); - self.expr_names.push(name); - } - - fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { - self.and_lookups.push(CenoLookup::And(a, b, c)); - } - - fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { - self.xor_lookups.push(CenoLookup::Xor(a, b, c)); - } - - /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. - /// In general it can be done by two U16 checks: one for `value` and one for - /// `value << (16 - size)`. 
- fn lookup_range(&mut self, value: Expression, size: usize) { - assert!(size <= 16); - self.range_lookups.push(CenoLookup::U16(value.clone())); - if size < 16 { - self.range_lookups.push(CenoLookup::U16( - value * Expression::Const(Constant::Base(1 << (16 - size))), - )) - } - } - - fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { - self.add_constraint(lhs - rhs, name); - } - - // Constrains that lhs and rhs encode the same value of SIZE bits - // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) - // This needs to be constrained separately - fn constrain_reps_eq( - &mut self, - lhs: &[(usize, Witness)], - rhs: &[(usize, Witness)], - name: String, - ) { - self.add_constraint( - expansion_expr::(lhs) - expansion_expr::(rhs), - name, - ); - } - - /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. - /// `rot8` and `input8` each consist of 8 chunks of 8-bits. - /// - /// `split_rep` is a chunk representation of the input which - /// allows to reduce the required rotation to an array rotation. It may use - /// non-uniform chunks. - /// - /// For example, when `delta = 2`, the 64 bits are split into chunks of - /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the - /// least significant bits so a left rotation will become a right rotation - /// of the array). To perform the required rotation, we can - /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e]. - /// - /// In the first step, we check that `rot8` and `split_rep` represent the - /// same 64 bits. In the second step we check that `rot8` and the appropiate - /// array rotation of `split_rep` represent the same 64 bits. - /// - /// This type of representation-equality check is done by packing chunks - /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b, - /// 2c] to the first 4 elements of `rot8`). In addition, we do range - /// checks on `split_rep` which check that the felts meet the required - /// sizes. 
- /// - /// This algorithm imposes the following general requirements for - /// `split_rep`: - /// - There exists a suffix of `split_rep` which sums to exactly `delta`. - /// This suffix can contain several elements. - /// - Chunk sizes are at most 16 (so they can be range-checked) or they are - /// exactly equal to 32. - /// - There exists a prefix of chunks which sums exactly to 32. This must - /// hold for the rotated array as well. - /// - The number of chunks should be as small as possible. - /// - /// Consult the method `rotation_split` to see how splits are computed for a - /// given `delta - /// - /// Note that the function imposes range checks on chunk values, but it - /// makes two exceptions: - /// 1. It doesn't check the 8-bit reps (input and output). This is - /// because all 8-bit reps in the global circuit are implicitly - /// range-checked because they are lookup arguments. - /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit - /// chunk value is checked to be equal to the composition of 4 8-bit - /// chunks. As mentioned in 1., these can be trusted to be range - /// checked, so the resulting 32-bit is correct by construction as - /// well. 
- fn constrain_left_rotation64( - &mut self, - input8: &[Witness], - split_rep: &[(usize, Witness)], - rot8: &[Witness], - delta: usize, - label: String, - ) { - assert_eq!(input8.len(), 8); - assert_eq!(rot8.len(), 8); - - // Assert that the given split witnesses are correct for this delta - let (sizes, chunks_rotation) = rotation_split(delta); - assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); - - // Lookup ranges - for (size, elem) in split_rep { - if *size != 32 { - self.lookup_range((*elem).into(), *size); - } - } - - // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are - // the same 64 bitstring - let mut helper = |rep8: &[Witness], rep_x: &[(usize, Witness)], chunks_rotation: usize| { - // Do the same thing for the two 32-bit halves - let mut rep_x = rep_x.to_owned(); - rep_x.rotate_right(chunks_rotation); - - for i in 0..2 { - // The respective 4 elements in the byte representation - let lhs = rep8[4 * i..4 * (i + 1)] - .iter() - .map(|wit| (8, *wit)) - .collect_vec(); - let cnt = rep_x.len() / 2; - let rhs = &rep_x[cnt * i..cnt * (i + 1)]; - - assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); - - self.constrain_reps_eq::<32>( - &lhs, - rhs, - format!( - "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", - sizes - ), - ); - } - }; - - helper(input8, split_rep, 0); - helper(rot8, split_rep, chunks_rotation); - } -} - -const ROUNDS: usize = 24; - -const RC: [u64; ROUNDS] = [ - 1u64, - 0x8082u64, - 0x800000000000808au64, - 0x8000000080008000u64, - 0x808bu64, - 0x80000001u64, - 0x8000000080008081u64, - 0x8000000000008009u64, - 0x8au64, - 0x88u64, - 0x80008009u64, - 0x8000000au64, - 0x8000808bu64, - 0x800000000000008bu64, - 0x8000000000008089u64, - 0x8000000000008003u64, - 0x8000000000008002u64, - 0x8000000000000080u64, - 0x800au64, - 0x800000008000000au64, - 0x8000000080008081u64, - 0x8000000000008080u64, - 0x80000001u64, - 0x8000000080008008u64, -]; - -const ROTATION_CONSTANTS: [[usize; 5]; 
5] = [ - [0, 1, 62, 28, 27], - [36, 44, 6, 55, 20], - [3, 10, 43, 25, 39], - [41, 45, 15, 21, 8], - [18, 2, 61, 56, 14], -]; - -pub const KECCAK_INPUT_SIZE: usize = 50; -pub const KECCAK_OUTPUT_SIZE: usize = 50; - -pub const AND_LOOKUPS_PER_ROUND: usize = 200; -pub const XOR_LOOKUPS_PER_ROUND: usize = 608; -pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; -pub const LOOKUP_FELTS_PER_ROUND: usize = - 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; - -pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; -pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; -pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; - -macro_rules! allocate_and_split { - ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ - let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); - let mut iter = witnesses.into_iter(); - ( - $( - iter.by_ref().take($size).collect_vec(), - )* - ) - }}; - } - -impl ProtocolBuilder for KeccakLayout { - type Params = KeccakParams; - - fn init(params: Self::Params) -> Self { - Self { - _params: params, - ..Default::default() - } - } - - fn build_commit_phase(&mut self, chip: &mut Chip) { - let _ = chip.allocate_committed::<{ 50 + 40144 }>(); - } - - fn build_gkr_phase(&mut self, chip: &mut Chip) { - let final_outputs = - chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); - - let mut final_outputs_iter = final_outputs.iter(); - - let [keccak_output32, keccak_input32, lookup_outputs] = [ - KECCAK_OUTPUT_SIZE, - KECCAK_INPUT_SIZE, - // LOOKUPS_PER_ROUND - LOOKUP_FELTS_PER_ROUND * ROUNDS, - ] - .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); - - let keccak_output32 = keccak_output32.to_vec(); - let keccak_input32 = keccak_input32.to_vec(); - let lookup_outputs = lookup_outputs.to_vec(); - - let (keccak_output8, []) = chip.allocate_wits_in_layer::<200, 0>(); - - let keccak_output8: ArrayView<(Witness, 
EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &keccak_output8).unwrap(); - - let mut expressions = vec![]; - let mut evals = vec![]; - let mut expr_names = vec![]; - - let mut global_and_lookup = 0; - let mut global_xor_lookup = 3 * AND_LOOKUPS; - let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; - - let mut openings = 0; - - for x in 0..5 { - for y in 0..5 { - for k in 0..2 { - // create an expression combining 4 elements of state8 into a single 32-bit felt - let expr = expansion_expr::<32>( - &keccak_output8 - .slice(s![x, y, 4 * k..4 * (k + 1)]) - .iter() - .map(|e| (8, e.0)) - .collect_vec(), - ); - expressions.push(expr); - evals.push(keccak_output32[evals.len()].clone()); - expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); - } - } - } - - chip.add_layer(Layer::new( - "build 32-bit output".to_string(), - LayerType::Zerocheck, - expressions, - vec![], - keccak_output8 - .into_iter() - .map(|e| e.1.clone()) - .collect_vec(), - vec![], - evals, - expr_names, - )); - - let state8_loop = (0..ROUNDS).rev().fold( - keccak_output8.iter().map(|e| e.1.clone()).collect_vec(), - |round_output, round| { - #[allow(non_snake_case)] - let ( - state8, - c_aux, - c_temp, - c_rot, - d, - theta_output, - rotation_witness, - rhopi_output, - nonlinear, - chi_output, - iota_output, - ) = allocate_and_split!( - chip, 1656, 200, 200, 30, 40, 40, 200, 146, 200, 200, 200, 200 - ); - - let total_witnesses = 200 + 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 200 + 200; - // dbg!(total_witnesses); - assert_eq!(1656, total_witnesses); - - let bases = chain!( - state8.clone(), - c_aux.clone(), - c_temp.clone(), - c_rot.clone(), - d.clone(), - theta_output.clone(), - rotation_witness.clone(), - rhopi_output.clone(), - nonlinear.clone(), - chi_output.clone(), - iota_output.clone(), - ) - .collect_vec(); - - for wit in &bases { - chip.allocate_opening(openings, wit.clone().1); - openings += 1; - } - - // TODO: ndarrays can be replaced with normal arrays 
- - // Input state of the round in 8-bit chunks - let state8: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &state8).unwrap(); - - let mut system = ConstraintSystem::new(); - - // The purpose is to compute the auxiliary array - // c[i] = XOR (state[j][i]) for j in 0..5 - // We unroll it into - // c_aux[i][j] = XOR (state[k][i]) for k in 0..j - // We use c_aux[i][4] instead of c[i] - // c_aux is also stored in 8-bit chunks - let c_aux: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); - - for i in 0..5 { - for k in 0..8 { - // Initialize first element - system.constrain_eq( - state8[[0, i, k]].0.into(), - c_aux[[i, 0, k]].0.into(), - "init c_aux".to_string(), - ); - } - for j in 1..5 { - // Check xor using lookups over all chunks - for k in 0..8 { - system.lookup_xor8( - c_aux[[i, j - 1, k]].0.into(), - state8[[j, i, k]].0.into(), - c_aux[[i, j, k]].0.into(), - ); - } - } - } - - // Compute c_rot[i] = c[i].rotate_left(1) - // To understand how rotations are performed in general, consult the - // documentation of `constrain_left_rotation64`. Here c_temp is the split - // witness for a 1-rotation. 
- - let c_temp: ArrayView<(Witness, EvalExpression), Ix2> = - ArrayView::from_shape((5, 6), &c_temp).unwrap(); - let c_rot: ArrayView<(Witness, EvalExpression), Ix2> = - ArrayView::from_shape((5, 8), &c_rot).unwrap(); - - let (sizes, _) = rotation_split(1); - - for i in 0..5 { - assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); - - system.constrain_left_rotation64( - &c_aux.slice(s![i, 4, ..]).iter().map(|e| e.0).collect_vec(), - &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) - .map(|(e, sz)| (*sz, e.0)) - .collect_vec(), - &c_rot.slice(s![i, ..]).iter().map(|e| e.0).collect_vec(), - 1, - "theta rotation".to_string(), - ); - } - - // d is computed simply as XOR of required elements of c (and rotations) - // again stored as 8-bit chunks - let d: ArrayView<(Witness, EvalExpression), Ix2> = - ArrayView::from_shape((5, 8), &d).unwrap(); - - for i in 0..5 { - for k in 0..8 { - system.lookup_xor8( - c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), - c_rot[[(i + 1) % 5, k]].0.into(), - d[[i, k]].0.into(), - ) - } - } - - // output state of the Theta sub-round, simple XOR, in 8-bit chunks - let theta_output: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); - - for i in 0..5 { - for j in 0..5 { - for k in 0..8 { - system.lookup_xor8( - state8[[j, i, k]].0.into(), - d[[i, k]].0.into(), - theta_output[[j, i, k]].0.into(), - ) - } - } - } - - // output state after applying both Rho and Pi sub-rounds - // sub-round Pi is a simple permutation of 64-bit lanes - // sub-round Rho requires rotations - let rhopi_output: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); - - // iterator over split witnesses - let mut rotation_witness = rotation_witness.iter(); - - for i in 0..5 { - #[allow(clippy::needless_range_loop)] - for j in 0..5 { - let arg = theta_output - .slice(s!(j, i, ..)) - .iter() - .map(|e| e.0) - .collect_vec(); - let (sizes, _) = 
rotation_split(ROTATION_CONSTANTS[j][i]); - let many = sizes.len(); - let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) - .map(|(sz, (wit, _))| (sz, *wit)) - .collect_vec(); - let arg_rotated = rhopi_output - .slice(s!((2 * i + 3 * j) % 5, j, ..)) - .iter() - .map(|e| e.0) - .collect_vec(); - system.constrain_left_rotation64( - &arg, - &rep_split, - &arg_rotated, - ROTATION_CONSTANTS[j][i], - format!("RHOPI {i}, {j}"), - ); - } - } - - let chi_output: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); - - // for the Chi sub-round, we use an intermediate witness storing the result of - // the required AND - let nonlinear: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); - - for i in 0..5 { - for j in 0..5 { - for k in 0..8 { - system.lookup_and8( - not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), - rhopi_output[[j, (i + 2) % 5, k]].0.into(), - nonlinear[[j, i, k]].0.into(), - ); - - system.lookup_xor8( - rhopi_output[[j, i, k]].0.into(), - nonlinear[[j, i, k]].0.into(), - chi_output[[j, i, k]].0.into(), - ); - } - } - } - - // TODO: 24/25 elements stay the same after Iota; eliminate duplication? - let iota_output: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); - - for i in 0..5 { - for j in 0..5 { - if i == 0 && j == 0 { - for k in 0..8 { - system.lookup_xor8( - chi_output[[j, i, k]].0.into(), - Expression::Const(Constant::Base( - ((RC[round] >> (k * 8)) & 0xFF) as i64, - )), - iota_output[[j, i, k]].0.into(), - ); - } - } else { - for k in 0..8 { - system.constrain_eq( - iota_output[[j, i, k]].0.into(), - chi_output[[j, i, k]].0.into(), - "nothing special".to_string(), - ); - } - } - } - } - - let ConstraintSystem { - mut expressions, - mut expr_names, - mut evals, - and_lookups, - xor_lookups, - range_lookups, - .. 
- } = system; - - iota_output - .into_iter() - .enumerate() - .map(|(i, val)| { - expressions.push(val.0.into()); - expr_names.push(format!("iota_output {i}")); - evals.push(round_output[i].clone()); - }) - .count(); - - for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) - .flatten() - .enumerate() - { - expressions.push(lookup); - expr_names.push(format!("round {round}: {i}th lookup felt")); - let idx = if i < 3 * AND_LOOKUPS_PER_ROUND { - &mut global_and_lookup - } else if i < 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND { - &mut global_xor_lookup - } else { - &mut global_range_lookup - }; - evals.push(lookup_outputs[*idx].clone()); - *idx += 1; - } - - chip.add_layer(Layer::new( - format!("Round {round}"), - LayerType::Zerocheck, - expressions, - vec![], - bases.into_iter().map(|e| e.1).collect_vec(), - vec![], - evals, - expr_names, - )); - - state8.into_iter().map(|e| e.1.clone()).collect_vec() - }, - ); - - assert!(global_and_lookup == 3 * AND_LOOKUPS); - assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); - assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); - - let (state8, _) = chip.allocate_wits_in_layer::<200, 0>(); - - let state8: ArrayView<(Witness, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &state8).unwrap(); - - let mut expressions = vec![]; - let mut evals = vec![]; - let mut expr_names = vec![]; - - for x in 0..5 { - for y in 0..5 { - for k in 0..2 { - // create an expression combining 4 elements of state8 into a single 32-bit felt - let expr = expansion_expr::<32>( - state8 - .slice(s![x, y, 4 * k..4 * (k + 1)]) - .iter() - .map(|e| (8, e.0)) - .collect_vec() - .as_slice(), - ); - expressions.push(expr); - evals.push(keccak_input32[evals.len()].clone()); - expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); - } - } - } - - // TODO: eliminate this duplication - zip_eq(state8.iter(), state8_loop.iter()) - .map(|(e, e_loop)| { - expressions.push(e.0.into()); - 
evals.push(e_loop.clone()); - expr_names.push("state8 identity".to_string()); - }) - .count(); - - chip.add_layer(Layer::new( - "build 32-bit input".to_string(), - LayerType::Zerocheck, - expressions, - vec![], - state8.into_iter().map(|e| e.1.clone()).collect_vec(), - vec![], - evals, - expr_names, - )); - - // TODO: allocate everything - chip.allocate_base_opening(0, state8[[0, 0, 0]].clone().1); - chip.allocate_base_opening(1, state8[[0, 0, 1]].clone().1); - } -} - -#[derive(Clone, Default)] -pub struct KeccakTrace { - pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, -} - -impl ProtocolWitnessGenerator for KeccakLayout -where - E: ExtensionField, -{ - type Trace = KeccakTrace; - - fn phase1_witness_group(&self, phase1: Self::Trace) -> Vec> { - let mut poly = vec![vec![]; KECCAK_INPUT_SIZE]; - for instance in phase1.instances { - let felts = u64s_to_felts::(instance.into_iter().map(|e| e as u64).collect_vec()); - for i in 0..KECCAK_INPUT_SIZE { - poly[i].push(felts[i]); - } - } - poly - } - - fn gkr_witness(&self, phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { - let n_layers = 24 + 2 + 1; - let mut layer_wits = vec![ - LayerWitness { - bases: vec![], - bases: vec![], - num_vars: 1 - }; - n_layers - ]; - - let num_instances = phase1[0].len(); - - for i in 0..num_instances { - fn conv64to8(input: u64) -> [u64; 8] { - MaskRepresentation::new(vec![(64, input).into()]) - .convert(vec![8; 8]) - .values() - .try_into() - .unwrap() - } - - let mut com_state = vec![]; - #[allow(clippy::needless_range_loop)] - for j in 0..KECCAK_INPUT_SIZE { - com_state.push(phase1[j][i]); - } - - let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; - let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; - let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; - - let mut add_and = |a: u64, b: u64, round: usize| { - let c = a & b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - and_lookups[round].extend(vec![a, b, c]); - }; - - let mut add_xor = |a: u64, b: u64, round: usize| { - let c = 
a ^ b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - xor_lookups[round].extend(vec![a, b, c]); - }; - - let mut add_range = |value: u64, size: usize, round: usize| { - assert!(size <= 16, "{size}"); - range_lookups[round].push(value); - if size < 16 { - range_lookups[round].push(value << (16 - size)); - assert!(value << (16 - size) < (1 << 16)); - } - }; - - let state32 = com_state - .into_iter() - // TODO double check assumptions about canonical - .map(|e| e.to_canonical_u64()) - .collect_vec(); - - let mut state64 = [[0u64; 5]; 5]; - let mut state8 = [[[0u64; 8]; 5]; 5]; - - zip_eq(iproduct!(0..5, 0..5), state32.clone().iter().tuples()) - .map(|((x, y), (lo, hi))| { - state64[x][y] = lo | (hi << 32); - }) - .count(); - - for x in 0..5 { - for y in 0..5 { - state8[x][y] = conv64to8(state64[x][y]); - } - } - - let mut curr_layer = 0; - let mut push_instance = |wits: Vec| { - let felts = u64s_to_felts::(wits); - if layer_wits[curr_layer].bases.is_empty() { - layer_wits[curr_layer] = LayerWitness::new(nest::(&felts), vec![]); - } else { - assert_eq!(felts.len(), layer_wits[curr_layer].bases.len()); - for (i, base) in layer_wits[curr_layer].bases.iter_mut().enumerate() { - base.push(felts[i]); - } - } - curr_layer += 1; - }; - - push_instance(state8.into_iter().flatten().flatten().collect_vec()); - - #[allow(clippy::needless_range_loop)] - for round in 0..ROUNDS { - let mut c_aux64 = [[0u64; 5]; 5]; - let mut c_aux8 = [[[0u64; 8]; 5]; 5]; - - for i in 0..5 { - c_aux64[i][0] = state64[0][i]; - c_aux8[i][0] = conv64to8(c_aux64[i][0]); - for j in 1..5 { - c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; - c_aux8[i][j] = conv64to8(c_aux64[i][j]); - - for k in 0..8 { - add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); - } - } - } - - let mut c64 = [0u64; 5]; - let mut c8 = [[0u64; 8]; 5]; - - for x in 0..5 { - c64[x] = c_aux64[x][4]; - c8[x] = conv64to8(c64[x]); - } - - let mut c_temp = [[0u64; 6]; 5]; - for i in 0..5 { - let rep = 
MaskRepresentation::new(vec![(64, c64[i]).into()]) - .convert(vec![16, 15, 1, 16, 15, 1]); - c_temp[i] = rep.values().try_into().unwrap(); - for mask in rep.rep { - add_range(mask.value, mask.size, round); - } - } - - let mut crot64 = [0u64; 5]; - let mut crot8 = [[0u64; 8]; 5]; - for i in 0..5 { - crot64[i] = c64[i].rotate_left(1); - crot8[i] = conv64to8(crot64[i]); - } - - let mut d64 = [0u64; 5]; - let mut d8 = [[0u64; 8]; 5]; - for x in 0..5 { - d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); - d8[x] = conv64to8(d64[x]); - for k in 0..8 { - add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); - } - } - - let mut theta_state64 = state64; - let mut theta_state8 = [[[0u64; 8]; 5]; 5]; - let mut rotation_witness = vec![]; - - for x in 0..5 { - for y in 0..5 { - theta_state64[y][x] ^= d64[x]; - theta_state8[y][x] = conv64to8(theta_state64[y][x]); - - for k in 0..8 { - add_xor(state8[y][x][k], d8[x][k], round); - } - - let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); - let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) - .convert(sizes); - for mask in rep.rep.iter() { - if mask.size != 32 { - add_range(mask.value, mask.size, round); - } - } - rotation_witness.extend(rep.values()); - } - } - - // Rho and Pi steps - let mut rhopi_output64 = [[0u64; 5]; 5]; - let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; - - for x in 0..5 { - for y in 0..5 { - rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = - theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); - } - } - - for x in 0..5 { - for y in 0..5 { - rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); - } - } - - // Chi step - - let mut nonlinear64 = [[0u64; 5]; 5]; - let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - nonlinear64[y][x] = - !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; - nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); - - for k in 0..8 { - add_and( - 0xFF - rhopi_output8[y][(x + 1) % 5][k], - 
rhopi_output8[y][(x + 2) % 5][k], - round, - ); - } - } - } - - let mut chi_output64 = [[0u64; 5]; 5]; - let mut chi_output8 = [[[0u64; 8]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; - chi_output8[y][x] = conv64to8(chi_output64[y][x]); - for k in 0..8 { - add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) - } - } - } - - // Iota step - let mut iota_output64 = chi_output64; - let mut iota_output8 = [[[0u64; 8]; 5]; 5]; - iota_output64[0][0] ^= RC[round]; - - for k in 0..8 { - add_xor(chi_output8[0][0][k], (RC[round] >> (k * 8)) & 0xFF, round); - } - - for x in 0..5 { - for y in 0..5 { - iota_output8[x][y] = conv64to8(iota_output64[x][y]); - } - } - - let all_wits64 = [ - state8.into_iter().flatten().flatten().collect_vec(), - c_aux8.into_iter().flatten().flatten().collect_vec(), - c_temp.into_iter().flatten().collect_vec(), - crot8.into_iter().flatten().collect_vec(), - d8.into_iter().flatten().collect_vec(), - theta_state8.into_iter().flatten().flatten().collect_vec(), - rotation_witness, - rhopi_output8.into_iter().flatten().flatten().collect_vec(), - nonlinear8.into_iter().flatten().flatten().collect_vec(), - chi_output8.into_iter().flatten().flatten().collect_vec(), - iota_output8.into_iter().flatten().flatten().collect_vec(), - ]; - - push_instance(all_wits64.into_iter().flatten().collect_vec()); - - state8 = iota_output8; - state64 = iota_output64; - } - - let mut keccak_output32 = vec![vec![vec![0; 2]; 5]; 5]; - - for x in 0..5 { - for y in 0..5 { - keccak_output32[x][y] = MaskRepresentation::from( - state8[x][y].into_iter().map(|e| (8, e)).collect_vec(), - ) - .convert(vec![32; 2]) - .values(); - } - } - - push_instance(state8.into_iter().flatten().flatten().collect_vec()); - - // For temporary convenience, use one extra layer to store the correct outputs - // of the circuit This is not used during proving - let lookups = chain!( - (0..ROUNDS).rev().flat_map(|i| 
and_lookups[i].clone()), - (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), - (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) - ) - .collect_vec(); - - push_instance( - chain!( - keccak_output32.into_iter().flatten().flatten(), - state32, - lookups - ) - .collect_vec(), - ); - } - - let len = layer_wits.len() - 1; - layer_wits[..len].reverse(); - - GKRCircuitWitness { layers: layer_wits } - } -} - -pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { - let params = KeccakParams {}; - let (layout, chip) = KeccakLayout::build(params); - - let mut instances = vec![]; - for state in &states { - let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); - let state_mask32 = state_mask64.convert(vec![32; 50]); - - instances.push( - state_mask32 - .values() - .iter() - .map(|e| *e as u32) - .collect_vec() - .try_into() - .unwrap(), - ); - } - - let num_instances = instances.len(); - let phase1_witness = layout.phase1_witness_group(KeccakTrace { - instances: instances.clone(), - }); - - let mut prover_transcript = BasicTranscript::::new(b"protocol"); - - // Omit the commit phase1 and phase2. 
- let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &[]); - - let out_evals = { - let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); - let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); - - if test_outputs { - // Confront outputs with tiny_keccak::keccakf call - let mut instance_outputs = vec![vec![]; num_instances]; - for base in gkr_witness - .layers - .last() - .unwrap() - .bases - .iter() - .take(KECCAK_OUTPUT_SIZE) - { - assert_eq!(base.len(), num_instances); - for i in 0..num_instances { - instance_outputs[i].push(base[i]); - } - } - - for i in 0..num_instances { - let mut state = states[i]; - keccakf(&mut state); - assert_eq!( - state - .to_vec() - .iter() - .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) - .map(|e| Goldilocks::from_u64(e as u64)) - .collect_vec(), - instance_outputs[i] - ); - } - } - - let out_evals = gkr_witness - .layers - .last() - .unwrap() - .bases - .iter() - .map(|base| PointAndEval { - point: point.clone(), - eval: subprotocols::utils::evaluate_mle_ext(base, &point), - }) - .collect_vec(); - - assert_eq!( - out_evals.len(), - KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS - ); - - out_evals - }; - - let gkr_circuit = chip.gkr_circuit(); - dbg!(&gkr_circuit.layers.len()); - let GKRProverOutput { gkr_proof, .. } = gkr_circuit - .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) - .expect("Failed to prove phase"); - - if verify { - { - let mut verifier_transcript = BasicTranscript::::new(b"protocol"); - - gkr_circuit - .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) - .expect("GKR verify failed"); - - // Omit the PCS opening phase. 
- } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use rand::{Rng, SeedableRng}; - - #[test] - fn test_keccakf() { - for _ in 0..3 { - // let random_u64: u64 = rand::random(); - // Use seeded rng for debugging convenience - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - - let num_instances = 8; - let mut states: Vec<[u64; 25]> = vec![]; - - for _ in 0..num_instances { - states.push(std::array::from_fn(|_| rng.gen())) - } - run_faster_keccakf(states, true, true); - } - } - - // TODO: make it pass - #[ignore] - #[test] - fn test_keccakf_nonpow2() { - for _ in 0..3 { - // let random_u64: u64 = rand::random(); - // Use seeded rng for debugging convenience - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - - let num_instances = 3; - let mut states: Vec<[u64; 25]> = vec![]; - - for _ in 0..num_instances { - states.push(std::array::from_fn(|_| rng.gen())) - } - run_faster_keccakf(states, true, true); - } - } -} +// use std::{cmp::Ordering, marker::PhantomData, sync::Arc}; + +// use crate::{ +// ProtocolBuilder, ProtocolWitnessGenerator, +// chip::Chip, +// evaluation::{EvalExpression, PointAndEval}, +// gkr::{ +// GKRCircuitWitness, GKRProverOutput, +// layer::{Layer, LayerType, LayerWitness}, +// }, +// precompiles::utils::{MaskRepresentation, nest, not8_expr, zero_expr}, +// }; +// use ndarray::{ArrayView, Ix2, Ix3, s}; + +// use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; +// use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; + +// use ff_ext::{ExtensionField, SmallField}; +// use itertools::{Itertools, chain, iproduct, zip_eq}; +// use p3_goldilocks::Goldilocks; +// use subprotocols::expression::{Constant, Expression, Witness}; +// use tiny_keccak::keccakf; +// use transcript::BasicTranscript; + +// type E = BinomialExtensionField; + +// #[derive(Clone, Debug, Default)] +// pub struct KeccakParams {} + +// #[derive(Clone, Debug, Default)] +// pub struct KeccakLayout { +// _params: KeccakParams, +// 
_input_columns: Vec, +// _result: Vec, +// _marker: PhantomData, +// } + +// fn expansion_expr(expansion: &[(usize, Witness)]) -> Expression { +// let (total, ret) = expansion +// .iter() +// .rev() +// .fold((0, zero_expr()), |acc, (sz, felt)| { +// ( +// acc.0 + sz, +// acc.1 * Expression::Const(Constant::Base(1 << sz)) + (*felt).into(), +// ) +// }); + +// assert_eq!(total, SIZE); +// ret +// } + +// /// Compute an adequate split of 64-bits into chunks for performing a rotation +// /// by `delta`. The first element of the return value is the vec of chunk sizes. +// /// The second one is the length of its suffix that needs to be rotated +// fn rotation_split(delta: usize) -> (Vec, usize) { +// let delta = delta % 64; + +// if delta == 0 { +// return (vec![32, 32], 0); +// } + +// // This split meets all requirements except for <= 16 sizes +// let split32 = match delta.cmp(&32) { +// Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], +// Ordering::Equal => vec![32, 32], +// Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], +// }; + +// // Split off large chunks +// let split16 = split32 +// .into_iter() +// .flat_map(|size| { +// assert!(size < 32); +// if size <= 16 { +// vec![size] +// } else { +// vec![16, size - 16] +// } +// }) +// .collect_vec(); + +// let mut sum = 0; +// for (i, size) in split16.iter().rev().enumerate() { +// sum += size; +// if sum == delta { +// return (split16, i + 1); +// } +// } + +// panic!(); +// } + +// struct ConstraintSystem { +// expressions: Vec, +// expr_names: Vec, +// evals: Vec, +// and_lookups: Vec, +// xor_lookups: Vec, +// range_lookups: Vec, +// } + +// impl ConstraintSystem { +// fn new() -> Self { +// ConstraintSystem { +// expressions: vec![], +// evals: vec![], +// expr_names: vec![], +// and_lookups: vec![], +// xor_lookups: vec![], +// range_lookups: vec![], +// } +// } + +// fn add_constraint(&mut self, expr: Expression, name: String) { +// 
self.expressions.push(expr); +// self.evals.push(zero_eval()); +// self.expr_names.push(name); +// } + +// fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { +// self.and_lookups.push(CenoLookup::And(a, b, c)); +// } + +// fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { +// self.xor_lookups.push(CenoLookup::Xor(a, b, c)); +// } + +// /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. +// /// In general it can be done by two U16 checks: one for `value` and one for +// /// `value << (16 - size)`. +// fn lookup_range(&mut self, value: Expression, size: usize) { +// assert!(size <= 16); +// self.range_lookups.push(CenoLookup::U16(value.clone())); +// if size < 16 { +// self.range_lookups.push(CenoLookup::U16( +// value * Expression::Const(Constant::Base(1 << (16 - size))), +// )) +// } +// } + +// fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { +// self.add_constraint(lhs - rhs, name); +// } + +// // Constrains that lhs and rhs encode the same value of SIZE bits +// // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) +// // This needs to be constrained separately +// fn constrain_reps_eq( +// &mut self, +// lhs: &[(usize, Witness)], +// rhs: &[(usize, Witness)], +// name: String, +// ) { +// self.add_constraint( +// expansion_expr::(lhs) - expansion_expr::(rhs), +// name, +// ); +// } + +// /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. +// /// `rot8` and `input8` each consist of 8 chunks of 8-bits. +// /// +// /// `split_rep` is a chunk representation of the input which +// /// allows to reduce the required rotation to an array rotation. It may use +// /// non-uniform chunks. +// /// +// /// For example, when `delta = 2`, the 64 bits are split into chunks of +// /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the +// /// least significant bits so a left rotation will become a right rotation +// /// of the array). 
To perform the required rotation, we can
+// /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e].
+// ///
+// /// In the first step, we check that `rot8` and `split_rep` represent the
+// /// same 64 bits. In the second step we check that `rot8` and the appropriate
+// /// array rotation of `split_rep` represent the same 64 bits.
+// ///
+// /// This type of representation-equality check is done by packing chunks
+// /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b,
+// /// 2c] to the first 4 elements of `rot8`). In addition, we do range
+// /// checks on `split_rep` which check that the felts meet the required
+// /// sizes.
+// ///
+// /// This algorithm imposes the following general requirements for
+// /// `split_rep`:
+// /// - There exists a suffix of `split_rep` which sums to exactly `delta`.
+// /// This suffix can contain several elements.
+// /// - Chunk sizes are at most 16 (so they can be range-checked) or they are
+// /// exactly equal to 32.
+// /// - There exists a prefix of chunks which sums exactly to 32. This must
+// /// hold for the rotated array as well.
+// /// - The number of chunks should be as small as possible.
+// ///
+// /// Consult the method `rotation_split` to see how splits are computed for a
+// /// given `delta`.
+// ///
+// /// Note that the function imposes range checks on chunk values, but it
+// /// makes two exceptions:
+// /// 1. It doesn't check the 8-bit reps (input and output). This is
+// /// because all 8-bit reps in the global circuit are implicitly
+// /// range-checked because they are lookup arguments.
+// /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit
+// /// chunk value is checked to be equal to the composition of 4 8-bit
+// /// chunks. As mentioned in 1., these can be trusted to be range
+// /// checked, so the resulting 32-bit is correct by construction as
+// /// well. 
+// fn constrain_left_rotation64( +// &mut self, +// input8: &[Witness], +// split_rep: &[(usize, Witness)], +// rot8: &[Witness], +// delta: usize, +// label: String, +// ) { +// assert_eq!(input8.len(), 8); +// assert_eq!(rot8.len(), 8); + +// // Assert that the given split witnesses are correct for this delta +// let (sizes, chunks_rotation) = rotation_split(delta); +// assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); + +// // Lookup ranges +// for (size, elem) in split_rep { +// if *size != 32 { +// self.lookup_range((*elem).into(), *size); +// } +// } + +// // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are +// // the same 64 bitstring +// let mut helper = |rep8: &[Witness], rep_x: &[(usize, Witness)], chunks_rotation: usize| { +// // Do the same thing for the two 32-bit halves +// let mut rep_x = rep_x.to_owned(); +// rep_x.rotate_right(chunks_rotation); + +// for i in 0..2 { +// // The respective 4 elements in the byte representation +// let lhs = rep8[4 * i..4 * (i + 1)] +// .iter() +// .map(|wit| (8, *wit)) +// .collect_vec(); +// let cnt = rep_x.len() / 2; +// let rhs = &rep_x[cnt * i..cnt * (i + 1)]; + +// assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); + +// self.constrain_reps_eq::<32>( +// &lhs, +// rhs, +// format!( +// "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", +// sizes +// ), +// ); +// } +// }; + +// helper(input8, split_rep, 0); +// helper(rot8, split_rep, chunks_rotation); +// } +// } + +// const ROUNDS: usize = 24; + +// const RC: [u64; ROUNDS] = [ +// 1u64, +// 0x8082u64, +// 0x800000000000808au64, +// 0x8000000080008000u64, +// 0x808bu64, +// 0x80000001u64, +// 0x8000000080008081u64, +// 0x8000000000008009u64, +// 0x8au64, +// 0x88u64, +// 0x80008009u64, +// 0x8000000au64, +// 0x8000808bu64, +// 0x800000000000008bu64, +// 0x8000000000008089u64, +// 0x8000000000008003u64, +// 0x8000000000008002u64, +// 0x8000000000000080u64, +// 0x800au64, +// 
0x800000008000000au64, +// 0x8000000080008081u64, +// 0x8000000000008080u64, +// 0x80000001u64, +// 0x8000000080008008u64, +// ]; + +// const ROTATION_CONSTANTS: [[usize; 5]; 5] = [ +// [0, 1, 62, 28, 27], +// [36, 44, 6, 55, 20], +// [3, 10, 43, 25, 39], +// [41, 45, 15, 21, 8], +// [18, 2, 61, 56, 14], +// ]; + +// pub const KECCAK_INPUT_SIZE: usize = 50; +// pub const KECCAK_OUTPUT_SIZE: usize = 50; + +// pub const AND_LOOKUPS_PER_ROUND: usize = 200; +// pub const XOR_LOOKUPS_PER_ROUND: usize = 608; +// pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; +// pub const LOOKUP_FELTS_PER_ROUND: usize = +// 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; + +// pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; +// pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; +// pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; + +// macro_rules! allocate_and_split { +// ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ +// let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); +// let mut iter = witnesses.into_iter(); +// ( +// $( +// iter.by_ref().take($size).collect_vec(), +// )* +// ) +// }}; +// } + +// impl ProtocolBuilder for KeccakLayout { +// type Params = KeccakParams; + +// fn init(params: Self::Params) -> Self { +// Self { +// _params: params, +// ..Default::default() +// } +// } + +// fn build_commit_phase(&mut self, chip: &mut Chip) { +// let _ = chip.allocate_committed::<{ 50 + 40144 }>(); +// } + +// fn build_gkr_phase(&mut self, chip: &mut Chip) { +// let final_outputs = +// chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); + +// let mut final_outputs_iter = final_outputs.iter(); + +// let [keccak_output32, keccak_input32, lookup_outputs] = [ +// KECCAK_OUTPUT_SIZE, +// KECCAK_INPUT_SIZE, +// // LOOKUPS_PER_ROUND +// LOOKUP_FELTS_PER_ROUND * ROUNDS, +// ] +// .map(|many| 
final_outputs_iter.by_ref().take(many).collect_vec()); + +// let keccak_output32 = keccak_output32.to_vec(); +// let keccak_input32 = keccak_input32.to_vec(); +// let lookup_outputs = lookup_outputs.to_vec(); + +// let (keccak_output8, []) = chip.allocate_wits_in_layer::<200, 0>(); + +// let keccak_output8: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &keccak_output8).unwrap(); + +// let mut expressions = vec![]; +// let mut evals = vec![]; +// let mut expr_names = vec![]; + +// let mut global_and_lookup = 0; +// let mut global_xor_lookup = 3 * AND_LOOKUPS; +// let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; + +// let mut openings = 0; + +// for x in 0..5 { +// for y in 0..5 { +// for k in 0..2 { +// // create an expression combining 4 elements of state8 into a single 32-bit felt +// let expr = expansion_expr::<32>( +// &keccak_output8 +// .slice(s![x, y, 4 * k..4 * (k + 1)]) +// .iter() +// .map(|e| (8, e.0)) +// .collect_vec(), +// ); +// expressions.push(expr); +// evals.push(keccak_output32[evals.len()].clone()); +// expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); +// } +// } +// } + +// chip.add_layer(Layer::new( +// "build 32-bit output".to_string(), +// LayerType::Zerocheck, +// expressions, +// vec![], +// keccak_output8 +// .into_iter() +// .map(|e| e.1.clone()) +// .collect_vec(), +// vec![], +// evals, +// expr_names, +// )); + +// let state8_loop = (0..ROUNDS).rev().fold( +// keccak_output8.iter().map(|e| e.1.clone()).collect_vec(), +// |round_output, round| { +// #[allow(non_snake_case)] +// let ( +// state8, +// c_aux, +// c_temp, +// c_rot, +// d, +// theta_output, +// rotation_witness, +// rhopi_output, +// nonlinear, +// chi_output, +// iota_output, +// ) = allocate_and_split!( +// chip, 1656, 200, 200, 30, 40, 40, 200, 146, 200, 200, 200, 200 +// ); + +// let total_witnesses = 200 + 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 200 + 200; +// // dbg!(total_witnesses); +// 
assert_eq!(1656, total_witnesses); + +// let bases = chain!( +// state8.clone(), +// c_aux.clone(), +// c_temp.clone(), +// c_rot.clone(), +// d.clone(), +// theta_output.clone(), +// rotation_witness.clone(), +// rhopi_output.clone(), +// nonlinear.clone(), +// chi_output.clone(), +// iota_output.clone(), +// ) +// .collect_vec(); + +// for wit in &bases { +// chip.allocate_opening(openings, wit.clone().1); +// openings += 1; +// } + +// // TODO: ndarrays can be replaced with normal arrays + +// // Input state of the round in 8-bit chunks +// let state8: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + +// let mut system = ConstraintSystem::new(); + +// // The purpose is to compute the auxiliary array +// // c[i] = XOR (state[j][i]) for j in 0..5 +// // We unroll it into +// // c_aux[i][j] = XOR (state[k][i]) for k in 0..j +// // We use c_aux[i][4] instead of c[i] +// // c_aux is also stored in 8-bit chunks +// let c_aux: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); + +// for i in 0..5 { +// for k in 0..8 { +// // Initialize first element +// system.constrain_eq( +// state8[[0, i, k]].0.into(), +// c_aux[[i, 0, k]].0.into(), +// "init c_aux".to_string(), +// ); +// } +// for j in 1..5 { +// // Check xor using lookups over all chunks +// for k in 0..8 { +// system.lookup_xor8( +// c_aux[[i, j - 1, k]].0.into(), +// state8[[j, i, k]].0.into(), +// c_aux[[i, j, k]].0.into(), +// ); +// } +// } +// } + +// // Compute c_rot[i] = c[i].rotate_left(1) +// // To understand how rotations are performed in general, consult the +// // documentation of `constrain_left_rotation64`. Here c_temp is the split +// // witness for a 1-rotation. 
+ +// let c_temp: ArrayView<(Witness, EvalExpression), Ix2> = +// ArrayView::from_shape((5, 6), &c_temp).unwrap(); +// let c_rot: ArrayView<(Witness, EvalExpression), Ix2> = +// ArrayView::from_shape((5, 8), &c_rot).unwrap(); + +// let (sizes, _) = rotation_split(1); + +// for i in 0..5 { +// assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); + +// system.constrain_left_rotation64( +// &c_aux.slice(s![i, 4, ..]).iter().map(|e| e.0).collect_vec(), +// &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) +// .map(|(e, sz)| (*sz, e.0)) +// .collect_vec(), +// &c_rot.slice(s![i, ..]).iter().map(|e| e.0).collect_vec(), +// 1, +// "theta rotation".to_string(), +// ); +// } + +// // d is computed simply as XOR of required elements of c (and rotations) +// // again stored as 8-bit chunks +// let d: ArrayView<(Witness, EvalExpression), Ix2> = +// ArrayView::from_shape((5, 8), &d).unwrap(); + +// for i in 0..5 { +// for k in 0..8 { +// system.lookup_xor8( +// c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), +// c_rot[[(i + 1) % 5, k]].0.into(), +// d[[i, k]].0.into(), +// ) +// } +// } + +// // output state of the Theta sub-round, simple XOR, in 8-bit chunks +// let theta_output: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); + +// for i in 0..5 { +// for j in 0..5 { +// for k in 0..8 { +// system.lookup_xor8( +// state8[[j, i, k]].0.into(), +// d[[i, k]].0.into(), +// theta_output[[j, i, k]].0.into(), +// ) +// } +// } +// } + +// // output state after applying both Rho and Pi sub-rounds +// // sub-round Pi is a simple permutation of 64-bit lanes +// // sub-round Rho requires rotations +// let rhopi_output: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); + +// // iterator over split witnesses +// let mut rotation_witness = rotation_witness.iter(); + +// for i in 0..5 { +// #[allow(clippy::needless_range_loop)] +// for j in 0..5 { +// let arg = 
theta_output +// .slice(s!(j, i, ..)) +// .iter() +// .map(|e| e.0) +// .collect_vec(); +// let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); +// let many = sizes.len(); +// let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) +// .map(|(sz, (wit, _))| (sz, *wit)) +// .collect_vec(); +// let arg_rotated = rhopi_output +// .slice(s!((2 * i + 3 * j) % 5, j, ..)) +// .iter() +// .map(|e| e.0) +// .collect_vec(); +// system.constrain_left_rotation64( +// &arg, +// &rep_split, +// &arg_rotated, +// ROTATION_CONSTANTS[j][i], +// format!("RHOPI {i}, {j}"), +// ); +// } +// } + +// let chi_output: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); + +// // for the Chi sub-round, we use an intermediate witness storing the result of +// // the required AND +// let nonlinear: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); + +// for i in 0..5 { +// for j in 0..5 { +// for k in 0..8 { +// system.lookup_and8( +// not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), +// rhopi_output[[j, (i + 2) % 5, k]].0.into(), +// nonlinear[[j, i, k]].0.into(), +// ); + +// system.lookup_xor8( +// rhopi_output[[j, i, k]].0.into(), +// nonlinear[[j, i, k]].0.into(), +// chi_output[[j, i, k]].0.into(), +// ); +// } +// } +// } + +// // TODO: 24/25 elements stay the same after Iota; eliminate duplication? 
+// let iota_output: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); + +// for i in 0..5 { +// for j in 0..5 { +// if i == 0 && j == 0 { +// for k in 0..8 { +// system.lookup_xor8( +// chi_output[[j, i, k]].0.into(), +// Expression::Const(Constant::Base( +// ((RC[round] >> (k * 8)) & 0xFF) as i64, +// )), +// iota_output[[j, i, k]].0.into(), +// ); +// } +// } else { +// for k in 0..8 { +// system.constrain_eq( +// iota_output[[j, i, k]].0.into(), +// chi_output[[j, i, k]].0.into(), +// "nothing special".to_string(), +// ); +// } +// } +// } +// } + +// let ConstraintSystem { +// mut expressions, +// mut expr_names, +// mut evals, +// and_lookups, +// xor_lookups, +// range_lookups, +// .. +// } = system; + +// iota_output +// .into_iter() +// .enumerate() +// .map(|(i, val)| { +// expressions.push(val.0.into()); +// expr_names.push(format!("iota_output {i}")); +// evals.push(round_output[i].clone()); +// }) +// .count(); + +// for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) +// .flatten() +// .enumerate() +// { +// expressions.push(lookup); +// expr_names.push(format!("round {round}: {i}th lookup felt")); +// let idx = if i < 3 * AND_LOOKUPS_PER_ROUND { +// &mut global_and_lookup +// } else if i < 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND { +// &mut global_xor_lookup +// } else { +// &mut global_range_lookup +// }; +// evals.push(lookup_outputs[*idx].clone()); +// *idx += 1; +// } + +// chip.add_layer(Layer::new( +// format!("Round {round}"), +// LayerType::Zerocheck, +// expressions, +// vec![], +// bases.into_iter().map(|e| e.1).collect_vec(), +// vec![], +// evals, +// expr_names, +// )); + +// state8.into_iter().map(|e| e.1.clone()).collect_vec() +// }, +// ); + +// assert!(global_and_lookup == 3 * AND_LOOKUPS); +// assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); +// assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); + +// let (state8, _) = 
chip.allocate_wits_in_layer::<200, 0>(); + +// let state8: ArrayView<(Witness, EvalExpression), Ix3> = +// ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + +// let mut expressions = vec![]; +// let mut evals = vec![]; +// let mut expr_names = vec![]; + +// for x in 0..5 { +// for y in 0..5 { +// for k in 0..2 { +// // create an expression combining 4 elements of state8 into a single 32-bit felt +// let expr = expansion_expr::<32>( +// state8 +// .slice(s![x, y, 4 * k..4 * (k + 1)]) +// .iter() +// .map(|e| (8, e.0)) +// .collect_vec() +// .as_slice(), +// ); +// expressions.push(expr); +// evals.push(keccak_input32[evals.len()].clone()); +// expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); +// } +// } +// } + +// // TODO: eliminate this duplication +// zip_eq(state8.iter(), state8_loop.iter()) +// .map(|(e, e_loop)| { +// expressions.push(e.0.into()); +// evals.push(e_loop.clone()); +// expr_names.push("state8 identity".to_string()); +// }) +// .count(); + +// chip.add_layer(Layer::new( +// "build 32-bit input".to_string(), +// LayerType::Zerocheck, +// expressions, +// vec![], +// state8.into_iter().map(|e| e.1.clone()).collect_vec(), +// vec![], +// evals, +// expr_names, +// )); + +// // TODO: allocate everything +// chip.allocate_base_opening(0, state8[[0, 0, 0]].clone().1); +// chip.allocate_base_opening(1, state8[[0, 0, 1]].clone().1); +// } +// } + +// #[derive(Clone, Default)] +// pub struct KeccakTrace { +// pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, +// } + +// impl ProtocolWitnessGenerator for KeccakLayout +// where +// E: ExtensionField, +// { +// type Trace = KeccakTrace; + +// fn phase1_witness_group(&self, phase1: Self::Trace) -> Vec> { +// let mut poly = vec![vec![]; KECCAK_INPUT_SIZE]; +// for instance in phase1.instances { +// let felts = u64s_to_felts::(instance.into_iter().map(|e| e as u64).collect_vec()); +// for i in 0..KECCAK_INPUT_SIZE { +// poly[i].push(felts[i]); +// } +// } +// poly +// } + +// fn gkr_witness(&self, 
phase1: &[Vec], _challenges: &[E]) -> GKRCircuitWitness { +// let n_layers = 24 + 2 + 1; +// let mut layer_wits = vec![ +// LayerWitness { +// bases: vec![], +// bases: vec![], +// num_vars: 1 +// }; +// n_layers +// ]; + +// let num_instances = phase1[0].len(); + +// for i in 0..num_instances { +// fn conv64to8(input: u64) -> [u64; 8] { +// MaskRepresentation::new(vec![(64, input).into()]) +// .convert(vec![8; 8]) +// .values() +// .try_into() +// .unwrap() +// } + +// let mut com_state = vec![]; +// #[allow(clippy::needless_range_loop)] +// for j in 0..KECCAK_INPUT_SIZE { +// com_state.push(phase1[j][i]); +// } + +// let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; +// let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; +// let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; + +// let mut add_and = |a: u64, b: u64, round: usize| { +// let c = a & b; +// assert!(a < (1 << 8)); +// assert!(b < (1 << 8)); +// and_lookups[round].extend(vec![a, b, c]); +// }; + +// let mut add_xor = |a: u64, b: u64, round: usize| { +// let c = a ^ b; +// assert!(a < (1 << 8)); +// assert!(b < (1 << 8)); +// xor_lookups[round].extend(vec![a, b, c]); +// }; + +// let mut add_range = |value: u64, size: usize, round: usize| { +// assert!(size <= 16, "{size}"); +// range_lookups[round].push(value); +// if size < 16 { +// range_lookups[round].push(value << (16 - size)); +// assert!(value << (16 - size) < (1 << 16)); +// } +// }; + +// let state32 = com_state +// .into_iter() +// // TODO double check assumptions about canonical +// .map(|e| e.to_canonical_u64()) +// .collect_vec(); + +// let mut state64 = [[0u64; 5]; 5]; +// let mut state8 = [[[0u64; 8]; 5]; 5]; + +// zip_eq(iproduct!(0..5, 0..5), state32.clone().iter().tuples()) +// .map(|((x, y), (lo, hi))| { +// state64[x][y] = lo | (hi << 32); +// }) +// .count(); + +// for x in 0..5 { +// for y in 0..5 { +// state8[x][y] = conv64to8(state64[x][y]); +// } +// } + +// let mut curr_layer = 0; +// let mut push_instance = |wits: Vec| { +// 
let felts = u64s_to_felts::(wits); +// if layer_wits[curr_layer].bases.is_empty() { +// layer_wits[curr_layer] = LayerWitness::new(nest::(&felts), vec![]); +// } else { +// assert_eq!(felts.len(), layer_wits[curr_layer].bases.len()); +// for (i, base) in layer_wits[curr_layer].bases.iter_mut().enumerate() { +// base.push(felts[i]); +// } +// } +// curr_layer += 1; +// }; + +// push_instance(state8.into_iter().flatten().flatten().collect_vec()); + +// #[allow(clippy::needless_range_loop)] +// for round in 0..ROUNDS { +// let mut c_aux64 = [[0u64; 5]; 5]; +// let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + +// for i in 0..5 { +// c_aux64[i][0] = state64[0][i]; +// c_aux8[i][0] = conv64to8(c_aux64[i][0]); +// for j in 1..5 { +// c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; +// c_aux8[i][j] = conv64to8(c_aux64[i][j]); + +// for k in 0..8 { +// add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); +// } +// } +// } + +// let mut c64 = [0u64; 5]; +// let mut c8 = [[0u64; 8]; 5]; + +// for x in 0..5 { +// c64[x] = c_aux64[x][4]; +// c8[x] = conv64to8(c64[x]); +// } + +// let mut c_temp = [[0u64; 6]; 5]; +// for i in 0..5 { +// let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) +// .convert(vec![16, 15, 1, 16, 15, 1]); +// c_temp[i] = rep.values().try_into().unwrap(); +// for mask in rep.rep { +// add_range(mask.value, mask.size, round); +// } +// } + +// let mut crot64 = [0u64; 5]; +// let mut crot8 = [[0u64; 8]; 5]; +// for i in 0..5 { +// crot64[i] = c64[i].rotate_left(1); +// crot8[i] = conv64to8(crot64[i]); +// } + +// let mut d64 = [0u64; 5]; +// let mut d8 = [[0u64; 8]; 5]; +// for x in 0..5 { +// d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); +// d8[x] = conv64to8(d64[x]); +// for k in 0..8 { +// add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); +// } +// } + +// let mut theta_state64 = state64; +// let mut theta_state8 = [[[0u64; 8]; 5]; 5]; +// let mut rotation_witness = vec![]; + +// for x in 0..5 { +// for y in 0..5 { +// 
theta_state64[y][x] ^= d64[x]; +// theta_state8[y][x] = conv64to8(theta_state64[y][x]); + +// for k in 0..8 { +// add_xor(state8[y][x][k], d8[x][k], round); +// } + +// let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); +// let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) +// .convert(sizes); +// for mask in rep.rep.iter() { +// if mask.size != 32 { +// add_range(mask.value, mask.size, round); +// } +// } +// rotation_witness.extend(rep.values()); +// } +// } + +// // Rho and Pi steps +// let mut rhopi_output64 = [[0u64; 5]; 5]; +// let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; + +// for x in 0..5 { +// for y in 0..5 { +// rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = +// theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); +// } +// } + +// for x in 0..5 { +// for y in 0..5 { +// rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); +// } +// } + +// // Chi step + +// let mut nonlinear64 = [[0u64; 5]; 5]; +// let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; +// for x in 0..5 { +// for y in 0..5 { +// nonlinear64[y][x] = +// !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; +// nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); + +// for k in 0..8 { +// add_and( +// 0xFF - rhopi_output8[y][(x + 1) % 5][k], +// rhopi_output8[y][(x + 2) % 5][k], +// round, +// ); +// } +// } +// } + +// let mut chi_output64 = [[0u64; 5]; 5]; +// let mut chi_output8 = [[[0u64; 8]; 5]; 5]; +// for x in 0..5 { +// for y in 0..5 { +// chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; +// chi_output8[y][x] = conv64to8(chi_output64[y][x]); +// for k in 0..8 { +// add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) +// } +// } +// } + +// // Iota step +// let mut iota_output64 = chi_output64; +// let mut iota_output8 = [[[0u64; 8]; 5]; 5]; +// iota_output64[0][0] ^= RC[round]; + +// for k in 0..8 { +// add_xor(chi_output8[0][0][k], (RC[round] >> (k * 8)) & 0xFF, round); +// } + +// for x in 0..5 { +// for y in 0..5 { 
+// iota_output8[x][y] = conv64to8(iota_output64[x][y]); +// } +// } + +// let all_wits64 = [ +// state8.into_iter().flatten().flatten().collect_vec(), +// c_aux8.into_iter().flatten().flatten().collect_vec(), +// c_temp.into_iter().flatten().collect_vec(), +// crot8.into_iter().flatten().collect_vec(), +// d8.into_iter().flatten().collect_vec(), +// theta_state8.into_iter().flatten().flatten().collect_vec(), +// rotation_witness, +// rhopi_output8.into_iter().flatten().flatten().collect_vec(), +// nonlinear8.into_iter().flatten().flatten().collect_vec(), +// chi_output8.into_iter().flatten().flatten().collect_vec(), +// iota_output8.into_iter().flatten().flatten().collect_vec(), +// ]; + +// push_instance(all_wits64.into_iter().flatten().collect_vec()); + +// state8 = iota_output8; +// state64 = iota_output64; +// } + +// let mut keccak_output32 = vec![vec![vec![0; 2]; 5]; 5]; + +// for x in 0..5 { +// for y in 0..5 { +// keccak_output32[x][y] = MaskRepresentation::from( +// state8[x][y].into_iter().map(|e| (8, e)).collect_vec(), +// ) +// .convert(vec![32; 2]) +// .values(); +// } +// } + +// push_instance(state8.into_iter().flatten().flatten().collect_vec()); + +// // For temporary convenience, use one extra layer to store the correct outputs +// // of the circuit This is not used during proving +// let lookups = chain!( +// (0..ROUNDS).rev().flat_map(|i| and_lookups[i].clone()), +// (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), +// (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) +// ) +// .collect_vec(); + +// push_instance( +// chain!( +// keccak_output32.into_iter().flatten().flatten(), +// state32, +// lookups +// ) +// .collect_vec(), +// ); +// } + +// let len = layer_wits.len() - 1; +// layer_wits[..len].reverse(); + +// GKRCircuitWitness { layers: layer_wits } +// } +// } + +// pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { +// let params = KeccakParams {}; +// let (layout, chip) = 
KeccakLayout::build(params); + +// let mut instances = vec![]; +// for state in &states { +// let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); +// let state_mask32 = state_mask64.convert(vec![32; 50]); + +// instances.push( +// state_mask32 +// .values() +// .iter() +// .map(|e| *e as u32) +// .collect_vec() +// .try_into() +// .unwrap(), +// ); +// } + +// let num_instances = instances.len(); +// let phase1_witness = layout.phase1_witness_group(KeccakTrace { +// instances: instances.clone(), +// }); + +// let mut prover_transcript = BasicTranscript::::new(b"protocol"); + +// // Omit the commit phase1 and phase2. +// let gkr_witness: GKRCircuitWitness = layout.gkr_witness(&phase1_witness, &[]); + +// let out_evals = { +// let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); +// let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); + +// if test_outputs { +// // Confront outputs with tiny_keccak::keccakf call +// let mut instance_outputs = vec![vec![]; num_instances]; +// for base in gkr_witness +// .layers +// .last() +// .unwrap() +// .bases +// .iter() +// .take(KECCAK_OUTPUT_SIZE) +// { +// assert_eq!(base.len(), num_instances); +// for i in 0..num_instances { +// instance_outputs[i].push(base[i]); +// } +// } + +// for i in 0..num_instances { +// let mut state = states[i]; +// keccakf(&mut state); +// assert_eq!( +// state +// .to_vec() +// .iter() +// .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) +// .map(|e| Goldilocks::from_u64(e as u64)) +// .collect_vec(), +// instance_outputs[i] +// ); +// } +// } + +// let out_evals = gkr_witness +// .layers +// .last() +// .unwrap() +// .bases +// .iter() +// .map(|base| PointAndEval { +// point: point.clone(), +// eval: subprotocols::utils::evaluate_mle_ext(base, &point), +// }) +// .collect_vec(); + +// assert_eq!( +// out_evals.len(), +// KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS +// ); + +// 
out_evals +// }; + +// let gkr_circuit = chip.gkr_circuit(); +// dbg!(&gkr_circuit.layers.len()); +// let GKRProverOutput { gkr_proof, .. } = gkr_circuit +// .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) +// .expect("Failed to prove phase"); + +// if verify { +// { +// let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + +// gkr_circuit +// .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) +// .expect("GKR verify failed"); + +// // Omit the PCS opening phase. +// } +// } +// } + +// #[cfg(test)] +// mod tests { +// use super::*; +// use rand::{Rng, SeedableRng}; + +// #[test] +// fn test_keccakf() { +// for _ in 0..3 { +// // let random_u64: u64 = rand::random(); +// // Use seeded rng for debugging convenience +// let mut rng = rand::rngs::StdRng::seed_from_u64(42); + +// let num_instances = 8; +// let mut states: Vec<[u64; 25]> = vec![]; + +// for _ in 0..num_instances { +// states.push(std::array::from_fn(|_| rng.gen())) +// } +// run_faster_keccakf(states, true, true); +// } +// } + +// // TODO: make it pass +// #[ignore] +// #[test] +// fn test_keccakf_nonpow2() { +// for _ in 0..3 { +// // let random_u64: u64 = rand::random(); +// // Use seeded rng for debugging convenience +// let mut rng = rand::rngs::StdRng::seed_from_u64(42); + +// let num_instances = 3; +// let mut states: Vec<[u64; 25]> = vec![]; + +// for _ in 0..num_instances { +// states.push(std::array::from_fn(|_| rng.gen())) +// } +// run_faster_keccakf(states, true, true); +// } +// } +// } diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index ee783a711..9b0ec2c98 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -2,7 +2,7 @@ mod bitwise_keccakf; mod lookup_keccakf; mod utils; pub use bitwise_keccakf::run_keccakf; -pub use lookup_keccakf::{ - AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakParams, KeccakTrace, RANGE_LOOKUPS, - RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, 
XOR_LOOKUPS_PER_ROUND, run_faster_keccakf, -}; +// pub use lookup_keccakf::{ +// AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakParams, KeccakTrace, RANGE_LOOKUPS, +// RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, run_faster_keccakf, +// }; diff --git a/multilinear_extensions/src/expression.rs b/multilinear_extensions/src/expression.rs index f9fea60b4..6bafdb30f 100644 --- a/multilinear_extensions/src/expression.rs +++ b/multilinear_extensions/src/expression.rs @@ -1022,17 +1022,25 @@ pub fn wit_infer_by_expr<'a, E: ExtensionField>( }, &|x, a, b| { op_mle_xa_b!(|x, a, b| { - assert_eq!(a.len(), 1); - assert_eq!(b.len(), 1); - let (a, b) = (a[0], b[0]); - MultilinearExtension::from_evaluation_vec_smart( - ceil_log2(x.len()), - x.par_iter() - .with_min_len(MIN_PAR_SIZE) - .map(|x| a * *x + b) - .collect(), - ) - .into() + match (x.len(), a.len(), b.len()) { + (_, 1, 1) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(x.len()), + x.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|x| a[0] * *x + b[0]) + .collect(), + ) + .into(), + (1, _, 1) => MultilinearExtension::from_evaluation_vec_smart( + ceil_log2(a.len()), + a.par_iter() + .with_min_len(MIN_PAR_SIZE) + .map(|a| *a * x[0] + b[0]) + .collect(), + ) + .into(), + lefted => panic!("unknown combination {:?}", lefted), + } }) }, ) diff --git a/multilinear_extensions/src/mle.rs b/multilinear_extensions/src/mle.rs index bf36ec7c7..ef58c824b 100644 --- a/multilinear_extensions/src/mle.rs +++ b/multilinear_extensions/src/mle.rs @@ -945,7 +945,7 @@ macro_rules! op_mle3_range { }}; } -/// deal with x * a + b +/// deal with x * a + b or a * x + b #[macro_export] macro_rules! op_mle_xa_b { (|$x:ident, $a:ident, $b:ident| $op:expr, |$bb_out:ident| $op_bb_out:expr) => { @@ -971,6 +971,13 @@ macro_rules! 
op_mle_xa_b { ) => { op_mle3_range!($x, $a, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) } + ( + $crate::mle::FieldType::Ext(x_vec), + $crate::mle::FieldType::Base(a_vec), + $crate::mle::FieldType::Base(b_vec), + ) => { + op_mle3_range!($a, $x, $b, x_vec, a_vec, b_vec, $op, |$bb_out| $op_bb_out) + } (x, a, b) => unreachable!( "unmatched pattern {:?} {:?} {:?}", x.variant_name(), From 4ec140a34a92198ddce80809897f8b2d61dd69b7 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 20 May 2025 11:17:52 +0800 Subject: [PATCH 06/28] cleanup --- gkr_iop/src/gkr/layer/linear_layer.rs | 5 ----- gkr_iop/src/precompiles/bitwise_keccakf.rs | 15 ++++++------- gkr_iop/src/precompiles/utils.rs | 25 +++++++--------------- 3 files changed, 14 insertions(+), 31 deletions(-) diff --git a/gkr_iop/src/gkr/layer/linear_layer.rs b/gkr_iop/src/gkr/layer/linear_layer.rs index 68f7acb18..36ef033ef 100644 --- a/gkr_iop/src/gkr/layer/linear_layer.rs +++ b/gkr_iop/src/gkr/layer/linear_layer.rs @@ -8,11 +8,6 @@ use crate::error::BackendError; use super::{Layer, LayerWitness, sumcheck_layer::SumcheckLayerProof}; -pub struct LinearLayerProof { - evals: Vec, - point: Point, -} - pub struct LayerClaims { pub in_point: Point, pub evals: Vec, diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index 944def7b4..97e6279a9 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -17,7 +17,7 @@ use multilinear_extensions::{ util::ceil_log2, wit_infer_by_expr, }; -use p3_field::{Field, PrimeCharacteristicRing, extension::BinomialExtensionField}; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; use sumcheck::util::optimal_sumcheck_threads; @@ -368,14 +368,11 @@ where .enumerate() .fold(&mut layer_wits, |layer_wits, (i, layer)| { tracing::info!("generating input {i} layer with layer name {}", layer.name); - let wit = { - 
LayerWitness::new(infer_layer_witness( - &layer, - layer_wits.last().unwrap().bases.clone(), - challenges, - )) - }; - layer_wits.push(wit); + layer_wits.push(LayerWitness::new(infer_layer_witness( + &layer, + layer_wits.last().unwrap().bases.clone(), + challenges, + ))); layer_wits }); diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index 58af29c11..e25523193 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -1,19 +1,10 @@ use ff_ext::ExtensionField; use itertools::Itertools; -use multilinear_extensions::ToExpr; +use multilinear_extensions::{Expression, ToExpr}; use p3_field::PrimeCharacteristicRing; -use subprotocols::expression::{Constant, Expression}; use crate::evaluation::EvalExpression; -pub fn zero_expr() -> Expression { - Expression::Const(Constant::Base(0)) -} - -pub fn not8_expr(expr: Expression) -> Expression { - Expression::Const(Constant::Base(0xFF)) - expr -} - pub fn zero_eval() -> EvalExpression { EvalExpression::Linear( 0, @@ -117,15 +108,15 @@ impl MaskRepresentation { } #[derive(Debug)] -pub enum CenoLookup { - And(Expression, Expression, Expression), - Xor(Expression, Expression, Expression), - U16(Expression), +pub enum CenoLookup { + And(Expression, Expression, Expression), + Xor(Expression, Expression, Expression), + U16(Expression), } -impl IntoIterator for CenoLookup { - type Item = Expression; - type IntoIter = std::vec::IntoIter; +impl IntoIterator for CenoLookup { + type Item = Expression; + type IntoIter = std::vec::IntoIter>; fn into_iter(self) -> Self::IntoIter { match self { From a0b4889780ac2a3714d3cf64bc1675e5d49d53b9 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 20 May 2025 22:45:41 +0800 Subject: [PATCH 07/28] witness compile pass --- ceno_zkvm/benches/fibonacci.rs | 2 +- gkr_iop/src/precompiles/bitwise_keccakf.rs | 119 ++++++++++++++------- 2 files changed, 82 insertions(+), 39 deletions(-) diff --git a/ceno_zkvm/benches/fibonacci.rs 
b/ceno_zkvm/benches/fibonacci.rs index c74c63e34..4f43130e7 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -31,7 +31,7 @@ fn setup() -> (Program, Platform) { let stack_size = 32768; let heap_size = 2097152; let pub_io_size = 16; - let program = Program::load_elf(ceno_examples::fibonacci, u32::MAX).unwrap(); + let program = Program::load_elf(ceno_examples::guest_keccak, u32::MAX).unwrap(); let platform = setup_platform(Preset::Ceno, &program, stack_size, heap_size, pub_io_size); (program, platform) } diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index 97e6279a9..619888e1a 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -112,7 +112,7 @@ fn keccak_witness<'a, E: ExtensionField>( log_num_states, bit_column .into_iter() - .map(|b| E::from_bool(b)) + .map(|b| E::BaseField::from_bool(b)) .collect::>(), ) }) @@ -320,7 +320,7 @@ pub struct KeccakTrace<'a, E: ExtensionField> { pub fn infer_layer_witness<'a, E>( layer: &Layer, - layer_wits: Vec>, + layer_wits: &[ArcMultilinearExtension<'a, E>], challenges: &[E], ) -> Vec> where @@ -329,10 +329,7 @@ where layer .exprs .iter() - .map(|expr| { - tracing::trace!("infer_layer_witness expr {}", expr,); - wit_infer_by_expr(&[], &layer_wits, &[], &[], challenges, expr) - }) + .map(|expr| wit_infer_by_expr(&[], layer_wits, &[], &[], challenges, expr)) .collect_vec() } @@ -352,29 +349,66 @@ where phase1_witness_group: Phase1WitnessGroup<'a, E>, challenges: &[E], ) -> GKRCircuitWitness<'a, E> { - let bits_ref: Vec> = + let input_bits: Vec> = phase1_witness_group[self.committed_bits_id].clone(); + if cfg!(debug_assertions) { + // phase 1 input must all in base field + input_bits.iter().for_each(|mle| { + let _ = mle.get_base_field_vec(); + }); + } + // layer order from output to input let n_layers = 100; let mut layer_wits = Vec::>::with_capacity(n_layers + 1); - 
layer_wits.push(LayerWitness::new(bits_ref.clone())); + layer_wits.push(LayerWitness::new(input_bits.clone())); + let mut witness_mle_flattern = vec![None; circuit.n_evaluations]; - circuit - .layers - .iter() - .rev() - .enumerate() - .fold(&mut layer_wits, |layer_wits, (i, layer)| { - tracing::info!("generating input {i} layer with layer name {}", layer.name); - layer_wits.push(LayerWitness::new(infer_layer_witness( - &layer, - layer_wits.last().unwrap().bases.clone(), - challenges, - ))); - layer_wits - }); + // set input to witness_mle_flattern via first layer in_eval_expr + circuit.layers.last().map(|first_layer| { + first_layer + .in_eval_expr + .iter() + .enumerate() + .for_each(|(index, eval_expr)| match eval_expr { + EvalExpression::Single(witin) => { + witness_mle_flattern[*witin] = Some(input_bits[index].clone()); + } + other => unimplemented!("{:?}", other), + }) + }); + + // generate all layer witness from input to output + for (i, layer) in circuit.layers.iter().rev().enumerate() { + tracing::info!("generating input {i} layer with layer name {}", layer.name); + // process in_evals to prepare layer witness + let current_layer_wits = layer + .in_eval_expr + .iter() + .map(|eval| match eval { + EvalExpression::Single(witin) => witness_mle_flattern[*witin] + .clone() + .expect("witness must exist"), + other => unimplemented!("{:?}", other), + }) + .collect_vec(); + let current_layer_output = infer_layer_witness(&layer, ¤t_layer_wits, challenges); + layer_wits.push(LayerWitness::new(current_layer_wits)); + + // process out to prepare output witness + layer + .outs + .iter() + .zip_eq(¤t_layer_output) + .for_each(|(out_eval, out_mle)| match out_eval { + EvalExpression::Single(out) => { + witness_mle_flattern[*out] = Some(out_mle.clone()) + } + other => unimplemented!("{:?}", other), + }); + } // Assumes one input instance let total_witness_size: usize = layer_wits.iter().map(|layer| layer.bases.len()).sum(); @@ -427,7 +461,7 @@ fn rho_and_pi_permutation() 
-> Vec { pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { type E = BinomialExtensionField; - let num_instances = 1; + let num_instances = states.len(); let log2_num_instances = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instances); @@ -450,21 +484,29 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { if test { // sanity check on first instance only + // TODO test all instances let result_from_witness = gkr_witness.layers[0] .bases .iter() - .map(|bit| bit.get_base_field_vec()[0]) + .map(|bit| { + if ::BaseField::ZERO == bit.get_base_field_vec()[0] { + ::BaseField::ZERO + } else { + ::BaseField::ONE + } + }) .collect_vec(); let mut state = states.clone(); keccakf(&mut state[0]); - assert_eq!( - keccak_witness(&state) // result from tiny keccak - .into_iter() - .map(|b: MultilinearExtension<'_, E>| b.get_base_field_vec()[0]) - .collect_vec(), - result_from_witness - ); + // TODO test this + // assert_eq!( + // keccak_witness(&state) // result from tiny keccak + // .into_iter() + // .map(|b: MultilinearExtension<'_, E>| b.get_base_field_vec()[0]) + // .collect_vec(), + // result_from_witness + // ); } gkr_witness.layers[0] @@ -514,12 +556,13 @@ mod tests { .with_test_writer() .try_init(); - for _ in 0..3 { - let random_u64: u64 = rand::random(); - // Use seeded rng for debugging convenience - let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); - let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); - run_keccakf(vec![state], true, true); - } + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); + let num_instance = 2; + let states: Vec<[u64; 25]> = (0..num_instance) + .map(|_| std::array::from_fn(|_| rng.gen())) + .collect_vec(); + run_keccakf(states, true, true); } } From 14315abe2269f685c09f7dafa4a35b59fc2d11c4 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 21 May 2025 00:16:15 
+0800 Subject: [PATCH 08/28] wip eq --- gkr_iop/src/gkr.rs | 4 +++- gkr_iop/src/gkr/layer.rs | 1 + gkr_iop/src/gkr/layer/zerocheck_layer.rs | 19 +++++++++++++------ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index b934c5d3f..5d40ec06e 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -70,7 +70,9 @@ impl GKRCircuit { running_evals.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) - .map(|(layer, layer_wit)| { + .enumerate() + .map(|(i, (layer, layer_wit))| { + tracing::info!("prove layer {i} layer with layer name {}", layer.name); layer.prove( num_threads, max_num_variables, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 90a6ba5d7..a7c388e71 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -72,6 +72,7 @@ impl Layer { outs: Vec>, expr_names: Vec, ) -> Self { + assert!(eqs.len() == 1); let mut expr_names = expr_names; if expr_names.len() < exprs.len() { expr_names.extend(vec![ diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index ab9d4c882..7cf9d49c2 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -11,7 +11,7 @@ use multilinear_extensions::{ virtual_polys::VirtualPolynomialsBuilder, }; use p3_field::dot_product; -use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverState, IOPVerifierState, SumCheckSubClaim, VerifierError}, @@ -56,11 +56,17 @@ impl ZerocheckLayer for Layer { challenges: &[E], transcript: &mut impl Transcript, ) -> SumcheckLayerProof { + assert!( + out_points.iter().all_equal(), + "output points not all equals len() {}", + out_points.len() + ); 
assert_eq!(self.exprs.len(), out_points.len()); let span = entered_span!("build_out_points_eq"); let mut eqs = out_points .par_iter() + .take(1) .map(|point| { MultilinearExtension::from_evaluations_ext_vec( point.len(), @@ -84,15 +90,16 @@ impl ZerocheckLayer for Layer { .into_iter() .map(|r| Expression::Constant(Either::Right(r))) .collect_vec(); - let expr = self + let zerocheck_expr = self .exprs .iter() - .zip_eq(&self.eqs) .zip_eq(alpha_pows) - .map(|((expr, eq), alpha)| alpha * eq * expr) + .map(|(expr, alpha)| alpha * expr) .sum::>(); - let (proof, prover_state) = - IOPProverState::prove(builder.to_virtual_polys(&[expr], challenges), transcript); + let (proof, prover_state) = IOPProverState::prove( + builder.to_virtual_polys(&[self.eqs[0].clone() * zerocheck_expr], challenges), + transcript, + ); SumcheckLayerProof { proof, evals: prover_state.get_mle_flatten_final_evaluations(), From fc3b5905e720dd4d897398683612f3e9742b1731 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 21 May 2025 16:17:10 +0800 Subject: [PATCH 09/28] bitwise keccak benchmark --- gkr_iop/benches/bitwise_keccakf.rs | 60 +++++++---- gkr_iop/benches/lookup_keccakf.rs | 78 +++++++------- gkr_iop/examples/multi_layer_logup.rs | 14 +-- gkr_iop/src/chip/builder.rs | 11 +- gkr_iop/src/gkr.rs | 3 +- gkr_iop/src/gkr/layer.rs | 112 ++++++++++++++------- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 88 ++++++++++------ gkr_iop/src/gkr/mock.rs | 6 +- gkr_iop/src/precompiles/bitwise_keccakf.rs | 102 +++++++++++-------- gkr_iop/src/precompiles/mod.rs | 2 +- multilinear_extensions/src/virtual_poly.rs | 2 +- sumcheck/src/prover.rs | 4 +- 12 files changed, 294 insertions(+), 188 deletions(-) diff --git a/gkr_iop/benches/bitwise_keccakf.rs b/gkr_iop/benches/bitwise_keccakf.rs index 902b63a1b..95686d0bd 100644 --- a/gkr_iop/benches/bitwise_keccakf.rs +++ b/gkr_iop/benches/bitwise_keccakf.rs @@ -1,7 +1,9 @@ use std::time::Duration; use criterion::*; -use gkr_iop::precompiles::run_keccakf; +use 
ff_ext::GoldilocksExt2; +use gkr_iop::precompiles::{run_keccakf, setup_gkr_circuit}; +use itertools::Itertools; +use rand::{Rng, SeedableRng}; criterion_group!(benches, keccak_f_fn); criterion_main!(benches); @@ -9,29 +11,43 @@ criterion_main!(benches); const NUM_SAMPLES: usize = 10; fn keccak_f_fn(c: &mut Criterion) { - // expand more input size once runtime is acceptable - let mut group = c.benchmark_group("keccak_f".to_string()); - group.sample_size(NUM_SAMPLES); - // Benchmark the proving time - group.bench_function(BenchmarkId::new("keccak_f", "keccak_f"), |b| { - b.iter_custom(|iters| { - let mut time = Duration::new(0, 0); - for _ in 0..iters { - // Use seeded rng for debugging convenience - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - let state: [u64; 25] = std::array::from_fn(|_| rng.gen()); + for log_instances in 5..8 { + let num_instance = 1 << log_instances; + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("keccak_f_{}", num_instance)); + group.sample_size(NUM_SAMPLES); + group.bench_function( + BenchmarkId::new("keccak_f", format!("prove_keccak_f_{}", num_instance)), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let states: Vec<[u64; 25]> = (0..num_instance) + .map(|_| std::array::from_fn(|_| rng.gen())) + .collect_vec(); - let instant = std::time::Instant::now(); - #[allow(clippy::unit_arg)] - black_box(run_keccakf(state, false, false)); - let elapsed = instant.elapsed(); - time += elapsed; - } + let instant = std::time::Instant::now(); - time - }); - }); + let circuit = setup_gkr_circuit(); + #[allow(clippy::unit_arg)] + run_keccakf::(circuit, black_box(states), false, false); + let elapsed = instant.elapsed(); + println!( + "keccak_f::create_proof, instances = {}, time = {}", + num_instance, + elapsed.as_secs_f64() + ); + time += elapsed; + 
} - group.finish(); + time + }); + }, + ); + group.finish(); + } } diff --git a/gkr_iop/benches/lookup_keccakf.rs b/gkr_iop/benches/lookup_keccakf.rs index a9ab4f5f0..b214f402c 100644 --- a/gkr_iop/benches/lookup_keccakf.rs +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -1,39 +1,39 @@ -use std::time::Duration; - -use criterion::*; -use gkr_iop::precompiles::run_faster_keccakf; - -use rand::{Rng, SeedableRng}; -criterion_group!(benches, keccak_f_fn); -criterion_main!(benches); - -const NUM_SAMPLES: usize = 10; - -fn keccak_f_fn(c: &mut Criterion) { - // expand more input size once runtime is acceptable - let mut group = c.benchmark_group("keccakf"); - group.sample_size(NUM_SAMPLES); - - // Benchmark the proving time - group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { - b.iter_custom(|iters| { - let mut time = Duration::new(0, 0); - for _ in 0..iters { - // Use seeded rng for debugging convenience - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); - let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); - - let instant = std::time::Instant::now(); - #[allow(clippy::unit_arg)] - black_box(run_faster_keccakf(vec![state1, state2], false, false)); - let elapsed = instant.elapsed(); - time += elapsed; - } - - time - }); - }); - - group.finish(); -} +// use std::time::Duration; + +// use criterion::*; +// use gkr_iop::precompiles::run_faster_keccakf; + +// use rand::{Rng, SeedableRng}; +// criterion_group!(benches, keccak_f_fn); +// criterion_main!(benches); + +// const NUM_SAMPLES: usize = 10; + +// fn keccak_f_fn(c: &mut Criterion) { +// // expand more input size once runtime is acceptable +// let mut group = c.benchmark_group("keccakf"); +// group.sample_size(NUM_SAMPLES); + +// // Benchmark the proving time +// group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { +// b.iter_custom(|iters| { +// let mut time = Duration::new(0, 0); +// for _ in 0..iters { +// // Use seeded 
rng for debugging convenience +// let mut rng = rand::rngs::StdRng::seed_from_u64(42); +// let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); +// let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + +// let instant = std::time::Instant::now(); +// #[allow(clippy::unit_arg)] +// black_box(run_faster_keccakf(vec![state1, state2], false, false)); +// let elapsed = instant.elapsed(); +// time += elapsed; +// } + +// time +// }); +// }); + +// group.finish(); +// } diff --git a/gkr_iop/examples/multi_layer_logup.rs b/gkr_iop/examples/multi_layer_logup.rs index 2301fcb8b..cb1bbb5dd 100644 --- a/gkr_iop/examples/multi_layer_logup.rs +++ b/gkr_iop/examples/multi_layer_logup.rs @@ -11,10 +11,10 @@ use gkr_iop::{ }, }; use itertools::{Itertools, izip}; +use multilinear_extensions::Expression; use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; use rand::{Rng, rngs::OsRng}; -use subprotocols::expression::{Constant, Expression}; use transcript::{BasicTranscript, Transcript}; #[cfg(debug_assertions)] @@ -31,21 +31,21 @@ struct TowerParams { } #[derive(Clone, Debug, Default)] -struct TowerChipLayout { +struct TowerChipLayout { params: TowerParams, // Committed poly indices. 
committed_table_id: usize, committed_count_id: usize, - lookup_challenge: Constant, + lookup_challenge: Expression, - output_cumulative_sum: [EvalExpression; 2], + output_cumulative_sum: [EvalExpression; 2], _field: PhantomData, } -impl ProtocolBuilder for TowerChipLayout { +impl ProtocolBuilder for TowerChipLayout { type Params = TowerParams; fn init(params: Self::Params) -> Self { @@ -55,12 +55,12 @@ impl ProtocolBuilder for TowerChipLayout { } } - fn build_commit_phase(&mut self, chip: &mut Chip) { + fn build_commit_phase(&mut self, chip: &mut Chip) { [self.committed_table_id, self.committed_count_id] = chip.allocate_committed(); [self.lookup_challenge] = chip.allocate_challenges(); } - fn build_gkr_phase(&mut self, chip: &mut Chip) { + fn build_gkr_phase(&mut self, chip: &mut Chip) { let height = self.params.height; let lookup_challenge = Expression::Const(self.lookup_challenge.clone()); diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index f23fddff4..de00ce3ae 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -90,7 +90,16 @@ impl Chip { /// Add a layer to the circuit. 
pub fn add_layer(&mut self, layer: Layer) { - assert_eq!(layer.outs.len(), layer.exprs.len()); + assert_eq!( + layer + .outs + .iter() + .map(|(_, outs)| outs) + .flatten() + .collect_vec() + .len(), + layer.exprs.len() + ); match layer.ty { LayerType::Linear => { assert!(layer.exprs.iter().all(|expr| expr.degree() == 1)); diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 5d40ec06e..e4ae98d6e 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -108,7 +108,8 @@ impl GKRCircuit { let mut challenges = challenges.to_vec(); let mut evaluations = out_evals.to_vec(); evaluations.resize(self.n_evaluations, PointAndEval::default()); - for (layer, layer_proof) in izip!(&self.layers, sumcheck_proofs) { + for (i, (layer, layer_proof)) in izip!(&self.layers, sumcheck_proofs).enumerate() { + tracing::info!("verifier layer {i} layer with layer name {}", layer.name); layer.verify( max_num_variables, layer_proof, diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index a7c388e71..86d76484c 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -32,6 +32,7 @@ pub enum LayerType { pub struct Layer { pub name: String, pub ty: LayerType, + pub max_expr_degree: usize, /// Challenges generated at the beginning of the layer protocol. pub challenges: Vec>, /// Expressions to prove in this layer. For zerocheck and linear layers, @@ -42,13 +43,14 @@ pub struct Layer { /// = \sum_x (r^0 eq_0(X) \cdot expr_0(x) + r^1 eq_1(X) \cdot expr_1(x) + ...)`. /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. pub exprs: Vec>, - /// eq expression for zero checks. Length should match with `exprs` - pub eqs: Vec>, + /// Positions to place the evaluations of the base inputs of this layer. pub in_eval_expr: Vec>, /// The expressions of the evaluations from the succeeding layers, which are /// connected to the outputs of this layer. 
- pub outs: Vec>, + /// Its format indicates the different output groups; + /// the first tuple value is an optional eq expression + pub outs: Vec<(Option>, Vec>)>, // For debugging purposes pub expr_names: Vec, @@ -66,13 +68,18 @@ impl Layer { name: String, ty: LayerType, exprs: Vec>, - eqs: Vec>, challenges: Vec>, in_eval_expr: Vec>, - outs: Vec>, + // first tuple value is eq + outs: Vec<(Option>, Vec>)>, expr_names: Vec, ) -> Self { - assert!(eqs.len() == 1); + assert_eq!( + outs.iter() + .map(|(_, eval_expressions)| eval_expressions.len()) + .sum::(), + exprs.len() // output eval not match with number of expression + ); let mut expr_names = expr_names; if expr_names.len() < exprs.len() { expr_names.extend(vec![ @@ -80,12 +87,13 @@ impl Layer { exprs.len() - expr_names.len() ]); } + let max_expr_degree = exprs.iter().map(|expr| expr.degree()).max().unwrap(); Self { name, ty, + max_expr_degree, challenges, exprs, - eqs, in_eval_expr, outs, expr_names, @@ -103,7 +111,7 @@ impl Layer { transcript: &mut T, ) -> SumcheckLayerProof { self.update_challenges(challenges, transcript); - let (_, out_points) = self.extract_claim_and_point(claims, challenges); + let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); let sumcheck_layer_proof = match self.ty { LayerType::Sumcheck => as SumcheckLayer>::prove( @@ -114,18 +122,25 @@ self, num_threads, max_num_variables, wit, &out_points, challenges, transcript, ), - LayerType::Zerocheck => as ZerocheckLayer>::prove( - self, - num_threads, - max_num_variables, - wit, - &out_points, - challenges, - transcript, - ), + LayerType::Zerocheck => { + let out_points = eval_and_dedup_points + .into_iter() + .map(|(_, point)| point.expect("point must exist")) + .collect_vec(); + as ZerocheckLayer>::prove( + self, + num_threads, + max_num_variables, + wit, + &out_points, + challenges, + transcript, + ) + } LayerType::Linear => { - assert!(out_points.iter().all_equal()); - as LinearLayer>::prove(self, wit, &out_points[0], transcript) + assert_eq!(eval_and_dedup_points.len(), 1);
+ let (_, point) = eval_and_dedup_points.remove(0); + as LinearLayer>::prove(self, wit, point.as_ref().unwrap(), transcript) } }; @@ -147,30 +162,40 @@ transcript: &mut Trans, ) -> Result<(), BackendError> { self.update_challenges(challenges, transcript); - let (sigmas, points) = self.extract_claim_and_point(claims, challenges); + let mut eval_and_dedup_points = self.extract_claim_and_point(claims, challenges); let LayerClaims { in_point, evals } = match self.ty { - LayerType::Sumcheck => as SumcheckLayer>::verify( - self, - max_num_variables, - proof, - &sigmas.iter().cloned().sum(), - challenges, - transcript, - )?, + LayerType::Sumcheck => { + assert_eq!(eval_and_dedup_points.len(), 1); + let (sigmas, point) = eval_and_dedup_points.remove(0); + assert!(point.is_none()); + as SumcheckLayer>::verify( + self, + max_num_variables, + proof, + &sigmas.iter().cloned().sum(), + challenges, + transcript, + )? + } LayerType::Zerocheck => as ZerocheckLayer>::verify( self, max_num_variables, proof, - sigmas, - &points, + eval_and_dedup_points, challenges, transcript, )?, LayerType::Linear => { - assert!(points.iter().all(|point| point == &points[0])); + assert_eq!(eval_and_dedup_points.len(), 1); + let (sigmas, point) = eval_and_dedup_points.remove(0); as LinearLayer>::verify( - self, proof, &sigmas, &points[0], challenges, transcript, + self, + proof, + &sigmas, + point.as_ref().unwrap(), + challenges, + transcript, )? } }; @@ -180,18 +205,31 @@ impl Layer { Ok(()) } + // extract claim and dedup point fn extract_claim_and_point( &self, claims: &[PointAndEval], challenges: &[E], - ) -> (Vec, Vec>) { + ) -> Vec<(Vec, Option>)> { self.outs .iter() - .map(|out| { - let PointAndEval { point, eval } = out.evaluate(claims, challenges); - (eval, point) + .map(|(_, out_evals)| { + let evals = out_evals + .iter() + .map(|out_eval| { + let PointAndEval { eval, ..
} = out_eval.evaluate(claims, challenges); + eval + }) + .collect_vec(); + // within same group, all the point should be the same + // so we assume only take first point as representative + let point = out_evals.first().map(|out_eval| { + let PointAndEval { point, .. } = out_eval.evaluate(claims, challenges); + point + }); + (evals, point) }) - .unzip() + .collect_vec() } // generate layer challenge, if have, and set to respective challenge_id index diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index 7cf9d49c2..d0aa5f539 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -11,7 +11,7 @@ use multilinear_extensions::{ virtual_polys::VirtualPolynomialsBuilder, }; use p3_field::dot_product; -use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverState, IOPVerifierState, SumCheckSubClaim, VerifierError}, @@ -39,8 +39,7 @@ pub trait ZerocheckLayer { &self, max_num_variables: usize, proof: SumcheckLayerProof, - sigmas: Vec, - out_points: &[Point], + eval_and_dedup_points: Vec<(Vec, Option>)>, challenges: &[E], transcript: &mut impl Transcript, ) -> Result, BackendError>; @@ -56,17 +55,38 @@ impl ZerocheckLayer for Layer { challenges: &[E], transcript: &mut impl Transcript, ) -> SumcheckLayerProof { - assert!( - out_points.iter().all_equal(), - "output points not all equals len() {}", - out_points.len() + assert_eq!( + self.outs.len(), + out_points.len(), + "out eval length {} != with distinct out_point {}", + self.outs.len(), + out_points.len(), ); - assert_eq!(self.exprs.len(), out_points.len()); + let mut expr_iter = self.exprs.iter(); + let mut zero_check_exprs = Vec::with_capacity(self.outs.len()); + + let alpha_pows = get_challenge_pows(self.exprs.len(), transcript) + .into_iter() + .map(|r| 
Expression::Constant(Either::Right(r))) + .collect_vec(); + let mut alpha_pows_iter = alpha_pows.iter(); + + for (eq_expr, out_evals) in self.outs.iter() { + let group_length = out_evals.len(); + let zero_check_expr = expr_iter + .by_ref() + .take(group_length) + .cloned() + .zip_eq(alpha_pows_iter.by_ref().take(group_length)) + .map(|(expr, alpha)| alpha * expr) + .sum::>(); + zero_check_exprs.push(eq_expr.clone().unwrap() * zero_check_expr); + } + assert!(expr_iter.next().is_none() && alpha_pows_iter.next().is_none()); let span = entered_span!("build_out_points_eq"); let mut eqs = out_points .par_iter() - .take(1) .map(|point| { MultilinearExtension::from_evaluations_ext_vec( point.len(), @@ -86,18 +106,8 @@ impl ZerocheckLayer for Layer { .chain(eqs.iter_mut().map(|eq| Either::Right(eq))) .collect_vec(), ); - let alpha_pows = get_challenge_pows(self.exprs.len(), transcript) - .into_iter() - .map(|r| Expression::Constant(Either::Right(r))) - .collect_vec(); - let zerocheck_expr = self - .exprs - .iter() - .zip_eq(alpha_pows) - .map(|(expr, alpha)| alpha * expr) - .sum::>(); let (proof, prover_state) = IOPProverState::prove( - builder.to_virtual_polys(&[self.eqs[0].clone() * zerocheck_expr], challenges), + builder.to_virtual_polys(&[zero_check_exprs.into_iter().sum()], challenges), transcript, ); SumcheckLayerProof { @@ -110,12 +120,17 @@ impl ZerocheckLayer for Layer { &self, max_num_variables: usize, proof: SumcheckLayerProof, - sigmas: Vec, - out_points: &[Point], + eval_and_dedup_points: Vec<(Vec, Option>)>, challenges: &[E], transcript: &mut impl Transcript, ) -> Result, BackendError> { - assert_eq!(sigmas.len(), out_points.len()); + assert_eq!( + self.outs.len(), + eval_and_dedup_points.len(), + "out eval length {} != with eval_and_dedup_points {}", + self.outs.len(), + eval_and_dedup_points.len(), + ); let SumcheckLayerProof { proof: IOPProof { proofs, .. 
}, mut evals, @@ -123,7 +138,14 @@ impl ZerocheckLayer for Layer { let alpha_pows = get_challenge_pows(self.exprs.len(), transcript); - let sigma: E = dot_product(alpha_pows.iter().cloned(), sigmas.iter().cloned()); + let sigma: E = dot_product( + alpha_pows.iter().copied(), + eval_and_dedup_points + .iter() + .map(|(sigmas, _)| sigmas) + .flatten() + .copied(), + ); let SumCheckSubClaim { point: in_point, @@ -135,7 +157,7 @@ impl ZerocheckLayer for Layer { proofs, }, &VPAuxInfo { - max_degree: self.exprs[0].degree(), + max_degree: self.max_expr_degree + 1, // +1 due to eq max_num_variables, phantom: PhantomData, }, @@ -144,12 +166,12 @@ impl ZerocheckLayer for Layer { let in_point = in_point.into_iter().map(|c| c.elements).collect_vec(); // eval eq and set to respective witin - out_points + eval_and_dedup_points .iter() - .map(|out_point| eq_eval(out_point, &in_point)) - .zip(&self.eqs) - .for_each(|(eval, eq_expr)| match eq_expr { - Expression::WitIn(witin_id) => evals[*witin_id as usize] = eval, + .map(|(_, out_point)| eq_eval(out_point.as_ref().unwrap(), &in_point)) + .zip(&self.outs) + .for_each(|(eval, (eq_expr, _))| match eq_expr { + Some(Expression::WitIn(witin_id)) => evals[*witin_id as usize] = eval, _ => unreachable!(), }); @@ -157,9 +179,9 @@ impl ZerocheckLayer for Layer { let got_claim = self .exprs .iter() - .zip_eq(&self.eqs) + .zip(&self.outs) .zip_eq(alpha_pows) - .map(|((expr, eq_expr), alpha)| { + .map(|((expr, (eq_expr, _)), alpha)| { alpha * eval_by_expr_with_instance( &[], @@ -167,7 +189,7 @@ impl ZerocheckLayer for Layer { &[], &[], challenges, - &(expr * eq_expr), + &(expr * eq_expr.clone().unwrap()), ) .right() .unwrap() diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 9d11819d7..64aadd2af 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -94,6 +94,8 @@ impl MockProver { let expects = layer .outs .iter() + .map(|(_, out)| out) + .flatten() .map(|out| out.mock_evaluate(&evaluations, &challenges, 1 << 
num_vars)) .collect_vec(); match layer.ty { @@ -117,7 +119,7 @@ impl MockProver { { if expect != got { return Err(MockProverError::ZerocheckExpressionNotMatch( - out.clone(), + out.1[0].clone(), expr.clone(), expect, got, @@ -130,7 +132,7 @@ impl MockProver { { if expect != got { return Err(MockProverError::LinearExpressionNotMatch( - out.clone(), + out.1[0].clone(), expr.clone(), expect, got, diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index 619888e1a..39aa2956c 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -17,18 +17,17 @@ use multilinear_extensions::{ util::ceil_log2, wit_infer_by_expr, }; -use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; -use p3_goldilocks::Goldilocks; +use p3_field::PrimeCharacteristicRing; use sumcheck::util::optimal_sumcheck_threads; use tiny_keccak::keccakf; use transcript::{BasicTranscript, Transcript}; #[derive(Clone, Debug, Default)] -struct KeccakParams {} +pub struct KeccakParams {} #[derive(Clone, Debug, Default)] -struct KeccakLayout { +pub struct KeccakLayout { _params: KeccakParams, committed_bits_id: usize, @@ -208,10 +207,9 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Iota:: compute output"), LayerType::Zerocheck, exprs, - vec![eq.0.expr()], vec![], chi_output.iter().map(|e| e.1.clone()).collect_vec(), - round_output.to_vec(), + vec![(Some(eq.0.expr()), round_output.to_vec())], vec![], )); @@ -232,10 +230,12 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Chi:: apply rho, pi and chi"), LayerType::Zerocheck, exprs, - vec![eq.0.expr()], vec![], theta_output.iter().map(|e| e.1.clone()).collect_vec(), - chi_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![( + Some(eq.0.expr()), + chi_output.iter().map(|e| e.1.clone()).collect_vec(), + )], vec![], )); @@ -255,10 +255,12 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: 
Theta::compute output"), LayerType::Zerocheck, exprs, - vec![eq.0.expr()], vec![], d_and_state.iter().map(|e| e.1.clone()).collect_vec(), - theta_output.iter().map(|e| e.1.clone()).collect_vec(), + vec![( + Some(eq.0.expr()), + theta_output.iter().map(|e| e.1.clone()).collect_vec(), + )], vec![], )); @@ -274,14 +276,16 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Theta::compute D[x][z]"), LayerType::Zerocheck, d_exprs, - vec![eq.0.expr()], vec![], c.iter().map(|e| e.1.clone()).collect_vec(), - d.iter().map(|e| e.1.clone()).collect_vec(), + vec![( + Some(eq.0.expr()), + d.iter().map(|e| e.1.clone()).collect_vec(), + )], vec![], )); - let (state, [eq]) = chip.allocate_wits_in_zero_layer::(); + let (state, [eq0, eq1]) = chip.allocate_wits_in_zero_layer::(); let state_wits = state.iter().map(|s| s.0.expr()).collect_vec(); // Compute C[][] from state @@ -296,14 +300,18 @@ impl ProtocolBuilder for KeccakLayout { format!("Round {round}: Theta::compute C[x][z]"), LayerType::Zerocheck, chain!(c_exprs, id_exprs).collect_vec(), - vec![eq.0.expr()], vec![], state.iter().map(|t| t.1.clone()).collect_vec(), - chain!( - c.iter().map(|e| e.1.clone()), - state2.iter().map(|e| e.1.clone()) - ) - .collect_vec(), + vec![ + ( + Some(eq0.0.expr()), + c.iter().map(|e| e.1.clone()).collect_vec(), + ), + ( + Some(eq1.0.expr()), + state2.iter().map(|e| e.1.clone()).collect_vec(), + ), + ], vec![], )); @@ -401,6 +409,8 @@ where layer .outs .iter() + .map(|(_, out_eval)| out_eval) + .flatten() .zip_eq(¤t_layer_output) .for_each(|(out_eval, out_mle)| match out_eval { EvalExpression::Single(out) => { @@ -410,10 +420,6 @@ where }); } - // Assumes one input instance - let total_witness_size: usize = layer_wits.iter().map(|layer| layer.bases.len()).sum(); - dbg!(total_witness_size); - layer_wits.reverse(); GKRCircuitWitness { layers: layer_wits } @@ -459,18 +465,23 @@ fn rho_and_pi_permutation() -> Vec { pi(&rho(&perm)) } -pub fn run_keccakf(states: Vec<[u64; 25]>, verify: 
bool, test: bool) { - type E = BinomialExtensionField; +pub fn setup_gkr_circuit() -> (KeccakLayout, GKRCircuit) { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + (layout, chip.gkr_circuit()) +} + +pub fn run_keccakf( + (layout, gkr_circuit): (KeccakLayout, GKRCircuit), + states: Vec<[u64; 25]>, + verify: bool, + test: bool, +) { let num_instances = states.len(); let log2_num_instances = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instances); - let params = KeccakParams {}; - let (layout, chip) = KeccakLayout::build(params); - let gkr_circuit = chip.gkr_circuit(); - let bits = keccak_witness(&states); - // get the view only phase 1 witness, since it need to be commit thus can't be in-place change let phase1_witness = layout.phase1_witness_group(KeccakTrace { bits }); let mut prover_transcript = BasicTranscript::::new(b"protocol"); @@ -480,7 +491,7 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { let out_evals = { let mut point = Point::new(); - point.extend(prover_transcript.sample_vec(1).to_vec()); + point.extend(prover_transcript.sample_vec(log2_num_instances).to_vec()); if test { // sanity check on first instance only @@ -500,13 +511,13 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { keccakf(&mut state[0]); // TODO test this - // assert_eq!( - // keccak_witness(&state) // result from tiny keccak - // .into_iter() - // .map(|b: MultilinearExtension<'_, E>| b.get_base_field_vec()[0]) - // .collect_vec(), - // result_from_witness - // ); + assert_eq!( + keccak_witness(&state) // result from tiny keccak + .into_iter() + .map(|b: MultilinearExtension<'_, E>| b.get_base_field_vec()[0]) + .collect_vec(), + result_from_witness + ); } gkr_witness.layers[0] @@ -522,7 +533,7 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove( num_threads, - 1, + log2_num_instances, gkr_witness, &out_evals, &[], @@ -534,6 +545,10 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { { let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + // TODO verify output + let mut point = Point::new(); + point.extend(verifier_transcript.sample_vec(1).to_vec()); + gkr_circuit .verify(1, gkr_proof, &out_evals, &[], &mut verifier_transcript) .expect("GKR verify failed"); @@ -545,12 +560,13 @@ pub fn run_keccakf(states: Vec<[u64; 25]>, verify: bool, test: bool) { #[cfg(test)] mod tests { - use rand::{Rng, SeedableRng}; - use super::*; + use ff_ext::GoldilocksExt2; + use rand::{Rng, SeedableRng}; #[test] fn test_keccakf() { + type E = GoldilocksExt2; let _ = tracing_subscriber::fmt() .with_max_level(tracing::Level::TRACE) .with_test_writer() @@ -559,10 +575,10 @@ mod tests { let random_u64: u64 = rand::random(); // Use seeded rng for debugging convenience let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); - let num_instance = 2; + let num_instance = 4; let states: Vec<[u64; 25]> = (0..num_instance) .map(|_| std::array::from_fn(|_| rng.gen())) .collect_vec(); - run_keccakf(states, true, true); + run_keccakf::(setup_gkr_circuit(), states, false, false); } } diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index 9b0ec2c98..d7715afe2 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -1,7 +1,7 @@ mod bitwise_keccakf; mod lookup_keccakf; mod utils; -pub use bitwise_keccakf::run_keccakf; +pub use bitwise_keccakf::{run_keccakf, setup_gkr_circuit}; // pub use lookup_keccakf::{ // AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KeccakLayout, KeccakParams, KeccakTrace, RANGE_LOOKUPS, // RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, run_faster_keccakf, diff --git a/multilinear_extensions/src/virtual_poly.rs b/multilinear_extensions/src/virtual_poly.rs index 2e3cbb105..361ff412d 100644 --- 
a/multilinear_extensions/src/virtual_poly.rs +++ b/multilinear_extensions/src/virtual_poly.rs @@ -157,7 +157,7 @@ impl<'a, E: ExtensionField> VirtualPolynomial<'a, E> { Expression::WitIn(witin_id) => { self.flattened_ml_extensions[*witin_id as usize].num_vars() } - _ => unimplemented!(), + e => unimplemented!("unimplemented {:?}", e), } }) .all_equal() diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index cc3f81125..b45d894be 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -423,7 +423,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .to_vec(), 5 => sumcheck_code_gen!(5, false, |i| &f[prod[i]], || get_poly_meta()) .to_vec(), - _ => unimplemented!("do not support degree {} > 5", prod.len()), + 6 => sumcheck_code_gen!(5, false, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + _ => unimplemented!("do not support degree {} > 6", prod.len()), }; uni_variate From 58b74dedd24860668cc33e885134f7395b40a572 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 21 May 2025 17:22:29 +0800 Subject: [PATCH 10/28] modify benchmark --- gkr_iop/benches/bitwise_keccakf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gkr_iop/benches/bitwise_keccakf.rs b/gkr_iop/benches/bitwise_keccakf.rs index 95686d0bd..83ca69e3f 100644 --- a/gkr_iop/benches/bitwise_keccakf.rs +++ b/gkr_iop/benches/bitwise_keccakf.rs @@ -12,7 +12,7 @@ const NUM_SAMPLES: usize = 10; fn keccak_f_fn(c: &mut Criterion) { // Benchmark the proving time - for log_instances in 5..8 { + for log_instances in 10..12 { let num_instance = 1 << log_instances; // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("keccak_f_{}", num_instance)); From 7920210986afc460c2d53ef76c6ed0f647069a01 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 21 May 2025 21:10:11 +0800 Subject: [PATCH 11/28] add more tracing and refactor --- Cargo.lock | 2 + gkr_iop/Cargo.toml | 2 + gkr_iop/src/bin/bitwise_keccak.rs | 71 +++++++++++++ gkr_iop/src/gkr.rs 
| 10 +- gkr_iop/src/gkr/layer/zerocheck_layer.rs | 7 +- gkr_iop/src/lib.rs | 78 +++++++++++++- gkr_iop/src/precompiles/bitwise_keccakf.rs | 116 +++------------------ gkr_iop/src/utils.rs | 21 ++++ sumcheck/src/prover.rs | 4 +- 9 files changed, 204 insertions(+), 107 deletions(-) create mode 100644 gkr_iop/src/bin/bitwise_keccak.rs diff --git a/Cargo.lock b/Cargo.lock index b01c7dd3f..c6ebbe73c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1111,6 +1111,7 @@ name = "gkr_iop" version = "0.1.0" dependencies = [ "ark-std", + "clap", "criterion", "either", "ff_ext", @@ -1127,6 +1128,7 @@ dependencies = [ "thiserror 1.0.69", "tiny-keccak", "tracing", + "tracing-forest", "tracing-subscriber", "transcript", "witness", ] diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index 38c43a968..a92dc48d0 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -10,6 +10,8 @@ repository.workspace = true version.workspace = true [dependencies] +clap.workspace = true +tracing-forest.workspace = true ark-std = { version = "0.5" } ff_ext = { path = "../ff_ext" } itertools.workspace = true diff --git a/gkr_iop/src/bin/bitwise_keccak.rs b/gkr_iop/src/bin/bitwise_keccak.rs new file mode 100644 index 000000000..f04e4a221 --- /dev/null +++ b/gkr_iop/src/bin/bitwise_keccak.rs @@ -0,0 +1,71 @@ +use clap::{Parser, command}; +use ff_ext::GoldilocksExt2; +use gkr_iop::precompiles::{run_keccakf, setup_gkr_circuit}; +use itertools::Itertools; +use rand::{Rng, SeedableRng}; +use tracing::level_filters::LevelFilter; +use tracing_forest::ForestLayer; +use tracing_subscriber::{ + EnvFilter, Registry, filter::filter_fn, fmt, layer::SubscriberExt, util::SubscriberInitExt, +}; +/// Run the bitwise keccak-f GKR prover on randomly generated instances. +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + // Profiling granularity.
+ // Setting any value restricts logs to profiling information + #[arg(long)] + profiling: Option, +} + +fn main() { + let args = { + let mut args = Args::parse(); + args + }; + type E = GoldilocksExt2; + + // default filter + let default_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::DEBUG.into()) + .from_env_lossy(); + + // filter by profiling level; + // spans with level i contain the field "profiling_{i}" + // this restricts statistics to first (args.profiling) levels + let profiling_level = args.profiling.unwrap_or(1); + let filter_by_profiling_level = filter_fn(move |metadata| { + (1..=profiling_level) + .map(|i| format!("profiling_{i}")) + .any(|field| metadata.fields().field(&field).is_some()) + }); + + let fmt_layer = fmt::layer() + .compact() + .with_thread_ids(false) + .with_thread_names(false) + .without_time(); + + Registry::default() + .with(args.profiling.is_some().then_some(ForestLayer::default())) + .with(fmt_layer) + // if some profiling granularity is specified, use the profiling filter, + // otherwise use the default + .with( + args.profiling + .is_some() + .then_some(filter_by_profiling_level), + ) + .with(args.profiling.is_none().then_some(default_filter)) + .init(); + + let random_u64: u64 = rand::random(); + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(random_u64); + let num_instance = 1024; + let states: Vec<[u64; 25]> = (0..num_instance) + .map(|_| std::array::from_fn(|_| rng.gen())) + .collect_vec(); + let circuit_setup = setup_gkr_circuit(); + run_keccakf::(circuit_setup, states, false, false); +} diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index e4ae98d6e..f5a18f857 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -3,6 +3,7 @@ use itertools::{Itertools, chain, izip}; use layer::{Layer, LayerWitness, sumcheck_layer::SumcheckLayerProof}; use multilinear_extensions::mle::{Point, PointAndEval}; use serde::{Deserialize, Serialize, 
de::DeserializeOwned}; +use sumcheck::macros::{entered_span, exit_span}; use transcript::Transcript; use crate::{error::BackendError, evaluation::EvalExpression}; @@ -69,20 +70,25 @@ impl GKRCircuit { // running evals is a global referable within chip running_evals.resize(self.n_evaluations, PointAndEval::default()); let mut challenges = challenges.to_vec(); + let span = entered_span!("layer_proof", profiling_2 = true); let sumcheck_proofs = izip!(&self.layers, circuit_wit.layers) .enumerate() .map(|(i, (layer, layer_wit))| { tracing::info!("prove layer {i} layer with layer name {}", layer.name); - layer.prove( + let span = entered_span!("per_layer_proof", profiling_3 = true); + let res = layer.prove( num_threads, max_num_variables, layer_wit, &mut running_evals, &mut challenges, transcript, - ) + ); + exit_span!(span); + res }) .collect_vec(); + exit_span!(span); let opening_evaluations = self.opening_evaluations(&running_evals, &challenges); diff --git a/gkr_iop/src/gkr/layer/zerocheck_layer.rs b/gkr_iop/src/gkr/layer/zerocheck_layer.rs index d0aa5f539..c631ae23e 100644 --- a/gkr_iop/src/gkr/layer/zerocheck_layer.rs +++ b/gkr_iop/src/gkr/layer/zerocheck_layer.rs @@ -72,6 +72,7 @@ impl ZerocheckLayer for Layer { .collect_vec(); let mut alpha_pows_iter = alpha_pows.iter(); + let span = entered_span!("gen_expr", profiling_4 = true); for (eq_expr, out_evals) in self.outs.iter() { let group_length = out_evals.len(); let zero_check_expr = expr_iter @@ -83,8 +84,10 @@ impl ZerocheckLayer for Layer { .sum::>(); zero_check_exprs.push(eq_expr.clone().unwrap() * zero_check_expr); } + exit_span!(span); assert!(expr_iter.next().is_none() && alpha_pows_iter.next().is_none()); - let span = entered_span!("build_out_points_eq"); + + let span = entered_span!("build_out_points_eq", profiling_4 = true); let mut eqs = out_points .par_iter() .map(|point| { @@ -106,10 +109,12 @@ impl ZerocheckLayer for Layer { .chain(eqs.iter_mut().map(|eq| Either::Right(eq))) .collect_vec(), ); + let 
span = entered_span!("IOPProverState::prove", profiling_4 = true); let (proof, prover_state) = IOPProverState::prove( builder.to_virtual_polys(&[zero_check_exprs.into_iter().sum()], challenges), transcript, ); + exit_span!(span); SumcheckLayerProof { proof, evals: prover_state.get_mle_flatten_final_evaluations(), diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 3029416f5..55c1e0d94 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -1,10 +1,14 @@ use std::marker::PhantomData; use chip::Chip; +use evaluation::EvalExpression; use ff_ext::ExtensionField; -use gkr::{GKRCircuit, GKRCircuitWitness}; +use gkr::{GKRCircuit, GKRCircuitWitness, layer::LayerWitness}; +use itertools::Itertools; use multilinear_extensions::mle::ArcMultilinearExtension; +use sumcheck::macros::{entered_span, exit_span}; use transcript::Transcript; +use utils::infer_layer_witness; pub mod chip; pub mod error; @@ -13,7 +17,7 @@ pub mod gkr; pub mod precompiles; pub mod utils; -pub type Phase1WitnessGroup<'a, E> = Vec>>; +pub type Phase1WitnessGroup<'a, E> = Vec>; pub trait ProtocolBuilder: Sized { type Params; @@ -51,10 +55,76 @@ where /// GKR witness. 
fn gkr_witness( &self, - chip: &GKRCircuit, + circuit: &GKRCircuit, phase1_witness_group: Phase1WitnessGroup<'a, E>, challenges: &[E], - ) -> GKRCircuitWitness<'a, E>; + ) -> GKRCircuitWitness<'a, E> { + if cfg!(debug_assertions) { + // phase 1 input must all in base field + phase1_witness_group.iter().for_each(|mle| { + let _ = mle.get_base_field_vec(); + }); + } + + // layer order from output to input + let n_layers = 100; + let mut layer_wits = Vec::>::with_capacity(n_layers + 1); + + layer_wits.push(LayerWitness::new(phase1_witness_group.clone())); + let mut witness_mle_flattern = vec![None; circuit.n_evaluations]; + + // set input to witness_mle_flattern via first layer in_eval_expr + circuit.layers.last().map(|first_layer| { + first_layer + .in_eval_expr + .iter() + .enumerate() + .for_each(|(index, eval_expr)| match eval_expr { + EvalExpression::Single(witin) => { + witness_mle_flattern[*witin] = Some(phase1_witness_group[index].clone()); + } + other => unimplemented!("{:?}", other), + }) + }); + + // generate all layer witness from input to output + for (i, layer) in circuit.layers.iter().rev().enumerate() { + tracing::info!("generating input {i} layer with layer name {}", layer.name); + let span = entered_span!("per_layer_gen_witness", profiling_2 = true); + // process in_evals to prepare layer witness + let current_layer_wits = layer + .in_eval_expr + .iter() + .map(|eval| match eval { + EvalExpression::Single(witin) => witness_mle_flattern[*witin] + .clone() + .expect("witness must exist"), + other => unimplemented!("{:?}", other), + }) + .collect_vec(); + let current_layer_output = infer_layer_witness(&layer, ¤t_layer_wits, challenges); + layer_wits.push(LayerWitness::new(current_layer_wits)); + + // process out to prepare output witness + layer + .outs + .iter() + .map(|(_, out_eval)| out_eval) + .flatten() + .zip_eq(¤t_layer_output) + .for_each(|(out_eval, out_mle)| match out_eval { + EvalExpression::Single(out) => { + witness_mle_flattern[*out] = 
Some(out_mle.clone()) + } + other => unimplemented!("{:?}", other), + }); + exit_span!(span); + } + + layer_wits.reverse(); + + GKRCircuitWitness { layers: layer_wits } + } } // TODO: the following trait consists of `commit_phase1`, `commit_phase2`, diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index 39aa2956c..30a8c9fe8 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -5,21 +5,22 @@ use crate::{ chip::Chip, evaluation::EvalExpression, gkr::{ - GKRCircuit, GKRCircuitWitness, GKRProverOutput, - layer::{Layer, LayerType, LayerWitness}, + GKRCircuit, GKRProverOutput, + layer::{Layer, LayerType}, }, }; use ff_ext::ExtensionField; use itertools::{Itertools, chain, iproduct}; use multilinear_extensions::{ Expression, ToExpr, - mle::{ArcMultilinearExtension, MultilinearExtension, Point, PointAndEval}, + mle::{MultilinearExtension, Point, PointAndEval}, util::ceil_log2, - wit_infer_by_expr, }; use p3_field::PrimeCharacteristicRing; - -use sumcheck::util::optimal_sumcheck_threads; +use sumcheck::{ + macros::{entered_span, exit_span}, + util::optimal_sumcheck_threads, +}; use tiny_keccak::keccakf; use transcript::{BasicTranscript, Transcript}; @@ -326,103 +327,13 @@ pub struct KeccakTrace<'a, E: ExtensionField> { pub bits: [MultilinearExtension<'a, E>; STATE_SIZE], } -pub fn infer_layer_witness<'a, E>( - layer: &Layer, - layer_wits: &[ArcMultilinearExtension<'a, E>], - challenges: &[E], -) -> Vec> -where - E: ExtensionField, -{ - layer - .exprs - .iter() - .map(|expr| wit_infer_by_expr(&[], layer_wits, &[], &[], challenges, expr)) - .collect_vec() -} - impl<'a, E> ProtocolWitnessGenerator<'a, E> for KeccakLayout where E: ExtensionField, { type Trace = KeccakTrace<'a, E>; - fn phase1_witness_group(&self, phase1: Self::Trace) -> Phase1WitnessGroup<'a, E> { - vec![phase1.bits.into_iter().map(Arc::new).collect_vec()] - } - - fn gkr_witness( - &self, - circuit: 
&GKRCircuit, - phase1_witness_group: Phase1WitnessGroup<'a, E>, - challenges: &[E], - ) -> GKRCircuitWitness<'a, E> { - let input_bits: Vec> = - phase1_witness_group[self.committed_bits_id].clone(); - - if cfg!(debug_assertions) { - // phase 1 input must all in base field - input_bits.iter().for_each(|mle| { - let _ = mle.get_base_field_vec(); - }); - } - - // layer order from output to input - let n_layers = 100; - let mut layer_wits = Vec::>::with_capacity(n_layers + 1); - - layer_wits.push(LayerWitness::new(input_bits.clone())); - let mut witness_mle_flattern = vec![None; circuit.n_evaluations]; - - // set input to witness_mle_flattern via first layer in_eval_expr - circuit.layers.last().map(|first_layer| { - first_layer - .in_eval_expr - .iter() - .enumerate() - .for_each(|(index, eval_expr)| match eval_expr { - EvalExpression::Single(witin) => { - witness_mle_flattern[*witin] = Some(input_bits[index].clone()); - } - other => unimplemented!("{:?}", other), - }) - }); - - // generate all layer witness from input to output - for (i, layer) in circuit.layers.iter().rev().enumerate() { - tracing::info!("generating input {i} layer with layer name {}", layer.name); - // process in_evals to prepare layer witness - let current_layer_wits = layer - .in_eval_expr - .iter() - .map(|eval| match eval { - EvalExpression::Single(witin) => witness_mle_flattern[*witin] - .clone() - .expect("witness must exist"), - other => unimplemented!("{:?}", other), - }) - .collect_vec(); - let current_layer_output = infer_layer_witness(&layer, ¤t_layer_wits, challenges); - layer_wits.push(LayerWitness::new(current_layer_wits)); - - // process out to prepare output witness - layer - .outs - .iter() - .map(|(_, out_eval)| out_eval) - .flatten() - .zip_eq(¤t_layer_output) - .for_each(|(out_eval, out_mle)| match out_eval { - EvalExpression::Single(out) => { - witness_mle_flattern[*out] = Some(out_mle.clone()) - } - other => unimplemented!("{:?}", other), - }); - } - - layer_wits.reverse(); - - 
GKRCircuitWitness { layers: layer_wits } + phase1.bits.into_iter().map(Arc::new).collect_vec() } } @@ -481,13 +392,18 @@ pub fn run_keccakf( let log2_num_instances = ceil_log2(num_instances); let num_threads = optimal_sumcheck_threads(log2_num_instances); + let span = entered_span!("keccak_witness", profiling_1 = true); let bits = keccak_witness(&states); - // get the view only phase 1 witness, since it need to be commit thus can't be in-place change + exit_span!(span); + let span = entered_span!("phase1_witness_group", profiling_1 = true); let phase1_witness = layout.phase1_witness_group(KeccakTrace { bits }); + exit_span!(span); let mut prover_transcript = BasicTranscript::::new(b"protocol"); // Omit the commit phase1 and phase2. + let span = entered_span!("gkr_witness", profiling_1 = true); let gkr_witness = layout.gkr_witness(&gkr_circuit, phase1_witness, &[]); + exit_span!(span); let out_evals = { let mut point = Point::new(); @@ -530,6 +446,7 @@ pub fn run_keccakf( .collect_vec() }; + let span = entered_span!("prove", profiling_1 = true); let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove( num_threads, @@ -540,6 +457,7 @@ pub fn run_keccakf( &mut prover_transcript, ) .expect("Failed to prove phase"); + exit_span!(span); if verify { { @@ -579,6 +497,6 @@ mod tests { let states: Vec<[u64; 25]> = (0..num_instance) .map(|_| std::array::from_fn(|_| rng.gen())) .collect_vec(); - run_keccakf::(setup_gkr_circuit(), states, false, false); + run_keccakf::(setup_gkr_circuit(), states, false, true); } } diff --git a/gkr_iop/src/utils.rs b/gkr_iop/src/utils.rs index f0d4a06b8..e1bad9412 100644 --- a/gkr_iop/src/utils.rs +++ b/gkr_iop/src/utils.rs @@ -1,5 +1,26 @@ use std::sync::Arc; +use ff_ext::ExtensionField; +use multilinear_extensions::{mle::ArcMultilinearExtension, wit_infer_by_expr}; +use rayon::iter::{IntoParallelRefIterator, ParallelIterator}; + +use crate::gkr::layer::Layer; + +pub fn infer_layer_witness<'a, E>( + layer: &Layer, + layer_wits: &[ArcMultilinearExtension<'a, E>], + challenges: &[E], +) -> Vec> +where + E: ExtensionField, +{ + layer + .exprs + .par_iter() + .map(|expr| wit_infer_by_expr(&[], layer_wits, &[], &[], challenges, expr)) + .collect::>() +} + pub trait SliceVector { fn slice_vector(&self) -> Vec<&[T]>; } diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index b45d894be..33a4af290 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -686,7 +686,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .to_vec(), 5 => sumcheck_code_gen!(5, true, |i| &f[prod[i]], || get_poly_meta()) .to_vec(), - _ => unimplemented!("do not support degree {} > 5", prod.len()), + 6 => sumcheck_code_gen!(5, true, |i| &f[prod[i]], || get_poly_meta()) + .to_vec(), + _ => unimplemented!("do not support degree {} > 6", prod.len()), }; sum.iter_mut() .for_each(|sum| either::for_both!(*scalar, scalar => *sum *= scalar)); From b78f9c56457c45ebb684599aaa0f935a83247776 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 28 May 2025 13:41:16 +0800 Subject: [PATCH 12/28] chores: cleanup --- 
gkr_iop/src/bin/bitwise_keccak.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/gkr_iop/src/bin/bitwise_keccak.rs b/gkr_iop/src/bin/bitwise_keccak.rs index f04e4a221..c99df3cb9 100644 --- a/gkr_iop/src/bin/bitwise_keccak.rs +++ b/gkr_iop/src/bin/bitwise_keccak.rs @@ -19,10 +19,7 @@ struct Args { } fn main() { - let args = { - let mut args = Args::parse(); - args - }; + let args = Args::parse(); type E = GoldilocksExt2; // default filter From 3cd93ea76a92a17f2a92db8cb2024f1204d45d08 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Wed, 28 May 2025 15:08:32 +0800 Subject: [PATCH 13/28] wip for lookup version --- ceno_zkvm/src/structs.rs | 27 +- gkr_iop/benches/lookup_keccakf.rs | 78 +- gkr_iop/src/precompiles/lookup_keccakf.rs | 2709 +++++++++++---------- gkr_iop/src/precompiles/mod.rs | 10 +- multilinear_extensions/src/expression.rs | 9 +- 5 files changed, 1414 insertions(+), 1419 deletions(-) diff --git a/ceno_zkvm/src/structs.rs b/ceno_zkvm/src/structs.rs index a22243969..86063769c 100644 --- a/ceno_zkvm/src/structs.rs +++ b/ceno_zkvm/src/structs.rs @@ -7,6 +7,7 @@ use crate::{ witness::LkMultiplicity, }; use ceno_emul::{CENO_PLATFORM, KeccakSpec, Platform, StepRecord, SyscallSpec}; +use either::Either; use ff_ext::ExtensionField; use gkr_iop::{ gkr::{GKRCircuitOutput, GKRCircuitWitness}, @@ -73,32 +74,6 @@ impl_expr_from_unsigned!(RAMType); pub type PointAndEval = multilinear_extensions::mle::PointAndEval; -impl Default for PointAndEval { - fn default() -> Self { - Self { - point: vec![], - eval: E::ZERO, - } - } -} - -impl PointAndEval { - /// Construct a new pair of point and eval. - /// Caller gives up ownership - pub fn new(point: Point, eval: F) -> Self { - Self { point, eval } - } - - /// Construct a new pair of point and eval. - /// Performs deep copy. 
- pub fn new_from_ref(point: &Point, eval: &F) -> Self { - Self { - point: (*point).clone(), - eval: eval.clone(), - } - } -} - #[derive(Clone)] pub struct ProvingKey { pub vk: VerifyingKey, diff --git a/gkr_iop/benches/lookup_keccakf.rs b/gkr_iop/benches/lookup_keccakf.rs index b214f402c..a9ab4f5f0 100644 --- a/gkr_iop/benches/lookup_keccakf.rs +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -1,39 +1,39 @@ -// use std::time::Duration; - -// use criterion::*; -// use gkr_iop::precompiles::run_faster_keccakf; - -// use rand::{Rng, SeedableRng}; -// criterion_group!(benches, keccak_f_fn); -// criterion_main!(benches); - -// const NUM_SAMPLES: usize = 10; - -// fn keccak_f_fn(c: &mut Criterion) { -// // expand more input size once runtime is acceptable -// let mut group = c.benchmark_group("keccakf"); -// group.sample_size(NUM_SAMPLES); - -// // Benchmark the proving time -// group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { -// b.iter_custom(|iters| { -// let mut time = Duration::new(0, 0); -// for _ in 0..iters { -// // Use seeded rng for debugging convenience -// let mut rng = rand::rngs::StdRng::seed_from_u64(42); -// let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); -// let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); - -// let instant = std::time::Instant::now(); -// #[allow(clippy::unit_arg)] -// black_box(run_faster_keccakf(vec![state1, state2], false, false)); -// let elapsed = instant.elapsed(); -// time += elapsed; -// } - -// time -// }); -// }); - -// group.finish(); -// } +use std::time::Duration; + +use criterion::*; +use gkr_iop::precompiles::run_faster_keccakf; + +use rand::{Rng, SeedableRng}; +criterion_group!(benches, keccak_f_fn); +criterion_main!(benches); + +const NUM_SAMPLES: usize = 10; + +fn keccak_f_fn(c: &mut Criterion) { + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group("keccakf"); + group.sample_size(NUM_SAMPLES); + + // Benchmark the proving time + 
group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); + let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); + + let instant = std::time::Instant::now(); + #[allow(clippy::unit_arg)] + black_box(run_faster_keccakf(vec![state1, state2], false, false)); + let elapsed = instant.elapsed(); + time += elapsed; + } + + time + }); + }); + + group.finish(); +} diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 97e2f29ca..5a0c0d1c6 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -1,1348 +1,1361 @@ -// use std::{array, cmp::Ordering, marker::PhantomData, sync::Arc}; - -// use ff_ext::{ExtensionField, SmallField}; -// use itertools::{Itertools, chain, iproduct, zip_eq}; -// use ndarray::{ArrayView, Ix2, Ix3, s}; -// use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; -// use p3_goldilocks::Goldilocks; -// use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; -// use serde::{Deserialize, Serialize}; -// use subprotocols::expression::{Constant, Expression, Witness}; -// use tiny_keccak::keccakf; -// use transcript::BasicTranscript; -// use witness::{InstancePaddingStrategy, RowMajorMatrix}; - -// use crate::{ -// ProtocolBuilder, ProtocolWitnessGenerator, -// chip::Chip, -// evaluation::{EvalExpression, PointAndEval}, -// gkr::{ -// GKRCircuitOutput, GKRCircuitWitness, GKRProverOutput, -// layer::{Layer, LayerType, LayerWitness}, -// }, -// precompiles::utils::{MaskRepresentation, not8_expr, zero_expr}, -// }; - -// use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; - -// type E = BinomialExtensionField; - -// #[derive(Clone, Debug, Serialize, 
Deserialize)] -// pub struct KeccakParams {} - -// #[derive(Clone, Debug, Default, Serialize, Deserialize)] -// pub struct KeccakLayerLayout { -// c_aux: Vec, -// c_temp: Vec, -// c_rot: Vec, -// d: Vec, -// theta_output: Vec, -// rotation_witness: Vec, -// rhopi_output: Vec, -// nonlinear: Vec, -// chi_output: Vec, -// iota_output: Vec, -// } - -// #[derive(Clone, Debug, Default, Serialize, Deserialize)] -// pub struct KeccakLayout { -// keccak_input8: Vec, -// keccak_layers: [KeccakLayerLayout; ROUNDS], -// _marker: PhantomData, -// } - -// fn expansion_expr(expansion: &[(usize, Witness)]) -> Expression { -// let (total, ret) = expansion -// .iter() -// .rev() -// .fold((0, zero_expr()), |acc, (sz, felt)| { -// ( -// acc.0 + sz, -// acc.1 * Expression::Const(Constant::Base(1 << sz)) + (*felt).into(), -// ) -// }); - -// assert_eq!(total, SIZE); -// ret -// } - -// /// Compute an adequate split of 64-bits into chunks for performing a rotation -// /// by `delta`. The first element of the return value is the vec of chunk sizes. 
-// /// The second one is the length of its suffix that needs to be rotated -// fn rotation_split(delta: usize) -> (Vec, usize) { -// let delta = delta % 64; - -// if delta == 0 { -// return (vec![32, 32], 0); -// } - -// // This split meets all requirements except for <= 16 sizes -// let split32 = match delta.cmp(&32) { -// Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], -// Ordering::Equal => vec![32, 32], -// Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], -// }; - -// // Split off large chunks -// let split16 = split32 -// .into_iter() -// .flat_map(|size| { -// assert!(size < 32); -// if size <= 16 { -// vec![size] -// } else { -// vec![16, size - 16] -// } -// }) -// .collect_vec(); - -// let mut sum = 0; -// for (i, size) in split16.iter().rev().enumerate() { -// sum += size; -// if sum == delta { -// return (split16, i + 1); -// } -// } - -// panic!(); -// } - -// struct ConstraintSystem { -// expressions: Vec, -// expr_names: Vec, -// evals: Vec, -// and_lookups: Vec, -// xor_lookups: Vec, -// range_lookups: Vec, -// } - -// impl ConstraintSystem { -// fn new() -> Self { -// ConstraintSystem { -// expressions: vec![], -// evals: vec![], -// expr_names: vec![], -// and_lookups: vec![], -// xor_lookups: vec![], -// range_lookups: vec![], -// } -// } - -// fn add_constraint(&mut self, expr: Expression, name: String) { -// self.expressions.push(expr); -// self.evals.push(zero_eval()); -// self.expr_names.push(name); -// } - -// fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { -// self.and_lookups.push(CenoLookup::And(a, b, c)); -// } - -// fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { -// self.xor_lookups.push(CenoLookup::Xor(a, b, c)); -// } - -// /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. -// /// In general it can be done by two U16 checks: one for `value` and one for -// /// `value << (16 - size)`. 
-// fn lookup_range(&mut self, value: Expression, size: usize) { -// assert!(size <= 16); -// self.range_lookups.push(CenoLookup::U16(value.clone())); -// if size < 16 { -// self.range_lookups.push(CenoLookup::U16( -// value * Expression::Const(Constant::Base(1 << (16 - size))), -// )) -// } -// } - -// fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { -// self.add_constraint(lhs - rhs, name); -// } - -// // Constrains that lhs and rhs encode the same value of SIZE bits -// // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) -// // This needs to be constrained separately -// fn constrain_reps_eq( -// &mut self, -// lhs: &[(usize, Witness)], -// rhs: &[(usize, Witness)], -// name: String, -// ) { -// self.add_constraint( -// expansion_expr::(lhs) - expansion_expr::(rhs), -// name, -// ); -// } - -// /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. -// /// `rot8` and `input8` each consist of 8 chunks of 8-bits. -// /// -// /// `split_rep` is a chunk representation of the input which -// /// allows to reduce the required rotation to an array rotation. It may use -// /// non-uniform chunks. -// /// -// /// For example, when `delta = 2`, the 64 bits are split into chunks of -// /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the -// /// least significant bits so a left rotation will become a right rotation -// /// of the array). To perform the required rotation, we can -// /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e]. -// /// -// /// In the first step, we check that `rot8` and `split_rep` represent the -// /// same 64 bits. In the second step we check that `rot8` and the appropiate -// /// array rotation of `split_rep` represent the same 64 bits. -// /// -// /// This type of representation-equality check is done by packing chunks -// /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b, -// /// 2c] to the first 4 elements of `rot8`). 
In addition, we do range -// /// checks on `split_rep` which check that the felts meet the required -// /// sizes. -// /// -// /// This algorithm imposes the following general requirements for -// /// `split_rep`: -// /// - There exists a suffix of `split_rep` which sums to exactly `delta`. -// /// This suffix can contain several elements. -// /// - Chunk sizes are at most 16 (so they can be range-checked) or they are -// /// exactly equal to 32. -// /// - There exists a prefix of chunks which sums exactly to 32. This must -// /// hold for the rotated array as well. -// /// - The number of chunks should be as small as possible. -// /// -// /// Consult the method `rotation_split` to see how splits are computed for a -// /// given `delta -// /// -// /// Note that the function imposes range checks on chunk values, but it -// /// makes two exceptions: -// /// 1. It doesn't check the 8-bit reps (input and output). This is -// /// because all 8-bit reps in the global circuit are implicitly -// /// range-checked because they are lookup arguments. -// /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit -// /// chunk value is checked to be equal to the composition of 4 8-bit -// /// chunks. As mentioned in 1., these can be trusted to be range -// /// checked, so the resulting 32-bit is correct by construction as -// /// well. 
-// fn constrain_left_rotation64( -// &mut self, -// input8: &[Witness], -// split_rep: &[(usize, Witness)], -// rot8: &[Witness], -// delta: usize, -// label: String, -// ) { -// assert_eq!(input8.len(), 8); -// assert_eq!(rot8.len(), 8); - -// // Assert that the given split witnesses are correct for this delta -// let (sizes, chunks_rotation) = rotation_split(delta); -// assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); - -// // Lookup ranges -// for (size, elem) in split_rep { -// if *size != 32 { -// self.lookup_range((*elem).into(), *size); -// } -// } - -// // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are -// // the same 64 bitstring -// let mut helper = |rep8: &[Witness], rep_x: &[(usize, Witness)], chunks_rotation: usize| { -// // Do the same thing for the two 32-bit halves -// let mut rep_x = rep_x.to_owned(); -// rep_x.rotate_right(chunks_rotation); - -// for i in 0..2 { -// // The respective 4 elements in the byte representation -// let lhs = rep8[4 * i..4 * (i + 1)] -// .iter() -// .map(|wit| (8, *wit)) -// .collect_vec(); -// let cnt = rep_x.len() / 2; -// let rhs = &rep_x[cnt * i..cnt * (i + 1)]; - -// assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); - -// self.constrain_reps_eq::<32>( -// &lhs, -// rhs, -// format!( -// "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", -// sizes -// ), -// ); -// } -// }; - -// helper(input8, split_rep, 0); -// helper(rot8, split_rep, chunks_rotation); -// } -// } - -// const ROUNDS: usize = 24; - -// const RC: [u64; ROUNDS] = [ -// 1u64, -// 0x8082u64, -// 0x800000000000808au64, -// 0x8000000080008000u64, -// 0x808bu64, -// 0x80000001u64, -// 0x8000000080008081u64, -// 0x8000000000008009u64, -// 0x8au64, -// 0x88u64, -// 0x80008009u64, -// 0x8000000au64, -// 0x8000808bu64, -// 0x800000000000008bu64, -// 0x8000000000008089u64, -// 0x8000000000008003u64, -// 0x8000000000008002u64, -// 0x8000000000000080u64, -// 0x800au64, -// 
0x800000008000000au64, -// 0x8000000080008081u64, -// 0x8000000000008080u64, -// 0x80000001u64, -// 0x8000000080008008u64, -// ]; - -// const ROTATION_CONSTANTS: [[usize; 5]; 5] = [ -// [0, 1, 62, 28, 27], -// [36, 44, 6, 55, 20], -// [3, 10, 43, 25, 39], -// [41, 45, 15, 21, 8], -// [18, 2, 61, 56, 14], -// ]; - -// pub const KECCAK_INPUT_SIZE: usize = 50; -// pub const KECCAK_OUTPUT_SIZE: usize = 50; - -// pub const KECCAK_LAYER_BYTE_SIZE: usize = 200; - -// pub const AND_LOOKUPS_PER_ROUND: usize = 200; -// pub const XOR_LOOKUPS_PER_ROUND: usize = 608; -// pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; -// pub const LOOKUP_FELTS_PER_ROUND: usize = -// 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; - -// pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; -// pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; -// pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; - -// pub const KECCAK_OUT_EVAL_SIZE: usize = -// KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS; - -// pub const KECCAK_WIT_SIZE_PER_ROUND: usize = 1264; -// pub const KECCAK_WIT_SIZE: usize = KECCAK_WIT_SIZE_PER_ROUND * ROUNDS + KECCAK_LAYER_BYTE_SIZE; - -// #[allow(unused)] -// macro_rules! allocate_and_split { -// ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ -// let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); -// let mut iter = witnesses.into_iter(); -// ( -// $( -// iter.by_ref().take($size).collect_vec(), -// )* -// ) -// }}; -// } - -// macro_rules! 
split_from_offset { -// ($witnesses:expr, $offset:ident, $total:expr, $( $size:expr ),* ) => {{ -// let mut iter = $witnesses[$offset..].iter().cloned(); -// ( -// $( -// iter.by_ref().take($size).collect_vec(), -// )* -// ) -// }}; -// } - -// impl ProtocolBuilder for KeccakLayout { -// type Params = KeccakParams; - -// fn init(_params: Self::Params) -> Self { -// Self { -// ..Default::default() -// } -// } - -// fn build_commit_phase(&mut self, chip: &mut Chip) { -// let bases = chip.allocate_committed_base::(); -// self.keccak_input8 = bases[..KECCAK_LAYER_BYTE_SIZE].to_vec(); - -// let mut offset = KECCAK_LAYER_BYTE_SIZE; -// self.keccak_layers = array::from_fn(|_| { -// let ( -// c_aux, -// c_temp, -// c_rot, -// d, -// theta_output, -// rotation_witness, -// rhopi_output, -// nonlinear, -// chi_output, -// iota_output, -// ) = split_from_offset!( -// bases, -// offset, -// KECCAK_WIT_SIZE_PER_ROUND, -// 200, -// 30, -// 40, -// 40, -// 200, -// 146, -// 200, -// 200, -// 8, -// 200 -// ); -// offset += KECCAK_WIT_SIZE_PER_ROUND; -// KeccakLayerLayout { -// c_aux, -// c_temp, -// c_rot, -// d, -// theta_output, -// rotation_witness, -// rhopi_output, -// nonlinear, -// chi_output, -// iota_output, -// } -// }); -// } - -// fn build_gkr_phase(&mut self, chip: &mut Chip) { -// let final_outputs = -// chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); - -// let mut final_outputs_iter = final_outputs.iter(); - -// let [keccak_output32, keccak_input32, lookup_outputs] = [ -// KECCAK_OUTPUT_SIZE, -// KECCAK_INPUT_SIZE, -// LOOKUP_FELTS_PER_ROUND * ROUNDS, -// ] -// .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); - -// let lookup_outputs = lookup_outputs.to_vec(); - -// let (bases, []) = chip.allocate_wits_in_layer::(); -// for (openings, wit) in bases.iter().enumerate() { -// chip.allocate_base_opening(openings, wit.1.clone()); -// } - -// let keccak_input8 = 
&bases[..KECCAK_LAYER_BYTE_SIZE]; -// let keccak_output8 = &bases[KECCAK_WIT_SIZE - KECCAK_LAYER_BYTE_SIZE..KECCAK_WIT_SIZE]; - -// let mut system = ConstraintSystem::new(); - -// let mut offset = KECCAK_LAYER_BYTE_SIZE; -// let _ = (0..ROUNDS).fold(keccak_input8.to_vec(), |state8, round| { -// #[allow(non_snake_case)] -// let ( -// c_aux, -// c_temp, -// c_rot, -// d, -// theta_output, -// rotation_witness, -// rhopi_output, -// nonlinear, -// chi_output, -// iota_output, -// ) = split_from_offset!( -// bases, -// offset, -// KECCAK_WIT_SIZE_PER_ROUND, -// 200, -// 30, -// 40, -// 40, -// 200, -// 146, -// 200, -// 200, -// 8, -// 200 -// ); -// offset += KECCAK_WIT_SIZE_PER_ROUND; - -// { -// let n_wits = 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 8 + 200; -// assert_eq!(KECCAK_WIT_SIZE_PER_ROUND, n_wits); -// } - -// // TODO: ndarrays can be replaced with normal arrays - -// // Input state of the round in 8-bit chunks -// let state8: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &state8).unwrap(); - -// // The purpose is to compute the auxiliary array -// // c[i] = XOR (state[j][i]) for j in 0..5 -// // We unroll it into -// // c_aux[i][j] = XOR (state[k][i]) for k in 0..j -// // We use c_aux[i][4] instead of c[i] -// // c_aux is also stored in 8-bit chunks -// let c_aux: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); - -// for i in 0..5 { -// for k in 0..8 { -// // Initialize first element -// system.constrain_eq( -// state8[[0, i, k]].0.into(), -// c_aux[[i, 0, k]].0.into(), -// "init c_aux".to_string(), -// ); -// } -// for j in 1..5 { -// // Check xor using lookups over all chunks -// for k in 0..8 { -// system.lookup_xor8( -// c_aux[[i, j - 1, k]].0.into(), -// state8[[j, i, k]].0.into(), -// c_aux[[i, j, k]].0.into(), -// ); -// } -// } -// } - -// // Compute c_rot[i] = c[i].rotate_left(1) -// // To understand how rotations are performed in general, consult the -// // 
documentation of `constrain_left_rotation64`. Here c_temp is the split -// // witness for a 1-rotation. - -// let c_temp: ArrayView<(Witness, EvalExpression), Ix2> = -// ArrayView::from_shape((5, 6), &c_temp).unwrap(); -// let c_rot: ArrayView<(Witness, EvalExpression), Ix2> = -// ArrayView::from_shape((5, 8), &c_rot).unwrap(); - -// let (sizes, _) = rotation_split(1); - -// for i in 0..5 { -// assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); - -// system.constrain_left_rotation64( -// &c_aux.slice(s![i, 4, ..]).iter().map(|e| e.0).collect_vec(), -// &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) -// .map(|(e, sz)| (*sz, e.0)) -// .collect_vec(), -// &c_rot.slice(s![i, ..]).iter().map(|e| e.0).collect_vec(), -// 1, -// "theta rotation".to_string(), -// ); -// } - -// // d is computed simply as XOR of required elements of c (and rotations) -// // again stored as 8-bit chunks -// let d: ArrayView<(Witness, EvalExpression), Ix2> = -// ArrayView::from_shape((5, 8), &d).unwrap(); - -// for i in 0..5 { -// for k in 0..8 { -// system.lookup_xor8( -// c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), -// c_rot[[(i + 1) % 5, k]].0.into(), -// d[[i, k]].0.into(), -// ) -// } -// } - -// // output state of the Theta sub-round, simple XOR, in 8-bit chunks -// let theta_output: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); - -// for i in 0..5 { -// for j in 0..5 { -// for k in 0..8 { -// system.lookup_xor8( -// state8[[j, i, k]].0.into(), -// d[[i, k]].0.into(), -// theta_output[[j, i, k]].0.into(), -// ) -// } -// } -// } - -// // output state after applying both Rho and Pi sub-rounds -// // sub-round Pi is a simple permutation of 64-bit lanes -// // sub-round Rho requires rotations -// let rhopi_output: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); - -// // iterator over split witnesses -// let mut rotation_witness = 
rotation_witness.iter(); - -// for i in 0..5 { -// #[allow(clippy::needless_range_loop)] -// for j in 0..5 { -// let arg = theta_output -// .slice(s!(j, i, ..)) -// .iter() -// .map(|e| e.0) -// .collect_vec(); -// let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); -// let many = sizes.len(); -// let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) -// .map(|(sz, (wit, _))| (sz, *wit)) -// .collect_vec(); -// let arg_rotated = rhopi_output -// .slice(s!((2 * i + 3 * j) % 5, j, ..)) -// .iter() -// .map(|e| e.0) -// .collect_vec(); -// system.constrain_left_rotation64( -// &arg, -// &rep_split, -// &arg_rotated, -// ROTATION_CONSTANTS[j][i], -// format!("RHOPI {i}, {j}"), -// ); -// } -// } - -// let mut chi_output = chi_output; -// chi_output.extend(iota_output[8..].to_vec()); -// let chi_output: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); - -// // for the Chi sub-round, we use an intermediate witness storing the result of -// // the required AND -// let nonlinear: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); - -// for i in 0..5 { -// for j in 0..5 { -// for k in 0..8 { -// system.lookup_and8( -// not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), -// rhopi_output[[j, (i + 2) % 5, k]].0.into(), -// nonlinear[[j, i, k]].0.into(), -// ); - -// system.lookup_xor8( -// rhopi_output[[j, i, k]].0.into(), -// nonlinear[[j, i, k]].0.into(), -// chi_output[[j, i, k]].0.into(), -// ); -// } -// } -// } - -// // TODO: 24/25 elements stay the same after Iota; eliminate duplication? 
-// let iota_output_arr: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); - -// for k in 0..8 { -// system.lookup_xor8( -// chi_output[[0, 0, k]].0.into(), -// Expression::Const(Constant::Base(((RC[round] >> (k * 8)) & 0xFF) as i64)), -// iota_output_arr[[0, 0, k]].0.into(), -// ); -// } - -// iota_output -// }); - -// let mut global_and_lookup = 0; -// let mut global_xor_lookup = 3 * AND_LOOKUPS; -// let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; - -// let ConstraintSystem { -// mut expressions, -// mut expr_names, -// mut evals, -// and_lookups, -// xor_lookups, -// range_lookups, -// .. -// } = system; - -// for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) -// .flatten() -// .enumerate() -// { -// expressions.push(lookup); -// let (idx, round) = if i < 3 * AND_LOOKUPS { -// let round = i / AND_LOOKUPS; -// (&mut global_and_lookup, round) -// } else if i < 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS { -// let round = (i - 3 * AND_LOOKUPS) / XOR_LOOKUPS; -// (&mut global_xor_lookup, round) -// } else { -// let round = (i - 3 * AND_LOOKUPS - 3 * XOR_LOOKUPS) / RANGE_LOOKUPS; -// (&mut global_range_lookup, round) -// }; -// expr_names.push(format!("round {round}: {i}th lookup felt")); -// evals.push(lookup_outputs[*idx].clone()); -// *idx += 1; -// } - -// assert!(global_and_lookup == 3 * AND_LOOKUPS); -// assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); -// assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); - -// let keccak_input8: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), keccak_input8).unwrap(); -// let keccak_input32 = keccak_input32.to_vec(); -// let mut keccak_input32_iter = keccak_input32.iter().cloned(); - -// for x in 0..5 { -// for y in 0..5 { -// for k in 0..2 { -// // create an expression combining 4 elements of state8 into a single 32-bit felt -// let expr = expansion_expr::<32>( -// keccak_input8 -// 
.slice(s![x, y, 4 * k..4 * (k + 1)]) -// .iter() -// .map(|e| (8, e.0)) -// .collect_vec() -// .as_slice(), -// ); -// expressions.push(expr); -// evals.push(keccak_input32_iter.next().unwrap().clone()); -// expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); -// } -// } -// } - -// let keccak_output32 = keccak_output32.to_vec(); -// let keccak_output8: ArrayView<(Witness, EvalExpression), Ix3> = -// ArrayView::from_shape((5, 5, 8), keccak_output8).unwrap(); -// let mut keccak_output32_iter = keccak_output32.iter().cloned(); - -// for x in 0..5 { -// for y in 0..5 { -// for k in 0..2 { -// // create an expression combining 4 elements of state8 into a single 32-bit felt -// let expr = expansion_expr::<32>( -// &keccak_output8 -// .slice(s![x, y, 4 * k..4 * (k + 1)]) -// .iter() -// .map(|e| (8, e.0)) -// .collect_vec(), -// ); -// expressions.push(expr); -// evals.push(keccak_output32_iter.next().unwrap().clone()); -// expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); -// } -// } -// } - -// chip.add_layer(Layer::new( -// "Rounds".to_string(), -// LayerType::Zerocheck, -// expressions, -// vec![], -// bases.into_iter().map(|e| e.1).collect_vec(), -// vec![], -// evals, -// expr_names, -// )); -// } -// } - -// #[derive(Clone, Default)] -// pub struct KeccakTrace { -// pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, -// } - -// impl ProtocolWitnessGenerator for KeccakLayout -// where -// E: ExtensionField, -// { -// type Trace = KeccakTrace; - -// fn phase1_witness(&self, phase1: Self::Trace) -> RowMajorMatrix { -// let instances = &phase1.instances; -// let num_instances = instances.len(); - -// let wits: Vec<_> = (0..num_instances) -// .into_par_iter() -// .flat_map(|instance_id| { -// fn conv64to8(input: u64) -> [u64; 8] { -// MaskRepresentation::new(vec![(64, input).into()]) -// .convert(vec![8; 8]) -// .values() -// .try_into() -// .unwrap() -// } - -// let state32 = instances[instance_id] -// .iter() -// .map(|&e| e as u64) -// 
.collect_vec(); - -// let mut state64 = [[0u64; 5]; 5]; -// let mut state8 = [[[0u64; 8]; 5]; 5]; - -// zip_eq(iproduct!(0..5, 0..5), state32.iter().tuples()) -// .map(|((x, y), (&lo, &hi))| { -// state64[x][y] = lo | (hi << 32); -// }) -// .count(); - -// for x in 0..5 { -// for y in 0..5 { -// state8[x][y] = conv64to8(state64[x][y]); -// } -// } - -// let mut wits = Vec::with_capacity(KECCAK_WIT_SIZE_PER_ROUND); -// let mut push_instance = |new_wits: Vec| { -// let felts = u64s_to_felts::(new_wits); -// wits.extend(felts); -// }; - -// push_instance(state8.into_iter().flatten().flatten().collect_vec()); - -// #[allow(clippy::needless_range_loop)] -// for round in 0..ROUNDS { -// let mut c_aux64 = [[0u64; 5]; 5]; -// let mut c_aux8 = [[[0u64; 8]; 5]; 5]; - -// for i in 0..5 { -// c_aux64[i][0] = state64[0][i]; -// c_aux8[i][0] = conv64to8(c_aux64[i][0]); -// for j in 1..5 { -// c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; -// c_aux8[i][j] = conv64to8(c_aux64[i][j]); -// } -// } - -// let mut c64 = [0u64; 5]; -// let mut c8 = [[0u64; 8]; 5]; - -// for x in 0..5 { -// c64[x] = c_aux64[x][4]; -// c8[x] = conv64to8(c64[x]); -// } - -// let mut c_temp = [[0u64; 6]; 5]; -// for i in 0..5 { -// let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) -// .convert(vec![16, 15, 1, 16, 15, 1]); -// c_temp[i] = rep.values().try_into().unwrap(); -// } - -// let mut crot64 = [0u64; 5]; -// let mut crot8 = [[0u64; 8]; 5]; -// for i in 0..5 { -// crot64[i] = c64[i].rotate_left(1); -// crot8[i] = conv64to8(crot64[i]); -// } - -// let mut d64 = [0u64; 5]; -// let mut d8 = [[0u64; 8]; 5]; -// for x in 0..5 { -// d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); -// d8[x] = conv64to8(d64[x]); -// } - -// let mut theta_state64 = state64; -// let mut theta_state8 = [[[0u64; 8]; 5]; 5]; -// let mut rotation_witness = vec![]; - -// for x in 0..5 { -// for y in 0..5 { -// theta_state64[y][x] ^= d64[x]; -// theta_state8[y][x] = conv64to8(theta_state64[y][x]); - -// 
let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); -// let rep = -// MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) -// .convert(sizes); -// rotation_witness.extend(rep.values()); -// } -// } - -// // Rho and Pi steps -// let mut rhopi_output64 = [[0u64; 5]; 5]; -// let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; - -// for x in 0..5 { -// for y in 0..5 { -// rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = -// theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); -// } -// } - -// for x in 0..5 { -// for y in 0..5 { -// rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); -// } -// } - -// // Chi step -// let mut nonlinear64 = [[0u64; 5]; 5]; -// let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; -// for x in 0..5 { -// for y in 0..5 { -// nonlinear64[y][x] = -// !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; -// nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); -// } -// } - -// let mut chi_output64 = [[0u64; 5]; 5]; -// let mut chi_output8 = [[[0u64; 8]; 5]; 5]; -// for x in 0..5 { -// for y in 0..5 { -// chi_output64[y][x] = nonlinear64[y][x] ^ rhopi_output64[y][x]; -// chi_output8[y][x] = conv64to8(chi_output64[y][x]); -// } -// } - -// // Iota step -// let mut iota_output64 = chi_output64; -// let mut iota_output8 = [[[0u64; 8]; 5]; 5]; -// iota_output64[0][0] ^= RC[round]; - -// for x in 0..5 { -// for y in 0..5 { -// iota_output8[x][y] = conv64to8(iota_output64[x][y]); -// } -// } - -// let all_wits64 = [ -// c_aux8.into_iter().flatten().flatten().collect_vec(), -// c_temp.into_iter().flatten().collect_vec(), -// crot8.into_iter().flatten().collect_vec(), -// d8.into_iter().flatten().collect_vec(), -// theta_state8.into_iter().flatten().flatten().collect_vec(), -// rotation_witness, -// rhopi_output8.into_iter().flatten().flatten().collect_vec(), -// nonlinear8.into_iter().flatten().flatten().collect_vec(), -// chi_output8[0][0].to_vec(), -// iota_output8.into_iter().flatten().flatten().collect_vec(), -// ] -// 
.into_iter() -// .flatten() -// .collect_vec(); - -// assert_eq!(all_wits64.len(), KECCAK_WIT_SIZE_PER_ROUND); -// push_instance(all_wits64); - -// state64 = iota_output64; -// } - -// wits -// }) -// .collect(); -// RowMajorMatrix::new_by_values(wits, KECCAK_WIT_SIZE, InstancePaddingStrategy::Default) -// } - -// fn gkr_witness( -// &self, -// phase1: &RowMajorMatrix, -// _challenges: &[E], -// ) -> (GKRCircuitWitness, GKRCircuitOutput) { -// // TODO: Make it more efficient. -// let instances = phase1 -// .values -// .par_iter() -// .map(|wit| wit.to_canonical_u64()) -// .collect::>(); -// let num_instances = phase1.num_vars(); -// let num_cols = phase1.n_col(); -// assert_eq!(num_cols, KECCAK_WIT_SIZE); - -// let to_5x5x8_array = |input: &[u64]| -> [[[u64; 8]; 5]; 5] { -// assert_eq!(input.len(), 5 * 5 * 8); -// input -// .chunks(40) -// .map(|chunk| { -// chunk -// .chunks(8) -// .map(|x| x.to_vec().try_into().unwrap()) -// .collect_vec() -// .try_into() -// .unwrap() -// }) -// .collect_vec() -// .try_into() -// .unwrap() -// }; -// let to_5x8_array = |input: &[u64]| -> [[u64; 8]; 5] { -// input -// .chunks(8) -// .map(|x| x.to_vec().try_into().unwrap()) -// .collect_vec() -// .try_into() -// .unwrap() -// }; -// let u8_slice_to_u64 = -// |input: &[u64]| -> u64 { input.iter().rev().fold(0, |acc, &e| (acc << 8) | e) }; -// let u8_slice_to_u32_slice = |input: &[u64]| -> [u64; 2] { -// input -// .chunks(4) -// .map(u8_slice_to_u64) -// .collect_vec() -// .try_into() -// .unwrap() -// }; - -// let output_bases: Vec = (0..num_instances) -// .into_par_iter() -// .flat_map(|instance_id| { -// let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; -// let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; -// let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; - -// let mut add_and = |a: u64, b: u64, round: usize| { -// let c = a & b; -// assert!(a < (1 << 8)); -// assert!(b < (1 << 8)); -// and_lookups[round].extend(vec![a, b, c]); -// }; - -// let mut add_xor = |a: u64, b: 
u64, round: usize| { -// let c = a ^ b; -// assert!(a < (1 << 8)); -// assert!(b < (1 << 8)); -// xor_lookups[round].extend(vec![a, b, c]); -// }; - -// let mut add_range = |value: u64, size: usize, round: usize| { -// assert!(size <= 16, "{size}"); -// range_lookups[round].push(value); -// if size < 16 { -// range_lookups[round].push(value << (16 - size)); -// assert!(value << (16 - size) < (1 << 16)); -// } -// }; - -// let mut state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( -// &instances -// [instance_id * num_cols..instance_id * num_cols + KECCAK_LAYER_BYTE_SIZE], -// ); -// let mut keccak_input32 = [[[0u64; 2]; 5]; 5]; -// for x in 0..5 { -// for y in 0..5 { -// keccak_input32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); -// } -// } -// let mut offset = KECCAK_LAYER_BYTE_SIZE; -// #[allow(clippy::needless_range_loop)] -// for round in 0..ROUNDS { -// let ( -// c_aux8, -// _c_temp, -// crot8, -// d8, -// theta_state8, -// _rotation_witness, -// rhopi_output8, -// nonlinear8, -// chi_output8, -// iota_output8, -// ) = split_from_offset!( -// instances[instance_id * num_cols..(instance_id + 1) * num_cols], -// offset, -// KECCAK_WIT_SIZE_PER_ROUND, -// 200, -// 30, -// 40, -// 40, -// 200, -// 146, -// 200, -// 200, -// 8, -// 200 -// ); -// offset += KECCAK_WIT_SIZE_PER_ROUND; -// let c_aux8 = to_5x5x8_array(&c_aux8); - -// for i in 0..5 { -// for j in 1..5 { -// for k in 0..8 { -// add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); -// } -// } -// } - -// let mut c8 = [[0u64; 8]; 5]; -// let mut c64 = [0u64; 5]; - -// for x in 0..5 { -// c8[x] = c_aux8[x][4]; -// c64[x] = u8_slice_to_u64(&c8[x]); -// } - -// for i in 0..5 { -// let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) -// .convert(vec![16, 15, 1, 16, 15, 1]); -// for mask in rep.rep { -// add_range(mask.value, mask.size, round); -// } -// } - -// let crot8 = to_5x8_array(&crot8); -// let d8 = to_5x8_array(&d8); -// for x in 0..5 { -// for k in 0..8 { -// add_xor(c_aux8[(x + 4) % 5][4][k], 
crot8[(x + 1) % 5][k], round); -// } -// } - -// let theta_state8 = to_5x5x8_array(&theta_state8); -// let mut theta_state64 = [[0u64; 5]; 5]; -// for x in 0..5 { -// for y in 0..5 { -// theta_state64[y][x] = u8_slice_to_u64(&theta_state8[y][x]); -// } -// } - -// for x in 0..5 { -// for y in 0..5 { -// for k in 0..8 { -// add_xor(state8[y][x][k], d8[x][k], round); -// } - -// let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); -// let rep = -// MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) -// .convert(sizes); -// for mask in rep.rep.iter() { -// if mask.size != 32 { -// add_range(mask.value, mask.size, round); -// } -// } -// } -// } - -// // Rho and Pi steps -// let rhopi_output8 = to_5x5x8_array(&rhopi_output8); - -// // Chi step -// let nonlinear8 = to_5x5x8_array(&nonlinear8); -// for x in 0..5 { -// for y in 0..5 { -// for k in 0..8 { -// add_and( -// 0xFF - rhopi_output8[y][(x + 1) % 5][k], -// rhopi_output8[y][(x + 2) % 5][k], -// round, -// ); -// } -// } -// } - -// for x in 0..5 { -// for y in 0..5 { -// for k in 0..8 { -// add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) -// } -// } -// } - -// // Iota step -// let chi_output8: [u64; 8] = chi_output8.try_into().unwrap(); // only save chi_output8[0][0]; -// let iota_output8 = to_5x5x8_array(&iota_output8); -// for k in 0..8 { -// add_xor(chi_output8[k], (RC[round] >> (k * 8)) & 0xFF, round); -// } - -// state8 = iota_output8; -// } - -// let mut keccak_output32 = [[[0u64; 2]; 5]; 5]; -// for x in 0..5 { -// for y in 0..5 { -// keccak_output32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); -// } -// } - -// chain!( -// keccak_output32.into_iter().flatten().flatten(), -// keccak_input32.into_iter().flatten().flatten(), -// (0..ROUNDS).rev().flat_map(|i| and_lookups[i].clone()), -// (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), -// (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) -// ) -// .collect_vec() -// }) -// .collect(); - -// let bases = 
phase1.to_cols_base::(); -// let output_bases = RowMajorMatrix::new_by_values( -// output_bases -// .into_iter() -// .map(E::BaseField::from_u64) -// .collect(), -// KECCAK_OUT_EVAL_SIZE, -// InstancePaddingStrategy::Default, -// ) -// .to_cols_base::(); - -// ( -// GKRCircuitWitness { -// layers: vec![LayerWitness { -// bases, -// ..Default::default() -// }], -// }, -// GKRCircuitOutput(LayerWitness { -// bases: output_bases, -// ..Default::default() -// }), -// ) -// } -// } - -// pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { -// let params = KeccakParams {}; -// let (layout, chip) = KeccakLayout::build(params); - -// let mut instances = vec![]; -// for state in &states { -// let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); -// let state_mask32 = state_mask64.convert(vec![32; 50]); - -// instances.push( -// state_mask32 -// .values() -// .iter() -// .map(|e| *e as u32) -// .collect_vec() -// .try_into() -// .unwrap(), -// ); -// } - -// let num_instances = instances.len(); -// let phase1_witness = layout.phase1_witness(KeccakTrace { -// instances: instances.clone(), -// }); - -// let mut prover_transcript = BasicTranscript::::new(b"protocol"); - -// // Omit the commit phase1 and phase2. 
-// let (gkr_witness, _gkr_output) = layout.gkr_witness(&phase1_witness, &[]); - -// let out_evals = { -// let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); -// let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); - -// if test_outputs { -// // Confront outputs with tiny_keccak::keccakf call -// let mut instance_outputs = vec![vec![]; num_instances]; -// for base in gkr_witness -// .layers -// .last() -// .unwrap() -// .bases -// .iter() -// .take(KECCAK_OUTPUT_SIZE) -// { -// assert_eq!(base.len(), num_instances); -// for i in 0..num_instances { -// instance_outputs[i].push(base[i]); -// } -// } - -// for i in 0..num_instances { -// let mut state = states[i]; -// keccakf(&mut state); -// assert_eq!( -// state -// .to_vec() -// .iter() -// .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) -// .map(|e| Goldilocks::from_u64(e as u64)) -// .collect_vec(), -// instance_outputs[i] -// ); -// } -// } - -// let out_evals = gkr_witness -// .layers -// .last() -// .unwrap() -// .bases -// .iter() -// .map(|base| PointAndEval { -// point: point.clone(), -// eval: subprotocols::utils::evaluate_mle_ext(base, &point), -// }) -// .collect_vec(); - -// assert_eq!(out_evals.len(), KECCAK_OUT_EVAL_SIZE); - -// out_evals -// }; - -// let gkr_circuit = chip.gkr_circuit(); -// dbg!(&gkr_circuit.layers.len()); -// let GKRProverOutput { gkr_proof, .. } = gkr_circuit -// .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) -// .expect("Failed to prove phase"); - -// if verify { -// { -// let mut verifier_transcript = BasicTranscript::::new(b"protocol"); - -// gkr_circuit -// .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) -// .expect("GKR verify failed"); - -// // Omit the PCS opening phase. 
-// } -// } -// } - -// #[cfg(test)] -// mod tests { -// use super::*; -// use rand::{Rng, SeedableRng}; - -// #[test] -// fn test_keccakf() { -// std::thread::Builder::new() -// .name("keccak_test".into()) -// .stack_size(64 * 1024 * 1024) -// .spawn(|| { -// let mut rng = rand::rngs::StdRng::seed_from_u64(42); - -// let num_instances = 8; -// let mut states: Vec<[u64; 25]> = vec![]; -// for _ in 0..num_instances { -// states.push(std::array::from_fn(|_| rng.gen())); -// } - -// run_faster_keccakf(states, true, true); -// }) -// .unwrap() -// .join() -// .unwrap(); -// } - -// #[ignore] -// #[test] -// fn test_keccakf_nonpow2() { -// std::thread::Builder::new() -// .name("keccak_test".into()) -// .stack_size(64 * 1024 * 1024) -// .spawn(|| { -// let mut rng = rand::rngs::StdRng::seed_from_u64(42); - -// let num_instances = 5; -// let mut states: Vec<[u64; 25]> = vec![]; -// for _ in 0..num_instances { -// states.push(std::array::from_fn(|_| rng.gen())); -// } - -// run_faster_keccakf(states, true, true); -// }) -// .unwrap() -// .join() -// .unwrap(); -// } -// } +use std::{array, cmp::Ordering, marker::PhantomData, sync::Arc}; + +use ff_ext::{ExtensionField, SmallField}; +use itertools::{Itertools, chain, iproduct, zip_eq}; +use multilinear_extensions::{Expression, ToExpr, WitIn}; +use ndarray::{ArrayView, Ix2, Ix3, s}; +use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; +use p3_goldilocks::Goldilocks; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use serde::{Deserialize, Serialize}; +use tiny_keccak::keccakf; +use transcript::BasicTranscript; +use witness::{InstancePaddingStrategy, RowMajorMatrix}; + +use crate::{ + ProtocolBuilder, ProtocolWitnessGenerator, + chip::Chip, + evaluation::EvalExpression, + gkr::{ + GKRCircuitOutput, GKRCircuitWitness, GKRProverOutput, + layer::{Layer, LayerType, LayerWitness}, + }, + precompiles::utils::{MaskRepresentation, not8_expr}, +}; + +use 
super::utils::{CenoLookup, u64s_to_felts, zero_eval}; + +type E = BinomialExtensionField; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct KeccakParams {} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct KeccakLayerLayout { + c_aux: Vec, + c_temp: Vec, + c_rot: Vec, + d: Vec, + theta_output: Vec, + rotation_witness: Vec, + rhopi_output: Vec, + nonlinear: Vec, + chi_output: Vec, + iota_output: Vec, +} + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct KeccakLayout { + keccak_input8: Vec, + keccak_layers: [KeccakLayerLayout; ROUNDS], + _marker: PhantomData, +} + +fn expansion_expr( + expansion: &[(usize, Expression)], +) -> Expression { + let (total, ret) = + expansion + .iter() + .rev() + .fold((0, E::BaseField::ZERO.expr()), |acc, (sz, felt)| { + ( + acc.0 + sz, + acc.1 * E::BaseField::from_u64((1 << sz) as u64).expr() + felt.expr(), + ) + }); + + assert_eq!(total, SIZE); + ret +} + +/// Compute an adequate split of 64-bits into chunks for performing a rotation +/// by `delta`. The first element of the return value is the vec of chunk sizes. 
+/// The second one is the length of its suffix that needs to be rotated +fn rotation_split(delta: usize) -> (Vec, usize) { + let delta = delta % 64; + + if delta == 0 { + return (vec![32, 32], 0); + } + + // This split meets all requirements except for <= 16 sizes + let split32 = match delta.cmp(&32) { + Ordering::Less => vec![32 - delta, delta, 32 - delta, delta], + Ordering::Equal => vec![32, 32], + Ordering::Greater => vec![32 - (delta - 32), delta - 32, 32 - (delta - 32), delta - 32], + }; + + // Split off large chunks + let split16 = split32 + .into_iter() + .flat_map(|size| { + assert!(size < 32); + if size <= 16 { + vec![size] + } else { + vec![16, size - 16] + } + }) + .collect_vec(); + + let mut sum = 0; + for (i, size) in split16.iter().rev().enumerate() { + sum += size; + if sum == delta { + return (split16, i + 1); + } + } + + panic!(); +} + +struct ConstraintSystem { + expressions: Vec>, + expr_names: Vec, + evals: Vec>, + and_lookups: Vec>, + xor_lookups: Vec>, + range_lookups: Vec>, +} + +impl ConstraintSystem { + fn new() -> Self { + ConstraintSystem { + expressions: vec![], + evals: vec![], + expr_names: vec![], + and_lookups: vec![], + xor_lookups: vec![], + range_lookups: vec![], + } + } + + fn add_constraint(&mut self, expr: Expression, name: String) { + self.expressions.push(expr); + self.evals.push(zero_eval()); + self.expr_names.push(name); + } + + fn lookup_and8(&mut self, a: Expression, b: Expression, c: Expression) { + self.and_lookups.push(CenoLookup::And(a, b, c)); + } + + fn lookup_xor8(&mut self, a: Expression, b: Expression, c: Expression) { + self.xor_lookups.push(CenoLookup::Xor(a, b, c)); + } + + /// Generates U16 lookups to prove that `value` fits on `size < 16` bits. + /// In general it can be done by two U16 checks: one for `value` and one for + /// `value << (16 - size)`. 
+ fn lookup_range(&mut self, value: Expression, size: usize) { + assert!(size <= 16); + self.range_lookups.push(CenoLookup::U16(value.clone())); + if size < 16 { + self.range_lookups.push(CenoLookup::U16( + value * E::BaseField::from_u64((1 << (16 - size)) as u64).expr(), + )) + } + } + + fn constrain_eq(&mut self, lhs: Expression, rhs: Expression, name: String) { + self.add_constraint(lhs - rhs, name); + } + + // Constrains that lhs and rhs encode the same value of SIZE bits + // WARNING: Assumes that forall i, (lhs[i].1 < (2 ^ lhs[i].0)) + // This needs to be constrained separately + fn constrain_reps_eq( + &mut self, + lhs: &[(usize, Expression)], + rhs: &[(usize, Expression)], + name: String, + ) { + self.add_constraint( + expansion_expr::(lhs) - expansion_expr::(rhs), + name, + ); + } + + /// Checks that `rot8` is equal to `input8` left-rotated by `delta`. + /// `rot8` and `input8` each consist of 8 chunks of 8-bits. + /// + /// `split_rep` is a chunk representation of the input which + /// allows to reduce the required rotation to an array rotation. It may use + /// non-uniform chunks. + /// + /// For example, when `delta = 2`, the 64 bits are split into chunks of + /// sizes `[16a, 14b, 2c, 16d, 14e, 2f]` (here the first chunks contains the + /// least significant bits so a left rotation will become a right rotation + /// of the array). To perform the required rotation, we can + /// simply rotate the array: [2f, 16a, 14b, 2c, 16d, 14e]. + /// + /// In the first step, we check that `rot8` and `split_rep` represent the + /// same 64 bits. In the second step we check that `rot8` and the appropiate + /// array rotation of `split_rep` represent the same 64 bits. + /// + /// This type of representation-equality check is done by packing chunks + /// into sizes of exactly 32 (so for `delta = 2` we compare [16a, 14b, + /// 2c] to the first 4 elements of `rot8`). 
In addition, we do range + /// checks on `split_rep` which check that the felts meet the required + /// sizes. + /// + /// This algorithm imposes the following general requirements for + /// `split_rep`: + /// - There exists a suffix of `split_rep` which sums to exactly `delta`. + /// This suffix can contain several elements. + /// - Chunk sizes are at most 16 (so they can be range-checked) or they are + /// exactly equal to 32. + /// - There exists a prefix of chunks which sums exactly to 32. This must + /// hold for the rotated array as well. + /// - The number of chunks should be as small as possible. + /// + /// Consult the method `rotation_split` to see how splits are computed for a + /// given `delta + /// + /// Note that the function imposes range checks on chunk values, but it + /// makes two exceptions: + /// 1. It doesn't check the 8-bit reps (input and output). This is + /// because all 8-bit reps in the global circuit are implicitly + /// range-checked because they are lookup arguments. + /// 2. It doesn't range-check 32-bit chunks. This is because a 32-bit + /// chunk value is checked to be equal to the composition of 4 8-bit + /// chunks. As mentioned in 1., these can be trusted to be range + /// checked, so the resulting 32-bit is correct by construction as + /// well. 
+ fn constrain_left_rotation64( + &mut self, + input8: &[Expression], + split_rep: &[(usize, Expression)], + rot8: &[Expression], + delta: usize, + label: String, + ) { + assert_eq!(input8.len(), 8); + assert_eq!(rot8.len(), 8); + + // Assert that the given split witnesses are correct for this delta + let (sizes, chunks_rotation) = rotation_split(delta); + assert_eq!(sizes, split_rep.iter().map(|e| e.0).collect_vec()); + + // Lookup ranges + for (size, elem) in split_rep { + if *size != 32 { + self.lookup_range(elem.expr(), *size); + } + } + + // constrain the fact that rep8 and repX.rotate_left(chunks_rotation) are + // the same 64 bitstring + let mut helper = |rep8: &[Expression], + rep_x: &[(usize, Expression)], + chunks_rotation: usize| { + // Do the same thing for the two 32-bit halves + let mut rep_x = rep_x.to_owned(); + rep_x.rotate_right(chunks_rotation); + + for i in 0..2 { + // The respective 4 elements in the byte representation + let lhs = rep8[4 * i..4 * (i + 1)] + .iter() + .map(|wit| (8, wit.expr())) + .collect_vec(); + let cnt = rep_x.len() / 2; + let rhs = &rep_x[cnt * i..cnt * (i + 1)]; + + assert_eq!(rhs.iter().map(|e| e.0).sum::(), 32); + + self.constrain_reps_eq::<32>( + &lhs, + rhs, + format!( + "rotation internal {label}, round {i}, rot: {chunks_rotation}, delta: {delta}, {:?}", + sizes + ), + ); + } + }; + + helper(input8, split_rep, 0); + helper(rot8, split_rep, chunks_rotation); + } +} + +const ROUNDS: usize = 24; + +const RC: [u64; ROUNDS] = [ + 1u64, + 0x8082u64, + 0x800000000000808au64, + 0x8000000080008000u64, + 0x808bu64, + 0x80000001u64, + 0x8000000080008081u64, + 0x8000000000008009u64, + 0x8au64, + 0x88u64, + 0x80008009u64, + 0x8000000au64, + 0x8000808bu64, + 0x800000000000008bu64, + 0x8000000000008089u64, + 0x8000000000008003u64, + 0x8000000000008002u64, + 0x8000000000000080u64, + 0x800au64, + 0x800000008000000au64, + 0x8000000080008081u64, + 0x8000000000008080u64, + 0x80000001u64, + 0x8000000080008008u64, +]; + +const 
ROTATION_CONSTANTS: [[usize; 5]; 5] = [ + [0, 1, 62, 28, 27], + [36, 44, 6, 55, 20], + [3, 10, 43, 25, 39], + [41, 45, 15, 21, 8], + [18, 2, 61, 56, 14], +]; + +pub const KECCAK_INPUT_SIZE: usize = 50; +pub const KECCAK_OUTPUT_SIZE: usize = 50; + +pub const KECCAK_LAYER_BYTE_SIZE: usize = 200; + +pub const AND_LOOKUPS_PER_ROUND: usize = 200; +pub const XOR_LOOKUPS_PER_ROUND: usize = 608; +pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; +pub const LOOKUP_FELTS_PER_ROUND: usize = + 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; + +pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; +pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; +pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; + +pub const KECCAK_OUT_EVAL_SIZE: usize = + KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS; + +pub const KECCAK_WIT_SIZE_PER_ROUND: usize = 1264; +pub const KECCAK_WIT_SIZE: usize = KECCAK_WIT_SIZE_PER_ROUND * ROUNDS + KECCAK_LAYER_BYTE_SIZE; + +#[allow(unused)] +macro_rules! allocate_and_split { + ($chip:expr, $total:expr, $( $size:expr ),* ) => {{ + let (witnesses, _) = $chip.allocate_wits_in_layer::<$total, 0>(); + let mut iter = witnesses.into_iter(); + ( + $( + iter.by_ref().take($size).collect_vec(), + )* + ) + }}; +} + +macro_rules! 
split_from_offset { + ($witnesses:expr, $offset:ident, $total:expr, $( $size:expr ),* ) => {{ + let mut iter = $witnesses[$offset..].iter().cloned(); + ( + $( + iter.by_ref().take($size).collect_vec(), + )* + ) + }}; +} + +impl ProtocolBuilder for KeccakLayout { + type Params = KeccakParams; + + fn init(_params: Self::Params) -> Self { + Self { + ..Default::default() + } + } + + fn build_commit_phase(&mut self, chip: &mut Chip) { + let bases = chip.allocate_committed::(); + self.keccak_input8 = bases[..KECCAK_LAYER_BYTE_SIZE].to_vec(); + + let mut offset = KECCAK_LAYER_BYTE_SIZE; + self.keccak_layers = array::from_fn(|_| { + let ( + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + ) = split_from_offset!( + bases, + offset, + KECCAK_WIT_SIZE_PER_ROUND, + 200, + 30, + 40, + 40, + 200, + 146, + 200, + 200, + 8, + 200 + ); + offset += KECCAK_WIT_SIZE_PER_ROUND; + KeccakLayerLayout { + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + } + }); + } + + fn build_gkr_phase(&mut self, chip: &mut Chip) { + let final_outputs = + chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); + + let mut final_outputs_iter = final_outputs.iter(); + + let [keccak_output32, keccak_input32, lookup_outputs] = [ + KECCAK_OUTPUT_SIZE, + KECCAK_INPUT_SIZE, + LOOKUP_FELTS_PER_ROUND * ROUNDS, + ] + .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); + + let lookup_outputs = lookup_outputs.to_vec(); + + let bases = chip.allocate_wits_in_layer::(); + for (openings, wit) in bases.iter().enumerate() { + chip.allocate_opening(openings, wit.1.clone()); + } + + let keccak_input8 = &bases[..KECCAK_LAYER_BYTE_SIZE]; + let keccak_output8 = &bases[KECCAK_WIT_SIZE - KECCAK_LAYER_BYTE_SIZE..KECCAK_WIT_SIZE]; + + let mut system = ConstraintSystem::new(); + + let mut offset = 
KECCAK_LAYER_BYTE_SIZE; + let _ = (0..ROUNDS).fold(keccak_input8.to_vec(), |state8, round| { + #[allow(non_snake_case)] + let ( + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + ) = split_from_offset!( + bases, + offset, + KECCAK_WIT_SIZE_PER_ROUND, + 200, + 30, + 40, + 40, + 200, + 146, + 200, + 200, + 8, + 200 + ); + offset += KECCAK_WIT_SIZE_PER_ROUND; + + { + let n_wits = 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 8 + 200; + assert_eq!(KECCAK_WIT_SIZE_PER_ROUND, n_wits); + } + + // TODO: ndarrays can be replaced with normal arrays + + // Input state of the round in 8-bit chunks + let state8: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + + // The purpose is to compute the auxiliary array + // c[i] = XOR (state[j][i]) for j in 0..5 + // We unroll it into + // c_aux[i][j] = XOR (state[k][i]) for k in 0..j + // We use c_aux[i][4] instead of c[i] + // c_aux is also stored in 8-bit chunks + let c_aux: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); + + for i in 0..5 { + for k in 0..8 { + // Initialize first element + system.constrain_eq( + state8[[0, i, k]].0.into(), + c_aux[[i, 0, k]].0.into(), + "init c_aux".to_string(), + ); + } + for j in 1..5 { + // Check xor using lookups over all chunks + for k in 0..8 { + system.lookup_xor8( + c_aux[[i, j - 1, k]].0.into(), + state8[[j, i, k]].0.into(), + c_aux[[i, j, k]].0.into(), + ); + } + } + } + + // Compute c_rot[i] = c[i].rotate_left(1) + // To understand how rotations are performed in general, consult the + // documentation of `constrain_left_rotation64`. Here c_temp is the split + // witness for a 1-rotation. 
+ + let c_temp: ArrayView<(WitIn, EvalExpression), Ix2> = + ArrayView::from_shape((5, 6), &c_temp).unwrap(); + let c_rot: ArrayView<(WitIn, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &c_rot).unwrap(); + + let (sizes, _) = rotation_split(1); + + for i in 0..5 { + assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); + + system.constrain_left_rotation64( + &c_aux + .slice(s![i, 4, ..]) + .iter() + .map(|e| e.0.expr()) + .collect_vec(), + &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) + .map(|(e, sz)| (*sz, e.0.expr())) + .collect_vec(), + &c_rot + .slice(s![i, ..]) + .iter() + .map(|e| e.0.expr()) + .collect_vec(), + 1, + "theta rotation".to_string(), + ); + } + + // d is computed simply as XOR of required elements of c (and rotations) + // again stored as 8-bit chunks + let d: ArrayView<(WitIn, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &d).unwrap(); + + for i in 0..5 { + for k in 0..8 { + system.lookup_xor8( + c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), + c_rot[[(i + 1) % 5, k]].0.into(), + d[[i, k]].0.into(), + ) + } + } + + // output state of the Theta sub-round, simple XOR, in 8-bit chunks + let theta_output: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_xor8( + state8[[j, i, k]].0.into(), + d[[i, k]].0.into(), + theta_output[[j, i, k]].0.into(), + ) + } + } + } + + // output state after applying both Rho and Pi sub-rounds + // sub-round Pi is a simple permutation of 64-bit lanes + // sub-round Rho requires rotations + let rhopi_output: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); + + // iterator over split witnesses + let mut rotation_witness = rotation_witness.iter(); + + for i in 0..5 { + #[allow(clippy::needless_range_loop)] + for j in 0..5 { + let arg = theta_output + .slice(s!(j, i, ..)) + .iter() + .map(|e| e.0.expr()) + 
.collect_vec(); + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); + let many = sizes.len(); + let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) + .map(|(sz, (wit, _))| (sz, wit.expr())) + .collect_vec(); + let arg_rotated = rhopi_output + .slice(s!((2 * i + 3 * j) % 5, j, ..)) + .iter() + .map(|e| e.0.expr()) + .collect_vec(); + system.constrain_left_rotation64( + &arg, + &rep_split, + &arg_rotated, + ROTATION_CONSTANTS[j][i], + format!("RHOPI {i}, {j}"), + ); + } + } + + let mut chi_output = chi_output; + chi_output.extend(iota_output[8..].to_vec()); + let chi_output: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); + + // for the Chi sub-round, we use an intermediate witness storing the result of + // the required AND + let nonlinear: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); + + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + system.lookup_and8( + not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), + rhopi_output[[j, (i + 2) % 5, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + ); + + system.lookup_xor8( + rhopi_output[[j, i, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + ); + } + } + } + + // TODO: 24/25 elements stay the same after Iota; eliminate duplication? + let iota_output_arr: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); + + for k in 0..8 { + system.lookup_xor8( + chi_output[[0, 0, k]].0.into(), + E::BaseField::from_i64(((RC[round] >> (k * 8)) & 0xFF) as i64).expr(), + iota_output_arr[[0, 0, k]].0.into(), + ); + } + + iota_output + }); + + let mut global_and_lookup = 0; + let mut global_xor_lookup = 3 * AND_LOOKUPS; + let mut global_range_lookup = 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS; + + let ConstraintSystem { + mut expressions, + mut expr_names, + mut evals, + and_lookups, + xor_lookups, + range_lookups, + .. 
+ } = system; + + for (i, lookup) in chain!(and_lookups, xor_lookups, range_lookups) + .flatten() + .enumerate() + { + expressions.push(lookup); + let (idx, round) = if i < 3 * AND_LOOKUPS { + let round = i / AND_LOOKUPS; + (&mut global_and_lookup, round) + } else if i < 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS { + let round = (i - 3 * AND_LOOKUPS) / XOR_LOOKUPS; + (&mut global_xor_lookup, round) + } else { + let round = (i - 3 * AND_LOOKUPS - 3 * XOR_LOOKUPS) / RANGE_LOOKUPS; + (&mut global_range_lookup, round) + }; + expr_names.push(format!("round {round}: {i}th lookup felt")); + evals.push(lookup_outputs[*idx].clone()); + *idx += 1; + } + + assert!(global_and_lookup == 3 * AND_LOOKUPS); + assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); + assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); + + let keccak_input8: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), keccak_input8).unwrap(); + let keccak_input32 = keccak_input32.to_vec(); + let mut keccak_input32_iter = keccak_input32.iter().cloned(); + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::( + keccak_input8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + .iter() + .map(|e| (8, e.0.expr())) + .collect_vec() + .as_slice(), + ); + expressions.push(expr); + evals.push(keccak_input32_iter.next().unwrap().clone()); + expr_names.push(format!("build 32-bit input: {x}, {y}, {k}")); + } + } + } + + let keccak_output32 = keccak_output32.to_vec(); + let keccak_output8: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), keccak_output8).unwrap(); + let mut keccak_output32_iter = keccak_output32.iter().cloned(); + + for x in 0..5 { + for y in 0..5 { + for k in 0..2 { + // create an expression combining 4 elements of state8 into a single 32-bit felt + let expr = expansion_expr::( + &keccak_output8 + .slice(s![x, y, 4 * k..4 * (k + 1)]) + 
.iter() + .map(|e| (8, e.0.expr())) + .collect_vec(), + ); + expressions.push(expr); + evals.push(keccak_output32_iter.next().unwrap().clone()); + expr_names.push(format!("build 32-bit output: {x}, {y}, {k}")); + } + } + } + + chip.add_layer(Layer::new( + "Rounds".to_string(), + LayerType::Zerocheck, + expressions, + vec![], + bases.into_iter().map(|e| e.1).collect_vec(), + vec![], + evals, + expr_names, + )); + } +} + +#[derive(Clone, Default)] +pub struct KeccakTrace { + pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, +} + +impl ProtocolWitnessGenerator for KeccakLayout +where + E: ExtensionField, +{ + type Trace = KeccakTrace; + + fn phase1_witness_group(&self, phase1: Self::Trace) -> RowMajorMatrix { + let instances = &phase1.instances; + let num_instances = instances.len(); + + let wits: Vec<_> = (0..num_instances) + .into_par_iter() + .flat_map(|instance_id| { + fn conv64to8(input: u64) -> [u64; 8] { + MaskRepresentation::new(vec![(64, input).into()]) + .convert(vec![8; 8]) + .values() + .try_into() + .unwrap() + } + + let state32 = instances[instance_id] + .iter() + .map(|&e| e as u64) + .collect_vec(); + + let mut state64 = [[0u64; 5]; 5]; + let mut state8 = [[[0u64; 8]; 5]; 5]; + + zip_eq(iproduct!(0..5, 0..5), state32.iter().tuples()) + .map(|((x, y), (&lo, &hi))| { + state64[x][y] = lo | (hi << 32); + }) + .count(); + + for x in 0..5 { + for y in 0..5 { + state8[x][y] = conv64to8(state64[x][y]); + } + } + + let mut wits = Vec::with_capacity(KECCAK_WIT_SIZE_PER_ROUND); + let mut push_instance = |new_wits: Vec| { + let felts = u64s_to_felts::(new_wits); + wits.extend(felts); + }; + + push_instance(state8.into_iter().flatten().flatten().collect_vec()); + + #[allow(clippy::needless_range_loop)] + for round in 0..ROUNDS { + let mut c_aux64 = [[0u64; 5]; 5]; + let mut c_aux8 = [[[0u64; 8]; 5]; 5]; + + for i in 0..5 { + c_aux64[i][0] = state64[0][i]; + c_aux8[i][0] = conv64to8(c_aux64[i][0]); + for j in 1..5 { + c_aux64[i][j] = state64[j][i] ^ c_aux64[i][j - 1]; 
+ c_aux8[i][j] = conv64to8(c_aux64[i][j]); + } + } + + let mut c64 = [0u64; 5]; + let mut c8 = [[0u64; 8]; 5]; + + for x in 0..5 { + c64[x] = c_aux64[x][4]; + c8[x] = conv64to8(c64[x]); + } + + let mut c_temp = [[0u64; 6]; 5]; + for i in 0..5 { + let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) + .convert(vec![16, 15, 1, 16, 15, 1]); + c_temp[i] = rep.values().try_into().unwrap(); + } + + let mut crot64 = [0u64; 5]; + let mut crot8 = [[0u64; 8]; 5]; + for i in 0..5 { + crot64[i] = c64[i].rotate_left(1); + crot8[i] = conv64to8(crot64[i]); + } + + let mut d64 = [0u64; 5]; + let mut d8 = [[0u64; 8]; 5]; + for x in 0..5 { + d64[x] = c64[(x + 4) % 5] ^ c64[(x + 1) % 5].rotate_left(1); + d8[x] = conv64to8(d64[x]); + } + + let mut theta_state64 = state64; + let mut theta_state8 = [[[0u64; 8]; 5]; 5]; + let mut rotation_witness = vec![]; + + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] ^= d64[x]; + theta_state8[y][x] = conv64to8(theta_state64[y][x]); + + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rep = + MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) + .convert(sizes); + rotation_witness.extend(rep.values()); + } + } + + // Rho and Pi steps + let mut rhopi_output64 = [[0u64; 5]; 5]; + let mut rhopi_output8 = [[[0u64; 8]; 5]; 5]; + + for x in 0..5 { + for y in 0..5 { + rhopi_output64[(2 * x + 3 * y) % 5][y % 5] = + theta_state64[y][x].rotate_left(ROTATION_CONSTANTS[y][x] as u32); + } + } + + for x in 0..5 { + for y in 0..5 { + rhopi_output8[x][y] = conv64to8(rhopi_output64[x][y]); + } + } + + // Chi step + let mut nonlinear64 = [[0u64; 5]; 5]; + let mut nonlinear8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + nonlinear64[y][x] = + !rhopi_output64[y][(x + 1) % 5] & rhopi_output64[y][(x + 2) % 5]; + nonlinear8[y][x] = conv64to8(nonlinear64[y][x]); + } + } + + let mut chi_output64 = [[0u64; 5]; 5]; + let mut chi_output8 = [[[0u64; 8]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + chi_output64[y][x] = 
nonlinear64[y][x] ^ rhopi_output64[y][x]; + chi_output8[y][x] = conv64to8(chi_output64[y][x]); + } + } + + // Iota step + let mut iota_output64 = chi_output64; + let mut iota_output8 = [[[0u64; 8]; 5]; 5]; + iota_output64[0][0] ^= RC[round]; + + for x in 0..5 { + for y in 0..5 { + iota_output8[x][y] = conv64to8(iota_output64[x][y]); + } + } + + let all_wits64 = [ + c_aux8.into_iter().flatten().flatten().collect_vec(), + c_temp.into_iter().flatten().collect_vec(), + crot8.into_iter().flatten().collect_vec(), + d8.into_iter().flatten().collect_vec(), + theta_state8.into_iter().flatten().flatten().collect_vec(), + rotation_witness, + rhopi_output8.into_iter().flatten().flatten().collect_vec(), + nonlinear8.into_iter().flatten().flatten().collect_vec(), + chi_output8[0][0].to_vec(), + iota_output8.into_iter().flatten().flatten().collect_vec(), + ] + .into_iter() + .flatten() + .collect_vec(); + + assert_eq!(all_wits64.len(), KECCAK_WIT_SIZE_PER_ROUND); + push_instance(all_wits64); + + state64 = iota_output64; + } + + wits + }) + .collect(); + RowMajorMatrix::new_by_values(wits, KECCAK_WIT_SIZE, InstancePaddingStrategy::Default) + } + + fn gkr_witness( + &self, + phase1: &RowMajorMatrix, + _challenges: &[E], + ) -> (GKRCircuitWitness, GKRCircuitOutput) { + // TODO: Make it more efficient. 
+ let instances = phase1 + .values + .par_iter() + .map(|wit| wit.to_canonical_u64()) + .collect::>(); + let num_instances = phase1.num_vars(); + let num_cols = phase1.n_col(); + assert_eq!(num_cols, KECCAK_WIT_SIZE); + + let to_5x5x8_array = |input: &[u64]| -> [[[u64; 8]; 5]; 5] { + assert_eq!(input.len(), 5 * 5 * 8); + input + .chunks(40) + .map(|chunk| { + chunk + .chunks(8) + .map(|x| x.to_vec().try_into().unwrap()) + .collect_vec() + .try_into() + .unwrap() + }) + .collect_vec() + .try_into() + .unwrap() + }; + let to_5x8_array = |input: &[u64]| -> [[u64; 8]; 5] { + input + .chunks(8) + .map(|x| x.to_vec().try_into().unwrap()) + .collect_vec() + .try_into() + .unwrap() + }; + let u8_slice_to_u64 = + |input: &[u64]| -> u64 { input.iter().rev().fold(0, |acc, &e| (acc << 8) | e) }; + let u8_slice_to_u32_slice = |input: &[u64]| -> [u64; 2] { + input + .chunks(4) + .map(u8_slice_to_u64) + .collect_vec() + .try_into() + .unwrap() + }; + + let output_bases: Vec = (0..num_instances) + .into_par_iter() + .flat_map(|instance_id| { + let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; + let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; + + let mut add_and = |a: u64, b: u64, round: usize| { + let c = a & b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + and_lookups[round].extend(vec![a, b, c]); + }; + + let mut add_xor = |a: u64, b: u64, round: usize| { + let c = a ^ b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + xor_lookups[round].extend(vec![a, b, c]); + }; + + let mut add_range = |value: u64, size: usize, round: usize| { + assert!(size <= 16, "{size}"); + range_lookups[round].push(value); + if size < 16 { + range_lookups[round].push(value << (16 - size)); + assert!(value << (16 - size) < (1 << 16)); + } + }; + + let mut state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( + &instances + [instance_id * num_cols..instance_id * num_cols + KECCAK_LAYER_BYTE_SIZE], + ); + let mut keccak_input32 = [[[0u64; 2]; 5]; 5]; 
+ for x in 0..5 { + for y in 0..5 { + keccak_input32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); + } + } + let mut offset = KECCAK_LAYER_BYTE_SIZE; + #[allow(clippy::needless_range_loop)] + for round in 0..ROUNDS { + let ( + c_aux8, + _c_temp, + crot8, + d8, + theta_state8, + _rotation_witness, + rhopi_output8, + nonlinear8, + chi_output8, + iota_output8, + ) = split_from_offset!( + instances[instance_id * num_cols..(instance_id + 1) * num_cols], + offset, + KECCAK_WIT_SIZE_PER_ROUND, + 200, + 30, + 40, + 40, + 200, + 146, + 200, + 200, + 8, + 200 + ); + offset += KECCAK_WIT_SIZE_PER_ROUND; + let c_aux8 = to_5x5x8_array(&c_aux8); + + for i in 0..5 { + for j in 1..5 { + for k in 0..8 { + add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); + } + } + } + + let mut c8 = [[0u64; 8]; 5]; + let mut c64 = [0u64; 5]; + + for x in 0..5 { + c8[x] = c_aux8[x][4]; + c64[x] = u8_slice_to_u64(&c8[x]); + } + + for i in 0..5 { + let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) + .convert(vec![16, 15, 1, 16, 15, 1]); + for mask in rep.rep { + add_range(mask.value, mask.size, round); + } + } + + let crot8 = to_5x8_array(&crot8); + let d8 = to_5x8_array(&d8); + for x in 0..5 { + for k in 0..8 { + add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); + } + } + + let theta_state8 = to_5x5x8_array(&theta_state8); + let mut theta_state64 = [[0u64; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] = u8_slice_to_u64(&theta_state8[y][x]); + } + } + + for x in 0..5 { + for y in 0..5 { + for k in 0..8 { + add_xor(state8[y][x][k], d8[x][k], round); + } + + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rep = + MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) + .convert(sizes); + for mask in rep.rep.iter() { + if mask.size != 32 { + add_range(mask.value, mask.size, round); + } + } + } + } + + // Rho and Pi steps + let rhopi_output8 = to_5x5x8_array(&rhopi_output8); + + // Chi step + let nonlinear8 = 
to_5x5x8_array(&nonlinear8); + for x in 0..5 { + for y in 0..5 { + for k in 0..8 { + add_and( + 0xFF - rhopi_output8[y][(x + 1) % 5][k], + rhopi_output8[y][(x + 2) % 5][k], + round, + ); + } + } + } + + for x in 0..5 { + for y in 0..5 { + for k in 0..8 { + add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) + } + } + } + + // Iota step + let chi_output8: [u64; 8] = chi_output8.try_into().unwrap(); // only save chi_output8[0][0]; + let iota_output8 = to_5x5x8_array(&iota_output8); + for k in 0..8 { + add_xor(chi_output8[k], (RC[round] >> (k * 8)) & 0xFF, round); + } + + state8 = iota_output8; + } + + let mut keccak_output32 = [[[0u64; 2]; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + keccak_output32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); + } + } + + chain!( + keccak_output32.into_iter().flatten().flatten(), + keccak_input32.into_iter().flatten().flatten(), + (0..ROUNDS).rev().flat_map(|i| and_lookups[i].clone()), + (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), + (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) + ) + .collect_vec() + }) + .collect(); + + let bases = phase1.to_cols_base::(); + let output_bases = RowMajorMatrix::new_by_values( + output_bases + .into_iter() + .map(E::BaseField::from_u64) + .collect(), + KECCAK_OUT_EVAL_SIZE, + InstancePaddingStrategy::Default, + ) + .to_cols_base::(); + + ( + GKRCircuitWitness { + layers: vec![LayerWitness { + bases, + ..Default::default() + }], + }, + GKRCircuitOutput(LayerWitness { + bases: output_bases, + ..Default::default() + }), + ) + } +} + +pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { + let params = KeccakParams {}; + let (layout, chip) = KeccakLayout::build(params); + + let mut instances = vec![]; + for state in &states { + let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); + let state_mask32 = state_mask64.convert(vec![32; 50]); + + instances.push( + state_mask32 + .values() + .iter() + .map(|e| *e 
as u32) + .collect_vec() + .try_into() + .unwrap(), + ); + } + + let num_instances = instances.len(); + let phase1_witness = layout.phase1_witness(KeccakTrace { + instances: instances.clone(), + }); + + let mut prover_transcript = BasicTranscript::::new(b"protocol"); + + // Omit the commit phase1 and phase2. + let (gkr_witness, _gkr_output) = layout.gkr_witness(&phase1_witness, &[]); + + let out_evals = { + let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); + let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); + + if test_outputs { + // Confront outputs with tiny_keccak::keccakf call + let mut instance_outputs = vec![vec![]; num_instances]; + for base in gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .take(KECCAK_OUTPUT_SIZE) + { + assert_eq!(base.len(), num_instances); + for i in 0..num_instances { + instance_outputs[i].push(base[i]); + } + } + + for i in 0..num_instances { + let mut state = states[i]; + keccakf(&mut state); + assert_eq!( + state + .to_vec() + .iter() + .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) + .map(|e| Goldilocks::from_u64(e as u64)) + .collect_vec(), + instance_outputs[i] + ); + } + } + + let out_evals = gkr_witness + .layers + .last() + .unwrap() + .bases + .iter() + .map(|base| PointAndEval { + point: point.clone(), + eval: subprotocols::utils::evaluate_mle_ext(base, &point), + }) + .collect_vec(); + + assert_eq!(out_evals.len(), KECCAK_OUT_EVAL_SIZE); + + out_evals + }; + + let gkr_circuit = chip.gkr_circuit(); + dbg!(&gkr_circuit.layers.len()); + let GKRProverOutput { gkr_proof, .. } = gkr_circuit + .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) + .expect("Failed to prove phase"); + + if verify { + { + let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + + gkr_circuit + .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) + .expect("GKR verify failed"); + + // Omit the PCS opening phase. 
+ } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rand::{Rng, SeedableRng}; + + #[test] + fn test_keccakf() { + std::thread::Builder::new() + .name("keccak_test".into()) + .stack_size(64 * 1024 * 1024) + .spawn(|| { + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let num_instances = 8; + let mut states: Vec<[u64; 25]> = vec![]; + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())); + } + + run_faster_keccakf(states, true, true); + }) + .unwrap() + .join() + .unwrap(); + } + + #[ignore] + #[test] + fn test_keccakf_nonpow2() { + std::thread::Builder::new() + .name("keccak_test".into()) + .stack_size(64 * 1024 * 1024) + .spawn(|| { + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let num_instances = 5; + let mut states: Vec<[u64; 25]> = vec![]; + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())); + } + + run_faster_keccakf(states, true, true); + }) + .unwrap() + .join() + .unwrap(); + } +} diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index 2eaddb5b1..6b58a4c19 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -2,8 +2,8 @@ mod bitwise_keccakf; mod lookup_keccakf; mod utils; pub use bitwise_keccakf::{run_keccakf, setup_gkr_circuit}; -// pub use lookup_keccakf::{ -// AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KECCAK_OUT_EVAL_SIZE, KeccakLayout, KeccakParams, -// KeccakTrace, RANGE_LOOKUPS, RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, -// run_faster_keccakf, -// }; +pub use lookup_keccakf::{ + AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KECCAK_OUT_EVAL_SIZE, KeccakLayout, KeccakParams, + KeccakTrace, RANGE_LOOKUPS, RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, + run_faster_keccakf, +}; diff --git a/multilinear_extensions/src/expression.rs b/multilinear_extensions/src/expression.rs index 2a54a0e99..99ff50be8 100644 --- a/multilinear_extensions/src/expression.rs +++ b/multilinear_extensions/src/expression.rs 
@@ -915,6 +915,13 @@ impl> ToExpr for F { } } +impl ToExpr for Expression { + type Output = Expression; + fn expr(&self) -> Self::Output { + self.clone() + } +} + pub fn wit_infer_by_expr<'a, E: ExtensionField>( fixed: &[ArcMultilinearExtension<'a, E>], witnesses: &[ArcMultilinearExtension<'a, E>], @@ -1067,7 +1074,7 @@ macro_rules! impl_expr_from_unsigned { $( impl> From<$t> for Expression { fn from(value: $t) -> Self { - Expression::Constant(itertools::Either::Left(F::from_u64(value as u64))) + Expression::Constant(Either::Left(F::from_u64(value as u64))) } } )* From ccda0a8b677b36df39663ce73d9b701803abc23d Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 29 May 2025 15:06:59 +0800 Subject: [PATCH 14/28] wip --- gkr_iop/src/precompiles/lookup_keccakf.rs | 95 +++++++++++++---------- 1 file changed, 56 insertions(+), 39 deletions(-) diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 5a0c0d1c6..5680c3d67 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -2,11 +2,13 @@ use std::{array, cmp::Ordering, marker::PhantomData, sync::Arc}; use ff_ext::{ExtensionField, SmallField}; use itertools::{Itertools, chain, iproduct, zip_eq}; -use multilinear_extensions::{Expression, ToExpr, WitIn}; +use multilinear_extensions::{Expression, ToExpr, WitIn, util::ceil_log2}; use ndarray::{ArrayView, Ix2, Ix3, s}; use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{ + IntoParallelIterator, IntoParallelRefIterator, ParallelExtend, ParallelIterator, +}; use serde::{Deserialize, Serialize}; use tiny_keccak::keccakf; use transcript::BasicTranscript; @@ -806,7 +808,10 @@ where } } - let mut wits = Vec::with_capacity(KECCAK_WIT_SIZE_PER_ROUND); + // TODO take structural id information from circuit to do wits assignment + // 1 
instance will derive 24 round result + 8 round padding to pow2 for easiler rotation design + let mut wits = + Vec::with_capacity(KECCAK_WIT_SIZE_PER_ROUND * ROUNDS.next_power_of_two()); let mut push_instance = |new_wits: Vec| { let felts = u64s_to_felts::(new_wits); wits.extend(felts); @@ -944,10 +949,20 @@ where state64 = iota_output64; } + // padding to next_power_of_2 rounds for rotation + wits.par_extend( + (0..(ROUNDS.next_power_of_two() - ROUNDS) * KECCAK_WIT_SIZE_PER_ROUND) + .into_par_iter() + .map(|_| E::BaseField::ZERO), + ); wits }) .collect(); - RowMajorMatrix::new_by_values(wits, KECCAK_WIT_SIZE, InstancePaddingStrategy::Default) + RowMajorMatrix::new_by_values( + wits, + KECCAK_WIT_SIZE_PER_ROUND, + InstancePaddingStrategy::Default, + ) } fn gkr_witness( @@ -955,15 +970,15 @@ where phase1: &RowMajorMatrix, _challenges: &[E], ) -> (GKRCircuitWitness, GKRCircuitOutput) { - // TODO: Make it more efficient. - let instances = phase1 + // TODO: fix efficient as here convert basefield back to u64 + let instances_with_rotations = phase1 .values .par_iter() .map(|wit| wit.to_canonical_u64()) .collect::>(); - let num_instances = phase1.num_vars(); + let num_instances_with_rotations = phase1.num_vars(); let num_cols = phase1.n_col(); - assert_eq!(num_cols, KECCAK_WIT_SIZE); + assert_eq!(num_cols, KECCAK_WIT_SIZE_PER_ROUND); let to_5x5x8_array = |input: &[u64]| -> [[[u64; 8]; 5]; 5] { assert_eq!(input.len(), 5 * 5 * 8); @@ -1000,39 +1015,12 @@ where .unwrap() }; - let output_bases: Vec = (0..num_instances) + // process output bases + let output_bases: Vec = (0..num_instances_with_rotations) .into_par_iter() .flat_map(|instance_id| { - let mut and_lookups: Vec> = vec![vec![]; ROUNDS]; - let mut xor_lookups: Vec> = vec![vec![]; ROUNDS]; - let mut range_lookups: Vec> = vec![vec![]; ROUNDS]; - - let mut add_and = |a: u64, b: u64, round: usize| { - let c = a & b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - and_lookups[round].extend(vec![a, b, c]); - }; - - let 
mut add_xor = |a: u64, b: u64, round: usize| { - let c = a ^ b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - xor_lookups[round].extend(vec![a, b, c]); - }; - - let mut add_range = |value: u64, size: usize, round: usize| { - assert!(size <= 16, "{size}"); - range_lookups[round].push(value); - if size < 16 { - range_lookups[round].push(value << (16 - size)); - assert!(value << (16 - size) < (1 << 16)); - } - }; - let mut state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( - &instances - [instance_id * num_cols..instance_id * num_cols + KECCAK_LAYER_BYTE_SIZE], + &instances_with_rotations[instance_id * num_cols..][..KECCAK_LAYER_BYTE_SIZE], ); let mut keccak_input32 = [[[0u64; 2]; 5]; 5]; for x in 0..5 { @@ -1043,6 +1031,34 @@ where let mut offset = KECCAK_LAYER_BYTE_SIZE; #[allow(clippy::needless_range_loop)] for round in 0..ROUNDS { + // TODO use with_capacity and retrive number of lookup from circuit + let mut and_lookups: Vec = vec![]; + let mut xor_lookups: Vec = vec![]; + let mut range_lookups: Vec = vec![]; + + let mut add_and = |a: u64, b: u64, round: usize| { + let c = a & b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + and_lookups.extend(vec![a, b, c]); + }; + + let mut add_xor = |a: u64, b: u64, round: usize| { + let c = a ^ b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + xor_lookups.extend(vec![a, b, c]); + }; + + let mut add_range = |value: u64, size: usize, round: usize| { + assert!(size <= 16, "{size}"); + range_lookups.push(value); + if size < 16 { + range_lookups.push(value << (16 - size)); + assert!(value << (16 - size) < (1 << 16)); + } + }; + let ( c_aux8, _c_temp, @@ -1055,7 +1071,8 @@ where chi_output8, iota_output8, ) = split_from_offset!( - instances[instance_id * num_cols..(instance_id + 1) * num_cols], + instances_with_rotations + [instance_id * num_cols..(instance_id + 1) * num_cols], offset, KECCAK_WIT_SIZE_PER_ROUND, 200, From 27eda78700272968369701e8d1fda989aebf8a17 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 29 May 
2025 20:38:11 +0800 Subject: [PATCH 15/28] gkr witness assignment pass --- gkr_iop/src/lib.rs | 5 +- gkr_iop/src/precompiles/lookup_keccakf.rs | 291 +++++++++++----------- 2 files changed, 152 insertions(+), 144 deletions(-) diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 70670df1b..933117244 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -59,7 +59,7 @@ where circuit: &GKRCircuit, phase1_witness_group: &RowMajorMatrix, challenges: &[E], - ) -> GKRCircuitWitness<'a, E> { + ) -> (GKRCircuitWitness<'a, E>, GKRCircuitOutput) { // layer order from output to input let n_layers = 100; let mut layer_wits = Vec::>::with_capacity(n_layers + 1); @@ -122,7 +122,8 @@ where layer_wits.reverse(); - GKRCircuitWitness { layers: layer_wits } + GKRCircuitWitness { layers: layer_wits }; + unimplemented!() } } diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 5680c3d67..3dc930000 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -19,7 +19,7 @@ use crate::{ chip::Chip, evaluation::EvalExpression, gkr::{ - GKRCircuitOutput, GKRCircuitWitness, GKRProverOutput, + GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, }, precompiles::utils::{MaskRepresentation, not8_expr}, @@ -336,7 +336,7 @@ pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; pub const KECCAK_OUT_EVAL_SIZE: usize = - KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS; + KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND; pub const KECCAK_WIT_SIZE_PER_ROUND: usize = 1264; pub const KECCAK_WIT_SIZE: usize = KECCAK_WIT_SIZE_PER_ROUND * ROUNDS + KECCAK_LAYER_BYTE_SIZE; @@ -767,7 +767,7 @@ pub struct KeccakTrace { pub instances: Vec<[u32; KECCAK_INPUT_SIZE]>, } -impl ProtocolWitnessGenerator for KeccakLayout +impl<'a, E> 
ProtocolWitnessGenerator<'a, E> for KeccakLayout where E: ExtensionField, { @@ -950,9 +950,8 @@ where } // padding to next_power_of_2 rounds for rotation - wits.par_extend( + wits.extend( (0..(ROUNDS.next_power_of_two() - ROUNDS) * KECCAK_WIT_SIZE_PER_ROUND) - .into_par_iter() .map(|_| E::BaseField::ZERO), ); wits @@ -967,11 +966,12 @@ where fn gkr_witness( &self, + _circuit: &GKRCircuit, phase1: &RowMajorMatrix, _challenges: &[E], - ) -> (GKRCircuitWitness, GKRCircuitOutput) { - // TODO: fix efficient as here convert basefield back to u64 - let instances_with_rotations = phase1 + ) -> (GKRCircuitWitness<'a, E>, GKRCircuitOutput) { + // TODO: fix efficient as here as it convert felts back to u64 + let instances_rounds = phase1 .values .par_iter() .map(|wit| wit.to_canonical_u64()) @@ -1018,9 +1018,16 @@ where // process output bases let output_bases: Vec = (0..num_instances_with_rotations) .into_par_iter() - .flat_map(|instance_id| { + .flat_map(|instance_round_id| { + let round = instance_round_id % ROUNDS.next_power_of_two(); + + if round >= ROUNDS { + // padding with zero + return vec![0; KECCAK_OUT_EVAL_SIZE]; + } + let mut state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( - &instances_with_rotations[instance_id * num_cols..][..KECCAK_LAYER_BYTE_SIZE], + &instances_rounds[instance_round_id * num_cols..][..KECCAK_LAYER_BYTE_SIZE], ); let mut keccak_input32 = [[[0u64; 2]; 5]; 5]; for x in 0..5 { @@ -1029,178 +1036,175 @@ where } } let mut offset = KECCAK_LAYER_BYTE_SIZE; - #[allow(clippy::needless_range_loop)] - for round in 0..ROUNDS { - // TODO use with_capacity and retrive number of lookup from circuit - let mut and_lookups: Vec = vec![]; - let mut xor_lookups: Vec = vec![]; - let mut range_lookups: Vec = vec![]; - - let mut add_and = |a: u64, b: u64, round: usize| { - let c = a & b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - and_lookups.extend(vec![a, b, c]); - }; - - let mut add_xor = |a: u64, b: u64, round: usize| { - let c = a ^ b; - assert!(a < (1 << 
8)); - assert!(b < (1 << 8)); - xor_lookups.extend(vec![a, b, c]); - }; - - let mut add_range = |value: u64, size: usize, round: usize| { - assert!(size <= 16, "{size}"); - range_lookups.push(value); - if size < 16 { - range_lookups.push(value << (16 - size)); - assert!(value << (16 - size) < (1 << 16)); - } - }; - - let ( - c_aux8, - _c_temp, - crot8, - d8, - theta_state8, - _rotation_witness, - rhopi_output8, - nonlinear8, - chi_output8, - iota_output8, - ) = split_from_offset!( - instances_with_rotations - [instance_id * num_cols..(instance_id + 1) * num_cols], - offset, - KECCAK_WIT_SIZE_PER_ROUND, - 200, - 30, - 40, - 40, - 200, - 146, - 200, - 200, - 8, - 200 - ); - offset += KECCAK_WIT_SIZE_PER_ROUND; - let c_aux8 = to_5x5x8_array(&c_aux8); + // #[allow(clippy::needless_range_loop)] + // for round in 0..ROUNDS { + // TODO use with_capacity and retrive number of lookup from circuit + let mut and_lookups: Vec = vec![]; + let mut xor_lookups: Vec = vec![]; + let mut range_lookups: Vec = vec![]; + + let mut add_and = |a: u64, b: u64| { + let c = a & b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + and_lookups.extend(vec![a, b, c]); + }; - for i in 0..5 { - for j in 1..5 { - for k in 0..8 { - add_xor(c_aux8[i][j - 1][k], state8[j][i][k], round); - } - } + let mut add_xor = |a: u64, b: u64| { + let c = a ^ b; + assert!(a < (1 << 8)); + assert!(b < (1 << 8)); + xor_lookups.extend(vec![a, b, c]); + }; + + let mut add_range = |value: u64, size: usize| { + assert!(size <= 16, "{size}"); + range_lookups.push(value); + if size < 16 { + range_lookups.push(value << (16 - size)); + assert!(value << (16 - size) < (1 << 16)); } + }; - let mut c8 = [[0u64; 8]; 5]; - let mut c64 = [0u64; 5]; + let ( + c_aux8, + _c_temp, + crot8, + d8, + theta_state8, + _rotation_witness, + rhopi_output8, + nonlinear8, + chi_output8, + iota_output8, + ) = split_from_offset!( + instances_rounds[instance_round_id * num_cols..][..num_cols], + offset, + KECCAK_WIT_SIZE_PER_ROUND, + 200, + 30, 
+ 40, + 40, + 200, + 146, + 200, + 200, + 8, + 200 + ); + offset += KECCAK_WIT_SIZE_PER_ROUND; + let c_aux8 = to_5x5x8_array(&c_aux8); - for x in 0..5 { - c8[x] = c_aux8[x][4]; - c64[x] = u8_slice_to_u64(&c8[x]); + for i in 0..5 { + for j in 1..5 { + for k in 0..8 { + add_xor(c_aux8[i][j - 1][k], state8[j][i][k]); + } } + } - for i in 0..5 { - let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) - .convert(vec![16, 15, 1, 16, 15, 1]); - for mask in rep.rep { - add_range(mask.value, mask.size, round); - } + let mut c8 = [[0u64; 8]; 5]; + let mut c64 = [0u64; 5]; + + for x in 0..5 { + c8[x] = c_aux8[x][4]; + c64[x] = u8_slice_to_u64(&c8[x]); + } + + for i in 0..5 { + let rep = MaskRepresentation::new(vec![(64, c64[i]).into()]) + .convert(vec![16, 15, 1, 16, 15, 1]); + for mask in rep.rep { + add_range(mask.value, mask.size); } + } - let crot8 = to_5x8_array(&crot8); - let d8 = to_5x8_array(&d8); - for x in 0..5 { - for k in 0..8 { - add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k], round); - } + let crot8 = to_5x8_array(&crot8); + let d8 = to_5x8_array(&d8); + for x in 0..5 { + for k in 0..8 { + add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k]); } + } - let theta_state8 = to_5x5x8_array(&theta_state8); - let mut theta_state64 = [[0u64; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - theta_state64[y][x] = u8_slice_to_u64(&theta_state8[y][x]); - } + let theta_state8 = to_5x5x8_array(&theta_state8); + let mut theta_state64 = [[0u64; 5]; 5]; + for x in 0..5 { + for y in 0..5 { + theta_state64[y][x] = u8_slice_to_u64(&theta_state8[y][x]); } + } - for x in 0..5 { - for y in 0..5 { - for k in 0..8 { - add_xor(state8[y][x][k], d8[x][k], round); - } + for x in 0..5 { + for y in 0..5 { + for k in 0..8 { + add_xor(state8[y][x][k], d8[x][k]); + } - let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); - let rep = - MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) - .convert(sizes); - for mask in rep.rep.iter() { - if mask.size != 32 { - 
add_range(mask.value, mask.size, round); - } + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); + let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) + .convert(sizes); + for mask in rep.rep.iter() { + if mask.size != 32 { + add_range(mask.value, mask.size); } } } + } - // Rho and Pi steps - let rhopi_output8 = to_5x5x8_array(&rhopi_output8); + // Rho and Pi steps + let rhopi_output8 = to_5x5x8_array(&rhopi_output8); - // Chi step - let nonlinear8 = to_5x5x8_array(&nonlinear8); - for x in 0..5 { - for y in 0..5 { - for k in 0..8 { - add_and( - 0xFF - rhopi_output8[y][(x + 1) % 5][k], - rhopi_output8[y][(x + 2) % 5][k], - round, - ); - } + // Chi step + let nonlinear8 = to_5x5x8_array(&nonlinear8); + for x in 0..5 { + for y in 0..5 { + for k in 0..8 { + add_and( + 0xFF - rhopi_output8[y][(x + 1) % 5][k], + rhopi_output8[y][(x + 2) % 5][k], + ); } } + } - for x in 0..5 { - for y in 0..5 { - for k in 0..8 { - add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k], round) - } + for x in 0..5 { + for y in 0..5 { + for k in 0..8 { + add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k]) } } + } - // Iota step - let chi_output8: [u64; 8] = chi_output8.try_into().unwrap(); // only save chi_output8[0][0]; - let iota_output8 = to_5x5x8_array(&iota_output8); - for k in 0..8 { - add_xor(chi_output8[k], (RC[round] >> (k * 8)) & 0xFF, round); - } + // Iota step + let chi_output8: [u64; 8] = chi_output8.try_into().unwrap(); // only save chi_output8[0][0]; + let iota_output8 = to_5x5x8_array(&iota_output8); - state8 = iota_output8; + for k in 0..8 { + add_xor(chi_output8[k], (RC[round] >> (k * 8)) & 0xFF); } + // } + let mut keccak_output32 = [[[0u64; 2]; 5]; 5]; for x in 0..5 { for y in 0..5 { - keccak_output32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); + keccak_output32[x][y] = u8_slice_to_u32_slice(&iota_output8[x][y]); } } chain!( keccak_output32.into_iter().flatten().flatten(), keccak_input32.into_iter().flatten().flatten(), - 
(0..ROUNDS).rev().flat_map(|i| and_lookups[i].clone()), - (0..ROUNDS).rev().flat_map(|i| xor_lookups[i].clone()), - (0..ROUNDS).rev().flat_map(|i| range_lookups[i].clone()) + and_lookups, + xor_lookups, + range_lookups ) .collect_vec() }) .collect(); - let bases = phase1.to_cols_base::(); + let bases = phase1.to_mles().into_iter().map(Arc::new).collect_vec(); let output_bases = RowMajorMatrix::new_by_values( output_bases .into_iter() @@ -1209,7 +1213,10 @@ where KECCAK_OUT_EVAL_SIZE, InstancePaddingStrategy::Default, ) - .to_cols_base::(); + .to_mles() + .into_iter() + .map(Arc::new) + .collect_vec(); ( GKRCircuitWitness { From 1640fd4dc47a8ab9765e8689bb1b3b77915110d6 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Thu, 29 May 2025 22:54:41 +0800 Subject: [PATCH 16/28] compile pass --- gkr_iop/src/gkr/layer.rs | 17 +- gkr_iop/src/precompiles/bitwise_keccakf.rs | 2 +- gkr_iop/src/precompiles/lookup_keccakf.rs | 480 +++++++++++---------- gkr_iop/src/precompiles/utils.rs | 4 + 4 files changed, 266 insertions(+), 237 deletions(-) diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 86d76484c..3d9dfe594 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -67,6 +67,7 @@ impl Layer { pub fn new( name: String, ty: LayerType, + // exprs concat zero/non-zero expression. 
exprs: Vec>, challenges: Vec>, in_eval_expr: Vec>, @@ -74,18 +75,12 @@ impl Layer { outs: Vec<(Option>, Vec>)>, expr_names: Vec, ) -> Self { - assert_eq!( - outs.iter() - .map(|(_, eval_expressions)| eval_expressions.len()) - .sum::(), - exprs.len() // output eval not match with number of expression - ); - let mut expr_names = expr_names; if expr_names.len() < exprs.len() { - expr_names.extend(vec![ - "unavailable".to_string(); - exprs.len() - expr_names.len() - ]); + // expr_names.extend(vec![ + // "unavailable".to_string(); + // exprs.len() - expr_names.len() + // ]); + panic!("there are expr without name") } let max_expr_degree = exprs.iter().map(|expr| expr.degree()).max().unwrap(); Self { diff --git a/gkr_iop/src/precompiles/bitwise_keccakf.rs b/gkr_iop/src/precompiles/bitwise_keccakf.rs index c7d6ff2a6..ec0e76138 100644 --- a/gkr_iop/src/precompiles/bitwise_keccakf.rs +++ b/gkr_iop/src/precompiles/bitwise_keccakf.rs @@ -393,7 +393,7 @@ pub fn run_keccakf( // Omit the commit phase1 and phase2. 
let span = entered_span!("gkr_witness", profiling_1 = true); - let gkr_witness = layout.gkr_witness(&gkr_circuit, &phase1_witness, &[]); + let (gkr_witness, _gkr_output) = layout.gkr_witness(&gkr_circuit, &phase1_witness, &[]); exit_span!(span); let out_evals = { diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 3dc930000..59af45e1a 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -2,15 +2,17 @@ use std::{array, cmp::Ordering, marker::PhantomData, sync::Arc}; use ff_ext::{ExtensionField, SmallField}; use itertools::{Itertools, chain, iproduct, zip_eq}; -use multilinear_extensions::{Expression, ToExpr, WitIn, util::ceil_log2}; +use multilinear_extensions::{ + Expression, ToExpr, WitIn, + mle::{Point, PointAndEval}, + util::ceil_log2, +}; use ndarray::{ArrayView, Ix2, Ix3, s}; use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; use p3_goldilocks::Goldilocks; -use rayon::iter::{ - IntoParallelIterator, IntoParallelRefIterator, ParallelExtend, ParallelIterator, -}; +use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use serde::{Deserialize, Serialize}; -use tiny_keccak::keccakf; +use sumcheck::util::optimal_sumcheck_threads; use transcript::BasicTranscript; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -113,6 +115,9 @@ fn rotation_split(delta: usize) -> (Vec, usize) { } struct ConstraintSystem { + // expressions include zero & non-zero expression, differentiate via evals + // zero expr represented as Linear with all 0 value + // TODO we should define an Zero enum for it expressions: Vec>, expr_names: Vec, evals: Vec>, @@ -424,245 +429,249 @@ impl ProtocolBuilder for KeccakLayout { fn build_gkr_phase(&mut self, chip: &mut Chip) { let final_outputs = - chip.allocate_output_evals::<{ KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND * ROUNDS }>(); + chip.allocate_output_evals::<{ 
KECCAK_OUTPUT_SIZE + KECCAK_INPUT_SIZE + LOOKUP_FELTS_PER_ROUND }>(); let mut final_outputs_iter = final_outputs.iter(); + // TODO we can rlc lookup via alpha/beta challenge, so gkr output layer only got rlc result + // with that, we save more prover cost with less allocation + let [keccak_output32, keccak_input32, lookup_outputs] = [ KECCAK_OUTPUT_SIZE, KECCAK_INPUT_SIZE, - LOOKUP_FELTS_PER_ROUND * ROUNDS, + LOOKUP_FELTS_PER_ROUND, ] .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); let lookup_outputs = lookup_outputs.to_vec(); - let bases = chip.allocate_wits_in_layer::(); + // TODO we should separate into different eq group, because they should reduce from differenent points + // TODO it should be at least 2 group. + // TODO - group1: lookup one group (due to same tower prover length) + // TODO - group2: read/write another group + let (bases, [eq]) = chip.allocate_wits_in_zero_layer::(); for (openings, wit) in bases.iter().enumerate() { chip.allocate_opening(openings, wit.1.clone()); } let keccak_input8 = &bases[..KECCAK_LAYER_BYTE_SIZE]; - let keccak_output8 = &bases[KECCAK_WIT_SIZE - KECCAK_LAYER_BYTE_SIZE..KECCAK_WIT_SIZE]; + let keccak_output8 = &bases[KECCAK_WIT_SIZE_PER_ROUND - KECCAK_LAYER_BYTE_SIZE..]; let mut system = ConstraintSystem::new(); - let mut offset = KECCAK_LAYER_BYTE_SIZE; - let _ = (0..ROUNDS).fold(keccak_input8.to_vec(), |state8, round| { - #[allow(non_snake_case)] - let ( - c_aux, - c_temp, - c_rot, - d, - theta_output, - rotation_witness, - rhopi_output, - nonlinear, - chi_output, - iota_output, - ) = split_from_offset!( - bases, - offset, - KECCAK_WIT_SIZE_PER_ROUND, - 200, - 30, - 40, - 40, - 200, - 146, - 200, - 200, - 8, - 200 - ); - offset += KECCAK_WIT_SIZE_PER_ROUND; + #[allow(non_snake_case)] + let ( + c_aux, + c_temp, + c_rot, + d, + theta_output, + rotation_witness, + rhopi_output, + nonlinear, + chi_output, + iota_output, + ) = split_from_offset!( + bases, + KECCAK_LAYER_BYTE_SIZE, + KECCAK_WIT_SIZE_PER_ROUND, 
+ 200, + 30, + 40, + 40, + 200, + 146, + 200, + 200, + 8, + 200 + ); - { - let n_wits = 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 8 + 200; - assert_eq!(KECCAK_WIT_SIZE_PER_ROUND, n_wits); - } + { + let n_wits = 200 + 30 + 40 + 40 + 200 + 146 + 200 + 200 + 8 + 200; + assert_eq!(KECCAK_WIT_SIZE_PER_ROUND, n_wits); + } - // TODO: ndarrays can be replaced with normal arrays + // TODO: ndarrays can be replaced with normal arrays - // Input state of the round in 8-bit chunks - let state8: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &state8).unwrap(); + // Input state of the round in 8-bit chunks + let state8: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &keccak_input8).unwrap(); - // The purpose is to compute the auxiliary array - // c[i] = XOR (state[j][i]) for j in 0..5 - // We unroll it into - // c_aux[i][j] = XOR (state[k][i]) for k in 0..j - // We use c_aux[i][4] instead of c[i] - // c_aux is also stored in 8-bit chunks - let c_aux: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); + // The purpose is to compute the auxiliary array + // c[i] = XOR (state[j][i]) for j in 0..5 + // We unroll it into + // c_aux[i][j] = XOR (state[k][i]) for k in 0..j + // We use c_aux[i][4] instead of c[i] + // c_aux is also stored in 8-bit chunks + let c_aux: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &c_aux).unwrap(); - for i in 0..5 { + for i in 0..5 { + for k in 0..8 { + // Initialize first element + system.constrain_eq( + state8[[0, i, k]].0.into(), + c_aux[[i, 0, k]].0.into(), + "init c_aux".to_string(), + ); + } + for j in 1..5 { + // Check xor using lookups over all chunks for k in 0..8 { - // Initialize first element - system.constrain_eq( - state8[[0, i, k]].0.into(), - c_aux[[i, 0, k]].0.into(), - "init c_aux".to_string(), + system.lookup_xor8( + c_aux[[i, j - 1, k]].0.into(), + state8[[j, i, k]].0.into(), + c_aux[[i, j, 
k]].0.into(), ); } - for j in 1..5 { - // Check xor using lookups over all chunks - for k in 0..8 { - system.lookup_xor8( - c_aux[[i, j - 1, k]].0.into(), - state8[[j, i, k]].0.into(), - c_aux[[i, j, k]].0.into(), - ); - } - } } + } - // Compute c_rot[i] = c[i].rotate_left(1) - // To understand how rotations are performed in general, consult the - // documentation of `constrain_left_rotation64`. Here c_temp is the split - // witness for a 1-rotation. + // Compute c_rot[i] = c[i].rotate_left(1) + // To understand how rotations are performed in general, consult the + // documentation of `constrain_left_rotation64`. Here c_temp is the split + // witness for a 1-rotation. - let c_temp: ArrayView<(WitIn, EvalExpression), Ix2> = - ArrayView::from_shape((5, 6), &c_temp).unwrap(); - let c_rot: ArrayView<(WitIn, EvalExpression), Ix2> = - ArrayView::from_shape((5, 8), &c_rot).unwrap(); + let c_temp: ArrayView<(WitIn, EvalExpression), Ix2> = + ArrayView::from_shape((5, 6), &c_temp).unwrap(); + let c_rot: ArrayView<(WitIn, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &c_rot).unwrap(); - let (sizes, _) = rotation_split(1); + let (sizes, _) = rotation_split(1); - for i in 0..5 { - assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); + for i in 0..5 { + assert_eq!(c_temp.slice(s![i, ..]).iter().len(), sizes.iter().len()); - system.constrain_left_rotation64( - &c_aux - .slice(s![i, 4, ..]) - .iter() - .map(|e| e.0.expr()) - .collect_vec(), - &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) - .map(|(e, sz)| (*sz, e.0.expr())) - .collect_vec(), - &c_rot - .slice(s![i, ..]) - .iter() - .map(|e| e.0.expr()) - .collect_vec(), - 1, - "theta rotation".to_string(), - ); + system.constrain_left_rotation64( + &c_aux + .slice(s![i, 4, ..]) + .iter() + .map(|e| e.0.expr()) + .collect_vec(), + &zip_eq(c_temp.slice(s![i, ..]).iter(), sizes.iter()) + .map(|(e, sz)| (*sz, e.0.expr())) + .collect_vec(), + &c_rot + .slice(s![i, ..]) + .iter() + .map(|e| e.0.expr()) 
+ .collect_vec(), + 1, + "theta rotation".to_string(), + ); + } + + // d is computed simply as XOR of required elements of c (and rotations) + // again stored as 8-bit chunks + let d: ArrayView<(WitIn, EvalExpression), Ix2> = + ArrayView::from_shape((5, 8), &d).unwrap(); + + for i in 0..5 { + for k in 0..8 { + system.lookup_xor8( + c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), + c_rot[[(i + 1) % 5, k]].0.into(), + d[[i, k]].0.into(), + ) } + } - // d is computed simply as XOR of required elements of c (and rotations) - // again stored as 8-bit chunks - let d: ArrayView<(WitIn, EvalExpression), Ix2> = - ArrayView::from_shape((5, 8), &d).unwrap(); + // output state of the Theta sub-round, simple XOR, in 8-bit chunks + let theta_output: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); - for i in 0..5 { + for i in 0..5 { + for j in 0..5 { for k in 0..8 { system.lookup_xor8( - c_aux[[(i + 5 - 1) % 5, 4, k]].0.into(), - c_rot[[(i + 1) % 5, k]].0.into(), + state8[[j, i, k]].0.into(), d[[i, k]].0.into(), + theta_output[[j, i, k]].0.into(), ) } } + } - // output state of the Theta sub-round, simple XOR, in 8-bit chunks - let theta_output: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &theta_output).unwrap(); + // output state after applying both Rho and Pi sub-rounds + // sub-round Pi is a simple permutation of 64-bit lanes + // sub-round Rho requires rotations + let rhopi_output: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); - for i in 0..5 { - for j in 0..5 { - for k in 0..8 { - system.lookup_xor8( - state8[[j, i, k]].0.into(), - d[[i, k]].0.into(), - theta_output[[j, i, k]].0.into(), - ) - } - } + // iterator over split witnesses + let mut rotation_witness = rotation_witness.iter(); + + for i in 0..5 { + #[allow(clippy::needless_range_loop)] + for j in 0..5 { + let arg = theta_output + .slice(s!(j, i, ..)) + .iter() + .map(|e| e.0.expr()) 
+ .collect_vec(); + let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); + let many = sizes.len(); + let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) + .map(|(sz, (wit, _))| (sz, wit.expr())) + .collect_vec(); + let arg_rotated = rhopi_output + .slice(s!((2 * i + 3 * j) % 5, j, ..)) + .iter() + .map(|e| e.0.expr()) + .collect_vec(); + system.constrain_left_rotation64( + &arg, + &rep_split, + &arg_rotated, + ROTATION_CONSTANTS[j][i], + format!("RHOPI {i}, {j}"), + ); } + } - // output state after applying both Rho and Pi sub-rounds - // sub-round Pi is a simple permutation of 64-bit lanes - // sub-round Rho requires rotations - let rhopi_output: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &rhopi_output).unwrap(); + let mut chi_output = chi_output; + chi_output.extend(iota_output[8..].to_vec()); + let chi_output: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); - // iterator over split witnesses - let mut rotation_witness = rotation_witness.iter(); + // for the Chi sub-round, we use an intermediate witness storing the result of + // the required AND + let nonlinear: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); - for i in 0..5 { - #[allow(clippy::needless_range_loop)] - for j in 0..5 { - let arg = theta_output - .slice(s!(j, i, ..)) - .iter() - .map(|e| e.0.expr()) - .collect_vec(); - let (sizes, _) = rotation_split(ROTATION_CONSTANTS[j][i]); - let many = sizes.len(); - let rep_split = zip_eq(sizes, rotation_witness.by_ref().take(many)) - .map(|(sz, (wit, _))| (sz, wit.expr())) - .collect_vec(); - let arg_rotated = rhopi_output - .slice(s!((2 * i + 3 * j) % 5, j, ..)) - .iter() - .map(|e| e.0.expr()) - .collect_vec(); - system.constrain_left_rotation64( - &arg, - &rep_split, - &arg_rotated, - ROTATION_CONSTANTS[j][i], - format!("RHOPI {i}, {j}"), + for i in 0..5 { + for j in 0..5 { + for k in 0..8 { + 
system.lookup_and8( + not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), + rhopi_output[[j, (i + 2) % 5, k]].0.into(), + nonlinear[[j, i, k]].0.into(), ); - } - } - - let mut chi_output = chi_output; - chi_output.extend(iota_output[8..].to_vec()); - let chi_output: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &chi_output).unwrap(); - // for the Chi sub-round, we use an intermediate witness storing the result of - // the required AND - let nonlinear: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &nonlinear).unwrap(); - - for i in 0..5 { - for j in 0..5 { - for k in 0..8 { - system.lookup_and8( - not8_expr(rhopi_output[[j, (i + 1) % 5, k]].0.into()), - rhopi_output[[j, (i + 2) % 5, k]].0.into(), - nonlinear[[j, i, k]].0.into(), - ); - - system.lookup_xor8( - rhopi_output[[j, i, k]].0.into(), - nonlinear[[j, i, k]].0.into(), - chi_output[[j, i, k]].0.into(), - ); - } + system.lookup_xor8( + rhopi_output[[j, i, k]].0.into(), + nonlinear[[j, i, k]].0.into(), + chi_output[[j, i, k]].0.into(), + ); } } + } - // TODO: 24/25 elements stay the same after Iota; eliminate duplication? - let iota_output_arr: ArrayView<(WitIn, EvalExpression), Ix3> = - ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); + // TODO: 24/25 elements stay the same after Iota; eliminate duplication? 
+ let iota_output_arr: ArrayView<(WitIn, EvalExpression), Ix3> = + ArrayView::from_shape((5, 5, 8), &iota_output).unwrap(); - for k in 0..8 { - system.lookup_xor8( - chi_output[[0, 0, k]].0.into(), - E::BaseField::from_i64(((RC[round] >> (k * 8)) & 0xFF) as i64).expr(), - iota_output_arr[[0, 0, k]].0.into(), - ); - } + for k in 0..8 { + system.lookup_xor8( + chi_output[[0, 0, k]].0.into(), + // TODO figure out how to deal with RC, since it's not a constant in rotation + E::BaseField::from_i64(((RC[0] >> (k * 8)) & 0xFF) as i64).expr(), + iota_output_arr[[0, 0, k]].0.into(), + ); + } - iota_output - }); + // TODO add rotation constrain let mut global_and_lookup = 0; let mut global_xor_lookup = 3 * AND_LOOKUPS; @@ -755,8 +764,7 @@ impl ProtocolBuilder for KeccakLayout { expressions, vec![], bases.into_iter().map(|e| e.1).collect_vec(), - vec![], - evals, + vec![(Some(eq.0.expr()), evals)], expr_names, )); } @@ -1026,7 +1034,7 @@ where return vec![0; KECCAK_OUT_EVAL_SIZE]; } - let mut state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( + let state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( &instances_rounds[instance_round_id * num_cols..][..KECCAK_LAYER_BYTE_SIZE], ); let mut keccak_input32 = [[[0u64; 2]; 5]; 5]; @@ -1035,7 +1043,6 @@ where keccak_input32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); } } - let mut offset = KECCAK_LAYER_BYTE_SIZE; // #[allow(clippy::needless_range_loop)] // for round in 0..ROUNDS { // TODO use with_capacity and retrive number of lookup from circuit @@ -1079,7 +1086,7 @@ where iota_output8, ) = split_from_offset!( instances_rounds[instance_round_id * num_cols..][..num_cols], - offset, + KECCAK_LAYER_BYTE_SIZE, // offset KECCAK_WIT_SIZE_PER_ROUND, 200, 30, @@ -1092,7 +1099,6 @@ where 8, 200 ); - offset += KECCAK_WIT_SIZE_PER_ROUND; let c_aux8 = to_5x5x8_array(&c_aux8); for i in 0..5 { @@ -1233,11 +1239,23 @@ where } } -pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bool) { +pub fn setup_gkr_circuit() -> (KeccakLayout, 
GKRCircuit) { let params = KeccakParams {}; let (layout, chip) = KeccakLayout::build(params); + (layout, chip.gkr_circuit()) +} + +pub fn run_faster_keccakf( + (layout, gkr_circuit): (KeccakLayout, GKRCircuit), + states: Vec<[u64; 25]>, + verify: bool, + test_outputs: bool, +) { + let num_instances = states.len(); + let log2_num_instances = ceil_log2(num_instances); + let num_threads = optimal_sumcheck_threads(log2_num_instances); + let mut instances = Vec::with_capacity(num_instances); - let mut instances = vec![]; for state in &states { let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); let state_mask32 = state_mask64.convert(vec![32; 50]); @@ -1253,19 +1271,18 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo ); } - let num_instances = instances.len(); - let phase1_witness = layout.phase1_witness(KeccakTrace { - instances: instances.clone(), + let phase1_witness = layout.phase1_witness_group(KeccakTrace { + instances: instances, }); let mut prover_transcript = BasicTranscript::::new(b"protocol"); // Omit the commit phase1 and phase2. 
- let (gkr_witness, _gkr_output) = layout.gkr_witness(&phase1_witness, &[]); + let (gkr_witness, _gkr_output) = layout.gkr_witness(&gkr_circuit, &phase1_witness, &[]); let out_evals = { let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); - let point = Arc::new(vec![E::from_u64(29); log2_num_instances as usize]); + let point = vec![E::from_u64(29); log2_num_instances as usize] as Point; if test_outputs { // Confront outputs with tiny_keccak::keccakf call @@ -1278,25 +1295,26 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo .iter() .take(KECCAK_OUTPUT_SIZE) { - assert_eq!(base.len(), num_instances); + assert_eq!(base.evaluations().len(), num_instances); for i in 0..num_instances { - instance_outputs[i].push(base[i]); + instance_outputs[i].push(base.get_base_field_vec()[i]); } } - for i in 0..num_instances { - let mut state = states[i]; - keccakf(&mut state); - assert_eq!( - state - .to_vec() - .iter() - .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) - .map(|e| Goldilocks::from_u64(e as u64)) - .collect_vec(), - instance_outputs[i] - ); - } + // TODO Need fix to check rotation mode + // for i in 0..num_instances { + // let mut state = states[i]; + // keccakf(&mut state); + // assert_eq!( + // state + // .to_vec() + // .iter() + // .flat_map(|e| vec![*e as u32, (e >> 32) as u32]) + // .map(|e| Goldilocks::from_u64(e as u64)) + // .collect_vec(), + // instance_outputs[i] + // ); + // } } let out_evals = gkr_witness @@ -1307,7 +1325,7 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo .iter() .map(|base| PointAndEval { point: point.clone(), - eval: subprotocols::utils::evaluate_mle_ext(base, &point), + eval: base.evaluate(&point), }) .collect_vec(); @@ -1316,10 +1334,16 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo out_evals }; - let gkr_circuit = chip.gkr_circuit(); dbg!(&gkr_circuit.layers.len()); let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit - .prove(gkr_witness, &out_evals, &[], &mut prover_transcript) + .prove( + num_threads, + log2_num_instances, + gkr_witness, + &out_evals, + &[], + &mut prover_transcript, + ) .expect("Failed to prove phase"); if verify { @@ -1327,7 +1351,13 @@ pub fn run_faster_keccakf(states: Vec<[u64; 25]>, verify: bool, test_outputs: bo let mut verifier_transcript = BasicTranscript::::new(b"protocol"); gkr_circuit - .verify(gkr_proof, &out_evals, &[], &mut verifier_transcript) + .verify( + log2_num_instances, + gkr_proof, + &out_evals, + &[], + &mut verifier_transcript, + ) .expect("GKR verify failed"); // Omit the PCS opening phase. @@ -1349,12 +1379,12 @@ mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let num_instances = 8; - let mut states: Vec<[u64; 25]> = vec![]; + let mut states: Vec<[u64; 25]> = Vec::with_capacity(num_instances); for _ in 0..num_instances { states.push(std::array::from_fn(|_| rng.gen())); } - run_faster_keccakf(states, true, true); + run_faster_keccakf(setup_gkr_circuit(), states, false, true); }) .unwrap() .join() @@ -1371,12 +1401,12 @@ mod tests { let mut rng = rand::rngs::StdRng::seed_from_u64(42); let num_instances = 5; - let mut states: Vec<[u64; 25]> = vec![]; + let mut states: Vec<[u64; 25]> = Vec::with_capacity(num_instances); for _ in 0..num_instances { states.push(std::array::from_fn(|_| rng.gen())); } - run_faster_keccakf(states, true, true); + run_faster_keccakf(setup_gkr_circuit(), states, false, true); }) .unwrap() .join() diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index e25523193..dbba574b4 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -5,6 +5,10 @@ use p3_field::PrimeCharacteristicRing; use crate::evaluation::EvalExpression; +pub fn not8_expr(expr: Expression) -> Expression { + E::BaseField::from_u8(0xFF).expr() - expr +} + pub fn zero_eval() -> EvalExpression { EvalExpression::Linear( 0, From 
375949819f30f03626ecef5a4c773e378adc8f83 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 15:36:25 +0800 Subject: [PATCH 17/28] benchmark ready --- gkr_iop/benches/bitwise_keccakf.rs | 4 +- gkr_iop/benches/lookup_keccakf.rs | 75 +++++++++----- gkr_iop/src/bin/bitwise_keccak.rs | 4 +- gkr_iop/src/gkr/mock.rs | 2 +- gkr_iop/src/precompiles/lookup_keccakf.rs | 113 ++++++++++++---------- gkr_iop/src/precompiles/mod.rs | 4 +- 6 files changed, 119 insertions(+), 83 deletions(-) diff --git a/gkr_iop/benches/bitwise_keccakf.rs b/gkr_iop/benches/bitwise_keccakf.rs index 83ca69e3f..34bfeaaba 100644 --- a/gkr_iop/benches/bitwise_keccakf.rs +++ b/gkr_iop/benches/bitwise_keccakf.rs @@ -2,7 +2,7 @@ use std::time::Duration; use criterion::*; use ff_ext::GoldilocksExt2; -use gkr_iop::precompiles::{run_keccakf, setup_gkr_circuit}; +use gkr_iop::precompiles::{run_keccakf, setup_keccak_bitwise_circuit}; use itertools::Itertools; use rand::{Rng, SeedableRng}; criterion_group!(benches, keccak_f_fn); @@ -32,7 +32,7 @@ fn keccak_f_fn(c: &mut Criterion) { let instant = std::time::Instant::now(); - let circuit = setup_gkr_circuit(); + let circuit = setup_keccak_bitwise_circuit(); #[allow(clippy::unit_arg)] run_keccakf::(circuit, black_box(states), false, false); let elapsed = instant.elapsed(); diff --git a/gkr_iop/benches/lookup_keccakf.rs b/gkr_iop/benches/lookup_keccakf.rs index a9ab4f5f0..9aa6f78bf 100644 --- a/gkr_iop/benches/lookup_keccakf.rs +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -1,8 +1,10 @@ use std::time::Duration; use criterion::*; -use gkr_iop::precompiles::run_faster_keccakf; +use ff_ext::GoldilocksExt2; +use gkr_iop::precompiles::{run_faster_keccakf, setup_keccak_lookup_circuit}; +use itertools::Itertools; use rand::{Rng, SeedableRng}; criterion_group!(benches, keccak_f_fn); criterion_main!(benches); @@ -10,30 +12,51 @@ criterion_main!(benches); const NUM_SAMPLES: usize = 10; fn keccak_f_fn(c: &mut Criterion) { - // expand more input size once runtime is 
acceptable - let mut group = c.benchmark_group("keccakf"); - group.sample_size(NUM_SAMPLES); - // Benchmark the proving time - group.bench_function(BenchmarkId::new("keccakf", "keccakf"), |b| { - b.iter_custom(|iters| { - let mut time = Duration::new(0, 0); - for _ in 0..iters { - // Use seeded rng for debugging convenience - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - let state1: [u64; 25] = std::array::from_fn(|_| rng.gen()); - let state2: [u64; 25] = std::array::from_fn(|_| rng.gen()); - - let instant = std::time::Instant::now(); - #[allow(clippy::unit_arg)] - black_box(run_faster_keccakf(vec![state1, state2], false, false)); - let elapsed = instant.elapsed(); - time += elapsed; - } - - time - }); - }); - - group.finish(); + for log_instances in 10..12 { + let num_instance = 1 << log_instances; + // expand more input size once runtime is acceptable + let mut group = c.benchmark_group(format!("keccak_lookup_f_{}", num_instance)); + group.sample_size(NUM_SAMPLES); + group.bench_function( + BenchmarkId::new( + "keccak_lookup_f", + format!("prove_keccak_lookup_f_{}", num_instance), + ), + |b| { + b.iter_custom(|iters| { + let mut time = Duration::new(0, 0); + for _ in 0..iters { + // Use seeded rng for debugging convenience + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let states: Vec<[u64; 25]> = (0..num_instance) + .map(|_| std::array::from_fn(|_| rng.gen())) + .collect_vec(); + + let instant = std::time::Instant::now(); + + let circuit = setup_keccak_lookup_circuit(); + #[allow(clippy::unit_arg)] + run_faster_keccakf::( + circuit, + black_box(states), + false, + false, + ); + let elapsed = instant.elapsed(); + println!( + "keccak_f::create_proof, instances = {}, time = {}", + num_instance, + elapsed.as_secs_f64() + ); + time += elapsed; + } + + time + }); + }, + ); + group.finish(); + } } diff --git a/gkr_iop/src/bin/bitwise_keccak.rs b/gkr_iop/src/bin/bitwise_keccak.rs index c99df3cb9..14e4af84a 100644 --- 
a/gkr_iop/src/bin/bitwise_keccak.rs +++ b/gkr_iop/src/bin/bitwise_keccak.rs @@ -1,6 +1,6 @@ use clap::{Parser, command}; use ff_ext::GoldilocksExt2; -use gkr_iop::precompiles::{run_keccakf, setup_gkr_circuit}; +use gkr_iop::precompiles::{run_keccakf, setup_keccak_bitwise_circuit}; use itertools::Itertools; use rand::{Rng, SeedableRng}; use tracing::level_filters::LevelFilter; @@ -63,6 +63,6 @@ fn main() { let states: Vec<[u64; 25]> = (0..num_instance) .map(|_| std::array::from_fn(|_| rng.gen())) .collect_vec(); - let circuit_setup = setup_gkr_circuit(); + let circuit_setup = setup_keccak_bitwise_circuit(); run_keccakf::(circuit_setup, states, false, false); } diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index 64aadd2af..f2ff6db88 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -11,7 +11,7 @@ use multilinear_extensions::{ use rand::{rngs::OsRng, thread_rng}; use thiserror::Error; -use crate::{evaluation::EvalExpression, utils::SliceIterator}; +use crate::evaluation::EvalExpression; use multilinear_extensions::{ Expression, mle::FieldType, smart_slice::SmartSlice, wit_infer_by_expr, }; diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 59af45e1a..9314d9503 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -2,18 +2,13 @@ use std::{array, cmp::Ordering, marker::PhantomData, sync::Arc}; use ff_ext::{ExtensionField, SmallField}; use itertools::{Itertools, chain, iproduct, zip_eq}; -use multilinear_extensions::{ - Expression, ToExpr, WitIn, - mle::{Point, PointAndEval}, - util::ceil_log2, -}; +use multilinear_extensions::{Expression, ToExpr, WitIn, mle::PointAndEval, util::ceil_log2}; use ndarray::{ArrayView, Ix2, Ix3, s}; -use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; -use p3_goldilocks::Goldilocks; +use p3_field::PrimeCharacteristicRing; use rayon::iter::{IntoParallelIterator, 
IntoParallelRefIterator, ParallelIterator}; use serde::{Deserialize, Serialize}; use sumcheck::util::optimal_sumcheck_threads; -use transcript::BasicTranscript; +use transcript::{BasicTranscript, Transcript}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ @@ -29,8 +24,6 @@ use crate::{ use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; -type E = BinomialExtensionField; - #[derive(Clone, Debug, Serialize, Deserialize)] pub struct KeccakParams {} @@ -51,7 +44,7 @@ pub struct KeccakLayerLayout { #[derive(Clone, Debug, Default, Serialize, Deserialize)] pub struct KeccakLayout { keccak_input8: Vec, - keccak_layers: [KeccakLayerLayout; ROUNDS], + keccak_layers: [KeccakLayerLayout; 1], _marker: PhantomData, } @@ -65,7 +58,7 @@ fn expansion_expr( .fold((0, E::BaseField::ZERO.expr()), |acc, (sz, felt)| { ( acc.0 + sz, - acc.1 * E::BaseField::from_u64((1 << sz) as u64).expr() + felt.expr(), + acc.1 * E::BaseField::from_i64(1 << sz).expr() + felt.expr(), ) }); @@ -160,7 +153,7 @@ impl ConstraintSystem { self.range_lookups.push(CenoLookup::U16(value.clone())); if size < 16 { self.range_lookups.push(CenoLookup::U16( - value * E::BaseField::from_u64((1 << (16 - size)) as u64).expr(), + value * E::BaseField::from_i64(1 << (16 - size)).expr(), )) } } @@ -336,15 +329,15 @@ pub const RANGE_LOOKUPS_PER_ROUND: usize = 290; pub const LOOKUP_FELTS_PER_ROUND: usize = 3 * AND_LOOKUPS_PER_ROUND + 3 * XOR_LOOKUPS_PER_ROUND + RANGE_LOOKUPS_PER_ROUND; -pub const AND_LOOKUPS: usize = ROUNDS * AND_LOOKUPS_PER_ROUND; -pub const XOR_LOOKUPS: usize = ROUNDS * XOR_LOOKUPS_PER_ROUND; -pub const RANGE_LOOKUPS: usize = ROUNDS * RANGE_LOOKUPS_PER_ROUND; +pub const AND_LOOKUPS: usize = AND_LOOKUPS_PER_ROUND; +pub const XOR_LOOKUPS: usize = XOR_LOOKUPS_PER_ROUND; +pub const RANGE_LOOKUPS: usize = RANGE_LOOKUPS_PER_ROUND; pub const KECCAK_OUT_EVAL_SIZE: usize = KECCAK_INPUT_SIZE + KECCAK_OUTPUT_SIZE + LOOKUP_FELTS_PER_ROUND; pub const KECCAK_WIT_SIZE_PER_ROUND: usize = 1264; 
-pub const KECCAK_WIT_SIZE: usize = KECCAK_WIT_SIZE_PER_ROUND * ROUNDS + KECCAK_LAYER_BYTE_SIZE; +pub const KECCAK_WIT_SIZE: usize = KECCAK_WIT_SIZE_PER_ROUND + KECCAK_LAYER_BYTE_SIZE; #[allow(unused)] macro_rules! allocate_and_split { @@ -449,13 +442,13 @@ impl ProtocolBuilder for KeccakLayout { // TODO it should be at least 2 group. // TODO - group1: lookup one group (due to same tower prover length) // TODO - group2: read/write another group - let (bases, [eq]) = chip.allocate_wits_in_zero_layer::(); + let (bases, [eq]) = chip.allocate_wits_in_zero_layer::(); for (openings, wit) in bases.iter().enumerate() { chip.allocate_opening(openings, wit.1.clone()); } let keccak_input8 = &bases[..KECCAK_LAYER_BYTE_SIZE]; - let keccak_output8 = &bases[KECCAK_WIT_SIZE_PER_ROUND - KECCAK_LAYER_BYTE_SIZE..]; + let keccak_output8 = &bases[KECCAK_WIT_SIZE - KECCAK_LAYER_BYTE_SIZE..]; let mut system = ConstraintSystem::new(); @@ -707,9 +700,9 @@ impl ProtocolBuilder for KeccakLayout { *idx += 1; } - assert!(global_and_lookup == 3 * AND_LOOKUPS); - assert!(global_xor_lookup == 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); - assert!(global_range_lookup == LOOKUP_FELTS_PER_ROUND * ROUNDS); + assert_eq!(global_and_lookup, 3 * AND_LOOKUPS); + assert_eq!(global_xor_lookup, 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS); + assert_eq!(global_range_lookup, LOOKUP_FELTS_PER_ROUND); let keccak_input8: ArrayView<(WitIn, EvalExpression), Ix3> = ArrayView::from_shape((5, 5, 8), keccak_input8).unwrap(); @@ -818,17 +811,20 @@ where // TODO take structural id information from circuit to do wits assignment // 1 instance will derive 24 round result + 8 round padding to pow2 for easiler rotation design - let mut wits = - Vec::with_capacity(KECCAK_WIT_SIZE_PER_ROUND * ROUNDS.next_power_of_two()); + let mut wits = Vec::with_capacity(KECCAK_WIT_SIZE * ROUNDS.next_power_of_two()); let mut push_instance = |new_wits: Vec| { let felts = u64s_to_felts::(new_wits); wits.extend(felts); }; - 
push_instance(state8.into_iter().flatten().flatten().collect_vec()); - #[allow(clippy::needless_range_loop)] for round in 0..ROUNDS { + if round == 0 { + push_instance(state8.into_iter().flatten().flatten().collect_vec()); + } else { + push_instance(vec![0u64; KECCAK_LAYER_BYTE_SIZE]); + } + let mut c_aux64 = [[0u64; 5]; 5]; let mut c_aux8 = [[[0u64; 8]; 5]; 5]; @@ -959,17 +955,13 @@ where // padding to next_power_of_2 rounds for rotation wits.extend( - (0..(ROUNDS.next_power_of_two() - ROUNDS) * KECCAK_WIT_SIZE_PER_ROUND) + (0..(ROUNDS.next_power_of_two() - ROUNDS) * KECCAK_WIT_SIZE) .map(|_| E::BaseField::ZERO), ); wits }) .collect(); - RowMajorMatrix::new_by_values( - wits, - KECCAK_WIT_SIZE_PER_ROUND, - InstancePaddingStrategy::Default, - ) + RowMajorMatrix::new_by_values(wits, KECCAK_WIT_SIZE, InstancePaddingStrategy::Default) } fn gkr_witness( @@ -978,15 +970,15 @@ where phase1: &RowMajorMatrix, _challenges: &[E], ) -> (GKRCircuitWitness<'a, E>, GKRCircuitOutput) { - // TODO: fix efficient as here as it convert felts back to u64 + // TODO: fix inefficiency as here as it convert felts back to u64 let instances_rounds = phase1 .values .par_iter() .map(|wit| wit.to_canonical_u64()) .collect::>(); - let num_instances_with_rotations = phase1.num_vars(); + let num_instances_with_rotations = 1 << phase1.num_vars(); let num_cols = phase1.n_col(); - assert_eq!(num_cols, KECCAK_WIT_SIZE_PER_ROUND); + assert_eq!(num_cols, KECCAK_WIT_SIZE); let to_5x5x8_array = |input: &[u64]| -> [[[u64; 8]; 5]; 5] { assert_eq!(input.len(), 5 * 5 * 8); @@ -1210,6 +1202,11 @@ where }) .collect(); + assert_eq!( + output_bases.len(), + num_instances_with_rotations * KECCAK_OUT_EVAL_SIZE + ); + let bases = phase1.to_mles().into_iter().map(Arc::new).collect_vec(); let output_bases = RowMajorMatrix::new_by_values( output_bases @@ -1245,15 +1242,16 @@ pub fn setup_gkr_circuit() -> (KeccakLayout, GKRCircuit (layout, chip.gkr_circuit()) } -pub fn run_faster_keccakf( +pub fn run_faster_keccakf( 
(layout, gkr_circuit): (KeccakLayout, GKRCircuit), states: Vec<[u64; 25]>, verify: bool, test_outputs: bool, ) { let num_instances = states.len(); - let log2_num_instances = ceil_log2(num_instances); - let num_threads = optimal_sumcheck_threads(log2_num_instances); + let num_instances_rounds = num_instances * ROUNDS.next_power_of_two(); + let log2_num_instance_rounds = ceil_log2(num_instances_rounds); + let num_threads = optimal_sumcheck_threads(log2_num_instance_rounds); let mut instances = Vec::with_capacity(num_instances); for state in &states { @@ -1281,8 +1279,12 @@ pub fn run_faster_keccakf( let (gkr_witness, _gkr_output) = layout.gkr_witness(&gkr_circuit, &phase1_witness, &[]); let out_evals = { - let log2_num_instances = num_instances.next_power_of_two().trailing_zeros(); - let point = vec![E::from_u64(29); log2_num_instances as usize] as Point; + let mut point = Vec::with_capacity(log2_num_instance_rounds); + point.extend( + prover_transcript + .sample_vec(log2_num_instance_rounds) + .to_vec(), + ); if test_outputs { // Confront outputs with tiny_keccak::keccakf call @@ -1295,7 +1297,10 @@ pub fn run_faster_keccakf( .iter() .take(KECCAK_OUTPUT_SIZE) { - assert_eq!(base.evaluations().len(), num_instances); + assert_eq!( + base.evaluations().len(), + num_instances * ROUNDS.next_power_of_two() + ); for i in 0..num_instances { instance_outputs[i].push(base.get_base_field_vec()[i]); } @@ -1317,10 +1322,8 @@ pub fn run_faster_keccakf( // } } - let out_evals = gkr_witness - .layers - .last() - .unwrap() + let out_evals = _gkr_output + .0 .bases .iter() .map(|base| PointAndEval { @@ -1334,11 +1337,10 @@ pub fn run_faster_keccakf( out_evals }; - dbg!(&gkr_circuit.layers.len()); let GKRProverOutput { gkr_proof, .. 
} = gkr_circuit .prove( num_threads, - log2_num_instances, + log2_num_instance_rounds, gkr_witness, &out_evals, &[], @@ -1350,9 +1352,17 @@ pub fn run_faster_keccakf( { let mut verifier_transcript = BasicTranscript::::new(b"protocol"); + // This is to make prover/verifier match + let mut _point = Vec::with_capacity(log2_num_instance_rounds); + _point.extend( + verifier_transcript + .sample_vec(log2_num_instance_rounds) + .to_vec(), + ); + gkr_circuit .verify( - log2_num_instances, + log2_num_instance_rounds, gkr_proof, &out_evals, &[], @@ -1368,10 +1378,12 @@ pub fn run_faster_keccakf( #[cfg(test)] mod tests { use super::*; + use ff_ext::GoldilocksExt2; use rand::{Rng, SeedableRng}; #[test] fn test_keccakf() { + type E = GoldilocksExt2; std::thread::Builder::new() .name("keccak_test".into()) .stack_size(64 * 1024 * 1024) @@ -1384,7 +1396,7 @@ mod tests { states.push(std::array::from_fn(|_| rng.gen())); } - run_faster_keccakf(setup_gkr_circuit(), states, false, true); + run_faster_keccakf(setup_gkr_circuit::(), states, false, true); }) .unwrap() .join() @@ -1394,6 +1406,7 @@ mod tests { #[ignore] #[test] fn test_keccakf_nonpow2() { + type E = GoldilocksExt2; std::thread::Builder::new() .name("keccak_test".into()) .stack_size(64 * 1024 * 1024) @@ -1406,7 +1419,7 @@ mod tests { states.push(std::array::from_fn(|_| rng.gen())); } - run_faster_keccakf(setup_gkr_circuit(), states, false, true); + run_faster_keccakf(setup_gkr_circuit::(), states, false, true); }) .unwrap() .join() diff --git a/gkr_iop/src/precompiles/mod.rs b/gkr_iop/src/precompiles/mod.rs index 6b58a4c19..431d6a431 100644 --- a/gkr_iop/src/precompiles/mod.rs +++ b/gkr_iop/src/precompiles/mod.rs @@ -1,9 +1,9 @@ mod bitwise_keccakf; mod lookup_keccakf; mod utils; -pub use bitwise_keccakf::{run_keccakf, setup_gkr_circuit}; +pub use bitwise_keccakf::{run_keccakf, setup_gkr_circuit as setup_keccak_bitwise_circuit}; pub use lookup_keccakf::{ AND_LOOKUPS, AND_LOOKUPS_PER_ROUND, KECCAK_OUT_EVAL_SIZE, 
KeccakLayout, KeccakParams, KeccakTrace, RANGE_LOOKUPS, RANGE_LOOKUPS_PER_ROUND, XOR_LOOKUPS, XOR_LOOKUPS_PER_ROUND, - run_faster_keccakf, + run_faster_keccakf, setup_gkr_circuit as setup_keccak_lookup_circuit, }; From 8cb1aee7785bf318c43558fe79f991f0a5dadb52 Mon Sep 17 00:00:00 2001 From: Ming Date: Thu, 29 May 2025 14:54:16 +0800 Subject: [PATCH 18/28] optimize extrapolation with zero field inverse during runtime (#956) Extracted from #952. Observe a bottleneck on previous interpolation which contribute to most of time due to `vector.extend` operation and bunch of allocations. This PR rewrite univariate extrapolation 1. as the point to be interpolate are fixed set, we can pre-compute all stuff require field inverse 2. in-place change to avoid allocation In Ceno opcode main sumcheck part we batch different degree > 1 into one batch so this function will be used. It shows a slightly improvement (~3%) on Fibonacci 2^24 e2e | Benchmark | Median Time (s) | Median Change (%) | |----------------------------------|------------------|--------------------| | fibonacci_max_steps_1048576 | 2.3978 | +0.9805% (No significant change ) | | fibonacci_max_steps_2097152 | 4.2579 | +1.7587% (Change within noise) | | fibonacci_max_steps_4194304 | 7.7561 | -3.5338% | --- ceno_zkvm/src/scheme/prover.rs | 4 +- ceno_zkvm/src/scheme/verifier.rs | 7 +- ceno_zkvm/src/utils.rs | 1 - mpcs/src/basefold/commit_phase.rs | 7 +- sumcheck/src/extrapolate.rs | 139 +++++++++++++++++++++++ sumcheck/src/lib.rs | 1 + sumcheck/src/prover.rs | 121 ++++++++++---------- sumcheck/src/structs.rs | 3 - sumcheck/src/util.rs | 176 +++++++++++------------------- 9 files changed, 268 insertions(+), 191 deletions(-) create mode 100644 sumcheck/src/extrapolate.rs diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index 5bda176ee..c5b66ebd8 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -24,7 +24,7 @@ use rayon::iter::{IntoParallelIterator, 
IntoParallelRefIterator, ParallelIterato use sumcheck::{ macros::{entered_span, exit_span}, structs::{IOPProverMessage, IOPProverState}, - util::optimal_sumcheck_threads, + util::{get_challenge_pows, optimal_sumcheck_threads}, }; use transcript::Transcript; use witness::{RowMajorMatrix, next_pow2_instance_padding}; @@ -44,7 +44,7 @@ use crate::{ GKRIOPProvingKey, KeccakGKRIOP, ProvingKey, TowerProver, TowerProverSpec, ZKVMProvingKey, ZKVMWitnesses, }, - utils::{add_mle_list_by_expr, get_challenge_pows}, + utils::add_mle_list_by_expr, }; use multilinear_extensions::Instance; diff --git a/ceno_zkvm/src/scheme/verifier.rs b/ceno_zkvm/src/scheme/verifier.rs index 2094dbba1..0ab145c72 100644 --- a/ceno_zkvm/src/scheme/verifier.rs +++ b/ceno_zkvm/src/scheme/verifier.rs @@ -14,7 +14,10 @@ use multilinear_extensions::{ }; use p3::field::PrimeCharacteristicRing; use std::collections::HashSet; -use sumcheck::structs::{IOPProof, IOPVerifierState}; +use sumcheck::{ + structs::{IOPProof, IOPVerifierState}, + util::get_challenge_pows, +}; use transcript::{ForkableTranscript, Transcript}; use witness::next_pow2_instance_padding; @@ -25,7 +28,7 @@ use crate::{ structs::{ GKRIOPVerifyingKey, KeccakGKRIOP, PointAndEval, TowerProofs, VerifyingKey, ZKVMVerifyingKey, }, - utils::{eq_eval_less_or_equal_than, eval_wellform_address_vec, get_challenge_pows}, + utils::{eq_eval_less_or_equal_than, eval_wellform_address_vec}, }; use multilinear_extensions::{Instance, StructuralWitIn, utils::eval_by_expr_with_instance}; diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index bdc1d5952..37e159669 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -13,7 +13,6 @@ use multilinear_extensions::{ Expression, mle::ArcMultilinearExtension, virtual_polys::VirtualPolynomialsBuilder, }; use p3::field::Field; -use transcript::Transcript; pub fn i64_to_base(x: i64) -> F { if x >= 0 { diff --git a/mpcs/src/basefold/commit_phase.rs b/mpcs/src/basefold/commit_phase.rs index 
4b8a4c071..28410448d 100644 --- a/mpcs/src/basefold/commit_phase.rs +++ b/mpcs/src/basefold/commit_phase.rs @@ -242,7 +242,6 @@ where IOPProverState::prover_init_with_extrapolation_aux( thread_id == 0, // set thread_id 0 to be main worker poly, - vec![(vec![], vec![])], Some(log2_num_threads), Some(poly_meta.clone()), ) @@ -282,11 +281,7 @@ where let merge_sumcheck_prover_state_span = entered_span!("merge_sumcheck_prover_state"); let poly = merge_sumcheck_prover_state(&prover_states); let mut prover_states = vec![IOPProverState::prover_init_with_extrapolation_aux( - true, - poly, - vec![(vec![], vec![])], - None, - None, + true, poly, None, None, )]; exit_span!(merge_sumcheck_prover_state_span); diff --git a/sumcheck/src/extrapolate.rs b/sumcheck/src/extrapolate.rs new file mode 100644 index 000000000..0b7d0e5e3 --- /dev/null +++ b/sumcheck/src/extrapolate.rs @@ -0,0 +1,139 @@ +use ff_ext::ExtensionField; +use itertools::Itertools; +use std::{ + any::{Any, TypeId}, + collections::BTreeMap, + marker::PhantomData, + sync::{Arc, Mutex, OnceLock}, +}; + +/// Precomputed extrapolation weights using the second form of barycentric interpolation. 
+/// +/// This table supports extrapolation of univariate polynomials where: +/// - The known values are at integer points `x = 0, 1, ..., d` +/// - The degree `d` is in a fixed range [`min_degree`, `max_degree`] +/// - A univariate polynomial of degree `d` has exactly `d + 1` evaluation points +/// - The extrapolated values are computed at integer points `z > d`, up to `max_degree` +/// - No field inversions are required at runtime +/// +/// The second form of the barycentric interpolation formula is: +/// +/// ```text +/// L(z) = ∑_{j=0}^d (w_j / (z - x_j)) / ∑_{j=0}^d (w_j / (z - x_j)) * f(x_j) +/// = ∑_{j=0}^d v_j * f(x_j) +/// ``` +/// +/// Where: +/// - `x_j = j` (fixed integer evaluation points) +/// - `w_j = 1 / ∏_{i ≠ j} (x_j - x_i)` are barycentric weights (precomputed) +/// - `v_j = (w_j / (z - x_j)) / denom` are normalized interpolation coefficients (precomputed) +/// +/// This structure stores all `v_j` coefficients for each `(degree, target_z)` pair. +/// At runtime, extrapolation is done by a simple dot product of `v_j` with the known values `f(x_j)`, +/// without needing any inverses. 
+pub struct ExtrapolationTable { + /// weights[degree][z - degree - 1][j] = coefficient for f(x_j) when extrapolating to z + pub weights: Vec>>, +} + +impl ExtrapolationTable { + pub fn new(min_degree: usize, max_degree: usize) -> Self { + let mut weights = Vec::new(); + + for d in min_degree..=max_degree { + let mut degree_weights = Vec::new(); + + let xs: Vec = (0..=d as u64).map(E::from_u64).collect_vec(); + let mut bary_weights = Vec::new(); + + // Compute barycentric weights w_j = 1 / prod_{i != j} (x_j - x_i) + for j in 0..=d { + let mut w = E::ONE; + for i in 0..=d { + if i != j { + w *= xs[j] - xs[i]; + } + } + bary_weights.push(w.inverse()); // safe because all x_i are distinct + } + + for z_idx in d + 1..=max_degree { + let z = E::from_u64(z_idx as u64); + let mut den = E::ZERO; + let mut tmp: Vec = Vec::with_capacity(d + 1); + + for j in 0..=d { + let t = bary_weights[j] / (z - xs[j]); + tmp.push(t); + den += t; + } + + // Normalize + for t in tmp.iter_mut() { + *t = *t / den; + } + + degree_weights.push(tmp); + } + + weights.push(degree_weights); + } + + Self { weights } + } +} + +pub struct ExtrapolationCache { + _marker: PhantomData, +} + +impl ExtrapolationCache { + fn global_cache() -> &'static Mutex>> { + static GLOBAL_CACHE: OnceLock>>> = + OnceLock::new(); + GLOBAL_CACHE.get_or_init(|| Mutex::new(BTreeMap::new())) + } + + #[allow(clippy::type_complexity)] + fn cache_map() -> Arc>>>> { + let global = Self::global_cache(); + let mut map = global.lock().unwrap(); + + map.entry(TypeId::of::()) + .or_insert_with(|| { + Box::new(Arc::new(Mutex::new(BTreeMap::< + (usize, usize), + Arc>, + >::new()))) as Box + }) + .downcast_ref::>>>>>() + .expect("TypeId mapped to wrong type") + .clone() + } + + /// precompute and cache `ExtrapolationTable`s for all `(min_degree, max_degree)` + /// pairs where `2 ≤ max_degree` and `1 ≤ min_degree < max_degree`. 
+ pub fn warm_up(max_degree: usize) { + assert!(max_degree >= 2, "max_degree must be at least 2"); + + for max in 2..=max_degree { + for min in 1..max { + let _ = Self::get(min, max); + } + } + } + + /// get or create a cached `ExtrapolationTable` for the range `(min_degree, max_degree)`. + pub fn get(min_degree: usize, max_degree: usize) -> Arc> { + let cache = Self::cache_map(); + let mut map = cache.lock().unwrap(); + + if let Some(existing) = map.get(&(min_degree, max_degree)) { + return existing.clone(); + } + + let table = Arc::new(ExtrapolationTable::new(min_degree, max_degree)); + map.insert((min_degree, max_degree), table.clone()); + table + } +} diff --git a/sumcheck/src/lib.rs b/sumcheck/src/lib.rs index 75c4f50ea..630d44b57 100644 --- a/sumcheck/src/lib.rs +++ b/sumcheck/src/lib.rs @@ -1,5 +1,6 @@ #![deny(clippy::cargo)] pub use multilinear_extensions::macros; +pub mod extrapolate; mod prover; pub mod structs; pub mod util; diff --git a/sumcheck/src/prover.rs b/sumcheck/src/prover.rs index 33a4af290..9acfd38de 100644 --- a/sumcheck/src/prover.rs +++ b/sumcheck/src/prover.rs @@ -1,5 +1,6 @@ use std::{mem, sync::Arc}; +use crate::{extrapolate::ExtrapolationCache, util::extrapolate_from_table}; use crossbeam_channel::bounded; use ff_ext::ExtensionField; use itertools::Itertools; @@ -23,8 +24,7 @@ use crate::{ macros::{entered_span, exit_span}, structs::{IOPProof, IOPProverMessage, IOPProverState}, util::{ - AdditiveArray, AdditiveVec, barycentric_weights, ceil_log2, extrapolate, - merge_sumcheck_polys, merge_sumcheck_prover_state, serial_extrapolate, + AdditiveArray, AdditiveVec, ceil_log2, merge_sumcheck_polys, merge_sumcheck_prover_state, }, }; use p3::field::PrimeCharacteristicRing; @@ -34,7 +34,12 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { /// multi-threads model follow https://arxiv.org/pdf/2210.00264#page=8 "distributed sumcheck" /// This is experiment features. 
It's preferable that we move parallel level up more to /// "bould_poly" so it can be more isolation - #[tracing::instrument(skip_all, name = "sumcheck::prove", level = "trace")] + #[tracing::instrument( + skip_all, + name = "sumcheck::prove", + level = "trace", + fields(profiling_5) + )] pub fn prove( virtual_poly: VirtualPolynomials<'a, E>, transcript: &mut impl Transcript, @@ -58,27 +63,35 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { polys[0].aux_info.max_degree, ); - // extrapolation_aux only need to init once - let extrapolation_aux: Vec<(Vec, Vec)> = (1..max_degree) - .map(|degree| { - let points = (0..1 + degree as u64).map(E::from_u64).collect::>(); - let weights = barycentric_weights(&points); - (points, weights) + let min_degree = polys[0] + .products + .iter() + .flat_map(|monomial_terms| { + monomial_terms + .terms + .iter() + .map(|Term { product, .. }| product.len()) }) - .collect::>(); + .min() + .unwrap(); + if min_degree < max_degree { + // warm up cache giving min/max_degree + let _ = ExtrapolationCache::::get(min_degree, max_degree); + } transcript.append_message(&(num_variables + log2_max_thread_id).to_le_bytes()); transcript.append_message(&max_degree.to_le_bytes()); let (phase1_point, mut prover_state, mut prover_msgs) = if num_variables > 0 { + let span = entered_span!("phase1_sumcheck", profiling_6 = true); let (mut prover_states, prover_msgs) = Self::phase1_sumcheck( max_thread_id, num_variables, - extrapolation_aux.clone(), poly_meta, polys, max_degree, transcript, ); + exit_span!(span); if log2_max_thread_id == 0 { let prover_state = mem::take(&mut prover_states[0]); return ( @@ -93,22 +106,17 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { prover_state, ); } + let span = entered_span!("merged_poly", profiling_6 = true); let point = prover_states[0] .challenges .iter() .map(|c| c.elements) .collect_vec(); let poly = merge_sumcheck_prover_state(&prover_states); - + exit_span!(span); ( point, - 
Self::prover_init_with_extrapolation_aux( - true, - poly, - extrapolation_aux.clone(), - None, - None, - ), + Self::prover_init_with_extrapolation_aux(true, poly, None, None), prover_msgs, ) } else { @@ -117,7 +125,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { Self::prover_init_with_extrapolation_aux( true, merge_sumcheck_polys(polys.iter().collect_vec(), Some(poly_meta)), - extrapolation_aux.clone(), None, None, ), @@ -126,7 +133,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { }; let mut challenge = None; - let span = entered_span!("prove_rounds_stage2"); + let span = entered_span!("prove_rounds_stage2", profiling_6 = true); for _ in 0..log2_max_thread_id { let prover_msg = IOPProverState::prove_round_and_update_state(&mut prover_state, &challenge); @@ -163,7 +170,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { fn phase1_sumcheck( max_thread_id: usize, num_variables: usize, - extrapolation_aux: Vec<(Vec, Vec)>, poly_meta: Vec, mut polys: Vec>, max_degree: usize, @@ -183,7 +189,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let mut prover_state = Self::prover_init_with_extrapolation_aux( false, mem::take(poly), - extrapolation_aux.clone(), Some(log2_max_thread_id), Some(poly_meta.clone()), ); @@ -194,7 +199,7 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { // Note: This span is not nested into the "spawn loop" span, although lexically it looks so. // Nesting is possible, but then `tracing-forest` does the wrong thing when measuring duration. 
// TODO: investigate possibility of nesting with correct duration of parent span - let span = entered_span!("prove_rounds", profiling_5 = true); + let span = entered_span!("prove_rounds"); for _ in 0..num_variables { let prover_msg = IOPProverState::prove_round_and_update_state( &mut prover_state, @@ -228,7 +233,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let mut prover_state = Self::prover_init_with_extrapolation_aux( true, mem::take(&mut polys[main_thread_id]), - extrapolation_aux.clone(), Some(log2_max_thread_id), Some(poly_meta.clone()), ); @@ -316,7 +320,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { pub fn prover_init_with_extrapolation_aux( is_main_worker: bool, polynomial: VirtualPolynomial<'a, E>, - extrapolation_aux: Vec<(Vec, Vec)>, phase2_numvar: Option, poly_meta: Option>, ) -> Self { @@ -334,8 +337,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { } exit_span!(start); - let max_degree = polynomial.aux_info.max_degree; - assert!(extrapolation_aux.len() == max_degree - 1); let num_polys = polynomial.flattened_ml_extensions.len(); Self { @@ -344,7 +345,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, - extrapolation_aux, poly_meta: poly_meta.unwrap_or_else(|| vec![PolyMeta::Normal; num_polys]), phase2_numvar, } @@ -412,7 +412,8 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let f = &self.poly.flattened_ml_extensions; let f_type = &self.poly_meta; let get_poly_meta = || f_type[prod[0]]; - let mut uni_variate: Vec = match prod.len() { + let mut uni_variate: Vec = vec![E::ZERO; self.poly.aux_info.max_degree + 1]; + let uni_variate_monomial: Vec = match prod.len() { 1 => sumcheck_code_gen!(1, false, |i| &f[prod[i]], || get_poly_meta()) .to_vec(), 2 => sumcheck_code_gen!(2, false, |i| &f[prod[i]], || get_poly_meta()) @@ -430,16 +431,19 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { uni_variate .iter_mut() - 
.for_each(|sum| either::for_both!(scalar, scalar => *sum *= *scalar)); - - let extrapolation = (0..self.poly.aux_info.max_degree - prod.len()) - .map(|i| { - let (points, weights) = &self.extrapolation_aux[prod.len() - 1]; - let at = E::from_u64((prod.len() + 1 + i) as u64); - serial_extrapolate(points, weights, &uni_variate, &at) - }) - .collect::>(); - uni_variate.extend(extrapolation); + .zip(uni_variate_monomial) + .take(prod.len() + 1) + .for_each(|(eval, monimial_eval,)| either::for_both!(scalar, scalar => *eval = monimial_eval**scalar)); + + + if prod.len() < self.poly.aux_info.max_degree { + // Perform extrapolation using the precomputed extrapolation table + extrapolate_from_table( + &mut uni_variate, + prod.len() + 1, + ); + } + uni_polys += AdditiveVec(uni_variate); } uni_polys @@ -589,7 +593,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { "Attempt to prove a constant." ); - let max_degree = polynomial.aux_info.max_degree; let num_polys = polynomial.flattened_ml_extensions.len(); let poly_meta = vec![PolyMeta::Normal; num_polys]; let prover_state = Self { @@ -598,13 +601,6 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { challenges: Vec::with_capacity(polynomial.aux_info.max_num_variables), round: 0, poly: polynomial, - extrapolation_aux: (1..max_degree) - .map(|degree| { - let points = (0..1 + degree as u64).map(E::from_u64).collect::>(); - let weights = barycentric_weights(&points); - (points, weights) - }) - .collect(), poly_meta, phase2_numvar: None, }; @@ -675,7 +671,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { let f = &self.poly.flattened_ml_extensions; let f_type = &self.poly_meta; let get_poly_meta = || f_type[prod[0]]; - let mut sum: Vec = match prod.len() { + let mut uni_variate: Vec = + vec![E::ZERO; self.poly.aux_info.max_degree + 1]; + let uni_variate_monomial: Vec = match prod.len() { 1 => sumcheck_code_gen!(1, true, |i| &f[prod[i]], || get_poly_meta()) .to_vec(), 2 => sumcheck_code_gen!(2, true, |i| &f[prod[i]], 
|| get_poly_meta()) @@ -690,19 +688,18 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> { .to_vec(), _ => unimplemented!("do not support degree {} > 6", prod.len()), }; - sum.iter_mut() - .for_each(|sum| either::for_both!(*scalar, scalar => *sum *= scalar)); - - let extrapolation = (0..self.poly.aux_info.max_degree - prod.len()) - .into_par_iter() - .map(|i| { - let (points, weights) = &self.extrapolation_aux[prod.len() - 1]; - let at = E::from_u64((prod.len() + 1 + i) as u64); - extrapolate(points, weights, &sum, &at) - }) - .collect::>(); - sum.extend(extrapolation); - uni_polys += AdditiveVec(sum); + uni_variate + .iter_mut() + .zip(uni_variate_monomial) + .take(prod.len() + 1) + .for_each(|(eval, monimial_eval,)| either::for_both!(scalar, scalar => *eval = monimial_eval**scalar)); + + + if prod.len() < self.poly.aux_info.max_degree { + // Perform extrapolation using the precomputed extrapolation table + extrapolate_from_table(&mut uni_variate, prod.len() + 1); + } + uni_polys += AdditiveVec(uni_variate); } uni_polys }, diff --git a/sumcheck/src/structs.rs b/sumcheck/src/structs.rs index cc4a3e05e..1f84da32c 100644 --- a/sumcheck/src/structs.rs +++ b/sumcheck/src/structs.rs @@ -45,9 +45,6 @@ pub struct IOPProverState<'a, E: ExtensionField> { pub(crate) round: usize, /// pointer to the virtual polynomial pub(crate) poly: VirtualPolynomial<'a, E>, - /// points with precomputed barycentric weights for extrapolating smaller - /// degree uni-polys to `max_degree + 1` evaluations. 
- pub(crate) extrapolation_aux: Vec<(Vec, Vec)>, pub(crate) max_num_variables: usize, pub(crate) poly_meta: Vec, /// phase 1 and phase 2 sumcheck we share similar implementation diff --git a/sumcheck/src/util.rs b/sumcheck/src/util.rs index 9c958eabc..ae1460ead 100644 --- a/sumcheck/src/util.rs +++ b/sumcheck/src/util.rs @@ -1,6 +1,5 @@ use std::{ array, - cmp::max, iter::Sum, ops::{Add, AddAssign, Deref, DerefMut, Mul, MulAssign}, sync::Arc, @@ -17,130 +16,44 @@ use multilinear_extensions::{ virtual_polys::PolyMeta, }; use p3::field::Field; -use rayon::{prelude::ParallelIterator, slice::ParallelSliceMut}; use transcript::Transcript; -use crate::structs::IOPProverState; +use crate::{extrapolate::ExtrapolationCache, structs::IOPProverState}; -pub fn barycentric_weights(points: &[F]) -> Vec { - let mut weights = points - .iter() - .enumerate() - .map(|(j, point_j)| { - points - .iter() - .enumerate() - .filter(|&(i, _)| (i != j)) - .map(|(_, point_i)| *point_j - *point_i) - .reduce(|acc, value| acc * value) - .unwrap_or(F::ONE) - }) - .collect::>(); - batch_inversion(&mut weights); - weights -} - -// Computes the inverse of each field element in a vector {v_i} using a parallelized batch inversion. -pub fn batch_inversion(v: &mut [F]) { - batch_inversion_and_mul(v, &F::ONE); -} +/// extrapolates values of a univariate polynomial in-place using precomputed barycentric weights. +/// +/// this function fills in the remaining entries of `uni_variate[start..]` assuming the first `start` +/// values are evaluations of a univariate polynomial at `0, 1, ..., start - 1`. +/// it uses a precomputed [`ExtrapolationTable`] from [`ExtrapolationCache`] to perform +/// efficient barycentric extrapolation without requiring any inverse operations at runtime. +/// +/// Note: this function is highly optimized without field inverse. 
see [`ExtrapolationTable`] for how to achieve it +pub fn extrapolate_from_table(uni_variate: &mut [E], start: usize) { + let cur_degree = start - 1; + let table = ExtrapolationCache::::get(cur_degree, uni_variate.len() - 1); + let target_len = uni_variate.len(); + assert!(start > 0, "start must be > 0 to define a degree"); + assert!( + target_len > start, + "no extrapolation needed if target_len <= start" + ); -// Computes the inverse of each field element in a vector {v_i} sequentially (serial version). -pub fn serial_batch_inversion(v: &mut [F]) { - serial_batch_inversion_and_mul(v, &F::ONE) -} + let (known, to_extrapolate) = uni_variate.split_at_mut(start); + let weight_sets = &table.weights[0]; // since min_degree == cur_degree -// Given a vector of field elements {v_i}, compute the vector {coeff * v_i^(-1)} -pub fn batch_inversion_and_mul(v: &mut [F], coeff: &F) { - // Divide the vector v evenly between all available cores - let min_elements_per_thread = 1; - let num_cpus_available = rayon::current_num_threads(); - let num_elems = v.len(); - let num_elem_per_thread = max(num_elems / num_cpus_available, min_elements_per_thread); - - // Batch invert in parallel, without copying the vector - v.par_chunks_mut(num_elem_per_thread).for_each(|chunk| { - serial_batch_inversion_and_mul(chunk, coeff); - }); -} + for (offset, target) in to_extrapolate.iter_mut().enumerate() { + let weights = &weight_sets[offset]; + assert_eq!(weights.len(), known.len()); -/// Given a vector of field elements {v_i}, compute the vector {coeff * v_i^(-1)}. -/// This method is explicitly single-threaded. -fn serial_batch_inversion_and_mul(v: &mut [F], coeff: &F) { - // Montgomery’s Trick and Fast Implementation of Masked AES - // Genelle, Prouff and Quisquater - // Section 3.2 - // but with an optimization to multiply every element in the returned vector by - // coeff - - // First pass: compute [a, ab, abc, ...] 
- let mut prod = Vec::with_capacity(v.len()); - let mut tmp = F::ONE; - for f in v.iter().filter(|f| !f.is_zero()) { - tmp.mul_assign(*f); - prod.push(tmp); - } + let acc = weights + .iter() + .zip(known.iter()) + .fold(E::ZERO, |acc, (w, x)| acc + (*w * *x)); - // Invert `tmp`. - tmp = tmp.try_inverse().unwrap(); // Guaranteed to be nonzero. - - // Multiply product by coeff, so all inverses will be scaled by coeff - tmp *= *coeff; - - // Second pass: iterate backwards to compute inverses - for (f, s) in v - .iter_mut() - // Backwards - .rev() - // Ignore normalized elements - .filter(|f| !f.is_zero()) - // Backwards, skip last element, fill in one for last term. - .zip(prod.into_iter().rev().skip(1).chain(Some(F::ONE))) - { - // tmp := tmp * f; f := tmp * s = 1/f - let new_tmp = tmp * *f; - *f = tmp * s; - tmp = new_tmp; + *target = acc; } } -pub(crate) fn extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { - inner_extrapolate::(points, weights, evals, at) -} - -pub(crate) fn serial_extrapolate(points: &[F], weights: &[F], evals: &[F], at: &F) -> F { - inner_extrapolate::(points, weights, evals, at) -} - -fn inner_extrapolate( - points: &[F], - weights: &[F], - evals: &[F], - at: &F, -) -> F { - let (coeffs, sum_inv) = { - let mut coeffs = points.iter().map(|point| *at - *point).collect::>(); - if IS_PARALLEL { - batch_inversion(&mut coeffs); - } else { - serial_batch_inversion(&mut coeffs); - } - let mut sum = F::ZERO; - coeffs.iter_mut().zip(weights).for_each(|(coeff, weight)| { - *coeff *= *weight; - sum += *coeff - }); - let sum_inv = sum.try_inverse().unwrap_or(F::ZERO); - (coeffs, sum_inv) - }; - coeffs - .iter() - .zip(evals) - .map(|(coeff, eval)| *coeff * *eval) - .sum::() - * sum_inv -} - /// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this /// polynomial at `eval_at`: /// @@ -444,3 +357,36 @@ impl Mul for AdditiveVec { self } } +#[cfg(test)] +mod tests { + use super::*; + use ff_ext::GoldilocksExt2; + use 
p3::field::PrimeCharacteristicRing; + + #[test] + fn test_extrapolate_from_table() { + type E = GoldilocksExt2; + fn f(x: u64) -> E { + E::from_u64(2u64) * E::from_u64(x) + E::from_u64(3u64) + } + // Test a known linear polynomial: f(x) = 2x + 3 + + let degree = 1; + let target_len = 5; // Extrapolate up to x=4 + + // Known values at x=0 and x=1 + let mut values: Vec = (0..=degree as u64).map(f).collect(); + + // Allocate extra space for extrapolated values + values.resize(target_len, E::ZERO); + + // Run extrapolation + extrapolate_from_table(&mut values, degree + 1); + + // Verify values against f(x) + for (x, val) in values.iter().enumerate() { + let expected = f(x as u64); + assert_eq!(*val, expected, "Mismatch at x={}", x); + } + } +} From 048e03cc28b31bb6553845741450dcb2507f409d Mon Sep 17 00:00:00 2001 From: Ming Date: Tue, 20 May 2025 12:09:58 +0800 Subject: [PATCH 19/28] add jemallocator as optional global allocator (#946) benchmark shows there are quite of time spending on glibc free (drop) when object end of its scopes. Follow openvm using [jemalloc](https://github.com/openvm-org/openvm/blob/c771a213f5e7f0732e0ddbafb273e15d99c5049d/crates/vm/Cargo.toml#L56) as global allocators. 
and set jemalloc parameter follows https://github.com/openvm-org/openvm/blob/c771a213f5e7f0732e0ddbafb273e15d99c5049d/.github/workflows/benchmark-call.yml#L218 > I do not use jemalloc "background_thread: true" as I thought thread in background might occupied other schedule which affect cpu intensive program ### change scope - enable jemalloc by default when compiling ceno_cli - support `cargo make cli` to install ceno_cli - introduce "jemalloc" features ### benchmark benchmark on AMD EPYC 32 cores with command `JEMALLOC_SYS_WITH_MALLOC_CONF="retain:true,metadata_thp:always,thp:always,dirty_decay_ms:-1,muzzy_decay_ms:-1,abort_conf:true" cargo bench --bench fibonacci --features jemalloc --package ceno_zkvm -- --baseline opt-baseline` | Benchmark | Average Time | Improvement | Throughput (instructions/sec) | |-----------------|--------------|-------------|---------------------------| | fibonacci 2^20 | 2.0020 s | -14.74% | 523.76k | | fibonacci 2^21 | 3.5903 s | -18.89% | 584.34k | | fibonacci 2^22 | 6.6531 s | -24.69% | 630.28k | --------- Co-authored-by: Zhang Zhuo --- .github/workflows/integration.yml | 8 ++++++ .github/workflows/lints.yml | 1 + Cargo.lock | 34 ++++++++++++++++++++++++++ Makefile.toml | 14 +++++++++++ ceno_cli/Cargo.toml | 5 ++++ ceno_cli/src/main.rs | 11 +++++++++ ceno_zkvm/Cargo.toml | 6 +++++ ceno_zkvm/benches/alloc.rs | 4 +++ ceno_zkvm/benches/fibonacci.rs | 1 + ceno_zkvm/benches/fibonacci_witness.rs | 1 + ceno_zkvm/benches/is_prime.rs | 1 + ceno_zkvm/benches/quadratic_sorting.rs | 1 + ceno_zkvm/benches/riscv_add.rs | 1 + ceno_zkvm/src/bin/e2e.rs | 12 +++++++++ ceno_zkvm/src/lib.rs | 2 ++ ceno_zkvm/src/utils.rs | 13 ++++++++++ 16 files changed, 115 insertions(+) create mode 100644 ceno_zkvm/benches/alloc.rs diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 715e80444..b275433f4 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -58,3 +58,11 @@ jobs: env: RUSTFLAGS: "-C 
opt-level=3" run: cargo run --release --package ceno_zkvm --bin e2e -- --platform=ceno --hints=10 --public-io=4191 examples/target/riscv32im-ceno-zkvm-elf/release/examples/fibonacci + + - name: Install cargo make + run: | + cargo make --version || cargo install cargo-make + + - name: Test install Ceno cli + run: | + cargo make cli diff --git a/.github/workflows/lints.yml b/.github/workflows/lints.yml index b8423b497..1c91eeb7a 100644 --- a/.github/workflows/lints.yml +++ b/.github/workflows/lints.yml @@ -43,6 +43,7 @@ jobs: - name: Install cargo make run: | cargo make --version || cargo install cargo-make + - name: Check code format run: cargo fmt --all --check diff --git a/Cargo.lock b/Cargo.lock index c6ebbe73c..4a843a701 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -344,6 +344,7 @@ dependencies = [ "parse-size", "serde", "tempfile", + "tikv-jemallocator", "tracing", "tracing-forest", "tracing-subscriber", @@ -481,6 +482,8 @@ dependencies = [ "sumcheck", "tempfile", "thread_local", + "tikv-jemalloc-ctl", + "tikv-jemallocator", "tiny-keccak", "tracing", "tracing-forest", @@ -3044,6 +3047,37 @@ dependencies = [ "once_cell", ] +[[package]] +name = "tikv-jemalloc-ctl" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f21f216790c8df74ce3ab25b534e0718da5a1916719771d3fec23315c99e468b" +dependencies = [ + "libc", + "paste", + "tikv-jemalloc-sys", +] + +[[package]] +name = "tikv-jemalloc-sys" +version = "0.6.0+5.3.0-1-ge13ca993e8ccb9ba9847cc330696e02839f328f7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd3c60906412afa9c2b5b5a48ca6a5abe5736aec9eb48ad05037a677e52e4e2d" +dependencies = [ + "cc", + "libc", +] + +[[package]] +name = "tikv-jemallocator" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cec5ff18518d81584f477e9bfdf957f5bb0979b0bac3af4ca30b5b3ae2d2865" +dependencies = [ + "libc", + "tikv-jemalloc-sys", +] + [[package]] name = "time" 
version = "0.3.41" diff --git a/Makefile.toml b/Makefile.toml index 64563086a..15dd542ea 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -30,3 +30,17 @@ args = [ ] command = "cargo" workspace = false + +[tasks.cli] +args = [ + "install", + "--features", + "jemalloc", + "--features", + "nightly-features", + "--path", + "./ceno_cli", +] +command = "cargo" +env = { "JEMALLOC_SYS_WITH_MALLOC_CONF" = "retain:true,metadata_thp:always,thp:always,dirty_decay_ms:-1,muzzy_decay_ms:-1,abort_conf:true" } +workspace = false diff --git a/ceno_cli/Cargo.toml b/ceno_cli/Cargo.toml index d9ecd4dc3..8ba753539 100644 --- a/ceno_cli/Cargo.toml +++ b/ceno_cli/Cargo.toml @@ -23,6 +23,9 @@ tracing.workspace = true tracing-forest.workspace = true tracing-subscriber.workspace = true +[target.'cfg(unix)'.dependencies] +tikv-jemallocator = { version = "0.6", optional = true } + ceno_emul = { path = "../ceno_emul" } ceno_host = { path = "../ceno_host" } ceno_zkvm = { path = "../ceno_zkvm" } @@ -33,6 +36,8 @@ mpcs = { path = "../mpcs" } vergen-git2 = { version = "1", features = ["build", "cargo", "rustc", "emit_and_set"] } [features] +jemalloc = ["dep:tikv-jemallocator", "ceno_zkvm/jemalloc"] +jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ "ceno_zkvm/nightly-features", "ff_ext/nightly-features", diff --git a/ceno_cli/src/main.rs b/ceno_cli/src/main.rs index 9d9363981..df7390590 100644 --- a/ceno_cli/src/main.rs +++ b/ceno_cli/src/main.rs @@ -1,10 +1,17 @@ use crate::{commands::*, utils::*}; use anyhow::Context; +#[cfg(all(feature = "jemalloc", unix, not(test)))] +use ceno_zkvm::print_allocated_bytes; use clap::{Args, Parser, Subcommand}; mod commands; mod utils; +// Use jemalloc as global allocator for performance +#[cfg(all(feature = "jemalloc", unix, not(test)))] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + const CENO_VERSION: &str = env!("CENO_VERSION"); #[derive(Parser)] @@ -86,4 +93,8 @@ fn main() { 
print_error(e); std::process::exit(1); } + #[cfg(all(feature = "jemalloc", unix, not(test)))] + { + print_allocated_bytes(); + } } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index ebcf5d6bc..ae80bf4fb 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -51,6 +51,10 @@ tempfile = "3.14" thread_local = "1.1" tiny-keccak.workspace = true +[target.'cfg(unix)'.dependencies] +tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } +tikv-jemallocator = { version = "0.6", optional = true } + [dev-dependencies] cfg-if.workspace = true criterion.workspace = true @@ -65,6 +69,8 @@ glob = "0.3" default = ["forbid_overflow"] flamegraph = ["pprof2/flamegraph", "pprof2/criterion"] forbid_overflow = [] +jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] +jemalloc-prof = ["jemalloc", "tikv-jemallocator?/profiling"] nightly-features = [ "p3/nightly-features", "ff_ext/nightly-features", diff --git a/ceno_zkvm/benches/alloc.rs b/ceno_zkvm/benches/alloc.rs new file mode 100644 index 000000000..581bb2ff0 --- /dev/null +++ b/ceno_zkvm/benches/alloc.rs @@ -0,0 +1,4 @@ +// Use jemalloc as global allocator for performance +#[cfg(all(feature = "jemalloc", unix))] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; diff --git a/ceno_zkvm/benches/fibonacci.rs b/ceno_zkvm/benches/fibonacci.rs index 4f43130e7..350113368 100644 --- a/ceno_zkvm/benches/fibonacci.rs +++ b/ceno_zkvm/benches/fibonacci.rs @@ -7,6 +7,7 @@ use ceno_zkvm::{ e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, scheme::{constants::MAX_NUM_VARIABLES, verifier::ZKVMVerifier}, }; +mod alloc; use criterion::*; use ff_ext::GoldilocksExt2; diff --git a/ceno_zkvm/benches/fibonacci_witness.rs b/ceno_zkvm/benches/fibonacci_witness.rs index beb06bd76..54443faf7 100644 --- a/ceno_zkvm/benches/fibonacci_witness.rs +++ b/ceno_zkvm/benches/fibonacci_witness.rs @@ -7,6 +7,7 @@ use ceno_zkvm::{ e2e::{Checkpoint, Preset, 
run_e2e_with_checkpoint, setup_platform}, scheme::constants::MAX_NUM_VARIABLES, }; +mod alloc; use criterion::*; use ff_ext::GoldilocksExt2; use mpcs::{BasefoldDefault, SecurityLevel}; diff --git a/ceno_zkvm/benches/is_prime.rs b/ceno_zkvm/benches/is_prime.rs index 21ec18071..5c40f6fa4 100644 --- a/ceno_zkvm/benches/is_prime.rs +++ b/ceno_zkvm/benches/is_prime.rs @@ -7,6 +7,7 @@ use ceno_zkvm::{ e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, scheme::constants::MAX_NUM_VARIABLES, }; +mod alloc; use criterion::*; use ff_ext::GoldilocksExt2; use mpcs::{BasefoldDefault, SecurityLevel}; diff --git a/ceno_zkvm/benches/quadratic_sorting.rs b/ceno_zkvm/benches/quadratic_sorting.rs index fa53c5ce9..643f0c631 100644 --- a/ceno_zkvm/benches/quadratic_sorting.rs +++ b/ceno_zkvm/benches/quadratic_sorting.rs @@ -7,6 +7,7 @@ use ceno_zkvm::{ e2e::{Checkpoint, Preset, run_e2e_with_checkpoint, setup_platform}, scheme::constants::MAX_NUM_VARIABLES, }; +mod alloc; use criterion::*; use ff_ext::GoldilocksExt2; use mpcs::{BasefoldDefault, SecurityLevel}; diff --git a/ceno_zkvm/benches/riscv_add.rs b/ceno_zkvm/benches/riscv_add.rs index 0a1f920d9..761bf743e 100644 --- a/ceno_zkvm/benches/riscv_add.rs +++ b/ceno_zkvm/benches/riscv_add.rs @@ -6,6 +6,7 @@ use ceno_zkvm::{ scheme::prover::ZKVMProver, structs::{ZKVMConstraintSystem, ZKVMFixedTraces}, }; +mod alloc; use criterion::*; use ceno_zkvm::scheme::constants::MAX_NUM_VARIABLES; diff --git a/ceno_zkvm/src/bin/e2e.rs b/ceno_zkvm/src/bin/e2e.rs index 19eb18afc..b32ab9ec0 100644 --- a/ceno_zkvm/src/bin/e2e.rs +++ b/ceno_zkvm/src/bin/e2e.rs @@ -1,5 +1,7 @@ use ceno_emul::{IterAddresses, Platform, Program, WORD_SIZE, Word}; use ceno_host::{CenoStdin, memory_from_file}; +#[cfg(all(feature = "jemalloc", unix, not(test)))] +use ceno_zkvm::print_allocated_bytes; use ceno_zkvm::{ e2e::{ Checkpoint, FieldType, PcsKind, Preset, run_e2e_with_checkpoint, setup_platform, @@ -26,6 +28,11 @@ use tracing_subscriber::{ }; use 
transcript::BasicTranscript as Transcript; +// Use jemalloc as global allocator for performance +#[cfg(all(feature = "jemalloc", unix, not(test)))] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + fn parse_size(s: &str) -> Result { parse_size::Config::new() .with_binary() @@ -272,6 +279,11 @@ fn main() { Checkpoint::PrepVerify, // FIXME: when whir and babybear is ready ) } + }; + + #[cfg(all(feature = "jemalloc", unix, not(test)))] + { + print_allocated_bytes(); } } diff --git a/ceno_zkvm/src/lib.rs b/ceno_zkvm/src/lib.rs index 12fa66161..ed572de3a 100644 --- a/ceno_zkvm/src/lib.rs +++ b/ceno_zkvm/src/lib.rs @@ -20,6 +20,8 @@ pub mod stats; pub mod structs; mod uint; mod utils; +#[cfg(all(feature = "jemalloc", unix, not(test)))] +pub use utils::print_allocated_bytes; mod witness; pub use structs::ROMType; diff --git a/ceno_zkvm/src/utils.rs b/ceno_zkvm/src/utils.rs index 37e159669..acf725417 100644 --- a/ceno_zkvm/src/utils.rs +++ b/ceno_zkvm/src/utils.rs @@ -274,6 +274,19 @@ pub fn add_mle_list_by_expr<'a, E: ExtensionField>( .collect::>() } +#[cfg(all(feature = "jemalloc", unix, not(test)))] +pub fn print_allocated_bytes() { + use tikv_jemalloc_ctl::{epoch, stats}; + + // Advance the epoch to refresh the stats + let e = epoch::mib().unwrap(); + e.advance().unwrap(); + + // Read allocated bytes + let allocated = stats::allocated::read().unwrap(); + tracing::info!("jemalloc total allocated bytes: {}", allocated); +} + #[cfg(test)] mod tests { use ff_ext::GoldilocksExt2; From 8e0d7f3ee18a8a647fe0156b1ec1c7fb247088b3 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 15:56:57 +0800 Subject: [PATCH 20/28] jemalloc in gkr-iop --- Cargo.lock | 2 ++ gkr_iop/Cargo.toml | 7 +++++++ gkr_iop/benches/alloc.rs | 4 ++++ gkr_iop/benches/bitwise_keccakf.rs | 1 + gkr_iop/benches/lookup_keccakf.rs | 1 + 5 files changed, 15 insertions(+) create mode 100644 gkr_iop/benches/alloc.rs diff --git a/Cargo.lock b/Cargo.lock index 
4a843a701..6aa94fd14 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1129,6 +1129,8 @@ dependencies = [ "subprotocols", "sumcheck", "thiserror 1.0.69", + "tikv-jemalloc-ctl", + "tikv-jemallocator", "tiny-keccak", "tracing", "tracing-forest", diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index a92dc48d0..d0dee4a7b 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -32,6 +32,10 @@ tracing.workspace = true tracing-subscriber.workspace = true witness = { path = "../witness" } +[target.'cfg(unix)'.dependencies] +tikv-jemalloc-ctl = { version = "0.6", features = ["stats"], optional = true } +tikv-jemallocator = { version = "0.6", optional = true } + [dev-dependencies] criterion.workspace = true @@ -42,3 +46,6 @@ name = "bitwise_keccakf" [[bench]] harness = false name = "lookup_keccakf" + +[features] +jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] \ No newline at end of file diff --git a/gkr_iop/benches/alloc.rs b/gkr_iop/benches/alloc.rs new file mode 100644 index 000000000..581bb2ff0 --- /dev/null +++ b/gkr_iop/benches/alloc.rs @@ -0,0 +1,4 @@ +// Use jemalloc as global allocator for performance +#[cfg(all(feature = "jemalloc", unix))] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; diff --git a/gkr_iop/benches/bitwise_keccakf.rs b/gkr_iop/benches/bitwise_keccakf.rs index 34bfeaaba..0b760f852 100644 --- a/gkr_iop/benches/bitwise_keccakf.rs +++ b/gkr_iop/benches/bitwise_keccakf.rs @@ -5,6 +5,7 @@ use ff_ext::GoldilocksExt2; use gkr_iop::precompiles::{run_keccakf, setup_keccak_bitwise_circuit}; use itertools::Itertools; use rand::{Rng, SeedableRng}; +mod alloc; criterion_group!(benches, keccak_f_fn); criterion_main!(benches); diff --git a/gkr_iop/benches/lookup_keccakf.rs b/gkr_iop/benches/lookup_keccakf.rs index 9aa6f78bf..9e1c5309d 100644 --- a/gkr_iop/benches/lookup_keccakf.rs +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -6,6 +6,7 @@ use gkr_iop::precompiles::{run_faster_keccakf, 
setup_keccak_lookup_circuit}; use itertools::Itertools; use rand::{Rng, SeedableRng}; +mod alloc; criterion_group!(benches, keccak_f_fn); criterion_main!(benches); From 9ce5ad98b48128e4b7e8f04269d4ec474db7010b Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 15:59:42 +0800 Subject: [PATCH 21/28] chores: bench --- gkr_iop/benches/lookup_keccakf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gkr_iop/benches/lookup_keccakf.rs b/gkr_iop/benches/lookup_keccakf.rs index 9e1c5309d..79fada463 100644 --- a/gkr_iop/benches/lookup_keccakf.rs +++ b/gkr_iop/benches/lookup_keccakf.rs @@ -14,7 +14,7 @@ const NUM_SAMPLES: usize = 10; fn keccak_f_fn(c: &mut Criterion) { // Benchmark the proving time - for log_instances in 10..12 { + for log_instances in 12..14 { let num_instance = 1 << log_instances; // expand more input size once runtime is acceptable let mut group = c.benchmark_group(format!("keccak_lookup_f_{}", num_instance)); From 473aeb5288ccdc5e28bcc645a532eef090043efb Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 16:13:12 +0800 Subject: [PATCH 22/28] temporarily extract subprotocols to another PR --- Cargo.lock | 20 - Cargo.toml | 2 - ceno_zkvm/Cargo.toml | 1 - ceno_zkvm/src/scheme/prover.rs | 9 +- gkr_iop/Cargo.toml | 1 - subprotocols/Cargo.toml | 30 -- subprotocols/benches/expr_based_logup.rs | 144 ------- subprotocols/examples/zerocheck_logup.rs | 93 ----- subprotocols/src/error.rs | 9 - subprotocols/src/expression.rs | 167 -------- subprotocols/src/expression/evaluate.rs | 459 --------------------- subprotocols/src/expression/macros.rs | 100 ----- subprotocols/src/expression/op.rs | 81 ---- subprotocols/src/lib.rs | 9 - subprotocols/src/points.rs | 75 ---- subprotocols/src/sumcheck.rs | 454 --------------------- subprotocols/src/test_utils.rs | 46 --- subprotocols/src/utils.rs | 235 ----------- subprotocols/src/zerocheck.rs | 495 ----------------------- 19 files changed, 5 insertions(+), 2425 deletions(-) delete mode 
100644 subprotocols/Cargo.toml delete mode 100644 subprotocols/benches/expr_based_logup.rs delete mode 100644 subprotocols/examples/zerocheck_logup.rs delete mode 100644 subprotocols/src/error.rs delete mode 100644 subprotocols/src/expression.rs delete mode 100644 subprotocols/src/expression/evaluate.rs delete mode 100644 subprotocols/src/expression/macros.rs delete mode 100644 subprotocols/src/expression/op.rs delete mode 100644 subprotocols/src/lib.rs delete mode 100644 subprotocols/src/points.rs delete mode 100644 subprotocols/src/sumcheck.rs delete mode 100644 subprotocols/src/test_utils.rs delete mode 100644 subprotocols/src/utils.rs delete mode 100644 subprotocols/src/zerocheck.rs diff --git a/Cargo.lock b/Cargo.lock index 6aa94fd14..0e9045eb4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -478,7 +478,6 @@ dependencies = [ "serde_json", "strum", "strum_macros", - "subprotocols", "sumcheck", "tempfile", "thread_local", @@ -1126,7 +1125,6 @@ dependencies = [ "rand", "rayon", "serde", - "subprotocols", "sumcheck", "thiserror 1.0.69", "tikv-jemalloc-ctl", @@ -2839,24 +2837,6 @@ dependencies = [ "syn 2.0.101", ] -[[package]] -name = "subprotocols" -version = "0.1.0" -dependencies = [ - "ark-std", - "criterion", - "ff_ext", - "itertools 0.13.0", - "multilinear_extensions", - "p3-field", - "p3-goldilocks", - "rand", - "rayon", - "serde", - "thiserror 1.0.69", - "transcript", -] - [[package]] name = "substrate-bn" version = "0.6.0" diff --git a/Cargo.toml b/Cargo.toml index 370340c4f..543e2f7d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,6 @@ members = [ "sumcheck_macro", "poseidon", "gkr_iop", - "subprotocols", "sumcheck", "transcript", "whir", @@ -76,7 +75,6 @@ serde_json = "1.0" strum = "0.26" thiserror = "1" # do we need this? 
strum_macros = "0.26" -subprotocols = { path = "subprotocols" } substrate-bn = { version = "0.6.0" } sumcheck = { path = "sumcheck" } tiny-keccak = { version = "2.0.2", features = ["keccak"] } diff --git a/ceno_zkvm/Cargo.toml b/ceno_zkvm/Cargo.toml index ae80bf4fb..1e712d8c8 100644 --- a/ceno_zkvm/Cargo.toml +++ b/ceno_zkvm/Cargo.toml @@ -24,7 +24,6 @@ rand_chacha.workspace = true rayon.workspace = true serde.workspace = true serde_json.workspace = true -subprotocols.workspace = true sumcheck.workspace = true transcript = { path = "../transcript" } witness = { path = "../witness" } diff --git a/ceno_zkvm/src/scheme/prover.rs b/ceno_zkvm/src/scheme/prover.rs index c5b66ebd8..325efdf13 100644 --- a/ceno_zkvm/src/scheme/prover.rs +++ b/ceno_zkvm/src/scheme/prover.rs @@ -782,7 +782,7 @@ impl> ZKVMProver { .iter() .map(|base| PointAndEval { point: input_open_point.clone(), - eval: subprotocols::utils::evaluate_mle_ext(base, &input_open_point), + eval: base.evaluate(&input_open_point), }) .collect_vec(); @@ -1073,9 +1073,10 @@ impl> ZKVMProver { .into_iter() .zip(w_wit_layers) .flat_map(|(r, w)| { - vec![TowerProverSpec { witness: r }, TowerProverSpec { - witness: w, - }] + vec![ + TowerProverSpec { witness: r }, + TowerProverSpec { witness: w }, + ] }) .collect_vec(), lk_wit_layers diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index d0dee4a7b..2937e57fe 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -22,7 +22,6 @@ p3-goldilocks.workspace = true rand.workspace = true rayon.workspace = true serde.workspace = true -subprotocols = { path = "../subprotocols" } sumcheck.workspace = true thiserror.workspace = true tiny-keccak.workspace = true diff --git a/subprotocols/Cargo.toml b/subprotocols/Cargo.toml deleted file mode 100644 index b425bf380..000000000 --- a/subprotocols/Cargo.toml +++ /dev/null @@ -1,30 +0,0 @@ -[package] -categories.workspace = true -description = "Subprotocols" -edition.workspace = true -keywords.workspace = true -license.workspace = 
true -name = "subprotocols" -readme.workspace = true -repository.workspace = true -version.workspace = true - -[dependencies] -ark-std = { version = "0.5" } -ff_ext = { path = "../ff_ext" } -itertools.workspace = true -multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } -p3-field.workspace = true -rand.workspace = true -rayon.workspace = true -serde.workspace = true -thiserror = "1" -transcript = { path = "../transcript" } - -[dev-dependencies] -criterion.workspace = true -p3-goldilocks.workspace = true - -[[bench]] -harness = false -name = "expr_based_logup" diff --git a/subprotocols/benches/expr_based_logup.rs b/subprotocols/benches/expr_based_logup.rs deleted file mode 100644 index 7c1bbaef7..000000000 --- a/subprotocols/benches/expr_based_logup.rs +++ /dev/null @@ -1,144 +0,0 @@ -use std::{array, time::Duration}; - -use ark_std::test_rng; -use criterion::*; -use ff_ext::FromUniformBytes; -use itertools::Itertools; -use p3_field::extension::BinomialExtensionField; -use p3_goldilocks::Goldilocks; -use subprotocols::{ - expression::{Constant, Expression, Witness}, - sumcheck::SumcheckProverState, - test_utils::{random_point, random_poly}, - zerocheck::ZerocheckProverState, -}; -use transcript::BasicTranscript as Transcript; - -criterion_group!(benches, zerocheck_fn, sumcheck_fn); -criterion_main!(benches); - -const NUM_SAMPLES: usize = 10; -const NV: [usize; 2] = [25, 26]; - -fn sumcheck_fn(c: &mut Criterion) { - type E = BinomialExtensionField; - - for nv in NV { - // expand more input size once runtime is acceptable - let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); - group.sample_size(NUM_SAMPLES); - - // Benchmark the proving time - group.bench_function( - BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), - |b| { - b.iter_custom(|iters| { - let mut time = Duration::new(0, 0); - for _ in 0..iters { - let mut rng = test_rng(); - // Initialize logup expression. 
- let eq = Expression::Wit(Witness::EqPoly(0)); - let beta = Expression::Const(Constant::Challenge(0)); - let [d0, d1, n0, n1] = - array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); - let expr = eq * (d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0)); - - // Randomly generate point and witness. - let point = random_point(&mut rng, nv); - - let d0 = random_poly(&mut rng, nv); - let d1 = random_poly(&mut rng, nv); - let n0 = random_poly(&mut rng, nv); - let n1 = random_poly(&mut rng, nv); - let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; - - let challenges = vec![E::random(&mut rng)]; - - let ext_mle_refs = - ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); - - let mut prover_transcript = Transcript::new(b"test"); - let prover = SumcheckProverState::new( - expr, - &[&point], - ext_mle_refs, - vec![], - &challenges, - &mut prover_transcript, - ); - - let instant = std::time::Instant::now(); - let _ = black_box(prover.prove()); - let elapsed = instant.elapsed(); - time += elapsed; - } - - time - }); - }, - ); - - group.finish(); - } -} - -fn zerocheck_fn(c: &mut Criterion) { - type E = BinomialExtensionField; - - for nv in NV { - // expand more input size once runtime is acceptable - let mut group = c.benchmark_group(format!("logup_sumcheck_nv_{}", nv)); - group.sample_size(NUM_SAMPLES); - - // Benchmark the proving time - group.bench_function( - BenchmarkId::new("prove_sumcheck", format!("sumcheck_nv_{}", nv)), - |b| { - b.iter_custom(|iters| { - let mut time = Duration::new(0, 0); - for _ in 0..iters { - // Initialize logup expression. - let mut rng = test_rng(); - let beta = Expression::Const(Constant::Challenge(0)); - let [d0, d1, n0, n1] = - array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); - let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); - - // Randomly generate point and witness. 
- let point = random_point(&mut rng, nv); - - let d0 = random_poly(&mut rng, nv); - let d1 = random_poly(&mut rng, nv); - let n0 = random_poly(&mut rng, nv); - let n1 = random_poly(&mut rng, nv); - let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; - - let challenges = vec![E::random(&mut rng)]; - - let ext_mle_refs = - ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); - - let mut prover_transcript = Transcript::new(b"test"); - let prover = ZerocheckProverState::new( - vec![expr], - &[&point], - ext_mle_refs, - vec![], - &challenges, - &mut prover_transcript, - ); - - let instant = std::time::Instant::now(); - let _ = black_box(prover.prove()); - let elapsed = instant.elapsed(); - time += elapsed; - } - - time - }); - }, - ); - - group.finish(); - } -} diff --git a/subprotocols/examples/zerocheck_logup.rs b/subprotocols/examples/zerocheck_logup.rs deleted file mode 100644 index 36c227b7e..000000000 --- a/subprotocols/examples/zerocheck_logup.rs +++ /dev/null @@ -1,93 +0,0 @@ -use std::array; - -use ff_ext::{ExtensionField, FromUniformBytes}; -use itertools::{Itertools, izip}; -use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; -use p3_goldilocks::Goldilocks as F; -use rand::thread_rng; -use subprotocols::{ - expression::{Constant, Expression, Witness}, - sumcheck::{SumcheckProof, SumcheckProverOutput}, - test_utils::{random_point, random_poly}, - utils::eq_vecs, - zerocheck::{ZerocheckProverState, ZerocheckVerifierState}, -}; -use transcript::BasicTranscript; - -type E = BinomialExtensionField; - -fn run_prover( - point: &[E], - ext_mles: &mut [Vec], - expr: Expression, - challenges: Vec, -) -> SumcheckProof { - let timer = std::time::Instant::now(); - let ext_mle_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); - - let mut prover_transcript = BasicTranscript::new(b"test"); - let prover = ZerocheckProverState::new( - vec![expr], - &[point], - ext_mle_refs, - vec![], - &challenges, - &mut 
prover_transcript, - ); - - let SumcheckProverOutput { proof, .. } = prover.prove(); - println!("Proving time: {:?}", timer.elapsed()); - proof -} - -fn run_verifier( - proof: SumcheckProof, - ans: &E, - point: &[E], - expr: Expression, - challenges: Vec, -) { - let mut verifier_transcript = BasicTranscript::new(b"test"); - let verifier = ZerocheckVerifierState::new( - vec![*ans], - vec![expr], - vec![], - vec![point], - proof, - &challenges, - &mut verifier_transcript, - ); - - verifier.verify().expect("verification failed"); -} - -fn main() { - let num_vars = 20; - let mut rng = thread_rng(); - - // Initialize logup expression. - let beta = Expression::Const(Constant::Challenge(0)); - let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); - let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0); - - // Randomly generate point and witness. - let point = random_point(&mut rng, num_vars); - - let d0 = random_poly(&mut rng, num_vars); - let d1 = random_poly(&mut rng, num_vars); - let n0 = random_poly(&mut rng, num_vars); - let n1 = random_poly(&mut rng, num_vars); - let mut ext_mles = [d0.clone(), d1.clone(), n0.clone(), n1.clone()]; - - let challenges = vec![E::random(&mut rng)]; - - let proof = run_prover(&point, &mut ext_mles, expr.clone(), challenges.clone()); - - let eqs = eq_vecs([point.as_slice()].into_iter(), &[E::ONE]); - - let ans: E = izip!(&eqs[0], &d0, &d1, &n0, &n1) - .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) - .sum(); - - run_verifier(proof, &ans, &point, expr, challenges); -} diff --git a/subprotocols/src/error.rs b/subprotocols/src/error.rs deleted file mode 100644 index ca8eddefe..000000000 --- a/subprotocols/src/error.rs +++ /dev/null @@ -1,9 +0,0 @@ -use thiserror::Error; - -use crate::expression::Expression; - -#[derive(Clone, Debug, Error)] -pub enum VerifierError { - #[error("Claim not match: expr: {0:?}\n (expr name: {3:?})\n expect: {1:?}, got: {2:?}")] - 
ClaimNotMatch(Expression, E, E, String), -} diff --git a/subprotocols/src/expression.rs b/subprotocols/src/expression.rs deleted file mode 100644 index 60fd882cb..000000000 --- a/subprotocols/src/expression.rs +++ /dev/null @@ -1,167 +0,0 @@ -use std::sync::Arc; - -use ff_ext::ExtensionField; -use serde::{Deserialize, Serialize}; - -mod evaluate; -mod op; - -mod macros; - -pub type Point = Arc>; - -#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Constant { - /// Base field - Base(i64), - /// Challenge - Challenge(usize), - /// Sum - Sum(Box, Box), - /// Product - Product(Box, Box), - /// Neg - Neg(Box), - /// Pow - Pow(Box, usize), -} - -impl Default for Constant { - fn default() -> Self { - Constant::Base(0) - } -} - -#[derive(Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Witness { - /// Base field polynomial (index). - BasePoly(usize), - /// Extension field polynomial (index). - ExtPoly(usize), - /// Eq polynomial - EqPoly(usize), -} - -impl Default for Witness { - fn default() -> Self { - Witness::BasePoly(0) - } -} - -#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)] -pub enum Expression { - /// Constant - Const(Constant), - /// Witness. - Wit(Witness), - /// This is the sum of two expressions, with `degree`. - Sum(Box, Box, usize), - /// This is the product of two expressions, with `degree`. - Product(Box, Box, usize), - /// Neg, with `degree`. - Neg(Box, usize), - /// Pow, with `D` and `degree`. 
- Pow(Box, usize, usize), -} - -impl std::fmt::Debug for Expression { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Expression::Const(c) => write!(f, "{:?}", c), - Expression::Wit(w) => write!(f, "{:?}", w), - Expression::Sum(a, b, _) => write!(f, "({:?} + {:?})", a, b), - Expression::Product(a, b, _) => write!(f, "({:?} * {:?})", a, b), - Expression::Neg(a, _) => write!(f, "(-{:?})", a), - Expression::Pow(a, n, _) => write!(f, "({:?})^({})", a, n), - } - } -} - -impl std::fmt::Debug for Witness { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Witness::BasePoly(i) => write!(f, "BP[{}]", i), - Witness::ExtPoly(i) => write!(f, "EP[{}]", i), - Witness::EqPoly(i) => write!(f, "EQ[{}]", i), - } - } -} - -/// Vector of univariate polys. -#[derive(Clone, Debug)] -enum UniPolyVectorType { - Base(Vec>), - Ext(Vec>), -} - -/// Vector of field type. -#[derive(Clone, PartialEq, Eq)] -pub enum VectorType { - Base(Vec), - Ext(Vec), -} - -impl std::fmt::Debug for VectorType { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - VectorType::Base(v) => { - let mut v = v.iter(); - write!(f, "[")?; - if let Some(e) = v.next() { - write!(f, "{:?}", e)?; - } - for _ in 0..2 { - if let Some(e) = v.next() { - write!(f, ", {:?}", e)?; - } else { - break; - } - } - if v.next().is_some() { - write!(f, ", ...]")?; - } else { - write!(f, "]")?; - }; - Ok(()) - } - VectorType::Ext(v) => { - let mut v = v.iter(); - write!(f, "[")?; - if let Some(e) = v.next() { - write!(f, "{:?}", e)?; - } - for _ in 0..2 { - if let Some(e) = v.next() { - write!(f, ", {:?}", e)?; - } else { - break; - } - } - if v.next().is_some() { - write!(f, ", ...]")?; - } else { - write!(f, "]")?; - }; - Ok(()) - } - } - } -} - -#[derive(Clone, Debug)] -enum ScalarType { - Base(E::BaseField), - Ext(E), -} - -impl From for Expression { - fn from(w: Witness) -> Self { - Expression::Wit(w) - } -} - -impl From 
for Expression { - fn from(c: Constant) -> Self { - Expression::Const(c) - } -} diff --git a/subprotocols/src/expression/evaluate.rs b/subprotocols/src/expression/evaluate.rs deleted file mode 100644 index cb17c9e34..000000000 --- a/subprotocols/src/expression/evaluate.rs +++ /dev/null @@ -1,459 +0,0 @@ -use ff_ext::ExtensionField; -use itertools::{Itertools, zip_eq}; -use multilinear_extensions::virtual_poly::eq_eval; -use p3_field::{Field, PrimeCharacteristicRing}; - -use crate::{op_by_type, utils::i64_to_field}; - -use super::{Constant, Expression, ScalarType, UniPolyVectorType, VectorType, Witness}; - -impl Expression { - pub fn degree(&self) -> usize { - match self { - Expression::Const(_) => 0, - Expression::Wit(_) => 1, - Expression::Sum(_, _, degree) => *degree, - Expression::Product(_, _, degree) => *degree, - Expression::Neg(_, degree) => *degree, - Expression::Pow(_, _, degree) => *degree, - } - } - - pub fn is_ext(&self) -> bool { - match self { - Expression::Const(c) => c.is_ext(), - Expression::Wit(w) => w.is_ext(), - Expression::Sum(e0, e1, _) | Expression::Product(e0, e1, _) => { - e0.is_ext() || e1.is_ext() - } - Expression::Neg(e, _) => e.is_ext(), - Expression::Pow(e, d, _) => { - if *d > 0 { - e.is_ext() - } else { - false - } - } - } - } - - pub fn evaluate( - &self, - ext_mle_evals: &[E], - base_mle_evals: &[E], - out_points: &[&[E]], - in_point: &[E], - challenges: &[E], - ) -> E { - match self { - Expression::Const(constant) => constant.evaluate(challenges), - Expression::Wit(w) => w.evaluate(base_mle_evals, ext_mle_evals, out_points, in_point), - Expression::Sum(e0, e1, _) => { - e0.evaluate( - ext_mle_evals, - base_mle_evals, - out_points, - in_point, - challenges, - ) + e1.evaluate( - ext_mle_evals, - base_mle_evals, - out_points, - in_point, - challenges, - ) - } - Expression::Product(e0, e1, _) => { - e0.evaluate( - ext_mle_evals, - base_mle_evals, - out_points, - in_point, - challenges, - ) * e1.evaluate( - ext_mle_evals, - 
base_mle_evals, - out_points, - in_point, - challenges, - ) - } - Expression::Neg(e, _) => -e.evaluate( - ext_mle_evals, - base_mle_evals, - out_points, - in_point, - challenges, - ), - Expression::Pow(e, d, _) => e - .evaluate( - ext_mle_evals, - base_mle_evals, - out_points, - in_point, - challenges, - ) - .exp_u64(*d as u64), - } - } - - pub fn calc( - &self, - ext: &[Vec], - base: &[Vec], - eqs: &[Vec], - challenges: &[E], - ) -> VectorType { - assert!(!(ext.is_empty() && base.is_empty())); - let size = if !ext.is_empty() { - ext[0].len() - } else { - base[0].len() - }; - match self { - Expression::Const(constant) => { - VectorType::Ext(vec![constant.evaluate(challenges); size]) - } - Expression::Wit(w) => match w { - Witness::BasePoly(index) => VectorType::Base(base[*index].clone()), - Witness::ExtPoly(index) => VectorType::Ext(ext[*index].clone()), - Witness::EqPoly(index) => VectorType::Ext(eqs[*index].clone()), - }, - Expression::Sum(e0, e1, _) => { - e0.calc(ext, base, eqs, challenges) + e1.calc(ext, base, eqs, challenges) - } - Expression::Product(e0, e1, _) => { - e0.calc(ext, base, eqs, challenges) * e1.calc(ext, base, eqs, challenges) - } - Expression::Neg(e, _) => -e.calc(ext, base, eqs, challenges), - Expression::Pow(e, d, _) => { - let poly = e.calc(ext, base, eqs, challenges); - op_by_type!( - VectorType, - poly, - |poly| { poly.into_iter().map(|x| x.exp_u64(*d as u64)).collect_vec() }, - |ext| VectorType::Ext(ext), - |base| VectorType::Base(base) - ) - } - } - } - - #[allow(clippy::too_many_arguments)] - pub fn sumcheck_uni_poly( - &self, - ext_mles: &[&mut [E]], - base_after_mles: &[Vec], - base_mles: &[&[E::BaseField]], - eqs: &[Vec], - challenges: &[E], - size: usize, - degree: usize, - ) -> Vec { - let poly = self.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - degree, - ); - op_by_type!(UniPolyVectorType, poly, |poly| { - poly.into_iter().fold(vec![E::ZERO; degree], |acc, x| { - zip_eq(acc, 
x).map(|(a, b)| a + b).collect_vec() - }) - }) - } - - /// Compute \sum_x (eq(0, x) + eq(1, x)) * expr_0(X, x) - #[allow(clippy::too_many_arguments)] - pub fn zerocheck_uni_poly<'a, E: ExtensionField>( - &self, - ext_mles: &[&mut [E]], - base_after_mles: &[Vec], - base_mles: &[&[E::BaseField]], - challenges: &[E], - coeffs: impl Iterator, - size: usize, - ) -> Vec { - let degree = self.degree(); - let poly = self.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - &[], - challenges, - size, - degree, - ); - - op_by_type!(UniPolyVectorType, poly, |poly| { - zip_eq(coeffs, poly).fold(vec![E::ZERO; degree], |mut acc, (c, poly)| { - zip_eq(&mut acc, poly).for_each(|(a, x)| *a += *c * x); - acc - }) - }) - } - - /// Compute the extension field univariate polynomial evaluated on 1..degree + 1. - #[allow(clippy::too_many_arguments)] - fn uni_poly_inner( - &self, - ext_mles: &[&mut [E]], - base_after_mles: &[Vec], - base_mles: &[&[E::BaseField]], - eqs: &[Vec], - challenges: &[E], - size: usize, - degree: usize, - ) -> UniPolyVectorType { - match self { - Expression::Const(constant) => { - let value = constant.evaluate(challenges); - UniPolyVectorType::Ext(vec![vec![value; degree]; size >> 1]) - } - Expression::Wit(w) => match w { - Witness::BasePoly(index) => { - if !base_mles.is_empty() { - UniPolyVectorType::Base(uni_poly_helper(base_mles[*index], size, degree)) - } else { - UniPolyVectorType::Ext(uni_poly_helper( - &base_after_mles[*index], - size, - degree, - )) - } - } - Witness::ExtPoly(index) => { - UniPolyVectorType::Ext(uni_poly_helper(ext_mles[*index], size, degree)) - } - Witness::EqPoly(index) => { - UniPolyVectorType::Ext(uni_poly_helper(&eqs[*index], size, degree)) - } - }, - Expression::Sum(expr0, expr1, _) => { - let poly0 = expr0.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - degree, - ); - let poly1 = expr1.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - 
degree, - ); - poly0 + poly1 - } - Expression::Product(expr0, expr1, _) => { - let poly0 = expr0.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - degree, - ); - let poly1 = expr1.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - degree, - ); - poly0 * poly1 - } - Expression::Neg(expr, _) => { - let poly = expr.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - degree, - ); - -poly - } - Expression::Pow(expr, d, _) => { - let poly = expr.uni_poly_inner( - ext_mles, - base_after_mles, - base_mles, - eqs, - challenges, - size, - degree, - ); - op_by_type!( - UniPolyVectorType, - poly, - |poly| { - poly.into_iter() - .map(|x| x.iter().map(|x| x.exp_u64(*d as u64)).collect_vec()) - .collect_vec() - }, - |ext| UniPolyVectorType::Ext(ext), - |base| UniPolyVectorType::Base(base) - ) - } - } - } -} - -impl Constant { - pub fn is_ext(&self) -> bool { - match self { - Constant::Base(_) => false, - Constant::Challenge(_) => true, - Constant::Sum(c0, c1) | Constant::Product(c0, c1) => c0.is_ext() || c1.is_ext(), - Constant::Neg(c) => c.is_ext(), - Constant::Pow(c, _) => c.is_ext(), - } - } - - pub fn evaluate(&self, challenges: &[E]) -> E { - let res = self.evaluate_inner(challenges); - op_by_type!(ScalarType, res, |b| b, |e| e, |bf| E::from(bf)) - } - - fn evaluate_inner(&self, challenges: &[E]) -> ScalarType { - match self { - Constant::Base(value) => ScalarType::Base(i64_to_field(*value)), - Constant::Challenge(index) => ScalarType::Ext(challenges[*index]), - Constant::Sum(c0, c1) => c0.evaluate_inner(challenges) + c1.evaluate_inner(challenges), - Constant::Product(c0, c1) => { - c0.evaluate_inner(challenges) * c1.evaluate_inner(challenges) - } - Constant::Neg(c) => -c.evaluate_inner(challenges), - Constant::Pow(c, degree) => { - let value = c.evaluate_inner(challenges); - op_by_type!( - ScalarType, - value, - |value| { value.exp_u64(*degree as u64) }, - 
|ext| ScalarType::Ext(ext), - |base| ScalarType::Base(base) - ) - } - } - } - - pub fn entry(&self, challenges: &[E]) -> E { - match self { - Constant::Challenge(index) => challenges[*index], - _ => unreachable!(), - } - } - - pub fn entry_mut<'a, E: ExtensionField>(&self, challenges: &'a mut [E]) -> &'a mut E { - match self { - Constant::Challenge(index) => &mut challenges[*index], - _ => unreachable!(), - } - } -} - -impl Witness { - pub fn is_ext(&self) -> bool { - match self { - Witness::BasePoly(_) => false, - Witness::ExtPoly(_) => true, - Witness::EqPoly(_) => true, - } - } - - pub fn evaluate( - &self, - base_mle_evals: &[E], - ext_mle_evals: &[E], - out_point: &[&[E]], - in_point: &[E], - ) -> E { - match self { - Witness::BasePoly(index) => base_mle_evals[*index], - Witness::ExtPoly(index) => ext_mle_evals[*index], - Witness::EqPoly(index) => eq_eval(out_point[*index], in_point), - } - } - - pub fn base<'a, T>(&self, base_mle_evals: &'a [T]) -> &'a T { - match self { - Witness::BasePoly(index) => &base_mle_evals[*index], - _ => unreachable!(), - } - } - - pub fn base_mut<'a, T>(&self, base_mle_evals: &'a mut [T]) -> &'a mut T { - match self { - Witness::BasePoly(index) => &mut base_mle_evals[*index], - _ => unreachable!(), - } - } - - pub fn ext<'a, T>(&self, ext_mle_evals: &'a [T]) -> &'a T { - match self { - Witness::ExtPoly(index) => &ext_mle_evals[*index], - _ => unreachable!(), - } - } - - pub fn ext_mut<'a, T>(&self, ext_mle_evals: &'a mut [T]) -> &'a mut T { - match self { - Witness::ExtPoly(index) => &mut ext_mle_evals[*index], - _ => unreachable!(), - } - } -} - -/// Compute the univariate polynomial evaluated on 1..degree. 
-#[inline] -fn uni_poly_helper(mle: &[F], size: usize, degree: usize) -> Vec> { - mle.chunks(2) - .take(size >> 1) - .map(|p| { - let start = p[0]; - let step = p[1] - start; - (0..degree) - .scan(start, |state, _| { - *state += step; - Some(*state) - }) - .collect_vec() - }) - .collect_vec() -} - -#[cfg(test)] -mod test { - use crate::field_vec; - use p3_field::PrimeCharacteristicRing; - use p3_goldilocks::Goldilocks as F; - - #[test] - fn test_uni_poly_helper() { - // (x + 2), (3x + 4), (5x + 6), (7x + 8) - let mle = field_vec![F, 2, 3, 4, 7, 6, 11, 8, 15, 11, 13, 17, 19, 23, 29, 31, 37]; - let size = 8; - let degree = 3; - let expected = vec![ - field_vec![F, 3, 4, 5], - field_vec![F, 7, 10, 13], - field_vec![F, 11, 16, 21], - field_vec![F, 15, 22, 29], - ]; - let result = super::uni_poly_helper(&mle, size, degree); - assert_eq!(result, expected); - } -} diff --git a/subprotocols/src/expression/macros.rs b/subprotocols/src/expression/macros.rs deleted file mode 100644 index 8930be70f..000000000 --- a/subprotocols/src/expression/macros.rs +++ /dev/null @@ -1,100 +0,0 @@ -#[macro_export] -macro_rules! op_by_type { - ($ele_type:ident, $ele:ident, |$x:ident| $op:expr, |$y_ext:ident| $convert_ext:expr, |$y_base:ident| $convert_base:expr) => { - match $ele { - $ele_type::Base($x) => { - let $y_base = $op; - $convert_base - } - $ele_type::Ext($x) => { - let $y_ext = $op; - $convert_ext - } - } - }; - - ($ele_type:ident, $ele:ident, |$x:ident| $op:expr, |$y_base:ident| $convert_base:expr) => { - match $ele { - $ele_type::Base($x) => { - let $y_base = $op; - $convert_base - } - $ele_type::Ext($x) => $op, - } - }; - - ($ele_type:ident, $ele:ident, |$x:ident| $op:expr) => { - match $ele { - $ele_type::Base($x) => $op, - $ele_type::Ext($x) => $op, - } - }; -} - -#[macro_export] -macro_rules! 
define_commutative_op_mle2 { - ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident, $y:ident| $op:expr) => { - impl $trait_type for $ele_type { - type Output = Self; - - fn $func_type(self, other: Self) -> Self::Output { - #[allow(unused)] - match (self, other) { - ($ele_type::Base(mut $x), $ele_type::Base($y)) => $ele_type::Base($op), - ($ele_type::Ext(mut $x), $ele_type::Base($y)) - | ($ele_type::Base($y), $ele_type::Ext(mut $x)) => $ele_type::Ext($op), - ($ele_type::Ext(mut $x), $ele_type::Ext($y)) => $ele_type::Ext($op), - } - } - } - - // impl<'a, E: ExtensionField> $trait_type<&'a Self> for $ele_type { - // type Output = Self; - - // fn $func_type(self, other: &'a Self) -> Self::Output { - // #[allow(unused)] - // match (self, other) { - // ($ele_type::Base(mut $x), $ele_type::Base($y)) => $ele_type::Base($op), - // ($ele_type::Ext(mut $x), $ele_type::Base($y)) => $ele_type::Ext($op), - // ($ele_type::Base($y), $ele_type::Ext($x)) => { - // let mut $x = $x.clone(); - // $ele_type::Ext($op) - // } - // ($ele_type::Ext(mut $x), $ele_type::Ext($y)) => $ele_type::Ext($op), - // } - // } - // } - }; -} - -#[macro_export] -macro_rules! define_op_mle2 { - ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident, $y:ident| $op:expr) => { - impl $trait_type for $ele_type { - type Output = Self; - - fn $func_type(self, other: Self) -> Self::Output { - let $x = self; - let $y = other; - $op - } - } - }; -} - -#[macro_export] -macro_rules! 
define_op_mle { - ($ele_type:ident, $trait_type:ident, $func_type:ident, |$x:ident| $op:expr) => { - impl $trait_type for $ele_type { - type Output = Self; - - fn $func_type(self) -> Self::Output { - #[allow(unused)] - match (self) { - $ele_type::Base(mut $x) => $ele_type::Base($op), - $ele_type::Ext(mut $x) => $ele_type::Ext($op), - } - } - } - }; -} diff --git a/subprotocols/src/expression/op.rs b/subprotocols/src/expression/op.rs deleted file mode 100644 index 1690c24e4..000000000 --- a/subprotocols/src/expression/op.rs +++ /dev/null @@ -1,81 +0,0 @@ -use std::{ - cmp::max, - ops::{Add, Mul, Neg, Sub}, -}; - -use ff_ext::ExtensionField; -use itertools::zip_eq; - -use crate::{define_commutative_op_mle2, define_op_mle, define_op_mle2}; - -use super::{Expression, ScalarType, UniPolyVectorType, VectorType}; - -impl Add for Expression { - type Output = Self; - - fn add(self, other: Self) -> Self { - let degree = max(self.degree(), other.degree()); - Expression::Sum(Box::new(self), Box::new(other), degree) - } -} - -impl Mul for Expression { - type Output = Self; - - fn mul(self, other: Self) -> Self { - #[allow(clippy::suspicious_arithmetic_impl)] - let degree = self.degree() + other.degree(); - Expression::Product(Box::new(self), Box::new(other), degree) - } -} - -impl Neg for Expression { - type Output = Self; - - fn neg(self) -> Self { - let deg = self.degree(); - Expression::Neg(Box::new(self), deg) - } -} - -impl Sub for Expression { - type Output = Self; - - fn sub(self, other: Self) -> Self { - self + (-other) - } -} - -define_commutative_op_mle2!(UniPolyVectorType, Add, add, |x, y| { - zip_eq(&mut x, y).for_each(|(x, y)| zip_eq(x, y).for_each(|(x, y)| *x += y)); - x -}); -define_commutative_op_mle2!(UniPolyVectorType, Mul, mul, |x, y| { - zip_eq(&mut x, y).for_each(|(x, y)| zip_eq(x, y).for_each(|(x, y)| *x *= y)); - x -}); -define_op_mle2!(UniPolyVectorType, Sub, sub, |x, y| x + (-y)); -define_op_mle!(UniPolyVectorType, Neg, neg, |x| { - x.iter_mut() - 
.for_each(|x| x.iter_mut().for_each(|x| *x = -(*x))); - x -}); - -define_commutative_op_mle2!(VectorType, Add, add, |x, y| { - zip_eq(&mut x, y).for_each(|(x, y)| *x += y); - x -}); -define_commutative_op_mle2!(VectorType, Mul, mul, |x, y| { - zip_eq(&mut x, y).for_each(|(x, y)| *x *= y); - x -}); -define_op_mle2!(VectorType, Sub, sub, |x, y| x + (-y)); -define_op_mle!(VectorType, Neg, neg, |x| { - x.iter_mut().for_each(|x| *x = -(*x)); - x -}); - -define_commutative_op_mle2!(ScalarType, Add, add, |x, y| x + y); -define_commutative_op_mle2!(ScalarType, Mul, mul, |x, y| x * y); -define_op_mle2!(ScalarType, Sub, sub, |x, y| x + (-y)); -define_op_mle!(ScalarType, Neg, neg, |x| -x); diff --git a/subprotocols/src/lib.rs b/subprotocols/src/lib.rs deleted file mode 100644 index a86f12c8f..000000000 --- a/subprotocols/src/lib.rs +++ /dev/null @@ -1,9 +0,0 @@ -pub mod error; -pub mod expression; -pub mod points; -pub mod sumcheck; -pub mod utils; -pub mod zerocheck; - -#[macro_use] -pub mod test_utils; diff --git a/subprotocols/src/points.rs b/subprotocols/src/points.rs deleted file mode 100644 index 9d128e0ac..000000000 --- a/subprotocols/src/points.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::sync::Arc; - -use ff_ext::ExtensionField; -use itertools::izip; - -use crate::expression::Point; - -pub trait PointBeforeMerge { - fn point_before_merge(&self, pos: &[usize]) -> Point; -} - -pub trait PointBeforePartition { - fn point_before_partition( - &self, - pos_and_var_ids: &[(usize, usize)], - challenges: &[E], - ) -> Point; -} - -/// Suppose we have several vectors v_0, ..., v_{N-1}, and want to merge it through n = log(N) variables, -/// x_0, ..., x_{n-1}, at the positions i_0, ..., i_{n - 1}. Suppose the output point is P, then the point -/// before it is P_0, ..., P_{i_0 - 1}, P_{i_0 + 1}, ..., P_{i_1 - 1}, ..., P_{i_{n - 1} + 1}, ..., P_{N - 1}. 
-impl PointBeforeMerge for Point { - fn point_before_merge(&self, pos: &[usize]) -> Point { - if pos.is_empty() { - return self.clone(); - } - - assert!(izip!(pos.iter(), pos.iter().skip(1)).all(|(i, j)| i < j)); - - let mut new_point = Vec::with_capacity(self.len() - pos.len()); - let mut i = 0usize; - for (j, p) in self.iter().enumerate() { - if j != pos[i] { - new_point.push(*p); - } else { - i += 1; - } - } - - Arc::new(new_point) - } -} - -/// Suppose we have a vector v, and want to partition it through n = log(N) variables -/// x_0, ..., x_{n-1}, at the positions i_0, ..., i_{n - 1}. Suppose the output point -/// is P, then the point before it is P after calling P.insert(i_0, x_0), ... -impl PointBeforePartition for Point { - fn point_before_partition( - &self, - pos_and_var_ids: &[(usize, usize)], - challenges: &[E], - ) -> Point { - if pos_and_var_ids.is_empty() { - return self.clone(); - } - - assert!( - izip!(pos_and_var_ids.iter(), pos_and_var_ids.iter().skip(1)).all(|(i, j)| i.0 < j.0) - ); - - let mut new_point = Vec::with_capacity(self.len() + pos_and_var_ids.len()); - let mut i = 0usize; - for (j, p) in self.iter().enumerate() { - if i + j != pos_and_var_ids[i].0 { - new_point.push(*p); - } else { - new_point.push(challenges[pos_and_var_ids[i].1]); - i += 1; - } - } - - Arc::new(new_point) - } -} diff --git a/subprotocols/src/sumcheck.rs b/subprotocols/src/sumcheck.rs deleted file mode 100644 index 73c331f57..000000000 --- a/subprotocols/src/sumcheck.rs +++ /dev/null @@ -1,454 +0,0 @@ -use std::{iter, mem, sync::Arc, vec}; - -use ark_std::log2; -use ff_ext::ExtensionField; -use itertools::chain; -use serde::{Deserialize, Serialize, de::DeserializeOwned}; -use transcript::Transcript; - -use crate::{ - error::VerifierError, - expression::{Expression, Point}, - utils::eq_vecs, -}; - -use super::utils::{fix_variables_ext, fix_variables_inplace, interpolate_uni_poly}; - -/// This is an randomly combined sumcheck protocol for the following equation: -/// 
\sigma = \sum_x expr(x) -pub struct SumcheckProverState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - /// Expression. - expr: Expression, - - /// Extension field mles. - ext_mles: Vec<&'a mut [E]>, - /// Base field mles after the first round. - base_mles_after: Vec>, - /// Base field mles. - base_mles: Vec<&'a [E::BaseField]>, - /// Eq polys - eqs: Vec>, - /// Challenges occurred in expressions - challenges: &'a [E], - - transcript: &'a mut Trans, - - degree: usize, - num_vars: usize, -} - -#[derive(Clone, Serialize, Deserialize)] -#[serde(bound( - serialize = "E::BaseField: Serialize", - deserialize = "E::BaseField: DeserializeOwned" -))] -pub struct SumcheckProof { - /// Messages for each round. - pub univariate_polys: Vec>>, - pub ext_mle_evals: Vec, - pub base_mle_evals: Vec, -} - -pub struct SumcheckProverOutput { - pub proof: SumcheckProof, - pub point: Point, -} - -impl<'a, E, Trans> SumcheckProverState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - #[allow(clippy::too_many_arguments)] - pub fn new( - expr: Expression, - points: &[&[E]], - ext_mles: Vec<&'a mut [E]>, - base_mles: Vec<&'a [E::BaseField]>, - challenges: &'a [E], - transcript: &'a mut Trans, - ) -> Self { - assert!(!(ext_mles.is_empty() && base_mles.is_empty())); - - let num_vars = if !ext_mles.is_empty() { - log2(ext_mles[0].len()) as usize - } else { - log2(base_mles[0].len()) as usize - }; - - // The length of all mles should be 2^{num_vars}. 
- assert!(ext_mles.iter().all(|mle| mle.len() == 1 << num_vars)); - assert!(base_mles.iter().all(|mle| mle.len() == 1 << num_vars)); - - let degree = expr.degree(); - - let eqs = eq_vecs(points.iter().copied(), &vec![E::ONE; points.len()]); - - Self { - expr, - ext_mles, - base_mles_after: vec![], - base_mles, - eqs, - challenges, - transcript, - num_vars, - degree, - } - } - - pub fn prove(mut self) -> SumcheckProverOutput { - let (univariate_polys, point) = (0..self.num_vars) - .map(|round| { - let round_msg = self.compute_univariate_poly(round); - self.transcript.append_field_element_exts(&round_msg); - - let r = self - .transcript - .sample_and_append_challenge(b"sumcheck round") - .elements; - self.update_mles(&r, round); - (vec![round_msg], r) - }) - .unzip(); - let point = Arc::new(point); - - // Send the final evaluations - let SumcheckProverState { - ext_mles, - base_mles_after, - base_mles, - .. - } = self; - let ext_mle_evaluations = ext_mles.into_iter().map(|mle| mle[0]).collect(); - let base_mle_evaluations = if !base_mles.is_empty() { - base_mles.into_iter().map(|mle| E::from(mle[0])).collect() - } else { - base_mles_after.into_iter().map(|mle| mle[0]).collect() - }; - - SumcheckProverOutput { - proof: SumcheckProof { - univariate_polys, - ext_mle_evals: ext_mle_evaluations, - base_mle_evals: base_mle_evaluations, - }, - point, - } - } - - /// Compute f(X) = r^0 \sum_x expr_0(X || x) + r^1 \sum_x expr_1(X || x) + ... - fn compute_univariate_poly(&self, round: usize) -> Vec { - self.expr.sumcheck_uni_poly( - &self.ext_mles, - &self.base_mles_after, - &self.base_mles, - &self.eqs, - self.challenges, - 1 << (self.num_vars - round), - self.degree, - ) - } - - fn update_mles(&mut self, r: &E, round: usize) { - // fix variables of eq polynomials - self.eqs.iter_mut().for_each(|eq| { - fix_variables_inplace(eq, r); - }); - // fix variables of ext field polynomials. 
- self.ext_mles.iter_mut().for_each(|mle| { - fix_variables_inplace(mle, r); - }); - // fix variables of base field polynomials. - if round == 0 { - self.base_mles_after = mem::take(&mut self.base_mles) - .into_iter() - .map(|mle| fix_variables_ext(mle, r)) - .collect(); - } else { - self.base_mles_after - .iter_mut() - .for_each(|mle| fix_variables_inplace(mle, r)); - } - } -} - -pub struct SumcheckVerifierState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - sigma: E, - expr: Expression, - proof: SumcheckProof, - expr_names: Vec, - challenges: &'a [E], - transcript: &'a mut Trans, - out_points: Vec<&'a [E]>, -} - -pub struct SumcheckClaims { - pub in_point: Point, - pub base_mle_evals: Vec, - pub ext_mle_evals: Vec, -} - -impl<'a, E, Trans> SumcheckVerifierState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - pub fn new( - sigma: E, - expr: Expression, - out_points: Vec<&'a [E]>, - proof: SumcheckProof, - challenges: &'a [E], - transcript: &'a mut Trans, - expr_names: Vec, - ) -> Self { - // Fill in missing debug data - let mut expr_names = expr_names; - expr_names.resize(1, "nothing".to_owned()); - Self { - sigma, - expr, - proof, - challenges, - transcript, - out_points, - expr_names, - } - } - - pub fn verify(self) -> Result, VerifierError> { - let SumcheckVerifierState { - sigma, - expr, - proof, - challenges, - transcript, - out_points, - expr_names, - } = self; - let SumcheckProof { - univariate_polys, - ext_mle_evals, - base_mle_evals, - } = proof; - - let (in_point, expected_claim) = univariate_polys.into_iter().fold( - (vec![], sigma), - |(mut last_point, last_sigma), msg| { - let msg = msg.into_iter().next().unwrap(); - transcript.append_field_element_exts(&msg); - - let len = msg.len() + 1; - let eval_at_0 = last_sigma - msg[0]; - - // Evaluations on degree, degree - 1, ..., 1, 0. 
- let evals_iter_rev = chain![msg.into_iter().rev(), iter::once(eval_at_0)]; - - let r = transcript - .sample_and_append_challenge(b"sumcheck round") - .elements; - let sigma = interpolate_uni_poly(evals_iter_rev, len, r); - last_point.push(r); - (last_point, sigma) - }, - ); - - // Check the final evaluations. - let got_claim = expr.evaluate( - &ext_mle_evals, - &base_mle_evals, - &out_points, - &in_point, - challenges, - ); - - if expected_claim != got_claim { - return Err(VerifierError::ClaimNotMatch( - expr, - expected_claim, - got_claim, - expr_names[0].clone(), - )); - } - - let in_point = Arc::new(in_point); - Ok(SumcheckClaims { - in_point, - base_mle_evals, - ext_mle_evals, - }) - } -} - -#[cfg(test)] -mod test { - use std::array; - - use ff_ext::ExtensionField; - use itertools::{Itertools, izip}; - use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; - use p3_goldilocks::Goldilocks as F; - use transcript::BasicTranscript; - - type E = BinomialExtensionField; - - use crate::{ - expression::{Constant, Expression, Witness}, - field_vec, - utils::eq_vecs, - }; - - use super::{SumcheckProverOutput, SumcheckProverState, SumcheckVerifierState}; - - #[allow(clippy::too_many_arguments)] - fn run( - points: Vec<&[E]>, - expr: Expression, - ext_mle_refs: Vec<&mut [E]>, - base_mle_refs: Vec<&[E::BaseField]>, - challenges: Vec, - - sigma: E, - ) { - let mut prover_transcript = BasicTranscript::new(b"test"); - let prover = SumcheckProverState::new( - expr.clone(), - &points, - ext_mle_refs, - base_mle_refs, - &challenges, - &mut prover_transcript, - ); - - let SumcheckProverOutput { proof, .. 
} = prover.prove(); - - let mut verifier_transcript = BasicTranscript::new(b"test"); - let verifier = SumcheckVerifierState::new( - sigma, - expr, - points, - proof, - &challenges, - &mut verifier_transcript, - vec![], - ); - - verifier.verify().expect("verification failed"); - } - - #[test] - fn test_sumcheck_trivial() { - let f = field_vec![F, 2]; - let g = field_vec![F, 3]; - let out_point = vec![]; - - let base_mle_refs = vec![f.as_slice(), g.as_slice()]; - let f = Expression::Wit(Witness::BasePoly(0)); - let g = Expression::Wit(Witness::BasePoly(1)); - let expr = f * g; - - run( - vec![out_point.as_slice()], - expr, - vec![], - base_mle_refs, - vec![], - E::from_u64(6), - ); - } - - #[test] - fn test_sumcheck_simple() { - let f = field_vec![F, 1, 2, 3, 4]; - let ans = E::from(f.iter().fold(F::ZERO, |acc, x| acc + *x)); - let base_mle_refs = vec![f.as_slice()]; - let expr = Expression::Wit(Witness::BasePoly(0)); - - run(vec![], expr, vec![], base_mle_refs, vec![], ans); - } - - #[test] - fn test_sumcheck_logup() { - let point = field_vec![E, 2, 3]; - - let eqs = eq_vecs([point.as_slice()].into_iter(), &[E::ONE]); - - let d0 = field_vec![E, 1, 2, 3, 4]; - let d1 = field_vec![E, 5, 6, 7, 8]; - let n0 = field_vec![E, 9, 10, 11, 12]; - let n1 = field_vec![E, 13, 14, 15, 16]; - - let challenges = vec![E::from_u64(7)]; - let ans = izip!(&eqs[0], &d0, &d1, &n0, &n1) - .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) - .sum(); - - let mut ext_mles = [d0, d1, n0, n1]; - let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); - let eq = Expression::Wit(Witness::EqPoly(0)); - let beta = Expression::Const(Constant::Challenge(0)); - - let expr = eq * (d0.clone() * d1.clone() + beta * (d0 * n1 + d1 * n0)); - - let ext_mle_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); - run( - vec![point.as_slice()], - expr, - ext_mle_refs, - vec![], - challenges, - ans, - ); - } - - #[test] - fn 
test_sumcheck_multi_points() { - let challenges = vec![E::from_u64(2)]; - - let points = [ - field_vec![E, 2, 3], - field_vec![E, 5, 7], - field_vec![E, 2, 5], - ]; - let point_refs = points.iter().map(|v| v.as_slice()).collect_vec(); - - let eqs = eq_vecs(point_refs.clone().into_iter(), &vec![E::ONE; points.len()]); - - let d0 = field_vec![F, 1, 2, 3, 4]; - let d1 = field_vec![F, 5, 6, 7, 8]; - let n0 = field_vec![F, 9, 10, 11, 12]; - let n1 = field_vec![F, 13, 14, 15, 16]; - - let ans_0 = izip!(&eqs[0], &d0, &d1) - .map(|(eq0, d0, d1)| *eq0 * *d0 * *d1) - .sum::(); - let ans_1 = izip!(&eqs[1], &d0, &n1) - .map(|(eq1, d0, n1)| *eq1 * *d0 * *n1) - .sum::(); - let ans_2 = izip!(&eqs[2], &d1, &n0) - .map(|(eq2, d1, n0)| *eq2 * *d1 * *n0) - .sum::(); - let ans = (ans_0 * challenges[0] + ans_1) * challenges[0] + ans_2; - - let base_mles = [d0, d1, n0, n1]; - let [eq0, eq1, eq2] = array::from_fn(|i| Expression::Wit(Witness::EqPoly(i))); - let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::BasePoly(i))); - let rlc_challenge = Expression::Const(Constant::Challenge(0)); - - let expr = (eq0 * d0.clone() * d1.clone() * rlc_challenge.clone() + eq1 * d0 * n1) - * rlc_challenge - + eq2 * d1 * n0; - - let base_mle_refs = base_mles.iter().map(|v| v.as_slice()).collect_vec(); - run(point_refs, expr, vec![], base_mle_refs, challenges, ans); - } -} diff --git a/subprotocols/src/test_utils.rs b/subprotocols/src/test_utils.rs deleted file mode 100644 index cb0812e1a..000000000 --- a/subprotocols/src/test_utils.rs +++ /dev/null @@ -1,46 +0,0 @@ -use ff_ext::{ExtensionField, FromUniformBytes}; -use itertools::Itertools; -use p3_field::Field; -use rand::RngCore; - -pub fn random_point(mut rng: impl RngCore, num_vars: usize) -> Vec { - (0..num_vars).map(|_| E::random(&mut rng)).collect_vec() -} - -pub fn random_vec(mut rng: impl RngCore, len: usize) -> Vec { - (0..len).map(|_| E::random(&mut rng)).collect_vec() -} - -pub fn random_poly(mut rng: impl RngCore, num_vars: 
usize) -> Vec { - (0..1 << num_vars) - .map(|_| E::random(&mut rng)) - .collect_vec() -} - -#[macro_export] -macro_rules! field_vec { - () => ( - $crate::vec::Vec::new() - ); - ($field_type:ident; $elem:expr; $n:expr) => ( - $crate::vec::from_elem({ - if $x < 0 { - -$field_type::from((-$x) as u64) - } else { - $field_type::from($x as u64) - } - }, $n) - ); - ($field_type:ident, $($x:expr),+ $(,)?) => ( - <[_]>::into_vec( - std::boxed::Box::new([$({ - let x = $x as i64; - if $x < 0 { - -$field_type::from_u64((-x) as u64) - } else { - $field_type::from_u64(x as u64) - } - }),+]) - ) - ); -} diff --git a/subprotocols/src/utils.rs b/subprotocols/src/utils.rs deleted file mode 100644 index 025fdf3c3..000000000 --- a/subprotocols/src/utils.rs +++ /dev/null @@ -1,235 +0,0 @@ -use std::{iter, ops::Mul}; - -use ff_ext::ExtensionField; -use itertools::{Itertools, chain, izip}; -use multilinear_extensions::virtual_poly::build_eq_x_r_vec_with_scalar; -use p3_field::Field; -use rayon::iter::{ - IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator, -}; - -pub fn i64_to_field(i: i64) -> F { - if i < 0 { - -F::from_u64(i.unsigned_abs()) - } else { - F::from_u64(i as u64) - } -} - -pub fn power_list(ele: &F, size: usize) -> Vec { - (0..size) - .scan(F::ONE, |state, _| { - let last = *state; - *state *= *ele; - Some(last) - }) - .collect() -} - -/// Grand product of ele, start from 1, with length ele.len() + 1. 
-pub fn grand_product(ele: &[F]) -> Vec { - let one = F::ONE; - chain![iter::once(&one), ele.iter()] - .scan(F::ONE, |state, e| { - *state *= *e; - Some(*state) - }) - .collect() -} - -pub fn eq_vecs<'a, E: ExtensionField>( - points: impl Iterator, - scalars: &[E], -) -> Vec> { - izip!(points, scalars) - .map(|(point, scalar)| build_eq_x_r_vec_with_scalar(point, *scalar)) - .collect_vec() -} - -#[inline(always)] -pub fn eq(x: &F, y: &F) -> F { - // x * y + (1 - x) * (1 - y) - let xy = *x * *y; - xy + xy - *x - *y + F::ONE -} - -pub fn fix_variables_ext(base_mle: &[E::BaseField], r: &E) -> Vec { - base_mle - .par_iter() - .chunks(2) - .with_min_len(64) - .map(|buf| *r * (*buf[1] - *buf[0]) + *buf[0]) - .collect() -} - -pub fn fix_variables_inplace(ext_mle: &mut [E], r: &E) { - ext_mle - .par_iter_mut() - .chunks(2) - .with_min_len(64) - .for_each(|mut buf| *buf[0] = *buf[0] + (*buf[1] - *buf[0]) * *r); - // sequentially update buf[b1, b2,..bt] = buf[b1, b2,..bt, 0] - let half_len = ext_mle.len() >> 1; - for index in 0..half_len { - ext_mle[index] = ext_mle[index << 1]; - } -} - -pub fn evaluate_mle_inplace(mle: &mut [E], point: &[E]) -> E { - for r in point { - fix_variables_inplace(mle, r); - } - mle[0] -} - -pub fn evaluate_mle_ext(mle: &[E::BaseField], point: &[E]) -> E { - let mut ext_mle = fix_variables_ext(mle, &point[0]); - evaluate_mle_inplace(&mut ext_mle, &point[1..]) -} - -/// Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this -/// polynomial at `eval_at`: -/// -/// \sum_{i=0}^len p_i * (\prod_{j!=i} (eval_at - j)/(i-j) ) -/// -/// This implementation is linear in number of inputs in terms of field -/// operations. It also has a quadratic term in primitive operations which is -/// negligible compared to field operations. -/// TODO: The quadratic term can be removed by precomputing the lagrange -/// coefficients. 
-pub(crate) fn interpolate_uni_poly>( - p_iter_rev: impl Iterator, - len: usize, - eval_at: E, -) -> E { - let mut evals = vec![eval_at]; - let mut prod = eval_at; - - // `prod = \prod_{j} (eval_at - j)` - for j in 1..len { - let tmp = eval_at - E::from_u64(j as u64); - evals.push(tmp); - prod *= tmp; - } - let mut res = E::ZERO; - // we want to compute \prod (j!=i) (i-j) for a given i - // - // we start from the last step, which is - // denom[len-1] = (len-1) * (len-2) *... * 2 * 1 - // the step before that is - // denom[len-2] = (len-2) * (len-3) * ... * 2 * 1 * -1 - // and the step before that is - // denom[len-3] = (len-3) * (len-4) * ... * 2 * 1 * -1 * -2 - // - // i.e., for any i, the one before this will be derived from - // denom[i-1] = denom[i] * (len-i) / i - // - // that is, we only need to store - // - the last denom for i = len-1, and - // - the ratio between current step and fhe last step, which is the product of (len-i) / i from - // all previous steps and we store this product as a fraction number to reduce field - // divisions. - - let mut denom_up = field_factorial::(len - 1); - let mut denom_down = F::ONE; - - for (j, p_i) in p_iter_rev.enumerate() { - let i = len - j - 1; - res += prod * p_i * denom_down * (evals[i] * denom_up).inverse(); - - // compute denom for the next step is current_denom * (len-i)/i - if i != 0 { - denom_up *= -F::from_u64((j + 1) as u64); - denom_down *= F::from_u64(i as u64); - } - } - res -} - -/// compute the factorial(a) = 1 * 2 * ... 
* a -#[inline] -fn field_factorial(a: usize) -> F { - let mut res = F::ONE; - for i in 2..=a { - res *= F::from_u64(i as u64); - } - res -} - -#[cfg(test)] -mod test { - use itertools::Itertools; - use multilinear_extensions::virtual_poly::eq_eval; - use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; - use p3_goldilocks::Goldilocks as F; - - use crate::field_vec; - - use super::*; - - type E = BinomialExtensionField; - - #[test] - fn test_power_list() { - let ele = F::from_u64(3u64); - let list = power_list(&ele, 4); - assert_eq!(list, field_vec![F, 1, 3, 9, 27]); - } - - #[test] - fn test_grand_product() { - let ele = field_vec![F, 2, 3, 4, 5]; - let expected = field_vec![F, 1, 2, 6, 24, 120]; - assert_eq!(grand_product(&ele), expected); - } - - #[test] - fn test_eq_vecs() { - let points = [field_vec![E, 2, 3, 5], field_vec![E, 7, 11, 13]]; - let point_refs = points.iter().map(|p| p.as_slice()).collect_vec(); - - let scalars = field_vec![E, 3, 5]; - - let eq_evals = eq_vecs(point_refs.into_iter(), &scalars); - - let expected = vec![ - field_vec![E, -24, 48, 36, -72, 30, -60, -45, 90], - field_vec![E, -3600, 4200, 3960, -4620, 3900, -4550, -4290, 5005], - ]; - assert_eq!(eq_evals, expected); - } - - #[test] - fn test_eq_eval() { - let xs = field_vec![E, 2, 3, 5]; - let ys = field_vec![E, 7, 11, 13]; - let expected = E::from_u64(119780); - assert_eq!(eq_eval(&xs, &ys), expected); - } - - #[test] - fn test_fix_variables_ext() { - let base_mle = field_vec![F, 1, 2, 3, 4, 5, 6]; - let r = E::from_u64(3u64); - let expected = field_vec![E, 4, 6, 8]; - assert_eq!(fix_variables_ext(&base_mle, &r), expected); - } - - #[test] - fn test_fix_variables_inplace() { - let mut ext_mle = field_vec![E, 1, 2, 3, 4, 5, 6]; - let r = E::from_u64(3u64); - fix_variables_inplace(&mut ext_mle, &r); - let expected = field_vec![E, 4, 6, 8]; - assert_eq!(ext_mle[..3], expected); - } - - #[test] - fn test_interpolate_uni_poly() { - // p(x) = x^3 + 2x^2 + 3x + 4 - let 
p_iter = field_vec![F, 4, 10, 26, 58].into_iter().rev(); - let eval_at = E::from_u64(11); - let expected = E::from_u64(1610); - assert_eq!(interpolate_uni_poly(p_iter, 4, eval_at), expected); - } -} diff --git a/subprotocols/src/zerocheck.rs b/subprotocols/src/zerocheck.rs deleted file mode 100644 index 00f2fd296..000000000 --- a/subprotocols/src/zerocheck.rs +++ /dev/null @@ -1,495 +0,0 @@ -use std::{iter, mem, sync::Arc, vec}; - -use ark_std::log2; -use ff_ext::ExtensionField; -use itertools::{Itertools, chain, izip, zip_eq}; -use p3_field::batch_multiplicative_inverse; -use transcript::Transcript; - -use crate::{ - error::VerifierError, - expression::Expression, - sumcheck::{SumcheckProof, SumcheckProverOutput}, -}; - -use super::{ - sumcheck::SumcheckClaims, - utils::{ - eq_vecs, fix_variables_ext, fix_variables_inplace, grand_product, interpolate_uni_poly, - }, -}; - -/// This is an randomly combined zerocheck protocol for the following equation: -/// \sigma = \sum_x (r^0 eq_0(X) \cdot expr_0(x) + r^1 eq_1(X) \cdot expr_1(x) + ...) -pub struct ZerocheckProverState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - /// Expressions and corresponding half eq reference. - exprs: Vec<(Expression, Vec)>, - - /// Extension field mles. - ext_mles: Vec<&'a mut [E]>, - /// Base field mles after the first round. - base_mles_after: Vec>, - /// Base field mles. - base_mles: Vec<&'a [E::BaseField]>, - /// Challenges occurred in expressions - challenges: &'a [E], - /// For each point in points, the inverse of prod_{j < i}(1 - point[i]) for 0 <= i < point.len(). 
- grand_prod_of_not_inv: Vec>, - - transcript: &'a mut Trans, - - num_vars: usize, -} - -impl<'a, E, Trans> ZerocheckProverState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - #[allow(clippy::too_many_arguments)] - pub fn new( - exprs: Vec, - points: &[&[E]], - ext_mles: Vec<&'a mut [E]>, - base_mles: Vec<&'a [E::BaseField]>, - challenges: &'a [E], - transcript: &'a mut Trans, - ) -> Self { - assert!(!(ext_mles.is_empty() && base_mles.is_empty())); - - let num_vars = if !ext_mles.is_empty() { - log2(ext_mles[0].len()) as usize - } else { - log2(base_mles[0].len()) as usize - }; - - // For each point, compute eq(point[1..], b) for b in [0, 2^{num_vars - 1}). - let (exprs, grand_prod_of_not_inv) = if num_vars > 0 { - let half_eq_evals = eq_vecs( - points.iter().map(|point| &point[1..]), - &vec![E::ONE; exprs.len()], - ); - let exprs = zip_eq(exprs, half_eq_evals).collect_vec(); - let grand_prod_of_not_inv = points - .iter() - .flat_map(|point| point[1..].iter().map(|p| E::ONE - *p).collect_vec()) - .collect_vec(); - let grand_prod_of_not_inv = batch_multiplicative_inverse(&grand_prod_of_not_inv); - let (_, grand_prod_of_not_inv) = - points - .iter() - .fold((0usize, vec![]), |(start, mut last_vec), point| { - let end = start + point.len() - 1; - last_vec.push(grand_product(&grand_prod_of_not_inv[start..end])); - (end, last_vec) - }); - (exprs, grand_prod_of_not_inv) - } else { - let expr = exprs.into_iter().map(|expr| (expr, vec![])).collect_vec(); - (expr, vec![]) - }; - - // The length of all mles should be 2^{num_vars}. 
- assert!(ext_mles.iter().all(|mle| mle.len() == 1 << num_vars)); - assert!(base_mles.iter().all(|mle| mle.len() == 1 << num_vars)); - - Self { - exprs, - ext_mles, - base_mles_after: vec![], - base_mles, - challenges, - grand_prod_of_not_inv, - transcript, - num_vars, - } - } - - pub fn prove(mut self) -> SumcheckProverOutput { - let (univariate_polys, point) = (0..self.num_vars) - .map(|round| { - let round_msg = self.compute_univariate_poly(round); - round_msg - .iter() - .for_each(|poly| self.transcript.append_field_element_exts(poly)); - - let r = self - .transcript - .sample_and_append_challenge(b"sumcheck round") - .elements; - self.update_mles(&r, round); - (round_msg, r) - }) - .unzip(); - let point = Arc::new(point); - - // Send the final evaluations - let ZerocheckProverState { - ext_mles, - base_mles_after, - base_mles, - .. - } = self; - let ext_mle_evaluations = ext_mles.into_iter().map(|mle| mle[0]).collect(); - let base_mle_evaluations = if !base_mles.is_empty() { - base_mles.into_iter().map(|mle| E::from(mle[0])).collect() - } else { - base_mles_after.into_iter().map(|mle| mle[0]).collect() - }; - - SumcheckProverOutput { - proof: SumcheckProof { - univariate_polys, - ext_mle_evals: ext_mle_evaluations, - base_mle_evals: base_mle_evaluations, - }, - point, - } - } - - /// Compute f_i(X) = \sum_x eq_i(x) expr_i(X || x) - fn compute_univariate_poly(&self, round: usize) -> Vec> { - izip!(&self.exprs, &self.grand_prod_of_not_inv) - .map(|((expr, half_eq_mle), coeff)| { - let mut uni_poly = expr.zerocheck_uni_poly( - &self.ext_mles, - &self.base_mles_after, - &self.base_mles, - self.challenges, - half_eq_mle.iter().step_by(1 << round), - 1 << (self.num_vars - round), - ); - uni_poly.iter_mut().for_each(|x| *x *= coeff[round]); - uni_poly - }) - .collect_vec() - } - - fn update_mles(&mut self, r: &E, round: usize) { - // fix variables of base field polynomials. 
- self.ext_mles.iter_mut().for_each(|mle| { - fix_variables_inplace(mle, r); - }); - if round == 0 { - self.base_mles_after = mem::take(&mut self.base_mles) - .into_iter() - .map(|mle| fix_variables_ext(mle, r)) - .collect(); - } else { - self.base_mles_after - .iter_mut() - .for_each(|mle| fix_variables_inplace(mle, r)); - } - } -} - -pub struct ZerocheckVerifierState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - sigmas: Vec, - inv_of_one_minus_points: Vec>, - exprs: Vec<(Expression, &'a [E])>, - proof: SumcheckProof, - expr_names: Vec, - challenges: &'a [E], - transcript: &'a mut Trans, -} - -impl<'a, E, Trans> ZerocheckVerifierState<'a, E, Trans> -where - E: ExtensionField, - Trans: Transcript, -{ - pub fn new( - sigmas: Vec, - exprs: Vec, - expr_names: Vec, - points: Vec<&'a [E]>, - proof: SumcheckProof, - challenges: &'a [E], - transcript: &'a mut Trans, - ) -> Self { - // Fill in missing debug data - let mut expr_names = expr_names; - expr_names.resize(exprs.len(), "nothing".to_owned()); - - let inv_of_one_minus_points = points - .iter() - .flat_map(|point| point.iter().map(|p| E::ONE - *p).collect_vec()) - .collect_vec(); - let inv_of_one_minus_points = batch_multiplicative_inverse(&inv_of_one_minus_points); - let (_, inv_of_one_minus_points) = - points - .iter() - .fold((0usize, vec![]), |(start, mut last_vec), point| { - let end = start + point.len(); - last_vec.push(inv_of_one_minus_points[start..start + point.len()].to_vec()); - (end, last_vec) - }); - - let exprs = zip_eq(exprs, points).collect_vec(); - Self { - sigmas, - inv_of_one_minus_points, - exprs, - proof, - challenges, - transcript, - expr_names, - } - } - - pub fn verify(self) -> Result, VerifierError> { - let ZerocheckVerifierState { - sigmas, - inv_of_one_minus_points, - exprs, - proof, - challenges, - transcript, - expr_names, - .. 
- } = self; - let SumcheckProof { - univariate_polys, - ext_mle_evals, - base_mle_evals, - } = proof; - - let (in_point, expected_claims) = univariate_polys.into_iter().enumerate().fold( - (vec![], sigmas), - |(mut last_point, last_sigmas), (round, round_msg)| { - round_msg - .iter() - .for_each(|poly| transcript.append_field_element_exts(poly)); - let r = transcript - .sample_and_append_challenge(b"sumcheck round") - .elements; - last_point.push(r); - - let sigmas = izip!(&exprs, &inv_of_one_minus_points, round_msg, last_sigmas) - .map(|((_, point), inv_of_one_minus_point, poly, last_sigma)| { - let len = poly.len() + 1; - // last_sigma = (1 - point[round]) * eval_at_0 + point[round] * eval_at_1 - // eval_at_0 = (last_sigma - point[round] * eval_at_1) * inv(1 - point[round]) - let eval_at_0 = if !poly.is_empty() { - (last_sigma - point[round] * poly[0]) * inv_of_one_minus_point[round] - } else { - last_sigma - }; - - // Evaluations on degree, degree - 1, ..., 1, 0. - let evals_iter_rev = chain![poly.into_iter().rev(), iter::once(eval_at_0)]; - - interpolate_uni_poly(evals_iter_rev, len, r) - }) - .collect_vec(); - - (last_point, sigmas) - }, - ); - - // Check the final evaluations. 
- assert_eq!(expr_names.len(), exprs.len()); - // assert_eq!(expected_claims.len(), expr_names.len()); - - for (expected_claim, (expr, _), expr_name) in izip!(expected_claims, exprs, expr_names) { - let got_claim = expr.evaluate(&ext_mle_evals, &base_mle_evals, &[], &[], challenges); - - if expected_claim != got_claim { - return Err(VerifierError::ClaimNotMatch( - expr, - expected_claim, - got_claim, - expr_name.clone(), - )); - } - } - - let in_point = Arc::new(in_point); - Ok(SumcheckClaims { - in_point, - ext_mle_evals, - base_mle_evals, - }) - } -} - -#[cfg(test)] -mod test { - use std::array; - - use ff_ext::ExtensionField; - use itertools::{Itertools, izip}; - use p3_field::{PrimeCharacteristicRing, extension::BinomialExtensionField}; - use p3_goldilocks::Goldilocks as F; - use transcript::BasicTranscript; - - use crate::{ - expression::{Constant, Expression, Witness}, - field_vec, - sumcheck::SumcheckProverOutput, - }; - - use super::{ZerocheckProverState, ZerocheckVerifierState}; - - type E = BinomialExtensionField; - - #[allow(clippy::too_many_arguments)] - fn run<'a, E: ExtensionField>( - points: Vec<&[E]>, - exprs: Vec, - ext_mle_refs: Vec<&'a mut [E]>, - base_mle_refs: Vec<&'a [E::BaseField]>, - challenges: Vec, - - sigmas: Vec, - ) { - let mut prover_transcript = BasicTranscript::new(b"test"); - let prover = ZerocheckProverState::new( - exprs.clone(), - &points, - ext_mle_refs, - base_mle_refs, - &challenges, - &mut prover_transcript, - ); - - let SumcheckProverOutput { proof, .. 
} = prover.prove(); - - let mut verifier_transcript = BasicTranscript::new(b"test"); - let verifier = ZerocheckVerifierState::new( - sigmas, - exprs, - vec![], - points, - proof, - &challenges, - &mut verifier_transcript, - ); - - verifier.verify().expect("verification failed"); - } - - #[test] - fn test_zerocheck_trivial() { - let f = field_vec![F, 2]; - let g = field_vec![F, 3]; - let out_point = vec![]; - - let base_mle_refs = vec![f.as_slice(), g.as_slice()]; - let f = Expression::Wit(Witness::BasePoly(0)); - let g = Expression::Wit(Witness::BasePoly(1)); - let expr = f * g; - - run( - vec![out_point.as_slice()], - vec![expr], - vec![], - base_mle_refs, - vec![], - vec![E::from_u64(6)], - ); - } - - #[test] - fn test_zerocheck_simple() { - let f = field_vec![F, 1, 2, 3, 4, 5, 6, 7, 8]; - let out_point = field_vec![E, 2, 3, 5]; - let out_eq = field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30]; - let ans = izip!(&out_eq, &f).fold(E::ZERO, |acc, (c, x)| acc + *c * *x); - - let base_mle_refs = vec![f.as_slice()]; - let expr = Expression::Wit(Witness::BasePoly(0)); - run( - vec![out_point.as_slice()], - vec![expr.clone()], - vec![], - base_mle_refs, - vec![], - vec![ans], - ); - } - - #[test] - fn test_zerocheck_logup() { - let out_point = field_vec![E, 2, 3, 5]; - let out_eq = field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30]; - - let d0 = field_vec![E, 1, 2, 3, 4, 5, 6, 7, 8]; - let d1 = field_vec![E, 9, 10, 11, 12, 13, 14, 15, 16]; - let n0 = field_vec![E, 17, 18, 19, 20, 21, 22, 23, 24]; - let n1 = field_vec![E, 25, 26, 27, 28, 29, 30, 31, 32]; - - let challenges = vec![E::from_u64(7)]; - let ans = izip!(&out_eq, &d0, &d1, &n0, &n1) - .map(|(eq, d0, d1, n0, n1)| *eq * (*d0 * *d1 + challenges[0] * (*d0 * *n1 + *d1 * *n0))) - .sum(); - - let mut ext_mles = [d0, d1, n0, n1]; - let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::ExtPoly(i))); - let beta = Expression::Const(Constant::Challenge(0)); - let expr = d0.clone() * d1.clone() + beta * (d0 * n1 + 
d1 * n0); - - let ext_mles_refs = ext_mles.iter_mut().map(|v| v.as_mut_slice()).collect_vec(); - run( - vec![out_point.as_slice()], - vec![expr.clone()], - ext_mles_refs, - vec![], - challenges, - vec![ans], - ); - } - - #[test] - fn test_zerocheck_multi_points() { - let points = [ - field_vec![E, 2, 3, 5], - field_vec![E, 7, 11, 13], - field_vec![E, 17, 19, 23], - ]; - let out_eqs = [ - field_vec![E, -8, 16, 12, -24, 10, -20, -15, 30], - field_vec![E, -720, 840, 792, -924, 780, -910, -858, 1001], - field_vec![E, -6336, 6732, 6688, -7106, 6624, -7038, -6992, 7429], - ]; - let point_refs = points.iter().map(|v| v.as_slice()).collect_vec(); - - let d0 = field_vec![F, 1, 2, 3, 4, 5, 6, 7, 8]; - let d1 = field_vec![F, 9, 10, 11, 12, 13, 14, 15, 16]; - let n0 = field_vec![F, 17, 18, 19, 20, 21, 22, 23, 24]; - let n1 = field_vec![F, 25, 26, 27, 28, 29, 30, 31, 32]; - - let ans_0 = izip!(&out_eqs[0], &d0, &d1) - .map(|(eq0, d0, d1)| *eq0 * *d0 * *d1) - .sum(); - let ans_1 = izip!(&out_eqs[1], &d0, &n1) - .map(|(eq1, d0, n1)| *eq1 * *d0 * *n1) - .sum(); - let ans_2 = izip!(&out_eqs[2], &d1, &n0) - .map(|(eq2, d1, n0)| *eq2 * *d1 * *n0) - .sum(); - - let base_mles = [d0, d1, n0, n1]; - let [d0, d1, n0, n1] = array::from_fn(|i| Expression::Wit(Witness::BasePoly(i))); - - let exprs = vec![d0.clone() * d1.clone(), d0 * n1, d1 * n0]; - - let base_mle_refs = base_mles.iter().map(|v| v.as_slice()).collect_vec(); - run( - point_refs, - exprs, - vec![], - base_mle_refs, - vec![], - vec![ans_0, ans_1, ans_2], - ); - } -} From a516ef029d8f1291de140be6f4aca0422565ef23 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 16:39:06 +0800 Subject: [PATCH 23/28] show gkr iop proof size --- Cargo.lock | 1 + Cargo.toml | 2 +- gkr_iop/Cargo.toml | 11 ++++++----- gkr_iop/src/gkr.rs | 20 ++++++++++++++++++++ gkr_iop/src/precompiles/lookup_keccakf.rs | 8 +++++--- sumcheck/Cargo.toml | 2 +- 6 files changed, 34 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 
0e9045eb4..b00d53a7e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1113,6 +1113,7 @@ name = "gkr_iop" version = "0.1.0" dependencies = [ "ark-std", + "bincode", "clap", "criterion", "either", diff --git a/Cargo.toml b/Cargo.toml index 543e2f7d7..1167fd29a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -73,10 +73,10 @@ secp = "0.4.1" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" strum = "0.26" -thiserror = "1" # do we need this? strum_macros = "0.26" substrate-bn = { version = "0.6.0" } sumcheck = { path = "sumcheck" } +thiserror = "1" # do we need this? tiny-keccak = { version = "2.0.2", features = ["keccak"] } tracing = { version = "0.1", features = [ "attributes", diff --git a/gkr_iop/Cargo.toml b/gkr_iop/Cargo.toml index 2937e57fe..dc5322eeb 100644 --- a/gkr_iop/Cargo.toml +++ b/gkr_iop/Cargo.toml @@ -10,9 +10,10 @@ repository.workspace = true version.workspace = true [dependencies] -clap.workspace = true -tracing-forest.workspace = true ark-std = { version = "0.5" } +bincode.workspace = true +clap.workspace = true +either.workspace = true ff_ext = { path = "../ff_ext" } itertools.workspace = true multilinear_extensions = { version = "0.1.0", path = "../multilinear_extensions" } @@ -25,10 +26,10 @@ serde.workspace = true sumcheck.workspace = true thiserror.workspace = true tiny-keccak.workspace = true -either.workspace = true -transcript = { path = "../transcript" } tracing.workspace = true +tracing-forest.workspace = true tracing-subscriber.workspace = true +transcript = { path = "../transcript" } witness = { path = "../witness" } [target.'cfg(unix)'.dependencies] @@ -47,4 +48,4 @@ harness = false name = "lookup_keccakf" [features] -jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] \ No newline at end of file +jemalloc = ["dep:tikv-jemallocator", "dep:tikv-jemalloc-ctl"] diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 1d7bb5144..00cfb9b80 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,3 +1,5 @@ 
+use core::fmt; + use ff_ext::ExtensionField; use itertools::{Itertools, chain, izip}; use layer::{Layer, LayerWitness, sumcheck_layer::SumcheckLayerProof}; @@ -147,3 +149,21 @@ impl GKRCircuit { .collect_vec() } } + +impl fmt::Display for GKRProof { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // overall size + let overall_size = bincode::serialized_size(&self).expect("serialization error"); + + write!( + f, + "overall_size {:.2}mb. \n\ + ", + byte_to_mb(overall_size), + ) + } +} + +fn byte_to_mb(byte_size: u64) -> f64 { + byte_size as f64 / (1024.0 * 1024.0) +} diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 9314d9503..4ae79de54 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -14,9 +14,10 @@ use witness::{InstancePaddingStrategy, RowMajorMatrix}; use crate::{ ProtocolBuilder, ProtocolWitnessGenerator, chip::Chip, + error::BackendError, evaluation::EvalExpression, gkr::{ - GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, GKRProverOutput, + GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, GKRProof, GKRProverOutput, layer::{Layer, LayerType, LayerWitness}, }, precompiles::utils::{MaskRepresentation, not8_expr}, @@ -1247,7 +1248,7 @@ pub fn run_faster_keccakf( states: Vec<[u64; 25]>, verify: bool, test_outputs: bool, -) { +) -> Result, BackendError> { let num_instances = states.len(); let num_instances_rounds = num_instances * ROUNDS.next_power_of_two(); let log2_num_instance_rounds = ceil_log2(num_instances_rounds); @@ -1363,7 +1364,7 @@ pub fn run_faster_keccakf( gkr_circuit .verify( log2_num_instance_rounds, - gkr_proof, + gkr_proof.clone(), &out_evals, &[], &mut verifier_transcript, @@ -1373,6 +1374,7 @@ pub fn run_faster_keccakf( // Omit the PCS opening phase. 
} } + Ok(gkr_proof) } #[cfg(test)] diff --git a/sumcheck/Cargo.toml b/sumcheck/Cargo.toml index d5c8b3f67..ec2dc4bb1 100644 --- a/sumcheck/Cargo.toml +++ b/sumcheck/Cargo.toml @@ -16,10 +16,10 @@ ff_ext = { path = "../ff_ext" } itertools.workspace = true multilinear_extensions = { path = "../multilinear_extensions", features = ["parallel"] } p3.workspace = true -thiserror.workspace = true rayon.workspace = true serde.workspace = true sumcheck_macro = { path = "../sumcheck_macro" } +thiserror.workspace = true tracing.workspace = true transcript = { path = "../transcript" } From 31cfa76d767001ce59afae281e30efbaf8605f90 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 16:50:12 +0800 Subject: [PATCH 24/28] coding style adjustment and cleanup --- gkr_iop/src/bin/bitwise_keccak.rs | 7 +++- gkr_iop/src/precompiles/lookup_keccakf.rs | 50 ++++++++--------------- gkr_iop/src/precompiles/utils.rs | 4 -- 3 files changed, 23 insertions(+), 38 deletions(-) diff --git a/gkr_iop/src/bin/bitwise_keccak.rs b/gkr_iop/src/bin/bitwise_keccak.rs index 14e4af84a..4e27af638 100644 --- a/gkr_iop/src/bin/bitwise_keccak.rs +++ b/gkr_iop/src/bin/bitwise_keccak.rs @@ -8,7 +8,12 @@ use tracing_forest::ForestLayer; use tracing_subscriber::{ EnvFilter, Registry, filter::filter_fn, fmt, layer::SubscriberExt, util::SubscriberInitExt, }; -/// Prove the execution of a fixed RISC-V program. 
+ +// Use jemalloc as global allocator for performance +#[cfg(all(feature = "jemalloc", unix, not(test)))] +#[global_allocator] +static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; + #[derive(Parser, Debug)] #[command(version, about, long_about = None)] struct Args { diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 4ae79de54..9af67e8c7 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -1300,7 +1300,7 @@ pub fn run_faster_keccakf( { assert_eq!( base.evaluations().len(), - num_instances * ROUNDS.next_power_of_two() + (num_instances * ROUNDS.next_power_of_two()).next_power_of_two() ); for i in 0..num_instances { instance_outputs[i].push(base.get_base_field_vec()[i]); @@ -1386,45 +1386,29 @@ mod tests { #[test] fn test_keccakf() { type E = GoldilocksExt2; - std::thread::Builder::new() - .name("keccak_test".into()) - .stack_size(64 * 1024 * 1024) - .spawn(|| { - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - - let num_instances = 8; - let mut states: Vec<[u64; 25]> = Vec::with_capacity(num_instances); - for _ in 0..num_instances { - states.push(std::array::from_fn(|_| rng.gen())); - } + let mut rng = rand::rngs::StdRng::seed_from_u64(42); - run_faster_keccakf(setup_gkr_circuit::(), states, false, true); - }) - .unwrap() - .join() - .unwrap(); + let num_instances = 8; + let mut states: Vec<[u64; 25]> = Vec::with_capacity(num_instances); + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())); + } + let _ = run_faster_keccakf(setup_gkr_circuit::(), states, false, true); } #[ignore] #[test] fn test_keccakf_nonpow2() { type E = GoldilocksExt2; - std::thread::Builder::new() - .name("keccak_test".into()) - .stack_size(64 * 1024 * 1024) - .spawn(|| { - let mut rng = rand::rngs::StdRng::seed_from_u64(42); - - let num_instances = 5; - let mut states: Vec<[u64; 25]> = Vec::with_capacity(num_instances); - for _ in 
0..num_instances { - states.push(std::array::from_fn(|_| rng.gen())); - } - run_faster_keccakf(setup_gkr_circuit::(), states, false, true); - }) - .unwrap() - .join() - .unwrap(); + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + let num_instances = 5; + let mut states: Vec<[u64; 25]> = Vec::with_capacity(num_instances); + for _ in 0..num_instances { + states.push(std::array::from_fn(|_| rng.gen())); + } + + let _ = run_faster_keccakf(setup_gkr_circuit::(), states, false, true); } } diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index dbba574b4..299da91ad 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -17,10 +17,6 @@ pub fn zero_eval() -> EvalExpression { ) } -pub fn nest(v: &[E::BaseField]) -> Vec> { - v.iter().map(|e| vec![*e]).collect_vec() -} - pub fn u64s_to_felts(words: Vec) -> Vec { words.into_iter().map(E::BaseField::from_u64).collect() } From 63a8a75f93e5dfb34b87ab749ff2ebc1ae13e273 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Fri, 30 May 2025 23:24:01 +0800 Subject: [PATCH 25/28] optimize witness generation --- gkr_iop/src/lib.rs | 32 ++- gkr_iop/src/precompiles/lookup_keccakf.rs | 330 +++------------------- 2 files changed, 69 insertions(+), 293 deletions(-) diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 933117244..7ac59d4c8 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -61,8 +61,7 @@ where challenges: &[E], ) -> (GKRCircuitWitness<'a, E>, GKRCircuitOutput) { // layer order from output to input - let n_layers = 100; - let mut layer_wits = Vec::>::with_capacity(n_layers + 1); + let mut layer_wits = Vec::>::with_capacity(circuit.layers.len() + 1); let phase1_witness_group = phase1_witness_group .to_mles() .into_iter() @@ -72,6 +71,10 @@ where layer_wits.push(LayerWitness::new(phase1_witness_group.clone())); let mut witness_mle_flattern = vec![None; circuit.n_evaluations]; + // this is to record every witness layer out + // only last layer will be 
use as the whole gkr circuit out + let mut gkr_out_well_order = Vec::with_capacity(circuit.n_evaluations); + // set input to witness_mle_flattern via first layer in_eval_expr circuit.layers.last().map(|first_layer| { first_layer @@ -113,7 +116,14 @@ where .zip_eq(¤t_layer_output) .for_each(|(out_eval, out_mle)| match out_eval { EvalExpression::Single(out) => { - witness_mle_flattern[*out] = Some(out_mle.clone()) + witness_mle_flattern[*out] = Some(out_mle.clone()); + // last layer we record gkr circuit output + if i == circuit.layers.len() - 1 { + gkr_out_well_order.push((*out, out_mle.clone())); + } + } + EvalExpression::Linear(0, _, _) => { // zero expression + // do nothing on zero expression } other => unimplemented!("{:?}", other), }); @@ -122,8 +132,20 @@ where layer_wits.reverse(); - GKRCircuitWitness { layers: layer_wits }; - unimplemented!() + // process and sort by out_id + gkr_out_well_order.sort_by_key(|(i, _)| *i); + let gkr_out_well_order = gkr_out_well_order + .into_iter() + .map(|(_, val)| val) + .collect_vec(); + + ( + GKRCircuitWitness { layers: layer_wits }, + GKRCircuitOutput(LayerWitness { + bases: gkr_out_well_order, + ..Default::default() + }), + ) } } diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index 9af67e8c7..f1e05ad46 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -1,13 +1,16 @@ -use std::{array, cmp::Ordering, marker::PhantomData, sync::Arc}; +use std::{array, cmp::Ordering, marker::PhantomData}; -use ff_ext::{ExtensionField, SmallField}; +use ff_ext::ExtensionField; use itertools::{Itertools, chain, iproduct, zip_eq}; use multilinear_extensions::{Expression, ToExpr, WitIn, mle::PointAndEval, util::ceil_log2}; use ndarray::{ArrayView, Ix2, Ix3, s}; use p3_field::PrimeCharacteristicRing; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; +use rayon::iter::{IntoParallelIterator, 
ParallelIterator}; use serde::{Deserialize, Serialize}; -use sumcheck::util::optimal_sumcheck_threads; +use sumcheck::{ + macros::{entered_span, exit_span}, + util::optimal_sumcheck_threads, +}; use transcript::{BasicTranscript, Transcript}; use witness::{InstancePaddingStrategy, RowMajorMatrix}; @@ -17,8 +20,8 @@ use crate::{ error::BackendError, evaluation::EvalExpression, gkr::{ - GKRCircuit, GKRCircuitOutput, GKRCircuitWitness, GKRProof, GKRProverOutput, - layer::{Layer, LayerType, LayerWitness}, + GKRCircuit, GKRProof, GKRProverOutput, + layer::{Layer, LayerType}, }, precompiles::utils::{MaskRepresentation, not8_expr}, }; @@ -437,6 +440,8 @@ impl ProtocolBuilder for KeccakLayout { ] .map(|many| final_outputs_iter.by_ref().take(many).collect_vec()); + assert!(final_outputs_iter.next().is_none()); + let lookup_outputs = lookup_outputs.to_vec(); // TODO we should separate into different eq group, because they should reduce from differenent points @@ -686,17 +691,14 @@ impl ProtocolBuilder for KeccakLayout { .enumerate() { expressions.push(lookup); - let (idx, round) = if i < 3 * AND_LOOKUPS { - let round = i / AND_LOOKUPS; - (&mut global_and_lookup, round) + let idx = if i < 3 * AND_LOOKUPS { + &mut global_and_lookup } else if i < 3 * AND_LOOKUPS + 3 * XOR_LOOKUPS { - let round = (i - 3 * AND_LOOKUPS) / XOR_LOOKUPS; - (&mut global_xor_lookup, round) + &mut global_xor_lookup } else { - let round = (i - 3 * AND_LOOKUPS - 3 * XOR_LOOKUPS) / RANGE_LOOKUPS; - (&mut global_range_lookup, round) + &mut global_range_lookup }; - expr_names.push(format!("round {round}: {i}th lookup felt")); + expr_names.push(format!("round 0th: {i}th lookup felt")); evals.push(lookup_outputs[*idx].clone()); *idx += 1; } @@ -710,6 +712,7 @@ impl ProtocolBuilder for KeccakLayout { let keccak_input32 = keccak_input32.to_vec(); let mut keccak_input32_iter = keccak_input32.iter().cloned(); + // process keccak input for x in 0..5 { for y in 0..5 { for k in 0..2 { @@ -734,6 +737,7 @@ impl 
ProtocolBuilder for KeccakLayout { ArrayView::from_shape((5, 5, 8), keccak_output8).unwrap(); let mut keccak_output32_iter = keccak_output32.iter().cloned(); + // process keccak output for x in 0..5 { for y in 0..5 { for k in 0..2 { @@ -924,7 +928,8 @@ where // Iota step let mut iota_output64 = chi_output64; let mut iota_output8 = [[[0u64; 8]; 5]; 5]; - iota_output64[0][0] ^= RC[round]; + // TODO figure out how to deal with RC, since it's not a constant in rotation + iota_output64[0][0] ^= RC[0]; for x in 0..5 { for y in 0..5 { @@ -964,277 +969,6 @@ where .collect(); RowMajorMatrix::new_by_values(wits, KECCAK_WIT_SIZE, InstancePaddingStrategy::Default) } - - fn gkr_witness( - &self, - _circuit: &GKRCircuit, - phase1: &RowMajorMatrix, - _challenges: &[E], - ) -> (GKRCircuitWitness<'a, E>, GKRCircuitOutput) { - // TODO: fix inefficiency as here as it convert felts back to u64 - let instances_rounds = phase1 - .values - .par_iter() - .map(|wit| wit.to_canonical_u64()) - .collect::>(); - let num_instances_with_rotations = 1 << phase1.num_vars(); - let num_cols = phase1.n_col(); - assert_eq!(num_cols, KECCAK_WIT_SIZE); - - let to_5x5x8_array = |input: &[u64]| -> [[[u64; 8]; 5]; 5] { - assert_eq!(input.len(), 5 * 5 * 8); - input - .chunks(40) - .map(|chunk| { - chunk - .chunks(8) - .map(|x| x.to_vec().try_into().unwrap()) - .collect_vec() - .try_into() - .unwrap() - }) - .collect_vec() - .try_into() - .unwrap() - }; - let to_5x8_array = |input: &[u64]| -> [[u64; 8]; 5] { - input - .chunks(8) - .map(|x| x.to_vec().try_into().unwrap()) - .collect_vec() - .try_into() - .unwrap() - }; - let u8_slice_to_u64 = - |input: &[u64]| -> u64 { input.iter().rev().fold(0, |acc, &e| (acc << 8) | e) }; - let u8_slice_to_u32_slice = |input: &[u64]| -> [u64; 2] { - input - .chunks(4) - .map(u8_slice_to_u64) - .collect_vec() - .try_into() - .unwrap() - }; - - // process output bases - let output_bases: Vec = (0..num_instances_with_rotations) - .into_par_iter() - 
.flat_map(|instance_round_id| { - let round = instance_round_id % ROUNDS.next_power_of_two(); - - if round >= ROUNDS { - // padding with zero - return vec![0; KECCAK_OUT_EVAL_SIZE]; - } - - let state8: [[[u64; 8]; 5]; 5] = to_5x5x8_array( - &instances_rounds[instance_round_id * num_cols..][..KECCAK_LAYER_BYTE_SIZE], - ); - let mut keccak_input32 = [[[0u64; 2]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - keccak_input32[x][y] = u8_slice_to_u32_slice(&state8[x][y]); - } - } - // #[allow(clippy::needless_range_loop)] - // for round in 0..ROUNDS { - // TODO use with_capacity and retrive number of lookup from circuit - let mut and_lookups: Vec = vec![]; - let mut xor_lookups: Vec = vec![]; - let mut range_lookups: Vec = vec![]; - - let mut add_and = |a: u64, b: u64| { - let c = a & b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - and_lookups.extend(vec![a, b, c]); - }; - - let mut add_xor = |a: u64, b: u64| { - let c = a ^ b; - assert!(a < (1 << 8)); - assert!(b < (1 << 8)); - xor_lookups.extend(vec![a, b, c]); - }; - - let mut add_range = |value: u64, size: usize| { - assert!(size <= 16, "{size}"); - range_lookups.push(value); - if size < 16 { - range_lookups.push(value << (16 - size)); - assert!(value << (16 - size) < (1 << 16)); - } - }; - - let ( - c_aux8, - _c_temp, - crot8, - d8, - theta_state8, - _rotation_witness, - rhopi_output8, - nonlinear8, - chi_output8, - iota_output8, - ) = split_from_offset!( - instances_rounds[instance_round_id * num_cols..][..num_cols], - KECCAK_LAYER_BYTE_SIZE, // offset - KECCAK_WIT_SIZE_PER_ROUND, - 200, - 30, - 40, - 40, - 200, - 146, - 200, - 200, - 8, - 200 - ); - let c_aux8 = to_5x5x8_array(&c_aux8); - - for i in 0..5 { - for j in 1..5 { - for k in 0..8 { - add_xor(c_aux8[i][j - 1][k], state8[j][i][k]); - } - } - } - - let mut c8 = [[0u64; 8]; 5]; - let mut c64 = [0u64; 5]; - - for x in 0..5 { - c8[x] = c_aux8[x][4]; - c64[x] = u8_slice_to_u64(&c8[x]); - } - - for i in 0..5 { - let rep = MaskRepresentation::new(vec![(64, 
c64[i]).into()]) - .convert(vec![16, 15, 1, 16, 15, 1]); - for mask in rep.rep { - add_range(mask.value, mask.size); - } - } - - let crot8 = to_5x8_array(&crot8); - let d8 = to_5x8_array(&d8); - for x in 0..5 { - for k in 0..8 { - add_xor(c_aux8[(x + 4) % 5][4][k], crot8[(x + 1) % 5][k]); - } - } - - let theta_state8 = to_5x5x8_array(&theta_state8); - let mut theta_state64 = [[0u64; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - theta_state64[y][x] = u8_slice_to_u64(&theta_state8[y][x]); - } - } - - for x in 0..5 { - for y in 0..5 { - for k in 0..8 { - add_xor(state8[y][x][k], d8[x][k]); - } - - let (sizes, _) = rotation_split(ROTATION_CONSTANTS[y][x]); - let rep = MaskRepresentation::new(vec![(64, theta_state64[y][x]).into()]) - .convert(sizes); - for mask in rep.rep.iter() { - if mask.size != 32 { - add_range(mask.value, mask.size); - } - } - } - } - - // Rho and Pi steps - let rhopi_output8 = to_5x5x8_array(&rhopi_output8); - - // Chi step - let nonlinear8 = to_5x5x8_array(&nonlinear8); - for x in 0..5 { - for y in 0..5 { - for k in 0..8 { - add_and( - 0xFF - rhopi_output8[y][(x + 1) % 5][k], - rhopi_output8[y][(x + 2) % 5][k], - ); - } - } - } - - for x in 0..5 { - for y in 0..5 { - for k in 0..8 { - add_xor(rhopi_output8[y][x][k], nonlinear8[y][x][k]) - } - } - } - - // Iota step - let chi_output8: [u64; 8] = chi_output8.try_into().unwrap(); // only save chi_output8[0][0]; - let iota_output8 = to_5x5x8_array(&iota_output8); - - for k in 0..8 { - add_xor(chi_output8[k], (RC[round] >> (k * 8)) & 0xFF); - } - - // } - - let mut keccak_output32 = [[[0u64; 2]; 5]; 5]; - for x in 0..5 { - for y in 0..5 { - keccak_output32[x][y] = u8_slice_to_u32_slice(&iota_output8[x][y]); - } - } - - chain!( - keccak_output32.into_iter().flatten().flatten(), - keccak_input32.into_iter().flatten().flatten(), - and_lookups, - xor_lookups, - range_lookups - ) - .collect_vec() - }) - .collect(); - - assert_eq!( - output_bases.len(), - num_instances_with_rotations * KECCAK_OUT_EVAL_SIZE 
- ); - - let bases = phase1.to_mles().into_iter().map(Arc::new).collect_vec(); - let output_bases = RowMajorMatrix::new_by_values( - output_bases - .into_iter() - .map(E::BaseField::from_u64) - .collect(), - KECCAK_OUT_EVAL_SIZE, - InstancePaddingStrategy::Default, - ) - .to_mles() - .into_iter() - .map(Arc::new) - .collect_vec(); - - ( - GKRCircuitWitness { - layers: vec![LayerWitness { - bases, - ..Default::default() - }], - }, - GKRCircuitOutput(LayerWitness { - bases: output_bases, - ..Default::default() - }), - ) - } } pub fn setup_gkr_circuit() -> (KeccakLayout, GKRCircuit) { @@ -1243,6 +977,12 @@ pub fn setup_gkr_circuit() -> (KeccakLayout, GKRCircuit (layout, chip.gkr_circuit()) } +#[tracing::instrument( + skip_all, + name = "run_faster_keccakf", + level = "trace", + fields(profiling_1) +)] pub fn run_faster_keccakf( (layout, gkr_circuit): (KeccakLayout, GKRCircuit), states: Vec<[u64; 25]>, @@ -1255,6 +995,7 @@ pub fn run_faster_keccakf( let num_threads = optimal_sumcheck_threads(log2_num_instance_rounds); let mut instances = Vec::with_capacity(num_instances); + let span = entered_span!("instances", profiling_2 = true); for state in &states { let state_mask64 = MaskRepresentation::from(state.iter().map(|e| (64, *e)).collect_vec()); let state_mask32 = state_mask64.convert(vec![32; 50]); @@ -1269,16 +1010,22 @@ pub fn run_faster_keccakf( .unwrap(), ); } + exit_span!(span); + let span = entered_span!("phase1_witness", profiling_2 = true); let phase1_witness = layout.phase1_witness_group(KeccakTrace { instances: instances, }); + exit_span!(span); let mut prover_transcript = BasicTranscript::::new(b"protocol"); + let span = entered_span!("gkr_witness", profiling_2 = true); // Omit the commit phase1 and phase2. 
let (gkr_witness, _gkr_output) = layout.gkr_witness(&gkr_circuit, &phase1_witness, &[]); + exit_span!(span); + let span = entered_span!("out_eval", profiling_2 = true); let out_evals = { let mut point = Vec::with_capacity(log2_num_instance_rounds); point.extend( @@ -1329,7 +1076,11 @@ pub fn run_faster_keccakf( .iter() .map(|base| PointAndEval { point: point.clone(), - eval: base.evaluate(&point), + eval: if base.num_vars() == 0 { + base.get_base_field_vec()[0].into() + } else { + base.evaluate(&point) + }, }) .collect_vec(); @@ -1337,7 +1088,9 @@ pub fn run_faster_keccakf( out_evals }; + exit_span!(span); + let span = entered_span!("create_proof", profiling_2 = true); let GKRProverOutput { gkr_proof, .. } = gkr_circuit .prove( num_threads, @@ -1348,6 +1101,7 @@ pub fn run_faster_keccakf( &mut prover_transcript, ) .expect("Failed to prove phase"); + exit_span!(span); if verify { { @@ -1393,7 +1147,7 @@ mod tests { for _ in 0..num_instances { states.push(std::array::from_fn(|_| rng.gen())); } - let _ = run_faster_keccakf(setup_gkr_circuit::(), states, false, true); + let _ = run_faster_keccakf(setup_gkr_circuit::(), states, true, true); } #[ignore] From 14299e3c945cd41b1a65ecfc5c6a0e8d5f1015cc Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 3 Jun 2025 16:14:12 +0800 Subject: [PATCH 26/28] fix typo and add zero_eval as new type --- gkr_iop/src/chip/builder.rs | 4 ++-- gkr_iop/src/evaluation.rs | 2 ++ gkr_iop/src/gkr.rs | 5 +++-- gkr_iop/src/gkr/layer.rs | 4 ++-- gkr_iop/src/gkr/mock.rs | 1 + gkr_iop/src/lib.rs | 2 +- gkr_iop/src/precompiles/lookup_keccakf.rs | 4 ++-- gkr_iop/src/precompiles/utils.rs | 10 ---------- 8 files changed, 13 insertions(+), 19 deletions(-) diff --git a/gkr_iop/src/chip/builder.rs b/gkr_iop/src/chip/builder.rs index de00ce3ae..4f9197b76 100644 --- a/gkr_iop/src/chip/builder.rs +++ b/gkr_iop/src/chip/builder.rs @@ -74,11 +74,11 @@ impl Chip { /// Allocate challenges. 
pub fn allocate_challenges(&mut self) -> [Expression; N] { - let challanges = array::from_fn(|i| { + let challenges = array::from_fn(|i| { Expression::Challenge((i + self.n_challenges) as ChallengeId, 1, E::ONE, E::ZERO) }); self.n_challenges += N; - challanges + challenges } /// Allocate a PCS opening action to a base polynomial with index diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index af0d59f2d..f7a9479a1 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize, de::DeserializeOwned}; #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound = "E: ExtensionField + DeserializeOwned")] pub enum EvalExpression { + Zero, /// Single entry in the evaluation vector. Single(usize), /// Linear expression of an entry with the scalar and offset. @@ -40,6 +41,7 @@ fn evaluate(expr: &Expression, challenges: &[E]) -> E { impl EvalExpression { pub fn evaluate(&self, evals: &[PointAndEval], challenges: &[E]) -> PointAndEval { match self { + EvalExpression::Zero => PointAndEval::default(), EvalExpression::Single(i) => evals[*i].clone(), EvalExpression::Linear(i, c0, c1) => PointAndEval { point: evals[*i].point.clone(), diff --git a/gkr_iop/src/gkr.rs b/gkr_iop/src/gkr.rs index 00cfb9b80..a1e4a6b7d 100644 --- a/gkr_iop/src/gkr.rs +++ b/gkr_iop/src/gkr.rs @@ -1,7 +1,7 @@ use core::fmt; use ff_ext::ExtensionField; -use itertools::{Itertools, chain, izip}; +use itertools::{Itertools, izip}; use layer::{Layer, LayerWitness, sumcheck_layer::SumcheckLayerProof}; use multilinear_extensions::mle::{Point, PointAndEval}; use serde::{Deserialize, Serialize, de::DeserializeOwned}; @@ -140,7 +140,8 @@ impl GKRCircuit { evaluations: &[PointAndEval], challenges: &[E], ) -> Vec> { - chain!(&self.openings, &self.openings) + self.openings + .iter() .map(|(poly, eval)| { let poly = *poly; let PointAndEval { point, eval: value } = eval.evaluate(evaluations, challenges); diff --git 
a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 3d9dfe594..840a8c241 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -39,9 +39,9 @@ pub struct Layer { /// each expression corresponds to an output. While in sumcheck, there /// is only 1 expression, which corresponds to the sum of all outputs. /// This design is for the convenience when building the following - /// expression: `e_0 + beta * e_1 + /// expression: `r^0 e_0 + r^1 * e_1 + ... /// = \sum_x (r^0 eq_0(X) \cdot expr_0(x) + r^1 eq_1(X) \cdot expr_1(x) + ...)`. - /// where `vec![e_0, beta * e_1]` will be the output evaluation expressions. + /// where `vec![e_0, e_1, ...]` will be the output evaluation expressions. pub exprs: Vec>, /// Positions to place the evaluations of the base inputs of this layer. diff --git a/gkr_iop/src/gkr/mock.rs b/gkr_iop/src/gkr/mock.rs index f2ff6db88..6715a9330 100644 --- a/gkr_iop/src/gkr/mock.rs +++ b/gkr_iop/src/gkr/mock.rs @@ -157,6 +157,7 @@ impl EvalExpression { len: usize, ) -> FieldType<'a, E> { match self { + EvalExpression::Zero => FieldType::default(), EvalExpression::Single(i) => evals[*i].clone(), EvalExpression::Linear(i, c0, c1) => Arc::into_inner(wit_infer_by_expr( &[], diff --git a/gkr_iop/src/lib.rs b/gkr_iop/src/lib.rs index 7ac59d4c8..5bd849bd0 100644 --- a/gkr_iop/src/lib.rs +++ b/gkr_iop/src/lib.rs @@ -122,7 +122,7 @@ where gkr_out_well_order.push((*out, out_mle.clone())); } } - EvalExpression::Linear(0, _, _) => { // zero expression + EvalExpression::Zero => { // zero expression // do nothing on zero expression } other => unimplemented!("{:?}", other), diff --git a/gkr_iop/src/precompiles/lookup_keccakf.rs b/gkr_iop/src/precompiles/lookup_keccakf.rs index f1e05ad46..dc721c2ac 100644 --- a/gkr_iop/src/precompiles/lookup_keccakf.rs +++ b/gkr_iop/src/precompiles/lookup_keccakf.rs @@ -26,7 +26,7 @@ use crate::{ precompiles::utils::{MaskRepresentation, not8_expr}, }; -use super::utils::{CenoLookup, u64s_to_felts, zero_eval}; 
+use super::utils::{CenoLookup, u64s_to_felts}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct KeccakParams {} @@ -137,7 +137,7 @@ impl ConstraintSystem { fn add_constraint(&mut self, expr: Expression, name: String) { self.expressions.push(expr); - self.evals.push(zero_eval()); + self.evals.push(EvalExpression::Zero); self.expr_names.push(name); } diff --git a/gkr_iop/src/precompiles/utils.rs b/gkr_iop/src/precompiles/utils.rs index 299da91ad..0b343d19a 100644 --- a/gkr_iop/src/precompiles/utils.rs +++ b/gkr_iop/src/precompiles/utils.rs @@ -3,20 +3,10 @@ use itertools::Itertools; use multilinear_extensions::{Expression, ToExpr}; use p3_field::PrimeCharacteristicRing; -use crate::evaluation::EvalExpression; - pub fn not8_expr(expr: Expression) -> Expression { E::BaseField::from_u8(0xFF).expr() - expr } -pub fn zero_eval() -> EvalExpression { - EvalExpression::Linear( - 0, - Box::new(E::BaseField::ZERO.expr()), - Box::new(E::BaseField::ZERO.expr()), - ) -} - pub fn u64s_to_felts(words: Vec) -> Vec { words.into_iter().map(E::BaseField::from_u64).collect() } From 536102e34d13c55baaf0785fa7db9358fdd0be13 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 3 Jun 2025 16:30:46 +0800 Subject: [PATCH 27/28] chores: fix typo --- gkr_iop/src/gkr/layer.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 840a8c241..298132d7c 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -233,12 +233,12 @@ impl Layer { for challenge in &self.challenges { let value = transcript.sample_and_append_challenge(b"layer challenge"); match challenge { - Expression::Challenge(challange_id, ..) => { - let challange_id = *challange_id as usize; - if challenges.len() <= challange_id as usize { - challenges.resize(challange_id + 1, E::default()); + Expression::Challenge(challenge_id, ..) 
=> { + let challenge_id = *challenge_id as usize; + if challenges.len() <= challenge_id as usize { + challenges.resize(challenge_id + 1, E::default()); } - challenges[challange_id] = value.elements; + challenges[challenge_id] = value.elements; } _ => unreachable!(), } From cc87e6a154e87adcbfc045eb9e2fe1d0f1536704 Mon Sep 17 00:00:00 2001 From: "sm.wu" Date: Tue, 3 Jun 2025 17:02:31 +0800 Subject: [PATCH 28/28] fix zero eval PointAndEval size --- gkr_iop/src/evaluation.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/gkr_iop/src/evaluation.rs b/gkr_iop/src/evaluation.rs index f7a9479a1..f4910df99 100644 --- a/gkr_iop/src/evaluation.rs +++ b/gkr_iop/src/evaluation.rs @@ -41,7 +41,13 @@ fn evaluate(expr: &Expression, challenges: &[E]) -> E { impl EvalExpression { pub fn evaluate(&self, evals: &[PointAndEval], challenges: &[E]) -> PointAndEval { match self { - EvalExpression::Zero => PointAndEval::default(), + // assume all point in evals are derived in random, thus pick arbirary one is ok + // here we pick first point as representative. + // for zero eval, eval is always zero + EvalExpression::Zero => PointAndEval { + point: evals[0].point.clone(), + eval: E::ZERO, + }, EvalExpression::Single(i) => evals[*i].clone(), EvalExpression::Linear(i, c0, c1) => PointAndEval { point: evals[*i].point.clone(),