Skip to content

Commit

Permalink
Public value for uni-stark and air (#285)
Browse files Browse the repository at this point in the history
* working

* finish basic fibonacci

* refine

* fix test

* add public input support for uni-stark

* fix examples

* fix examples

* make public_value optional

* fix verifier public value type

* make AirBuilder's public_values() return reference instead

* move fibonacci-air crate and use it as aa test for public values

* remove test mod because it is an integration test

* fix CI

* refine

* use rustfmt to make CI happy
  • Loading branch information
patrickmao93 authored Mar 22, 2024
1 parent 4809fa7 commit 2edbd19
Show file tree
Hide file tree
Showing 13 changed files with 250 additions and 26 deletions.
4 changes: 4 additions & 0 deletions air/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@ pub trait AirBuilder: Sized {
}
}

pub trait AirBuilderWithPublicValues: AirBuilder {
fn public_values(&self) -> &[Self::F];
}

pub trait PairBuilder: AirBuilder {
fn preprocessed(&self) -> Self::M;
}
Expand Down
4 changes: 2 additions & 2 deletions keccak-air/examples/prove_baby_bear_keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ fn main() -> Result<(), VerificationError> {

let mut challenger = Challenger::from_hasher(vec![], byte_hash);

let proof = prove::<MyConfig, _>(&config, &KeccakAir {}, &mut challenger, trace);
let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]);

let mut challenger = Challenger::from_hasher(vec![], byte_hash);
verify(&config, &KeccakAir {}, &mut challenger, &proof)
verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![])
}
4 changes: 2 additions & 2 deletions keccak-air/examples/prove_baby_bear_poseidon2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ fn main() -> Result<(), VerificationError> {

let mut challenger = Challenger::new(perm.clone());

let proof = prove::<MyConfig, _>(&config, &KeccakAir {}, &mut challenger, trace);
let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]);

let mut challenger = Challenger::new(perm);
verify(&config, &KeccakAir {}, &mut challenger, &proof)
verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![])
}
4 changes: 2 additions & 2 deletions keccak-air/examples/prove_goldilocks_keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ fn main() -> Result<(), VerificationError> {

let mut challenger = Challenger::from_hasher(vec![], byte_hash);

let proof = prove::<MyConfig, _>(&config, &KeccakAir {}, &mut challenger, trace);
let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]);

let mut challenger = Challenger::from_hasher(vec![], byte_hash);
verify(&config, &KeccakAir {}, &mut challenger, &proof)
verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![])
}
4 changes: 2 additions & 2 deletions keccak-air/examples/prove_goldilocks_poseidon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ fn main() -> Result<(), VerificationError> {

let mut challenger = Challenger::new(perm.clone());

let proof = prove::<MyConfig, _>(&config, &KeccakAir {}, &mut challenger, trace);
let proof = prove(&config, &KeccakAir {}, &mut challenger, trace, &vec![]);

let mut challenger = Challenger::new(perm);
verify(&config, &KeccakAir {}, &mut challenger, &proof)
verify(&config, &KeccakAir {}, &mut challenger, &proof, &vec![])
}
4 changes: 2 additions & 2 deletions keccak-air/examples/prove_m31_keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,10 @@ fn main() -> Result<(), VerificationError> {

let mut challenger = Challenger::from_hasher(vec![], byte_hash);

let proof = prove::<MyConfig, _>(&config, &air, &mut challenger, trace);
let proof = prove(&config, &air, &mut challenger, trace, &vec![]);

let mut challenger = Challenger::from_hasher(vec![], byte_hash);
verify(&config, &air, &mut challenger, &proof)?;
verify(&config, &air, &mut challenger, &proof, &vec![])?;

println!("OK!!! 👍");
Ok(())
Expand Down
14 changes: 12 additions & 2 deletions uni-stark/src/check_constraints.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
use p3_air::{Air, AirBuilder, TwoRowMatrixView};
use alloc::vec::Vec;

use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, TwoRowMatrixView};
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::{Matrix, MatrixRowSlices};
use tracing::instrument;

#[instrument(name = "check constraints", skip_all)]
pub(crate) fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>)
pub(crate) fn check_constraints<F, A>(air: &A, main: &RowMajorMatrix<F>, public_values: &Vec<F>)
where
F: Field,
A: for<'a> Air<DebugConstraintBuilder<'a, F>>,
Expand All @@ -25,6 +27,7 @@ where
let mut builder = DebugConstraintBuilder {
row_index: i,
main,
public_values,
is_first_row: F::from_bool(i == 0),
is_last_row: F::from_bool(i == height - 1),
is_transition: F::from_bool(i != height - 1),
Expand All @@ -39,6 +42,7 @@ where
pub struct DebugConstraintBuilder<'a, F: Field> {
row_index: usize,
main: TwoRowMatrixView<'a, F>,
public_values: &'a [F],
is_first_row: F,
is_last_row: F,
is_transition: F,
Expand Down Expand Up @@ -92,3 +96,9 @@ where
);
}
}

impl<'a, F: Field> AirBuilderWithPublicValues for DebugConstraintBuilder<'a, F> {
fn public_values(&self) -> &[Self::F] {
self.public_values
}
}
17 changes: 16 additions & 1 deletion uni-stark/src/folder.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
use p3_air::{AirBuilder, TwoRowMatrixView};
use alloc::vec::Vec;

use p3_air::{AirBuilder, AirBuilderWithPublicValues, TwoRowMatrixView};
use p3_field::AbstractField;

use crate::{PackedChallenge, PackedVal, StarkGenericConfig, Val};

pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {
pub main: TwoRowMatrixView<'a, PackedVal<SC>>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: PackedVal<SC>,
pub is_last_row: PackedVal<SC>,
pub is_transition: PackedVal<SC>,
Expand All @@ -14,6 +17,7 @@ pub struct ProverConstraintFolder<'a, SC: StarkGenericConfig> {

pub struct VerifierConstraintFolder<'a, SC: StarkGenericConfig> {
pub main: TwoRowMatrixView<'a, SC::Challenge>,
pub public_values: &'a Vec<Val<SC>>,
pub is_first_row: SC::Challenge,
pub is_last_row: SC::Challenge,
pub is_transition: SC::Challenge,
Expand Down Expand Up @@ -54,6 +58,12 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for ProverConstraintFolder<'a, SC> {
}
}

impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for ProverConstraintFolder<'a, SC> {
fn public_values(&self) -> &[Self::F] {
self.public_values
}
}

impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC> {
type F = Val<SC>;
type Expr = SC::Challenge;
Expand Down Expand Up @@ -86,3 +96,8 @@ impl<'a, SC: StarkGenericConfig> AirBuilder for VerifierConstraintFolder<'a, SC>
self.accumulator += x;
}
}
impl<'a, SC: StarkGenericConfig> AirBuilderWithPublicValues for VerifierConstraintFolder<'a, SC> {
fn public_values(&self) -> &[Self::F] {
self.public_values
}
}
8 changes: 6 additions & 2 deletions uni-stark/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,19 @@ pub fn prove<
air: &A,
challenger: &mut SC::Challenger,
trace: RowMajorMatrix<Val<SC>>,
public_values: &Vec<Val<SC>>,
) -> Proof<SC>
where
SC: StarkGenericConfig,
A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<ProverConstraintFolder<'a, SC>>,
{
#[cfg(debug_assertions)]
crate::check_constraints::check_constraints(air, &trace);
crate::check_constraints::check_constraints(air, &trace, public_values);

let degree = trace.height();
let log_degree = log2_strict_usize(degree);

let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air);
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, public_values.len());
let quotient_degree = 1 << log_quotient_degree;

let pcs = config.pcs();
Expand All @@ -58,6 +59,7 @@ where

let quotient_values = quotient_values(
air,
public_values,
trace_domain,
quotient_domain,
trace_on_quotient_domain,
Expand Down Expand Up @@ -109,6 +111,7 @@ where
#[instrument(name = "compute quotient polynomial", skip_all)]
fn quotient_values<SC, A, Mat>(
air: &A,
public_values: &Vec<Val<SC>>,
trace_domain: Domain<SC>,
quotient_domain: Domain<SC>,
trace_on_quotient_domain: Mat,
Expand Down Expand Up @@ -162,6 +165,7 @@ where
local: &local,
next: &next,
},
public_values,
is_first_row,
is_last_row,
is_transition,
Expand Down
28 changes: 20 additions & 8 deletions uni-stark/src/symbolic_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use alloc::vec;
use alloc::vec::Vec;
use core::marker::PhantomData;

use p3_air::{Air, AirBuilder};
use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues};
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;
use p3_util::log2_ceil_usize;
Expand All @@ -12,13 +12,13 @@ use crate::symbolic_expression::SymbolicExpression;
use crate::symbolic_variable::SymbolicVariable;

#[instrument(name = "infer log of constraint degree", skip_all)]
pub fn get_log_quotient_degree<F, A>(air: &A) -> usize
pub fn get_log_quotient_degree<F, A>(air: &A, num_public_values: usize) -> usize
where
F: Field,
A: Air<SymbolicAirBuilder<F>>,
{
// We pad to at least degree 2, since a quotient argument doesn't make sense with smaller degrees.
let constraint_degree = get_max_constraint_degree(air).max(2);
let constraint_degree = get_max_constraint_degree(air, num_public_values).max(2);

// The quotient's actual degree is approximately (max_constraint_degree - 1) n,
// where subtracting 1 comes from division by the zerofier.
Expand All @@ -27,37 +27,41 @@ where
}

#[instrument(name = "infer constraint degree", skip_all, level = "debug")]
pub fn get_max_constraint_degree<F, A>(air: &A) -> usize
pub fn get_max_constraint_degree<F, A>(air: &A, num_public_values: usize) -> usize
where
F: Field,
A: Air<SymbolicAirBuilder<F>>,
{
get_symbolic_constraints(air)
get_symbolic_constraints(air, num_public_values)
.iter()
.map(|c| c.degree_multiple())
.max()
.unwrap_or(0)
}

#[instrument(name = "evaluate constraints symbolically", skip_all, level = "debug")]
pub fn get_symbolic_constraints<F, A>(air: &A) -> Vec<SymbolicExpression<F>>
pub fn get_symbolic_constraints<F, A>(
air: &A,
num_public_values: usize,
) -> Vec<SymbolicExpression<F>>
where
F: Field,
A: Air<SymbolicAirBuilder<F>>,
{
let mut builder = SymbolicAirBuilder::new(air.width());
let mut builder = SymbolicAirBuilder::new(air.width(), num_public_values);
air.eval(&mut builder);
builder.constraints()
}

/// An `AirBuilder` for evaluating constraints symbolically, and recording them for later use.
pub struct SymbolicAirBuilder<F: Field> {
main: RowMajorMatrix<SymbolicVariable<F>>,
public_values: Vec<F>,
constraints: Vec<SymbolicExpression<F>>,
}

impl<F: Field> SymbolicAirBuilder<F> {
pub(crate) fn new(width: usize) -> Self {
pub(crate) fn new(width: usize, num_public_values: usize) -> Self {
let values = [false, true]
.into_iter()
.flat_map(|is_next| {
Expand All @@ -70,6 +74,8 @@ impl<F: Field> SymbolicAirBuilder<F> {
.collect();
Self {
main: RowMajorMatrix::new(values, width),
// TODO replace zeros once we have SymbolicExpression::PublicValue
public_values: vec![F::zero(); num_public_values],
constraints: vec![],
}
}
Expand Down Expand Up @@ -109,3 +115,9 @@ impl<F: Field> AirBuilder for SymbolicAirBuilder<F> {
self.constraints.push(x.into());
}
}

impl<F: Field> AirBuilderWithPublicValues for SymbolicAirBuilder<F> {
fn public_values(&self) -> &[Self::F] {
self.public_values.as_slice()
}
}
5 changes: 4 additions & 1 deletion uni-stark/src/verifier.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use alloc::vec;
use alloc::vec::Vec;

use itertools::Itertools;
use p3_air::{Air, BaseAir, TwoRowMatrixView};
Expand All @@ -16,6 +17,7 @@ pub fn verify<SC, A>(
air: &A,
challenger: &mut SC::Challenger,
proof: &Proof<SC>,
public_values: &Vec<Val<SC>>,
) -> Result<(), VerificationError>
where
SC: StarkGenericConfig,
Expand All @@ -29,7 +31,7 @@ where
} = proof;

let degree = 1 << degree_bits;
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air);
let log_quotient_degree = get_log_quotient_degree::<Val<SC>, A>(air, public_values.len());
let quotient_degree = 1 << log_quotient_degree;

let pcs = config.pcs();
Expand Down Expand Up @@ -118,6 +120,7 @@ where
local: &opened_values.trace_local,
next: &opened_values.trace_next,
},
public_values,
is_first_row: sels.is_first_row,
is_last_row: sels.is_last_row,
is_transition: sels.is_transition,
Expand Down
Loading

0 comments on commit 2edbd19

Please sign in to comment.