diff --git a/air/src/air.rs b/air/src/air.rs index db8446d..8ee089f 100644 --- a/air/src/air.rs +++ b/air/src/air.rs @@ -108,6 +108,10 @@ pub trait AirBuilder: Sized { } } +pub trait AirBuilderWithPublicValues: AirBuilder { + fn public_values(&self) -> &[Self::F]; +} + pub trait PairBuilder: AirBuilder { fn preprocessed(&self) -> Self::M; } diff --git a/keccak-air/examples/prove_baby_bear_keccak.rs b/keccak-air/examples/prove_baby_bear_keccak.rs index 2517af3..e18a5ce 100644 --- a/keccak-air/examples/prove_baby_bear_keccak.rs +++ b/keccak-air/examples/prove_baby_bear_keccak.rs @@ -69,8 +69,8 @@ fn main() -> Result<(), VerificationError> { let mut challenger = Challenger::from_hasher(vec![], byte_hash); - let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); + let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]); let mut challenger = Challenger::from_hasher(vec![], byte_hash); - verify(&config, &KeccakAir {}, &mut challenger, &proof) + verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![]) } diff --git a/keccak-air/examples/prove_baby_bear_poseidon2.rs b/keccak-air/examples/prove_baby_bear_poseidon2.rs index deb3268..87d5e31 100644 --- a/keccak-air/examples/prove_baby_bear_poseidon2.rs +++ b/keccak-air/examples/prove_baby_bear_poseidon2.rs @@ -77,8 +77,8 @@ fn main() -> Result<(), VerificationError> { let mut challenger = Challenger::new(perm.clone()); - let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); + let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]); let mut challenger = Challenger::new(perm); - verify(&config, &KeccakAir {}, &mut challenger, &proof) + verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![]) } diff --git a/keccak-air/examples/prove_goldilocks_keccak.rs b/keccak-air/examples/prove_goldilocks_keccak.rs index 0b721ff..4a9fd09 100644 --- a/keccak-air/examples/prove_goldilocks_keccak.rs +++ b/keccak-air/examples/prove_goldilocks_keccak.rs @@ -69,8 +69,8 @@ fn main() -> Result<(), VerificationError> { let mut challenger = Challenger::from_hasher(vec![], byte_hash); - let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); + let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]); let mut challenger = Challenger::from_hasher(vec![], byte_hash); - verify(&config, &KeccakAir {}, &mut challenger, &proof) + verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![]) } diff --git a/keccak-air/examples/prove_goldilocks_poseidon.rs b/keccak-air/examples/prove_goldilocks_poseidon.rs index f17e54c..39b3506 100644 --- a/keccak-air/examples/prove_goldilocks_poseidon.rs +++ b/keccak-air/examples/prove_goldilocks_poseidon.rs @@ -77,8 +77,8 @@ fn main() -> Result<(), VerificationError> { let mut challenger = Challenger::new(perm.clone()); - let proof = prove::(&config, &KeccakAir {}, &mut challenger, trace); + let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]); let mut challenger = Challenger::new(perm); - verify(&config, &KeccakAir {}, &mut challenger, &proof) + verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![]) } diff --git a/keccak-air/examples/prove_m31_keccak.rs b/keccak-air/examples/prove_m31_keccak.rs index 4b0ca9f..eb48939 100644 --- a/keccak-air/examples/prove_m31_keccak.rs +++ b/keccak-air/examples/prove_m31_keccak.rs @@ -75,10 +75,10 @@ fn main() -> Result<(), VerificationError> { let mut challenger = Challenger::from_hasher(vec![], byte_hash); - let proof = prove::(&config, &air, &mut challenger, trace); + let proof = prove(&config, &air, &mut challenger, trace, &vec![]); let mut challenger = Challenger::from_hasher(vec![], byte_hash); - verify(&config, &air, &mut challenger, &proof)?; + verify(&config, &air, &mut challenger, &proof, &vec![])?; println!("OK!!! 👍"); Ok(()) diff --git a/uni-stark/src/check_constraints.rs b/uni-stark/src/check_constraints.rs index 1971897..50ce043 100644 --- a/uni-stark/src/check_constraints.rs +++ b/uni-stark/src/check_constraints.rs @@ -1,11 +1,13 @@ -use p3_air::{Air, AirBuilder, TwoRowMatrixView}; +use alloc::vec::Vec; + +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, TwoRowMatrixView}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::{Matrix, MatrixRowSlices}; use tracing::instrument; #[instrument(name = "check constraints", skip_all)] -pub(crate) fn check_constraints(air: &A, main: &RowMajorMatrix) +pub(crate) fn check_constraints(air: &A, main: &RowMajorMatrix, public_values: &Vec) where F: Field, A: for<'a> Air>, @@ -25,6 +27,7 @@ where let mut builder = DebugConstraintBuilder { row_index: i, main, + public_values, is_first_row: F::from_bool(i == 0), is_last_row: F::from_bool(i == height - 1), is_transition: F::from_bool(i != height - 1), @@ -39,6 +42,7 @@ where pub struct DebugConstraintBuilder<'a, F: Field> { row_index: usize, main: TwoRowMatrixView<'a, F>, + public_values: &'a [F], is_first_row: F, is_last_row: F, is_transition: F, @@ -92,3 +96,9 @@ where ); } } + +impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> { + fn public_values(&self) -> &[Self::F] { + self.public_values + } +} diff --git a/uni-stark/src/folder.rs b/uni-stark/src/folder.rs index 3015ee8..54c15ab 100644 --- a/uni-stark/src/folder.rs +++ b/uni-stark/src/folder.rs @@ -1,10 +1,13 @@ -use p3_air::{AirBuilder, TwoRowMatrixView}; +use alloc::vec::Vec; + +use p3_air::{AirBuilder, AirBuilderWithPublicValues, TwoRowMatrixView}; use p3_field::AbstractField; use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val}; pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub main: TwoRowMatrixView<'a, PackedVal>, + pub public_values: &'a Vec>, pub is_first_row: PackedVal, pub is_last_row: PackedVal, pub is_transition: PackedVal, @@ -14,6 +17,7 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> { pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> { pub main: TwoRowMatrixView<'a, SC::Challenge>, + pub public_values: &'a Vec>, pub is_first_row: SC::Challenge, pub is_last_row: SC::Challenge, pub is_transition: SC::Challenge, @@ -54,6 +58,12 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> { } } +impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraintFolder<'a, SC> { + fn public_values(&self) -> &[Self::F] { + self.public_values + } +} + impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> { type F = Val; type Expr = SC::Challenge; @@ -86,3 +96,8 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> self.accumulator += x; } } +impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstraintFolder<'a, SC> { + fn public_values(&self) -> &[Self::F] { + self.public_values + } +} diff --git a/uni-stark/src/prover.rs b/uni-stark/src/prover.rs index ef433e1..6042449 100644 --- a/uni-stark/src/prover.rs +++ b/uni-stark/src/prover.rs @@ -28,18 +28,19 @@ pub fn prove< air: &A, challenger: &mut SC::Challenger, trace: RowMajorMatrix>, + public_values: &Vec>, ) -> Proof where SC: StarkGenericConfig, A: Air>> + for<'a> Air>, { #[cfg(debug_assertions)] - crate::check_constraints::check_constraints(air, &trace); + crate::check_constraints::check_constraints(air, &trace, public_values); let degree = trace.height(); let log_degree = log2_strict_usize(degree); - let log_quotient_degree = get_log_quotient_degree::, A>(air); + let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); let quotient_degree = 1 << log_quotient_degree; let pcs = config.pcs(); @@ -58,6 +59,7 @@ where let quotient_values = quotient_values( air, + public_values, trace_domain, quotient_domain, trace_on_quotient_domain, @@ -109,6 +111,7 @@ where #[instrument(name = "compute quotient polynomial", skip_all)] fn quotient_values( air: &A, + public_values: &Vec>, trace_domain: Domain, quotient_domain: Domain, trace_on_quotient_domain: Mat, @@ -162,6 +165,7 @@ where local: &local, next: &next, }, + public_values, is_first_row, is_last_row, is_transition, diff --git a/uni-stark/src/symbolic_builder.rs b/uni-stark/src/symbolic_builder.rs index 9a02e04..e11bd1f 100644 --- a/uni-stark/src/symbolic_builder.rs +++ b/uni-stark/src/symbolic_builder.rs @@ -2,7 +2,7 @@ use alloc::vec; use alloc::vec::Vec; use core::marker::PhantomData; -use p3_air::{Air, AirBuilder}; +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues}; use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; use p3_util::log2_ceil_usize; @@ -12,13 +12,13 @@ use crate::symbolic_expression::SymbolicExpression; use crate::symbolic_variable::SymbolicVariable; #[instrument(name = "infer log of constraint degree", skip_all)] -pub fn get_log_quotient_degree(air: &A) -> usize +pub fn get_log_quotient_degree(air: &A, num_public_values: usize) -> usize where F: Field, A: Air>, { // We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees. - let constraint_degree = get_max_constraint_degree(air).max(2); + let constraint_degree = get_max_constraint_degree(air, num_public_values).max(2); // The quotient's actual degree is approximately (max_constraint_degree - 1) n, // where subtracting 1 comes from division by the zerofier. @@ -27,12 +27,12 @@ where } #[instrument(name = "infer constraint degree", skip_all, level = "debug")] -pub fn get_max_constraint_degree(air: &A) -> usize +pub fn get_max_constraint_degree(air: &A, num_public_values: usize) -> usize where F: Field, A: Air>, { - get_symbolic_constraints(air) + get_symbolic_constraints(air, num_public_values) .iter() .map(|c| c.degree_multiple()) .max() @@ -40,12 +40,15 @@ where } #[instrument(name = "evaluate constraints symbolically", skip_all, level = "debug")] -pub fn get_symbolic_constraints(air: &A) -> Vec> +pub fn get_symbolic_constraints( + air: &A, + num_public_values: usize, +) -> Vec> where F: Field, A: Air>, { - let mut builder = SymbolicAirBuilder::new(air.width()); + let mut builder = SymbolicAirBuilder::new(air.width(), num_public_values); air.eval(&mut builder); builder.constraints() } @@ -53,11 +56,12 @@ where /// An `AirBuilder` for evaluating constraints symbolically, and recording them for later use. pub struct SymbolicAirBuilder { main: RowMajorMatrix>, + public_values: Vec, constraints: Vec>, } impl SymbolicAirBuilder { - pub(crate) fn new(width: usize) -> Self { + pub(crate) fn new(width: usize, num_public_values: usize) -> Self { let values = [false, true] .into_iter() .flat_map(|is_next| { @@ -70,6 +74,8 @@ impl SymbolicAirBuilder { .collect(); Self { main: RowMajorMatrix::new(values, width), + // TODO replace zeros once we have SymbolicExpression::PublicValue + public_values: vec![F::zero(); num_public_values], constraints: vec![], } } @@ -109,3 +115,9 @@ impl AirBuilder for SymbolicAirBuilder { self.constraints.push(x.into()); } } + +impl AirBuilderWithPublicValues for SymbolicAirBuilder { + fn public_values(&self) -> &[Self::F] { + self.public_values.as_slice() + } +} diff --git a/uni-stark/src/verifier.rs b/uni-stark/src/verifier.rs index 3c92dff..8b0917e 100644 --- a/uni-stark/src/verifier.rs +++ b/uni-stark/src/verifier.rs @@ -1,4 +1,5 @@ use alloc::vec; +use alloc::vec::Vec; use itertools::Itertools; use p3_air::{Air, BaseAir, TwoRowMatrixView}; @@ -16,6 +17,7 @@ pub fn verify( air: &A, challenger: &mut SC::Challenger, proof: &Proof, + public_values: &Vec>, ) -> Result<(), VerificationError> where SC: StarkGenericConfig, @@ -29,7 +31,7 @@ where } = proof; let degree = 1 << degree_bits; - let log_quotient_degree = get_log_quotient_degree::, A>(air); + let log_quotient_degree = get_log_quotient_degree::, A>(air, public_values.len()); let quotient_degree = 1 << log_quotient_degree; let pcs = config.pcs(); @@ -118,6 +120,7 @@ where local: &opened_values.trace_local, next: &opened_values.trace_next, }, + public_values, is_first_row: sels.is_first_row, is_last_row: sels.is_last_row, is_transition: sels.is_transition, diff --git a/uni-stark/tests/fib_air.rs b/uni-stark/tests/fib_air.rs new file mode 100644 index 0000000..72da124 --- /dev/null +++ b/uni-stark/tests/fib_air.rs @@ -0,0 +1,170 @@ +use std::borrow::Borrow; + +use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir}; +use p3_baby_bear::{BabyBear, DiffusionMatrixBabybear}; +use p3_challenger::DuplexChallenger; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::{AbstractField, Field, PrimeField64}; +use p3_fri::{FriConfig, TwoAdicFriPcs}; +use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::{Matrix, MatrixRowSlices}; +use p3_merkle_tree::FieldMerkleTreeMmcs; +use p3_poseidon2::Poseidon2; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::{prove, verify, StarkConfig}; +use p3_util::log2_ceil_usize; +use rand::thread_rng; + +/// For testing the public values feature + +pub struct FibonacciAir {} + +impl BaseAir for FibonacciAir { + fn width(&self) -> usize { + NUM_FIBONACCI_COLS + } +} + +impl Air for FibonacciAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let pis = builder.public_values(); + + let a = pis[0]; + let b = pis[1]; + let x = pis[2]; + + let local: &FibonacciRow = main.row_slice(0).borrow(); + let next: &FibonacciRow = main.row_slice(1).borrow(); + + let mut when_first_row = builder.when_first_row(); + + when_first_row.assert_eq(local.left, a); + when_first_row.assert_eq(local.right, b); + + let mut when_transition = builder.when_transition(); + + // a' <- b + when_transition.assert_eq(local.right, next.left); + + // b' <- a + b + when_transition.assert_eq(local.left + local.right, next.right); + + builder.when_last_row().assert_eq(local.right, x); + } +} + +pub fn generate_trace_rows(a: u64, b: u64, n: usize) -> RowMajorMatrix { + assert!(n.is_power_of_two()); + + let mut trace = + RowMajorMatrix::new(vec![F::zero(); n * NUM_FIBONACCI_COLS], NUM_FIBONACCI_COLS); + + let (prefix, rows, suffix) = unsafe { trace.values.align_to_mut::>() }; + assert!(prefix.is_empty(), "Alignment should match"); + assert!(suffix.is_empty(), "Alignment should match"); + assert_eq!(rows.len(), n); + + rows[0] = FibonacciRow::new(F::from_canonical_u64(a), F::from_canonical_u64(b)); + + for i in 1..n { + rows[i].left = rows[i - 1].right; + rows[i].right = rows[i - 1].left + rows[i - 1].right; + } + + trace +} + +const NUM_FIBONACCI_COLS: usize = 2; + +pub struct FibonacciRow { + pub left: F, + pub right: F, +} + +impl FibonacciRow { + fn new(left: F, right: F) -> FibonacciRow { + FibonacciRow { left, right } + } +} + +impl Borrow> for [F] { + fn borrow(&self) -> &FibonacciRow { + debug_assert_eq!(self.len(), NUM_FIBONACCI_COLS); + let (prefix, shorts, suffix) = unsafe { self.align_to::>() }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +type Val = BabyBear; +type Perm = Poseidon2; +type MyHash = PaddingFreeSponge; +type MyCompress = TruncatedPermutation; +type ValMmcs = + FieldMerkleTreeMmcs<::Packing, ::Packing, MyHash, MyCompress, 8>; +type Challenge = BinomialExtensionField; +type ChallengeMmcs = ExtensionMmcs; +type Challenger = DuplexChallenger; +type Dft = Radix2DitParallel; +type Pcs = TwoAdicFriPcs; +type MyConfig = StarkConfig; + +#[test] +fn test_public_value() { + let perm = Perm::new_from_rng(8, 22, DiffusionMatrixBabybear, &mut thread_rng()); + let hash = MyHash::new(perm.clone()); + let compress = MyCompress::new(perm.clone()); + let val_mmcs = ValMmcs::new(hash, compress); + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + let dft = Dft {}; + let trace = generate_trace_rows::(0, 1, 1 << 3); + let fri_config = FriConfig { + log_blowup: 2, + num_queries: 28, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + let pcs = Pcs::new(log2_ceil_usize(trace.height()), dft, val_mmcs, fri_config); + let config = MyConfig::new(pcs); + let mut challenger = Challenger::new(perm.clone()); + let pis = vec![ + BabyBear::from_canonical_u64(0), + BabyBear::from_canonical_u64(1), + BabyBear::from_canonical_u64(21), + ]; + let proof = prove(&config, &FibonacciAir {}, &mut challenger, trace, &pis); + let mut challenger = Challenger::new(perm); + verify(&config, &FibonacciAir {}, &mut challenger, &proof, &pis).expect("verification failed"); +} + +#[test] +#[should_panic(expected = "assertion `left == right` failed: constraints had nonzero value")] +fn test_incorrect_public_value() { + let perm = Perm::new_from_rng(8, 22, DiffusionMatrixBabybear, &mut thread_rng()); + let hash = MyHash::new(perm.clone()); + let compress = MyCompress::new(perm.clone()); + let val_mmcs = ValMmcs::new(hash, compress); + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + let dft = Dft {}; + let fri_config = FriConfig { + log_blowup: 2, + num_queries: 28, + proof_of_work_bits: 8, + mmcs: challenge_mmcs, + }; + let trace = generate_trace_rows::(0, 1, 1 << 3); + let pcs = Pcs::new(log2_ceil_usize(trace.height()), dft, val_mmcs, fri_config); + let config = MyConfig::new(pcs); + let mut challenger = Challenger::new(perm.clone()); + let pis = vec![ + BabyBear::from_canonical_u64(0), + BabyBear::from_canonical_u64(1), + BabyBear::from_canonical_u64(123_123), // incorrect result + ]; + prove(&config, &FibonacciAir {}, &mut challenger, trace, &pis); +} diff --git a/uni-stark/tests/mul_air.rs b/uni-stark/tests/mul_air.rs index 7b44d9c..7a72430 100644 --- a/uni-stark/tests/mul_air.rs +++ b/uni-stark/tests/mul_air.rs @@ -128,7 +128,7 @@ where let trace = air.random_valid_trace(log_height, true); let mut p_challenger = challenger.clone(); - let proof = prove(&config, &air, &mut p_challenger, trace); + let proof = prove(&config, &air, &mut p_challenger, trace, &vec![]); let serialized_proof = postcard::to_allocvec(&proof).expect("unable to serialize proof"); tracing::debug!("serialized_proof len: {} bytes", serialized_proof.len()); @@ -137,7 +137,13 @@ where postcard::from_bytes(&serialized_proof).expect("unable to deserialize proof"); let mut v_challenger = challenger.clone(); - verify(&config, &air, &mut v_challenger, &deserialized_proof) + verify( + &config, + &air, + &mut v_challenger, + &deserialized_proof, + &vec![], + ) } fn do_test_bb_trivial(degree: u64, log_n: usize) -> Result<(), VerificationError> {