Skip to content

Commit

Permalink
Merge pull request #4 from HerodotusDev/mpt-key-be
Browse files Browse the repository at this point in the history
Key Verification in Big endian
  • Loading branch information
feltroidprime authored May 8, 2024
2 parents 629825c + 4dcb024 commit 7e9aeec
Show file tree
Hide file tree
Showing 17 changed files with 4,556 additions and 283 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,6 @@ jobs:
- name: Run Cairo tests
env:
RPC_URL_MAINNET: ${{ secrets.RPC_URL_MAINNET }}
run: source ./tools/make/cairo_tests.sh
run: source ./tools/make/cairo_tests.sh
- name: Run MPT tests
run: source ./tools/make/fuzzer.sh tests/fuzzing/mpt.cairo --ci
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ out/
.env
.encryptedKey
broadcast/

*.idea
tests/rust/target
node_modules
package-lock.json

.DS_Store
src/.DS_Store

solidity-verifier/lib/*
*.log

!tests/fuzzing/fixtures/*.json
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ test-full:

format-cairo:
@echo "Format all .cairo files"
./tools/make/format_cairo_files.sh
./tools/make/format_cairo_files.sh

fuzz-mpt:
./tools/make/fuzzer.sh tests/fuzzing/mpt.cairo
268 changes: 162 additions & 106 deletions lib/mpt.cairo

Large diffs are not rendered by default.

506 changes: 358 additions & 148 deletions lib/rlp_little.cairo

Large diffs are not rendered by default.

194 changes: 187 additions & 7 deletions lib/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,174 @@ from starkware.cairo.common.registers import get_fp_and_pc
const DIV_32 = 2 ** 32;
const DIV_32_MINUS_1 = DIV_32 - 1;

// Takes the hex representation and count the number of zeroes.
// Ie: returns the number of trailing zeroes bytes.
// If x is 0, returns 16.
func count_trailing_zeroes_128{bitwise_ptr: BitwiseBuiltin*}(x: felt, pow2_array: felt*) -> (
res: felt
) {
if (x == 0) {
return (res=16);
}
alloc_locals;
local trailing_zeroes_bytes;
%{
from tools.py.utils import count_trailing_zero_bytes_from_int
ids.trailing_zeroes_bytes = count_trailing_zero_bytes_from_int(ids.x)
#print(f"Input: {hex(ids.x)}_{ids.trailing_zeroes_bytes}Tr_Zerobytes")
%}
// Verify.
if (trailing_zeroes_bytes == 0) {
// Make sure the last byte is not zero.
let (_, last_byte) = bitwise_divmod(x, 2 ** 8);
if (last_byte == 0) {
assert 1 = 0; // Add unsatisfiability check.
return (res=0);
} else {
return (res=0);
}
} else {
// Make sure the last trailing_zeroes_bytes are zeroes.
let (q, r) = bitwise_divmod(x, pow2_array[8 * trailing_zeroes_bytes]);
assert r = 0;
// Make sure the byte just before the last trailing_zeroes_bytes is not zero.
let (_, first_non_zero_byte) = bitwise_divmod(q, 2 ** 8);
if (first_non_zero_byte == 0) {
assert 1 = 0; // Add unsatisfiability check.
return (res=0);
} else {
return (res=trailing_zeroes_bytes);
}
}
}

// Returns the number of bytes in a number with n_bits bits.
// Assumptions:
// - 0 <= n_bits < 8 * RC_BOUND
func n_bits_to_n_bytes{range_check_ptr: felt}(n_bits: felt) -> (res: felt) {
if (n_bits == 0) {
return (res=0);
}
let (q, r) = felt_divmod_8(n_bits);
if (q == 0) {
return (res=1);
}
if (r == 0) {
return (res=q);
}
return (res=q + 1);
}

// Returns the number of nibbles in a number with n_bits bits.
// Assumptions:
// - 0 <= n_bits < 4 * RC_BOUND
func n_bits_to_n_nibbles{range_check_ptr: felt}(n_bits: felt) -> (res: felt) {
if (n_bits == 0) {
return (res=0);
}
let (q, r) = felt_divmod(n_bits, 4);
if (q == 0) {
return (res=1);
}
if (r == 0) {
return (res=q);
}
return (res=q + 1);
}

// Returns the number of bytes in a 128 bits number.
// Assumptions:
// - 0 <= x < 2^128
func get_felt_n_bytes_128{range_check_ptr: felt}(x: felt, pow2_array: felt*) -> (n_bytes: felt) {
let n_bits = get_felt_bitlength_128{pow2_array=pow2_array}(x);
let (n_bytes) = n_bits_to_n_bytes(n_bits);
return (n_bytes,);
}

// Returns the number of nibbles in a 128 bits number.
func get_felt_n_nibbles{range_check_ptr: felt}(x: felt, pow2_array: felt*) -> (n_nibbles: felt) {
let n_bits = get_felt_bitlength_128{pow2_array=pow2_array}(x);
let (n_nibbles) = n_bits_to_n_nibbles(n_bits);
return (n_nibbles,);
}
// Returns the total number of bits in the uint256 number.
// Assumptions :
// - 0 <= x < 2^256
// Returns:
// - nbits: felt - Total number of bits in the uint256 number.
func get_uint256_bit_length{range_check_ptr}(x: Uint256, pow2_array: felt*) -> (nbits: felt) {
alloc_locals;
with pow2_array {
if (x.high != 0) {
let x_bit_high = get_felt_bitlength_128(x.high);
return (nbits=128 + x_bit_high);
} else {
if (x.low != 0) {
let x_bit_low = get_felt_bitlength_128(x.low);
return (nbits=x_bit_low);
} else {
return (nbits=0);
}
}
}
}

// Takes a uint128 number, reverse its byte endianness without adding right-padding
// Ex :
// Input = 0x123456
// Output = 0x563412
// Input = 0x123
// Output = 0x0312
func uint128_reverse_endian_no_padding{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
x: felt, pow2_array: felt*
) -> (res: felt, n_bytes: felt) {
alloc_locals;
let (num_bytes_input) = get_felt_n_bytes_128(x, pow2_array);
let (x_reversed) = word_reverse_endian(x);
let (num_bytes_reversed) = get_felt_n_bytes_128(x_reversed, pow2_array);
let (trailing_zeroes_input) = count_trailing_zeroes_128(x, pow2_array);

if (num_bytes_input != num_bytes_reversed) {
// %{ print(f"\tinput128: {hex(ids.x)}_{ids.num_bytes_input}bytes") %}
// %{ print(f"\treversed: {hex(ids.x_reversed)}_{ids.num_bytes_reversed}bytes") %}
let (x_reversed, r) = bitwise_divmod(
x_reversed,
pow2_array[8 * (num_bytes_reversed - num_bytes_input + trailing_zeroes_input)],
);
assert r = 0; // Sanity check.
// %{
// import math
// print(f"\treversed_fixed: {hex(ids.x_reversed)}_{math.ceil(ids.x_reversed.bit_length() / 8)}bytes")
// %}
return (res=x_reversed, n_bytes=num_bytes_input);
}
return (res=x_reversed, n_bytes=num_bytes_input);
}

// Takes a uint256 number, reverse its byte endianness without adding right-padding and returns the number of bytes.
func uint256_reverse_endian_no_padding{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
x: Uint256, pow2_array: felt*
) -> (res: Uint256, n_bytes: felt) {
alloc_locals;
if (x.high != 0) {
let (high_reversed, n_bytes_high) = uint128_reverse_endian_no_padding(x.high, pow2_array);
// %{ print(f"High_Rev: {hex(ids.high_reversed)}_{ids.high_reversed.bit_length()}b {ids.n_bytes_high}bytes") %}
let (low_reversed) = word_reverse_endian(x.low);
// %{ print(f"Low_rev: {hex(ids.low_reversed)}_{ids.low_reversed.bit_length()}b") %}
let (q, r) = bitwise_divmod(low_reversed, pow2_array[8 * (16 - n_bytes_high)]);
// %{ print(f"Q: {hex(ids.q)}") %}
// %{ print(f"R: {hex(ids.r)}") %}
return (
res=Uint256(low=high_reversed + pow2_array[8 * n_bytes_high] * r, high=q),
n_bytes=16 + n_bytes_high,
);
} else {
let (low_reversed, n_bytes_low) = uint128_reverse_endian_no_padding(x.low, pow2_array);
return (res=Uint256(low=low_reversed, high=0), n_bytes=n_bytes_low);
}
}

// Adds two integers. Returns the result as a 256-bit integer and the (1-bit) carry.
// Strictly equivalent and faster version of common.uint256.uint256_add using the same whitelisted hint.
func uint256_add{range_check_ptr}(a: Uint256, b: Uint256) -> (res: Uint256, carry: felt) {
Expand Down Expand Up @@ -126,10 +294,13 @@ func write_felt_array_to_dict_keys{dict_end: DictAccess*}(array: felt*, index: f
// Params:
// - x: felt - Input value.
// Assumptions for the caller:
// - 1 <= x < 2^127
// - 0 <= x < 2^127
// Returns:
// - bit_length: felt - Number of bits in x.
func get_felt_bitlength{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
if (x == 0) {
return 0;
}
alloc_locals;
local bit_length;
%{
Expand All @@ -154,12 +325,16 @@ func get_felt_bitlength{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
// Params:
// - x: felt - Input value.
// Assumptions for the caller:
// - 1 <= x < 2^128
// - 0 <= x < 2^128
// Returns:
// - bit_length: felt - Number of bits in x.
func get_felt_bitlength_128{range_check_ptr, pow2_array: felt*}(x: felt) -> felt {
if (x == 0) {
return 0;
}
alloc_locals;
local bit_length;

%{
x = ids.x
ids.bit_length = x.bit_length()
Expand Down Expand Up @@ -191,12 +366,17 @@ func get_felt_bitlength_128{range_check_ptr, pow2_array: felt*}(x: felt) -> felt
// q: the quotient.
// r: the remainder.
func bitwise_divmod{bitwise_ptr: BitwiseBuiltin*}(x: felt, y: felt) -> (q: felt, r: felt) {
assert bitwise_ptr.x = x;
assert bitwise_ptr.y = y - 1;
let x_and_y = bitwise_ptr.x_and_y;
if (y == 1) {
let bitwise_ptr = bitwise_ptr;
return (q=x, r=0);
} else {
assert bitwise_ptr.x = x;
assert bitwise_ptr.y = y - 1;
let x_and_y = bitwise_ptr.x_and_y;

let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE;
return (q=(x - x_and_y) / y, r=x_and_y);
let bitwise_ptr = bitwise_ptr + BitwiseBuiltin.SIZE;
return (q=(x - x_and_y) / y, r=x_and_y);
}
}

// Computes x//(2**32) and x%(2**32) using range checks operations.
Expand Down
37 changes: 20 additions & 17 deletions tests/cairo_programs/mpt_prove_EOA_test.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,10 @@ func main{
%}

let (pow2_array: felt*) = pow2alloc127();
let (keys_little: Uint256*) = alloc();
let (keys_be: Uint256*) = alloc();

hash_n_addresses(
addresses_64_little=addresses_64_little,
keys_little=keys_little,
n_addresses=n_proofs,
index=0,
addresses_64_little=addresses_64_little, keys_be=keys_be, n_addresses=n_proofs, index=0
);

let (values: felt**) = alloc();
Expand All @@ -91,7 +88,7 @@ func main{
mpt_proofs=account_proofs,
mpt_proofs_bytes_len=account_proofs_bytes_len,
mpt_proofs_len=account_proofs_len,
keys_little=keys_little,
keys_be=keys_be,
hashes_to_assert=state_roots,
n_proofs=n_proofs,
index=0,
Expand All @@ -104,19 +101,20 @@ func main{
}

func hash_n_addresses{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*}(
addresses_64_little: felt**, keys_little: Uint256*, n_addresses: felt, index: felt
addresses_64_little: felt**, keys_be: Uint256*, n_addresses: felt, index: felt
) {
alloc_locals;
if (index == n_addresses) {
return ();
} else {
let (hash: Uint256) = keccak(addresses_64_little[index], 20);
assert keys_little[index].low = hash.low;
assert keys_little[index].high = hash.high;
let (hash_le: Uint256) = keccak(addresses_64_little[index], 20);
let (hash: Uint256) = uint256_reverse_endian(hash_le);
assert keys_be[index].low = hash.low;
assert keys_be[index].high = hash.high;

return hash_n_addresses(
addresses_64_little=addresses_64_little,
keys_little=keys_little,
keys_be=keys_be,
n_addresses=n_addresses,
index=index + 1,
);
Expand All @@ -127,7 +125,7 @@ func verify_n_mpt_proofs{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_p
mpt_proofs: felt***,
mpt_proofs_bytes_len: felt**,
mpt_proofs_len: felt*,
keys_little: Uint256*,
keys_be: Uint256*,
hashes_to_assert: Uint256*,
n_proofs: felt,
index: felt,
Expand All @@ -139,14 +137,19 @@ func verify_n_mpt_proofs{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_p
if (index == n_proofs) {
return (values=values, values_lens=values_lens);
} else {
local key_be_leading_zeroes_nibbles;
let key_be: Uint256 = keys_be[index];
%{
from tools.py.utils import count_leading_zero_nibbles_from_hex
ids.key_be_leading_zeroes_nibbles = count_leading_zero_nibbles_from_hex(hex(ids.key_be.low+2**128*ids.key_be.high))
%}
let (value: felt*, value_len: felt) = verify_mpt_proof(
mpt_proof=mpt_proofs[index],
mpt_proof_bytes_len=mpt_proofs_bytes_len[index],
mpt_proof_len=mpt_proofs_len[index],
key_little=keys_little[index],
n_nibbles_already_checked=0,
node_index=0,
hash_to_assert=hashes_to_assert[index],
key_be=key_be,
key_be_leading_zeroes_nibbles=key_be_leading_zeroes_nibbles,
root=hashes_to_assert[index],
pow2_array=pow2_array,
);
assert values_lens[index] = value_len;
Expand All @@ -155,7 +158,7 @@ func verify_n_mpt_proofs{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_p
mpt_proofs=mpt_proofs,
mpt_proofs_bytes_len=mpt_proofs_bytes_len,
mpt_proofs_len=mpt_proofs_len,
keys_little=keys_little,
keys_be=keys_be,
hashes_to_assert=hashes_to_assert,
n_proofs=n_proofs,
index=index + 1,
Expand Down
Loading

0 comments on commit 7e9aeec

Please sign in to comment.