Skip to content

Commit

Permalink
Sponge construction for keccak syscall & Intruction::keccakf (#2263)
Browse files Browse the repository at this point in the history
  • Loading branch information
gzanitti authored Jan 3, 2025
1 parent 28ad2d8 commit 50fd458
Show file tree
Hide file tree
Showing 7 changed files with 174 additions and 74 deletions.
1 change: 1 addition & 0 deletions riscv-executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ p3-goldilocks = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432
p3-symmetric = { git = "https://github.com/plonky3/Plonky3.git", rev = "2192432ddf28e7359dd2c577447886463e6124f0" }
rustc-demangle = "0.1"
inferno = "0.11.19"
tiny-keccak = { version = "2.0.2", features = ["keccak"] }

rayon = "1.7.0"

Expand Down
48 changes: 46 additions & 2 deletions riscv-executor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ use powdr_ast::{
BinaryOperation, Expression, FunctionCall, Number, UnaryOperation,
},
};
use tiny_keccak::keccakf;

use powdr_executor::constant_evaluator::VariablySizedColumn;
use powdr_number::{write_polys_csv_file, FieldElement, LargeInt};
pub use profiler::ProfilerOptions;
Expand Down Expand Up @@ -132,7 +134,8 @@ instructions! {
ec_add,
ec_double,
commit_public,
fail
fail,
keccakf
}

/// Enum with columns directly accessed by the executor (as to avoid matching on strings)
Expand Down Expand Up @@ -299,7 +302,7 @@ machine_instances! {
poseidon_gl
// TODO: these are not implemented yet
// poseidon2_gl,
// keccakf,
// keccakf
// arith,
}

Expand Down Expand Up @@ -2570,6 +2573,47 @@ impl<F: FieldElement> Executor<'_, '_, F> {
);
None
}
Instruction::keccakf => {
let reg1 = args[0].u();
let reg2 = args[1].u();
let lid = self.instr_link_id(instr, MachineInstance::regs, 0);
let input_ptr = self.reg_read(0, reg1, lid);
let lid = self.instr_link_id(instr, MachineInstance::regs, 1);
let output_ptr = self.reg_read(1, reg2, lid);

set_col!(tmp1_col, input_ptr);
set_col!(tmp2_col, output_ptr);

let mut state = [0u64; 25];
// Note: lo/hi positions are swapped (lo at +4 offset, hi at +0) to match
// the Keccak machine specification's memory layout
for (i, state_i) in state.iter_mut().enumerate() {
let lo = self
.proc
.get_mem(input_ptr.u() + 8 * i as u32 + 4, self.step, lid);
let hi = self
.proc
.get_mem(input_ptr.u() + 8 * i as u32, self.step, lid);
*state_i = ((hi as u64) << 32) | lo as u64;
}

keccakf(&mut state);

for (i, val) in state.iter().enumerate() {
let lo = *val as u32;
let hi = (val >> 32) as u32;

self.proc
.set_mem(output_ptr.u() + i as u32 * 8 + 4, lo, self.step + 1, lid);
self.proc
.set_mem(output_ptr.u() + i as u32 * 8, hi, self.step + 1, lid);
}

//let lid = self.instr_link_id(instr, "main_keccakf", 0);
//submachine_op!(keccakf, lid, &[input_ptr.into_fe(), output_ptr.into_fe()],);
//main_op!(keccakf32_memory);
None
}
Instruction::Count => unreachable!(),
};

Expand Down
111 changes: 79 additions & 32 deletions riscv-runtime/src/hash.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use core::arch::asm;
use core::convert::TryInto;
use core::mem::{self, MaybeUninit};
use core::mem::MaybeUninit;

use crate::goldilocks::Goldilocks;
use powdr_riscv_syscalls::Syscall;
Expand Down Expand Up @@ -40,47 +40,94 @@ pub fn poseidon2_gl(data: &[Goldilocks; 8]) -> [Goldilocks; 8] {

/// Calls the keccakf machine.
/// Return value is placed in the output array.
pub fn keccakf(input: &[u64; 25], output: &mut [u64; 25]) {
pub fn keccakf(input: &[u32; 50], output: &mut [u32; 50]) {
unsafe {
// Syscall inputs: memory pointer to input array and memory pointer to output array.
ecall!(Syscall::KeccakF, in("a0") input, in("a1") output);
}
}

// Output number of bytes for keccak-256 (32 bytes)
const W: usize = 32;

/// Keccak function that calls the keccakf machine.
/// Input is a byte array of arbitrary length and a delimiter byte.
/// Output is a byte array of length W.
pub fn keccak(data: &[u8], delim: u8) -> [u8; W] {
let mut b = [[0u8; 200]; 2];
let [mut b_input, mut b_output] = &mut b;
let rate = 200 - (2 * W);
let mut pt = 0;

// update
for &byte in data {
b_input[pt] ^= byte;
pt = (pt + 1) % rate;
if pt == 0 {
unsafe {
keccakf(mem::transmute(&b_input), mem::transmute(&mut b_output));
pub struct Keccak {
state: [u32; 50],
next_word: usize,
input_buffer: u32,
next_byte: usize,
}

impl Keccak {
const RATE: usize = 34; // Rate in u32 words

pub fn v256() -> Self {
Self {
state: [0u32; 50],
next_word: 0,
input_buffer: 0,
next_byte: 0,
}
}

fn xor_word_to_state(&mut self, word: u32) {
let word_pair = self.next_word & !1;
if (self.next_word & 1) == 0 {
self.state[word_pair + 1] ^= word;
} else {
self.state[word_pair] ^= word;
}
self.next_word += 1;

if self.next_word == Self::RATE {
let mut state_out = [0u32; 50];
keccakf(&self.state, &mut state_out);
self.state = state_out;
self.next_word = 0;
}
}

pub fn update(&mut self, data: &[u8]) {
unsafe {
let (prefix, words, suffix) = data.align_to::<u32>();

self.update_unaligned(prefix);
for &word in words {
if self.next_byte == 0 {
self.xor_word_to_state(word);
} else {
self.xor_word_to_state(self.input_buffer | (word << (8 * self.next_byte)));
self.input_buffer = word >> (32 - 8 * self.next_byte);
}
}
mem::swap(&mut b_input, &mut b_output);
self.update_unaligned(suffix);
}
}

// finalize
b_input[pt] ^= delim;
b_input[rate - 1] ^= 0x80;
unsafe {
keccakf(mem::transmute(&b_input), mem::transmute(&mut b_output));
fn update_unaligned(&mut self, bytes: &[u8]) {
for &byte in bytes {
self.input_buffer |= (byte as u32) << (8 * self.next_byte);
self.next_byte += 1;
if self.next_byte == 4 {
self.xor_word_to_state(self.input_buffer);
self.input_buffer = 0;
self.next_byte = 0;
}
}
}

// Extract the first W bytes and return as a fixed-size array
// Need to copy the data, not just returning a slice
let mut output = [0u8; W];
output.copy_from_slice(&b_output[..W]);
output
pub fn finalize(&mut self, output: &mut [u8]) {
if self.next_byte > 0 {
self.input_buffer |= 0x01 << (8 * self.next_byte);
self.xor_word_to_state(self.input_buffer);
} else {
self.xor_word_to_state(0x01);
}

while self.next_word < Self::RATE - 1 {
self.xor_word_to_state(0);
}
self.xor_word_to_state(0x80000000);

for i in 0..4 {
output[i * 8..(i * 8 + 4)].copy_from_slice(&self.state[i * 2 + 1].to_le_bytes());
output[(i * 8 + 4)..(i * 8 + 8)].copy_from_slice(&self.state[i * 2].to_le_bytes());
}
}
}
26 changes: 13 additions & 13 deletions riscv/src/large_field/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,22 +169,22 @@ impl Runtime {

fn with_keccak(mut self) -> Self {
self.add_submachine(
"std::machines::hash::keccakf::KeccakF",
"std::machines::hash::keccakf32_memory::Keccakf32Memory",
None,
"keccakf",
vec!["memory"],
vec!["memory", "MIN_DEGREE", "MAIN_MAX_DEGREE"],
[r#"instr keccakf X, Y
link ~> tmp1_col = regs.mload(X, STEP),
link ~> tmp2_col = regs.mload(Y, STEP + 1)
link ~> keccakf.keccakf(tmp1_col, tmp2_col, STEP)
{
// make sure tmp1_col and tmp2_col are aligned memory addresses
tmp3_col * 4 = tmp1_col,
tmp4_col * 4 = tmp2_col,
// make sure the factors fit in 32 bits
tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000,
tmp4_col = Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000
}
link ~> tmp1_col = regs.mload(X, STEP)
link ~> tmp2_col = regs.mload(Y, STEP + 1)
link ~> keccakf.keccakf32_memory(tmp1_col, tmp2_col, STEP)
{
// make sure tmp1_col and tmp2_col are aligned memory addresses
tmp3_col * 4 = tmp1_col,
tmp4_col * 4 = tmp2_col,
// make sure the factors fit in 32 bits
tmp3_col = X_b1 + X_b2 * 0x100 + X_b3 * 0x10000 + X_b4 * 0x1000000,
tmp4_col = Y_b5 + Y_b6 * 0x100 + Y_b7 * 0x10000 + Y_b8 * 0x1000000
}
"#
.to_string()],
0,
Expand Down
8 changes: 8 additions & 0 deletions riscv/tests/riscv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,14 @@ fn keccak() {
verify_riscv_crate(case, Default::default(), false);
}

#[test]
#[ignore = "Too slow"]
fn keccak_powdr() {
let case = "keccak_powdr";
let options = CompilerOptions::new_gl().with_keccak();
verify_riscv_crate_gl_with_options(case, Default::default(), options, false);
}

#[cfg(feature = "estark-polygon")]
#[test]
#[ignore = "Too slow"]
Expand Down
1 change: 0 additions & 1 deletion riscv/tests/riscv_data/keccak_powdr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,5 @@ edition = "2021"

[dependencies]
powdr-riscv-runtime = { path = "../../../../riscv-runtime" }
hex-literal = "0.3.1"

[workspace]
53 changes: 27 additions & 26 deletions riscv/tests/riscv_data/keccak_powdr/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,35 @@
#![no_main]
#![no_std]
extern crate alloc;
use alloc::vec::Vec;

use powdr_riscv_runtime::hash::keccak;
use hex_literal::hex;
extern crate powdr_riscv_runtime;
use powdr_riscv_runtime::hash::Keccak;

#[no_mangle]
pub fn main() {
// Tests using our keccak syscall.
let input: [&[u8]; 4] = [
// Zokrates test vectors
&[0x7a, 0x6f, 0x6b, 0x72, 0x61, 0x74, 0x65, 0x73],
&[0x2a; 135],
&[0x2a; 136],
// All zero test vector
&[0x00; 256],
];
let inputs = [b"Solidity", b"Powdrrrr"];
let mut hasher = Keccak::v256();
let mut output = [0u8; 32];
for input in inputs.into_iter().cycle().take(100) {
hasher.update(input);
}
hasher.finalize(&mut output);

let output: Vec<[u8; 32]> = input.iter().map(|x| keccak(x, 0x01)).collect();

let expected = [
hex!("ca85d1976d40dcb6ca3becc8c6596e83c0774f4185cf016a05834f5856a37f39"),
hex!("723e2ae02ca8d8fb45dca21e5f6369c4f124da72f217dca5e657a4bbc69b917d"),
hex!("e60d5160227cb1b8dc8547deb9c6a2c5e6c3306a1ca155611a73ed2c2324bfc0"),
hex!("d397b3b043d87fcd6fad1291ff0bfd16401c274896d8c63a923727f077b8e0b5")
];

// Currently commented out because keccakf syscall is not ready (roadblocked by circuit compiler).
// output.iter().zip(expected.iter()).for_each(|(out, exp)| {
// assert_eq!(out, exp);
// });
// The expected output was generated using tiny-keccak's Keccak256 implementation
// with the following code:
// use tiny_keccak::{Hasher, Keccak};
//
// let mut output = [0u8; 32];
// let mut hasher = Keccak::v256();
// for input in inputs.into_iter().cycle().take(100) {
// hasher.update(input);
// }
// hasher.finalize(&mut output);
assert_eq!(
output,
[
0xb2, 0x60, 0x1c, 0x72, 0x12, 0xd8, 0x26, 0x0d, 0xa4, 0x6d, 0xde, 0x19, 0x8d, 0x50,
0xa7, 0xe4, 0x67, 0x1f, 0xc1, 0xbb, 0x8f, 0xf2, 0xd1, 0x72, 0x5a, 0x8d, 0xa1, 0x08,
0x11, 0xb5, 0x81, 0x69
]
);
}

0 comments on commit 50fd458

Please sign in to comment.