Skip to content

Commit

Permalink
Enforce a batch_size of u32
Browse files Browse the repository at this point in the history
Catch error in map
  • Loading branch information
vicsn committed May 12, 2023
1 parent 0ed49bb commit 8fb4a79
Show file tree
Hide file tree
Showing 18 changed files with 183 additions and 140 deletions.
3 changes: 3 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ snarkVM is a big project, so (non-)adherence to best practices related to perfor
- if possible, reuse collections; an example would be a loop that needs a clean vector on each iteration: instead of creating and allocating it over and over, create it _before_ the loop and use `.clear()` on every iteration instead
- try to keep the sizes of `enum` variants uniform; use `Box<T>` on ones that are large

### Cross-platform consistency
- types which contain consensus- or cryptographic logic should have a consistent size across platforms. Their serialized output should not contain `usize`, and at times it may be worth it to avoid using `usize` in the types themselves for clarity.

### Misc. performance

- avoid the `format!()` macro; if it is used only to convert a single value to a `String`, use `.to_string()` instead, which is also available to all the implementors of `Display`
Expand Down
24 changes: 15 additions & 9 deletions algorithms/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,36 @@ pub enum SNARKError {
#[error("{}", _0)]
AnyhowError(#[from] anyhow::Error),

#[error("Batch size was different between public input and proof")]
BatchSizeMismatch,

#[error("Circuit not found")]
CircuitNotFound,

#[error("{}", _0)]
ConstraintFieldError(#[from] ConstraintFieldError),

#[error("{}: {}", _0, _1)]
Crate(&'static str, String),

#[error("Batch size was zero; must be at least 1")]
EmptyBatch,

#[error("Expected a circuit-specific SRS in SNARK")]
ExpectedCircuitSpecificSRS,

#[error(transparent)]
IntError(#[from] std::num::TryFromIntError),

#[error("{}", _0)]
Message(String),

#[error(transparent)]
ParseIntError(#[from] std::num::ParseIntError),

#[error("{}", _0)]
SynthesisError(SynthesisError),

#[error("Batch size was zero; must be at least 1")]
EmptyBatch,

#[error("Batch size was different between public input and proof")]
BatchSizeMismatch,

#[error("Circuit not found")]
CircuitNotFound,

#[error("terminated")]
Terminated,
}
Expand Down
3 changes: 1 addition & 2 deletions algorithms/src/fft/polynomial/sparse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@
use crate::fft::{EvaluationDomain, Evaluations, Polynomial};
use snarkvm_fields::{Field, PrimeField};
use snarkvm_utilities::serialize::*;

use std::{collections::BTreeMap, fmt};

/// Stores a sparse polynomial in coefficient form.
#[derive(Clone, PartialEq, Eq, Hash, Default, CanonicalSerialize, CanonicalDeserialize)]
#[derive(Clone, PartialEq, Eq, Hash, Default)]
#[must_use]
pub struct SparsePolynomial<F: Field> {
/// The coefficient a_i of `x^i` is stored as (i, a_i) in `self.coeffs`.
Expand Down
10 changes: 10 additions & 0 deletions algorithms/src/polycommit/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ pub enum PCError {
label: String,
},

/// Could not convert from int.
IntError(std::num::TryFromIntError),

Terminated,
}

Expand All @@ -105,6 +108,12 @@ impl From<anyhow::Error> for PCError {
}
}

impl From<std::num::TryFromIntError> for PCError {
fn from(other: std::num::TryFromIntError) -> Self {
Self::IntError(other)
}
}

impl core::fmt::Display for PCError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Expand Down Expand Up @@ -150,6 +159,7 @@ impl core::fmt::Display for PCError {
(having degree {poly_degree:?}) is greater than the maximum \
supported degree ({supported_degree:?})"
),
Self::IntError(error) => write!(f, "{error}"),
Self::Terminated => write!(f, "terminated"),
}
}
Expand Down
54 changes: 28 additions & 26 deletions algorithms/src/snark/marlin/ahp/ahp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use snarkvm_fields::{Field, PrimeField};
use snarkvm_r1cs::SynthesisError;

use core::{borrow::Borrow, marker::PhantomData};
use std::collections::BTreeMap;
use std::{collections::BTreeMap, num::TryFromIntError};

/// The algebraic holographic proof defined in [CHMMVW19](https://eprint.iacr.org/2019/1047).
/// Currently, this AHP only supports inputs of size one
Expand Down Expand Up @@ -91,24 +91,26 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {
/// of this protocol.
/// The number of the variables must include the "one" variable. That is, it
/// must be with respect to the number of formatted public inputs.
pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result<usize, AHPError> {
pub fn max_degree(num_constraints: usize, num_variables: usize, num_non_zero: usize) -> Result<u32, AHPError> {
let padded_matrix_dim = matrices::padded_matrix_dim(num_variables, num_constraints);
let zk_bound = Self::zk_bound().unwrap_or(0);
let constraint_domain_size = EvaluationDomain::<F>::compute_size_of_domain(padded_matrix_dim)
.ok_or(AHPError::PolynomialDegreeTooLarge)?;
let non_zero_domain_size =
EvaluationDomain::<F>::compute_size_of_domain(num_non_zero).ok_or(AHPError::PolynomialDegreeTooLarge)?;

Ok(*[
2 * constraint_domain_size + zk_bound - 2,
if MM::ZK { constraint_domain_size + 3 } else { 0 }, // mask_poly
constraint_domain_size,
constraint_domain_size,
non_zero_domain_size - 1, // non-zero polynomials
]
.iter()
.max()
.unwrap())
Ok(u32::try_from(
*[
2 * constraint_domain_size + zk_bound - 2,
if MM::ZK { constraint_domain_size + 3 } else { 0 }, // mask_poly
constraint_domain_size,
constraint_domain_size,
non_zero_domain_size - 1, // non-zero polynomials
]
.iter()
.max()
.unwrap(),
)?)
}

/// Get all the strict degree bounds enforced in the AHP.
Expand Down Expand Up @@ -243,13 +245,13 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {
.map(|(&circuit_id, circuit_state)| {
let z_b_i = (0..circuit_state.batch_size)
.map(|i| {
let z_b = witness_label(circuit_id, "z_b", i);
LinearCombination::new(z_b.clone(), [(F::one(), z_b)])
let z_b = witness_label(circuit_id, "z_b", usize::try_from(i)?);
Ok::<_, TryFromIntError>(LinearCombination::new(z_b.clone(), [(F::one(), z_b)]))
})
.collect::<Vec<_>>();
(circuit_id, z_b_i)
.collect::<Result<Vec<_>, _>>()?;
Ok((circuit_id, z_b_i))
})
.collect::<BTreeMap<_, _>>();
.collect::<Result<BTreeMap<CircuitId, _>, TryFromIntError>>()?;

let g_1 = LinearCombination::new("g_1", [(F::one(), "g_1")]);

Expand Down Expand Up @@ -279,18 +281,18 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {
end_timer!(v_X_at_beta_time);

let z_b_s_at_beta = z_b_s
.values()
.map(|z_b_i| {
let z_b_i_s = z_b_i.iter().map(|z_b| evals.get_lc_eval(z_b, beta)).collect::<Result<Vec<F>, _>>();
z_b_i_s
.iter()
.map(|(circuit_id, z_b_i)| {
let z_b_i_s = z_b_i.iter().map(|z_b| evals.get_lc_eval(z_b, beta)).try_collect()?;
Ok((*circuit_id, z_b_i_s))
})
.collect::<Result<Vec<_>, _>>()?;
.collect::<Result<BTreeMap<CircuitId, Vec<F>>, AHPError>>()?;

let batch_z_b_s_at_beta = z_b_s_at_beta
.iter()
.zip_eq(batch_combiners.iter())
.zip_eq(batch_combiners.values())
.zip_eq(r_alpha_at_beta_s.values())
.map(|((z_b_i_at_beta, (circuit_id, combiners)), &r_alpha_at_beta)| {
.map(|(((circuit_id, z_b_i_at_beta), combiners), &r_alpha_at_beta)| {
let z_b_at_beta = z_b_i_at_beta
.iter()
.zip_eq(&combiners.instance_combiners)
Expand Down Expand Up @@ -325,13 +327,13 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {
if MM::ZK {
lincheck_sumcheck.add(F::one(), "mask_poly");
}
for (i, (id, c)) in batch_combiners.iter().enumerate() {
for (id, c) in batch_combiners.iter() {
let mut circuit_term = LinearCombination::empty(format!("lincheck_sumcheck term {id}"));
for (j, instance_combiner) in c.instance_combiners.iter().enumerate() {
let z_a_j = witness_label(*id, "z_a", j);
let w_j = witness_label(*id, "w", j);
circuit_term
.add(r_alpha_at_beta_s[id] * instance_combiner * (eta_a + eta_c * z_b_s_at_beta[i][j]), z_a_j)
.add(r_alpha_at_beta_s[id] * instance_combiner * (eta_a + eta_c * z_b_s_at_beta[id][j]), z_a_j)
.add(-t_at_beta_s[id] * v_X_at_beta[id] * instance_combiner, w_j);
}
circuit_term
Expand Down
8 changes: 8 additions & 0 deletions algorithms/src/snark/marlin/ahp/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub enum AHPError {
ConstraintSystemError(snarkvm_r1cs::errors::SynthesisError),
/// The instance generated during proving does not match that in the index.
InstanceDoesNotMatchIndex,
/// Could not convert from int.
IntError(std::num::TryFromIntError),
/// The number of public inputs is incorrect.
InvalidPublicInputLength,
/// During verification, a required evaluation is missing
Expand All @@ -38,3 +40,9 @@ impl From<snarkvm_r1cs::errors::SynthesisError> for AHPError {
AHPError::ConstraintSystemError(other)
}
}

impl From<std::num::TryFromIntError> for AHPError {
fn from(other: std::num::TryFromIntError) -> Self {
Self::IntError(other)
}
}
2 changes: 1 addition & 1 deletion algorithms/src/snark/marlin/ahp/indexer/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ impl<F: PrimeField, MM: MarlinMode> Circuit<F, MM> {
}

/// The maximum degree required to represent polynomials of this index.
pub fn max_degree(&self) -> usize {
pub fn max_degree(&self) -> u32 {
self.index_info.max_degree::<MM>()
}

Expand Down
2 changes: 1 addition & 1 deletion algorithms/src/snark/marlin/ahp/indexer/circuit_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pub struct CircuitInfo<F: Sync + Send> {

impl<F: PrimeField> CircuitInfo<F> {
/// The maximum degree of polynomial required to represent this index in the AHP.
pub fn max_degree<MM: MarlinMode>(&self) -> usize {
pub fn max_degree<MM: MarlinMode>(&self) -> u32 {
let max_non_zero = self.num_non_zero_a.max(self.num_non_zero_b).max(self.num_non_zero_c);
AHPForR1CS::<F, MM>::max_degree(self.num_constraints, self.num_variables, max_non_zero).unwrap()
}
Expand Down
17 changes: 8 additions & 9 deletions algorithms/src/snark/marlin/ahp/prover/round_functions/first.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {

/// Output the degree bounds of oracles in the first round.
pub fn first_round_polynomial_info<'a>(
circuits: impl Iterator<Item = (&'a CircuitId, &'a usize)>,
batch_sizes: impl Iterator<Item = (&'a CircuitId, &'a usize)>,
) -> BTreeMap<PolynomialLabel, PolynomialInfo> {
let mut polynomials = circuits
let mut polynomials = batch_sizes
.flat_map(|(&circuit_id, &batch_size)| {
(0..batch_size).flat_map(move |i| {
[
Expand All @@ -74,13 +74,14 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {
#[allow(clippy::type_complexity)]
pub fn prover_first_round<'a, R: RngCore>(
mut state: prover::State<'a, F, MM>,
batch_sizes: impl Iterator<Item = (&'a CircuitId, &'a usize)>,
rng: &mut R,
) -> Result<prover::State<'a, F, MM>, AHPError> {
let round_time = start_timer!(|| "AHP::Prover::FirstRound");
let mut r_b_s = Vec::with_capacity(state.circuit_specific_states.len());
let mut job_pool = snarkvm_utilities::ExecutionPool::with_capacity(3 * state.total_instances);
let mut job_pool = snarkvm_utilities::ExecutionPool::with_capacity((3 * state.total_instances).try_into()?);
for (circuit, circuit_state) in state.circuit_specific_states.iter_mut() {
let batch_size = circuit_state.batch_size;
let batch_size = usize::try_from(circuit_state.batch_size)?;

let z_a = circuit_state.z_a.take().unwrap();
let z_b = circuit_state.z_b.take().unwrap();
Expand Down Expand Up @@ -123,20 +124,18 @@ impl<F: PrimeField, MM: MarlinMode> AHPForR1CS<F, MM> {
prover::SingleEntry { z_a, z_b, w_poly, z_a_poly, z_b_poly }
})
.collect::<Vec<_>>();
assert_eq!(batches.len(), state.total_instances);
assert_eq!(batches.len(), usize::try_from(state.total_instances)?);

let mut circuit_specific_batches = BTreeMap::new();
for ((circuit, state), r_b_s) in state.circuit_specific_states.iter_mut().zip(r_b_s) {
let batches = batches.drain(0..state.batch_size).collect_vec();
let batches = batches.drain(0..state.batch_size.try_into()?).collect_vec();
circuit_specific_batches.insert(circuit.id, batches);
state.mz_poly_randomizer = MM::ZK.then_some(r_b_s);
end_timer!(round_time);
}
let mask_poly = Self::calculate_mask_poly(state.max_constraint_domain, rng);
let oracles = prover::FirstOracles { batches: circuit_specific_batches, mask_poly };
assert!(oracles.matches_info(&Self::first_round_polynomial_info(
state.circuit_specific_states.iter().map(|(c, s)| (&c.id, &s.batch_size))
)));
assert!(oracles.matches_info(&Self::first_round_polynomial_info(batch_sizes)));
state.first_round_oracles = Some(Arc::new(oracles));
Ok(state)
}
Expand Down
19 changes: 10 additions & 9 deletions algorithms/src/snark/marlin/ahp/prover/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub struct CircuitSpecificState<F: PrimeField> {
pub(super) non_zero_c_domain: EvaluationDomain<F>,

/// The number of instances being proved in this batch.
pub(in crate::snark) batch_size: usize,
pub(in crate::snark) batch_size: u32,

/// The list of public inputs for each instance in the batch.
/// The length of this list must be equal to the batch size.
Expand Down Expand Up @@ -76,7 +76,7 @@ pub struct State<'a, F: PrimeField, MM: MarlinMode> {
/// The largest constraint domain of all circuits in the batch.
pub(in crate::snark) max_constraint_domain: EvaluationDomain<F>,
/// The total number of instances we're proving in the batch.
pub(in crate::snark) total_instances: usize,
pub(in crate::snark) total_instances: u32,
}

/// The public inputs for a single instance.
Expand Down Expand Up @@ -114,13 +114,14 @@ impl<'a, F: PrimeField, MM: MarlinMode> State<'a, F, MM> {

let first_padded_public_inputs = &variable_assignments[0].0;
let input_domain = EvaluationDomain::new(first_padded_public_inputs.len()).unwrap();
let batch_size = variable_assignments.len();
let batch_size = variable_assignments.len().try_into()?;
total_instances += batch_size;
let mut z_as = Vec::with_capacity(batch_size);
let mut z_bs = Vec::with_capacity(batch_size);
let mut x_polys = Vec::with_capacity(batch_size);
let mut padded_public_variables = Vec::with_capacity(batch_size);
let mut private_variables = Vec::with_capacity(batch_size);
let batch_size_usize = batch_size as usize;
let mut z_as = Vec::with_capacity(batch_size_usize);
let mut z_bs = Vec::with_capacity(batch_size_usize);
let mut x_polys = Vec::with_capacity(batch_size_usize);
let mut padded_public_variables = Vec::with_capacity(batch_size_usize);
let mut private_variables = Vec::with_capacity(batch_size_usize);

for Assignments(padded_public_input, private_input, z_a, z_b) in variable_assignments {
z_as.push(z_a);
Expand Down Expand Up @@ -165,7 +166,7 @@ impl<'a, F: PrimeField, MM: MarlinMode> State<'a, F, MM> {
}

/// Get the batch size for a given circuit.
pub fn batch_size(&self, circuit: &Circuit<F, MM>) -> Option<usize> {
pub fn batch_size(&self, circuit: &Circuit<F, MM>) -> Option<u32> {
self.circuit_specific_states.get(circuit).map(|s| s.batch_size)
}

Expand Down
14 changes: 9 additions & 5 deletions algorithms/src/snark/marlin/ahp/verifier/messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use snarkvm_fields::PrimeField;

use crate::snark::marlin::{witness_label, CircuitId, MarlinMode};
use itertools::Itertools;
use std::collections::BTreeMap;
use std::{collections::BTreeMap, num::TryFromIntError};

/// Randomizers used to combine circuit-specific and instance-specific elements in the AHP sumchecks
#[derive(Clone, Debug)]
Expand Down Expand Up @@ -82,7 +82,7 @@ pub struct QuerySet<F: PrimeField> {
}

impl<F: PrimeField> QuerySet<F> {
pub fn new<MM: MarlinMode>(state: &super::State<F, MM>) -> Self {
pub fn new<MM: MarlinMode>(state: &super::State<F, MM>) -> Result<Self, TryFromIntError> {
let beta = state.second_round_message.unwrap().beta;
let gamma = state.gamma.unwrap();
// For the first linear combination
Expand All @@ -93,8 +93,12 @@ impl<F: PrimeField> QuerySet<F> {
// Note that z is the interpolation of x || w, so it equals x + v_X * w
// We also use an optimization: instead of explicitly calculating z_c, we
// use the "virtual oracle" z_a * z_b
Self {
batch_sizes: state.circuit_specific_states.iter().map(|(c, s)| (*c, s.batch_size)).collect(),
Ok(Self {
batch_sizes: state
.circuit_specific_states
.iter()
.map(|(c, s)| Ok::<_, TryFromIntError>((*c, usize::try_from(s.batch_size)?)))
.try_collect()?,
g_1_query: ("beta".into(), beta),
z_b_query: ("beta".into(), beta),
lincheck_sumcheck_query: ("beta".into(), beta),
Expand All @@ -103,7 +107,7 @@ impl<F: PrimeField> QuerySet<F> {
g_b_query: ("gamma".into(), gamma),
g_c_query: ("gamma".into(), gamma),
matrix_sumcheck_query: ("gamma".into(), gamma),
}
})
}

/// Returns a `BTreeSet` containing elements of the form
Expand Down
2 changes: 1 addition & 1 deletion algorithms/src/snark/marlin/ahp/verifier/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub struct CircuitSpecificState<F: PrimeField> {
pub(crate) non_zero_c_domain: EvaluationDomain<F>,

/// The number of instances being proved in this batch.
pub(in crate::snark::marlin) batch_size: usize,
pub(in crate::snark::marlin) batch_size: u32,
}
/// State of the AHP verifier.
#[derive(Debug)]
Expand Down
Loading

0 comments on commit 8fb4a79

Please sign in to comment.