Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Several refactor and implement backend UniHyperPlonk #30

Merged
merged 7 commits into from
Oct 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[workspace]
members = ["benchmark", "plonkish_backend"]
resolver = "2"

[profile.flamegraph]
inherits = "release"
Expand Down
1 change: 1 addition & 0 deletions benchmark/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ plonkish_backend = { path = "../plonkish_backend", features = ["benchmark"] }
halo2_proofs = { git = "https://github.com/han0110/halo2.git", branch = "feature/for-benchmark" }
halo2_gadgets = { git = "https://github.com/han0110/halo2.git", branch = "feature/for-benchmark", features = ["unstable"] }
snark-verifier = { git = "https://github.com/han0110/snark-verifier", branch = "feature/for-benchmark", default-features = false, features = ["loader_halo2", "system_halo2"] }
zkevm-circuits = { git = "https://github.com/han0110/zkevm-circuits", branch = "feature/for-benchmark" }

# espresso
ark-ff = { version = "0.4.0", default-features = false }
Expand Down
93 changes: 65 additions & 28 deletions benchmark/benches/proof_system.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
use benchmark::{
espresso,
halo2::{AggregationCircuit, Sha256Circuit},
halo2::{AggregationCircuit, Keccak256Circuit, Sha256Circuit},
};
use espresso_hyperplonk::{prelude::MockCircuit, HyperPlonkSNARK};
use espresso_subroutines::{MultilinearKzgPCS, PolyIOP, PolynomialCommitmentScheme};
use halo2_proofs::{
plonk::{create_proof, keygen_pk, keygen_vk, verify_proof},
poly::kzg::{
commitment::ParamsKZG,
multiopen::{ProverGWC, VerifierGWC},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
},
transcript::{Blake2bRead, Blake2bWrite, TranscriptReadBuffer, TranscriptWriterBuffer},
};
use itertools::Itertools;
use plonkish_backend::{
backend::{self, PlonkishBackend, PlonkishCircuit},
backend::{self, PlonkishBackend, PlonkishCircuit, WitnessEncoding},
frontend::halo2::{circuit::VanillaPlonk, CircuitExt, Halo2Circuit},
halo2_curves::bn256::{Bn256, Fr},
pcs::multilinear,
pcs::{multilinear, univariate, CommitmentChunk},
util::{
end_timer, start_timer,
test::std_rng,
transcript::{InMemoryTranscript, Keccak256Transcript},
transcript::{InMemoryTranscript, Keccak256Transcript, TranscriptRead, TranscriptWrite},
},
};
use std::{
env::args,
fmt::Display,
fs::{create_dir, File, OpenOptions},
io::Write,
io::{Cursor, Write},
iter,
ops::Range,
path::Path,
Expand All @@ -44,38 +44,54 @@ fn main() {
k_range.for_each(|k| systems.iter().for_each(|system| system.bench(k, circuit)));
}

fn bench_hyperplonk<C: CircuitExt<Fr>>(k: usize) {
type MultilinearKzg = multilinear::MultilinearKzg<Bn256>;
type HyperPlonk = backend::hyperplonk::HyperPlonk<MultilinearKzg>;

fn bench_plonkish_backend<B, C>(system: System, k: usize)
where
B: PlonkishBackend<Fr> + WitnessEncoding,
C: CircuitExt<Fr>,
Keccak256Transcript<Cursor<Vec<u8>>>: TranscriptRead<CommitmentChunk<Fr, B::Pcs>, Fr>
+ TranscriptWrite<CommitmentChunk<Fr, B::Pcs>, Fr>
+ InMemoryTranscript,
{
let circuit = C::rand(k, std_rng());
let circuit = Halo2Circuit::new::<HyperPlonk>(k, circuit);
let circuit = Halo2Circuit::new::<B>(k, circuit);
let circuit_info = circuit.circuit_info().unwrap();
let instances = circuit.instances();

let timer = start_timer(|| format!("hyperplonk_setup-{k}"));
let param = HyperPlonk::setup(&circuit_info, std_rng()).unwrap();
let timer = start_timer(|| format!("{system}_setup-{k}"));
let param = B::setup(&circuit_info, std_rng()).unwrap();
end_timer(timer);

let timer = start_timer(|| format!("hyperplonk_preprocess-{k}"));
let (pp, vp) = HyperPlonk::preprocess(&param, &circuit_info).unwrap();
let timer = start_timer(|| format!("{system}_preprocess-{k}"));
let (pp, vp) = B::preprocess(&param, &circuit_info).unwrap();
end_timer(timer);

let proof = sample(System::HyperPlonk, k, || {
let _timer = start_timer(|| format!("hyperplonk_prove-{k}"));
let proof = sample(system, k, || {
let _timer = start_timer(|| format!("{system}_prove-{k}"));
let mut transcript = Keccak256Transcript::default();
HyperPlonk::prove(&pp, &circuit, &mut transcript, std_rng()).unwrap();
B::prove(&pp, &circuit, &mut transcript, std_rng()).unwrap();
transcript.into_proof()
});

let _timer = start_timer(|| format!("hyperplonk_verify-{k}"));
let _timer = start_timer(|| format!("{system}_verify-{k}"));
let accept = {
let mut transcript = Keccak256Transcript::from_proof((), proof.as_slice());
HyperPlonk::verify(&vp, instances, &mut transcript, std_rng()).is_ok()
B::verify(&vp, instances, &mut transcript, std_rng()).is_ok()
};
assert!(accept);
}

fn bench_hyperplonk<C: CircuitExt<Fr>>(k: usize) {
type GeminiKzg = multilinear::Gemini<univariate::UnivariateKzg<Bn256>>;
type HyperPlonk = backend::hyperplonk::HyperPlonk<GeminiKzg>;
bench_plonkish_backend::<HyperPlonk, C>(System::HyperPlonk, k)
}

fn bench_unihyperplonk<C: CircuitExt<Fr>>(k: usize) {
type UnivariateKzg = univariate::UnivariateKzg<Bn256>;
type UniHyperPlonk = backend::unihyperplonk::UniHyperPlonk<UnivariateKzg, true>;
bench_plonkish_backend::<UniHyperPlonk, C>(System::UniHyperPlonk, k)
}

fn bench_halo2<C: CircuitExt<Fr>>(k: usize) {
let circuit = C::rand(k, std_rng());
let circuits = &[circuit];
Expand All @@ -93,11 +109,13 @@ fn bench_halo2<C: CircuitExt<Fr>>(k: usize) {
end_timer(timer);

let create_proof = |c, d, e, mut f: Blake2bWrite<_, _, _>| {
create_proof::<_, ProverGWC<_>, _, _, _, _, false>(&param, &pk, c, d, e, &mut f).unwrap();
create_proof::<_, ProverSHPLONK<_>, _, _, _, _, false>(&param, &pk, c, d, e, &mut f)
.unwrap();
f.finalize()
};
let verify_proof =
|c, d, e| verify_proof::<_, VerifierGWC<_>, _, _, _, false>(&param, pk.get_vk(), c, d, e);
let verify_proof = |c, d, e| {
verify_proof::<_, VerifierSHPLONK<_>, _, _, _, false>(&param, pk.get_vk(), c, d, e)
};

let proof = sample(System::Halo2, k, || {
let _timer = start_timer(|| format!("halo2_prove-{k}"));
Expand Down Expand Up @@ -150,6 +168,7 @@ fn bench_espresso_hyperplonk(circuit: MockCircuit<ark_bn254::Fr>) {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum System {
HyperPlonk,
UniHyperPlonk,
Halo2,
EspressoHyperPlonk,
}
Expand All @@ -158,6 +177,7 @@ impl System {
fn all() -> Vec<System> {
vec![
System::HyperPlonk,
System::UniHyperPlonk,
System::Halo2,
System::EspressoHyperPlonk,
]
Expand All @@ -176,12 +196,15 @@ impl System {

fn support(&self, circuit: Circuit) -> bool {
match self {
System::HyperPlonk | System::Halo2 => match circuit {
Circuit::VanillaPlonk | Circuit::Aggregation | Circuit::Sha256 => true,
System::HyperPlonk | System::UniHyperPlonk | System::Halo2 => match circuit {
Circuit::VanillaPlonk
| Circuit::Aggregation
| Circuit::Sha256
| Circuit::Keccak256 => true,
},
System::EspressoHyperPlonk => match circuit {
Circuit::VanillaPlonk => true,
Circuit::Aggregation | Circuit::Sha256 => false,
Circuit::Aggregation | Circuit::Sha256 | Circuit::Keccak256 => false,
},
}
}
Expand All @@ -199,15 +222,23 @@ impl System {
Circuit::VanillaPlonk => bench_hyperplonk::<VanillaPlonk<Fr>>(k),
Circuit::Aggregation => bench_hyperplonk::<AggregationCircuit<Bn256>>(k),
Circuit::Sha256 => bench_hyperplonk::<Sha256Circuit>(k),
Circuit::Keccak256 => bench_hyperplonk::<Keccak256Circuit>(k),
},
System::UniHyperPlonk => match circuit {
Circuit::VanillaPlonk => bench_unihyperplonk::<VanillaPlonk<Fr>>(k),
Circuit::Aggregation => bench_unihyperplonk::<AggregationCircuit<Bn256>>(k),
Circuit::Sha256 => bench_unihyperplonk::<Sha256Circuit>(k),
Circuit::Keccak256 => bench_unihyperplonk::<Keccak256Circuit>(k),
},
System::Halo2 => match circuit {
Circuit::VanillaPlonk => bench_halo2::<VanillaPlonk<Fr>>(k),
Circuit::Aggregation => bench_halo2::<AggregationCircuit<Bn256>>(k),
Circuit::Sha256 => bench_halo2::<Sha256Circuit>(k),
Circuit::Keccak256 => bench_halo2::<Keccak256Circuit>(k),
},
System::EspressoHyperPlonk => match circuit {
Circuit::VanillaPlonk => bench_espresso_hyperplonk(espresso::vanilla_plonk(k)),
Circuit::Aggregation | Circuit::Sha256 => unreachable!(),
Circuit::Aggregation | Circuit::Sha256 | Circuit::Keccak256 => unreachable!(),
},
}
}
Expand All @@ -217,6 +248,7 @@ impl Display for System {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
System::HyperPlonk => write!(f, "hyperplonk"),
System::UniHyperPlonk => write!(f, "unihyperplonk"),
System::Halo2 => write!(f, "halo2"),
System::EspressoHyperPlonk => write!(f, "espresso_hyperplonk"),
}
Expand All @@ -228,6 +260,7 @@ enum Circuit {
VanillaPlonk,
Aggregation,
Sha256,
Keccak256,
}

impl Circuit {
Expand All @@ -236,6 +269,7 @@ impl Circuit {
Circuit::VanillaPlonk => 4,
Circuit::Aggregation => 20,
Circuit::Sha256 => 17,
Circuit::Keccak256 => 10,
}
}
}
Expand All @@ -246,6 +280,7 @@ impl Display for Circuit {
Circuit::VanillaPlonk => write!(f, "vanilla_plonk"),
Circuit::Aggregation => write!(f, "aggregation"),
Circuit::Sha256 => write!(f, "sha256"),
Circuit::Keccak256 => write!(f, "keccak256"),
}
}
}
Expand All @@ -258,16 +293,18 @@ fn parse_args() -> (Vec<System>, Circuit, Range<usize>) {
"--system" => match value.as_str() {
"all" => systems = System::all(),
"hyperplonk" => systems.push(System::HyperPlonk),
"unihyperplonk" => systems.push(System::UniHyperPlonk),
"halo2" => systems.push(System::Halo2),
"espresso_hyperplonk" => systems.push(System::EspressoHyperPlonk),
_ => panic!(
"system should be one of {{all,hyperplonk,halo2,espresso_hyperplonk}}"
"system should be one of {{all,hyperplonk,unihyperplonk,halo2,espresso_hyperplonk}}"
),
},
"--circuit" => match value.as_str() {
"vanilla_plonk" => circuit = Circuit::VanillaPlonk,
"aggregation" => circuit = Circuit::Aggregation,
"sha256" => circuit = Circuit::Sha256,
"keccak256" => circuit = Circuit::Keccak256,
_ => panic!("circuit should be one of {{aggregation,vanilla_plonk}}"),
},
"--k" => {
Expand Down
79 changes: 62 additions & 17 deletions benchmark/src/bin/plotter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ fn main() {
}

fn parse_args() -> (bool, Vec<String>) {
let (verbose, logs) = args().chain(Some("".to_string())).tuple_windows().fold(
let (verbose, logs) = args().chain(["".to_string()]).tuple_windows().fold(
(false, None),
|(mut verbose, mut logs), (key, value)| {
match key.as_str() {
Expand Down Expand Up @@ -94,6 +94,7 @@ fn parse_args() -> (bool, Vec<String>) {
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum System {
HyperPlonk,
UniHyperPlonk,
Halo2,
EspressoHyperPlonk,
}
Expand All @@ -102,6 +103,7 @@ impl System {
fn iter() -> impl Iterator<Item = System> {
[
System::HyperPlonk,
System::UniHyperPlonk,
System::Halo2,
System::EspressoHyperPlonk,
]
Expand All @@ -110,7 +112,7 @@ impl System {

fn key_fn(&self) -> impl Fn(&Log) -> (bool, &str) + '_ {
move |log| match self {
System::HyperPlonk | System::Halo2 => (
System::HyperPlonk | System::UniHyperPlonk | System::Halo2 => (
false,
log.name.split([' ', '-']).next().unwrap_or(&log.name),
),
Expand Down Expand Up @@ -167,6 +169,49 @@ impl System {
]),
),
],
System::UniHyperPlonk => vec![
(
"all",
vec![
vec!["variable_base_msm"],
vec!["sum_check_prove"],
vec!["prove_multilinear_eval"],
],
None,
),
("multiexp", vec![vec!["variable_base_msm"]], None),
("sum check", vec![vec!["sum_check_prove"]], None),
(
"mleval multiexp",
vec![
vec!["prove_multilinear_eval", "variable_base_msm"],
vec![
"prove_multilinear_eval",
"pcs_batch_open",
"variable_base_msm",
],
],
None,
),
(
"mleval fft",
vec![vec!["prove_multilinear_eval", "fft"]],
None,
),
(
"mleval rest",
vec![vec!["prove_multilinear_eval"]],
Some(vec![
vec!["prove_multilinear_eval", "variable_base_msm"],
vec![
"prove_multilinear_eval",
"pcs_batch_open",
"variable_base_msm",
],
vec!["prove_multilinear_eval", "fft"],
]),
),
],
System::Halo2 => vec![
(
"all",
Expand Down Expand Up @@ -320,6 +365,7 @@ impl Display for System {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
System::HyperPlonk => write!(f, "hyperplonk"),
System::UniHyperPlonk => write!(f, "unihyperplonk"),
System::Halo2 => write!(f, "halo2"),
System::EspressoHyperPlonk => write!(f, "espresso_hyperplonk"),
}
Expand Down Expand Up @@ -613,21 +659,20 @@ fn plot_comparison(cost_breakdowns_by_system: &[BTreeMap<usize, Vec<(&str, Durat
let lines = System::iter()
.zip(cost_breakdowns_by_system.iter())
.skip(1)
.filter_map(|(system, cost_breakdowns)| {
(!cost_breakdowns.is_empty()).then(|| {
let [numer, denom] =
[cost_breakdowns, hyperplonk_cost_breakdowns].map(|cost_breakdowns| {
x.iter()
.map(|k| cost_breakdowns[k][0].1.as_nanos() as f64)
.collect_vec()
});
let ratio = numer
.iter()
.zip(denom.iter())
.map(|(numer, denom)| numer / denom)
.collect_vec();
(format!("{system}/{}", System::HyperPlonk), ratio)
})
.filter(|(_, cost_breakdowns)| !cost_breakdowns.is_empty())
.map(|(system, cost_breakdowns)| {
let [numer, denom] =
[cost_breakdowns, hyperplonk_cost_breakdowns].map(|cost_breakdowns| {
x.iter()
.map(|k| cost_breakdowns[k][0].1.as_nanos() as f64)
.collect_vec()
});
let ratio = numer
.iter()
.zip(denom.iter())
.map(|(numer, denom)| numer / denom)
.collect_vec();
(format!("{system}/{}", System::HyperPlonk), ratio)
})
.collect_vec();

Expand Down
Loading