diff --git a/cairo/ethereum/cancun/vm.cairo b/cairo/ethereum/cancun/vm.cairo index 701aaf8c..584f26a2 100644 --- a/cairo/ethereum/cancun/vm.cairo +++ b/cairo/ethereum/cancun/vm.cairo @@ -297,7 +297,7 @@ namespace EvmImpl { return (); } - func set_refund_counter{evm: Evm}(new_refund_counter: Uint) { + func set_refund_counter{evm: Evm}(new_refund_counter: felt) { tempvar evm = Evm( new EvmStruct( pc=evm.value.pc, diff --git a/cairo/ethereum/cancun/vm/instructions/storage.cairo b/cairo/ethereum/cancun/vm/instructions/storage.cairo index 7544cc80..c9608805 100644 --- a/cairo/ethereum/cancun/vm/instructions/storage.cairo +++ b/cairo/ethereum/cancun/vm/instructions/storage.cairo @@ -1,8 +1,15 @@ from ethereum.cancun.vm.stack import pop, push from ethereum.cancun.vm import Evm, EvmImpl, Environment, EnvImpl -from ethereum.cancun.vm.exceptions import ExceptionalHalt, WriteInStaticContext +from ethereum.cancun.vm.exceptions import ExceptionalHalt, WriteInStaticContext, OutOfGasError from ethereum.cancun.vm.gas import charge_gas, GasConstants -from ethereum.cancun.state import get_storage, get_transient_storage, set_transient_storage +from ethereum.utils.numeric import U256__eq__ +from ethereum.cancun.state import ( + get_storage, + get_storage_original, + set_storage, + get_transient_storage, + set_transient_storage, +) from ethereum.cancun.fork_types import ( SetTupleAddressBytes32, SetTupleAddressBytes32DictAccess, @@ -21,6 +28,7 @@ from starkware.cairo.common.cairo_builtins import PoseidonBuiltin, BitwiseBuilti from starkware.cairo.lang.compiler.lib.registers import get_fp_and_pc from starkware.cairo.common.alloc import alloc from starkware.cairo.common.dict_access import DictAccess +from starkware.cairo.common.math_cmp import is_le // @notice Loads to the stack the value corresponding to a certain key from the // storage of the current account. @@ -96,6 +104,143 @@ func sload{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, poseidon_ptr: Poseidon return ok; } +// @notice Stores a value at a certain key in the current context's storage. +func sstore{ + range_check_ptr, bitwise_ptr: BitwiseBuiltin*, poseidon_ptr: PoseidonBuiltin*, evm: Evm +}() -> ExceptionalHalt* { + alloc_locals; + // STACK + let stack = evm.value.stack; + with stack { + let (key, err) = pop(); + if (cast(err, felt) != 0) { + return err; + } + let (new_value, err) = pop(); + if (cast(err, felt) != 0) { + return err; + } + } + + let is_gas_left_not_enough = is_le(evm.value.gas_left.value, GasConstants.GAS_CALL_STIPEND); + if (is_gas_left_not_enough != 0) { + tempvar err = new ExceptionalHalt(OutOfGasError); + return err; + } + + // Get storage values + let key_bytes32 = U256_to_be_bytes(key); + let state = evm.value.env.value.state; + let current_target = evm.value.message.value.current_target; + with state { + let original_value = get_storage_original(current_target, key_bytes32); + let current_value = get_storage(current_target, key_bytes32); + } + + // Gas calculation + // Check accessed storage keys + tempvar accessed_tuple = TupleAddressBytes32( + new TupleAddressBytes32Struct(current_target, key_bytes32) + ); + let (serialized_keys: felt*) = alloc(); + assert serialized_keys[0] = accessed_tuple.value.address.value; + assert serialized_keys[1] = accessed_tuple.value.bytes32.value.low; + assert serialized_keys[2] = accessed_tuple.value.bytes32.value.high; + let dict_ptr = cast(evm.value.accessed_storage_keys.value.dict_ptr, DictAccess*); + with dict_ptr { + let (is_present) = hashdict_read(3, serialized_keys); + if (is_present == 0) { + hashdict_write(3, serialized_keys, 1); + tempvar gas_cost = GasConstants.GAS_COLD_SLOAD; + tempvar poseidon_ptr = poseidon_ptr; + tempvar dict_ptr = dict_ptr; + } else { + tempvar gas_cost = 0; + tempvar poseidon_ptr = poseidon_ptr; + tempvar dict_ptr = dict_ptr; + } + } + let gas_cost = [ap - 3]; + let poseidon_ptr = cast([ap - 2], PoseidonBuiltin*); + let dict_ptr = cast([ap - 1], DictAccess*); + + let new_dict_ptr = cast(dict_ptr, SetTupleAddressBytes32DictAccess*); + tempvar new_accessed_storage_keys = SetTupleAddressBytes32( + new SetTupleAddressBytes32Struct( + evm.value.accessed_storage_keys.value.dict_ptr_start, new_dict_ptr + ), + ); + + // Calculate storage gas cost + tempvar zero_u256 = U256(new U256Struct(0, 0)); + let is_original_eq_current = U256__eq__(original_value, current_value); + let is_current_eq_new = U256__eq__(current_value, new_value); + let is_original_zero = U256__eq__(original_value, zero_u256); + if (is_original_eq_current.value != 0) { + if (is_current_eq_new.value == 0) { + if (is_original_zero.value != 0) { + tempvar gas_cost = gas_cost + GasConstants.GAS_STORAGE_SET; + } else { + tempvar gas_cost = gas_cost + ( + GasConstants.GAS_STORAGE_UPDATE - GasConstants.GAS_COLD_SLOAD + ); + } + } + } else { + tempvar gas_cost = GasConstants.GAS_WARM_ACCESS; + } + let gas_cost = [ap - 1]; + + tempvar refund_counter = evm.value.refund_counter; + let is_original_eq_new = U256__eq__(original_value, new_value); + // Refund calculation + if (is_current_eq_new.value == 0) { + let is_current_zero = U256__eq__(current_value, zero_u256); + let is_new_zero = U256__eq__(new_value, zero_u256); + if (is_original_zero.value == 0 and is_current_zero.value == 0 and is_new_zero.value != 0) { + refund_counter = refund_counter + GasConstants.GAS_STORAGE_CLEAR_REFUND; + } + if (is_original_zero.value == 0 and is_current_zero.value != 0) { + refund_counter = refund_counter - GasConstants.GAS_STORAGE_CLEAR_REFUND; + } + if (is_original_eq_new.value != 0) { + if (is_original_zero.value != 0) { + refund_counter = refund_counter + + (GasConstants.GAS_STORAGE_SET - GasConstants.GAS_WARM_ACCESS); + } else { + refund_counter = refund_counter + + (GasConstants.GAS_STORAGE_UPDATE - GasConstants.GAS_COLD_SLOAD - GasConstants.GAS_WARM_ACCESS); + } + } + } + + // Charge gas + let err = charge_gas(Uint(gas_cost)); + if (cast(err, felt) != 0) { + return err; + } + // Check static call + if (evm.value.message.value.is_static.value != 0) { + tempvar err = new ExceptionalHalt(WriteInStaticContext); + return err; + } + + // Set storage + with state { + set_storage(current_target, key_bytes32, new_value); + } + + // Update EVM state + let env = evm.value.env; + EnvImpl.set_state{env=env}(state); + EvmImpl.set_env(env); + EvmImpl.set_pc_stack(Uint(evm.value.pc.value + 1), stack); + EvmImpl.set_refund_counter(refund_counter); + EvmImpl.set_accessed_storage_keys(new_accessed_storage_keys); + let ok = cast(0, ExceptionalHalt*); + return ok; +} + // @notice Loads to the stack the value corresponding to a certain key from the // transient storage of the current account. func tload{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, poseidon_ptr: PoseidonBuiltin*, evm: Evm}( diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py b/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py index 02cf7066..13f235c6 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_arithmetic.py @@ -1,4 +1,3 @@ -import pytest from hypothesis import given from ethereum.cancun.vm.instructions.arithmetic import ( @@ -15,6 +14,7 @@ sub, ) from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder arithmetic_tests_strategy = EvmBuilder().with_stack().with_gas_left().build() @@ -26,7 +26,7 @@ def test_add(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("add", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): add(evm) return @@ -38,7 +38,7 @@ def test_sub(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("sub", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): sub(evm) return @@ -50,7 +50,7 @@ def test_mul(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("mul", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mul(evm) return @@ -63,7 +63,7 @@ def test_div(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("div", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): div(evm) return @@ -75,7 +75,7 @@ def test_sdiv(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("sdiv", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): sdiv(evm) return @@ -87,7 +87,7 @@ def test_mod(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("mod", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mod(evm) return @@ -99,7 +99,7 @@ def test_smod(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("smod", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): smod(evm) return @@ -111,7 +111,7 @@ def test_addmod(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("addmod", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): addmod(evm) return @@ -123,7 +123,7 @@ def test_mulmod(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("mulmod", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mulmod(evm) return @@ -135,7 +135,7 @@ def test_exp(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("exp", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): exp(evm) return @@ -147,7 +147,7 @@ def test_signextend(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("signextend", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): signextend(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py b/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py index 3c66fa34..a3efbd72 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_bitwise.py @@ -1,4 +1,3 @@ -import pytest from hypothesis import given from ethereum.cancun.vm.exceptions import ExceptionalHalt @@ -13,6 +12,7 @@ get_byte, ) from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder bitwise_tests_strategy = EvmBuilder().with_stack().with_gas_left().build() @@ -24,7 +24,7 @@ def test_and(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_and", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_and(evm) return @@ -36,7 +36,7 @@ def test_or(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_or", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_or(evm) return @@ -48,7 +48,7 @@ def test_xor(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_xor", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_xor(evm) return @@ -60,7 +60,7 @@ def test_not(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_not", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_not(evm) return @@ -72,7 +72,7 @@ def test_get_byte(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("get_byte", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): get_byte(evm) return @@ -84,7 +84,7 @@ def test_shl(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_shl", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_shl(evm) return @@ -96,7 +96,7 @@ def test_shr(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_shr", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_shr(evm) return @@ -108,7 +108,7 @@ def test_sar(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("bitwise_sar", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): bitwise_sar(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_block.py b/cairo/tests/ethereum/cancun/vm/instructions/test_block.py index 90c0ca28..7df32c51 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_block.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_block.py @@ -1,4 +1,3 @@ -import pytest from ethereum_types.numeric import U64, Uint from hypothesis import given from hypothesis import strategies as st @@ -14,6 +13,7 @@ timestamp, ) from tests.utils.args_gen import Environment, Evm, TransientStorage +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder, address_zero from tests.utils.strategies import ( BLOCK_HASHES_LIST, @@ -73,7 +73,7 @@ def test_block_hash(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("block_hash", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): block_hash(evm) return @@ -85,7 +85,7 @@ def test_coinbase(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("coinbase", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): coinbase(evm) return @@ -97,7 +97,7 @@ def test_timestamp(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("timestamp", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): timestamp(evm) return @@ -109,7 +109,7 @@ def test_number(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("number", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): number(evm) return @@ -121,7 +121,7 @@ def test_prev_randao(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("prev_randao", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): prev_randao(evm) return @@ -133,7 +133,7 @@ def test_gas_limit(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("gas_limit", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): gas_limit(evm) return @@ -145,7 +145,7 @@ def test_chain_id(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("chain_id", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): chain_id(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py b/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py index e3551e08..957ae973 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_comparison.py @@ -1,4 +1,3 @@ -import pytest from hypothesis import given from ethereum.cancun.vm.exceptions import ExceptionalHalt @@ -11,6 +10,7 @@ signed_less_than, ) from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder comparison_tests_strategy = EvmBuilder().with_stack().with_gas_left().build() @@ -22,7 +22,7 @@ def test_less_than(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("less_than", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): less_than(evm) return @@ -34,7 +34,7 @@ def test_greater_than(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("greater_than", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): greater_than(evm) return @@ -46,7 +46,7 @@ def test_signed_less_than(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("signed_less_than", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): signed_less_than(evm) return @@ -58,7 +58,7 @@ def test_signed_greater_than(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("signed_greater_than", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): signed_greater_than(evm) return @@ -70,7 +70,7 @@ def test_equal(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("equal", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): equal(evm) return @@ -82,7 +82,7 @@ def test_is_zero(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("is_zero", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): is_zero(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py b/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py index 720dd1c8..1ee49394 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_control_flow.py @@ -1,4 +1,3 @@ -import pytest from ethereum_types.numeric import U256 from hypothesis import given @@ -13,6 +12,7 @@ ) from ethereum.cancun.vm.stack import push from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder @@ -22,7 +22,7 @@ def test_stop(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("stop", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): stop(evm) return @@ -49,7 +49,7 @@ def test_jump(self, cairo_run, evm: Evm, push_valid_jump_destination: bool): try: cairo_result = cairo_run("jump", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): jump(evm) return @@ -84,7 +84,7 @@ def test_jumpi( try: cairo_result = cairo_run("jumpi", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): jumpi(evm) return @@ -96,7 +96,7 @@ def test_pc(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("pc", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): pc(evm) return @@ -108,7 +108,7 @@ def test_gas_left(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("gas_left", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): gas_left(evm) return @@ -120,7 +120,7 @@ def test_jumpdest(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("jumpdest", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): jumpdest(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py b/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py index 1f34806b..b79eafbc 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_keccak.py @@ -1,4 +1,3 @@ -import pytest from ethereum_types.numeric import U256 from hypothesis import given @@ -6,6 +5,7 @@ from ethereum.cancun.vm.instructions.keccak import keccak from ethereum.cancun.vm.stack import push from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder from tests.utils.strategies import memory_lite_access_size, memory_lite_start_position @@ -25,7 +25,7 @@ def test_keccak(self, cairo_run, evm: Evm, start_index: U256, size: U256): try: cairo_result = cairo_run("keccak", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): keccak(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_log.py b/cairo/tests/ethereum/cancun/vm/instructions/test_log.py index 944d9814..28e627ca 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_log.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_log.py @@ -1,4 +1,3 @@ -import pytest from ethereum_types.numeric import U256 from hypothesis import given @@ -6,6 +5,7 @@ from ethereum.cancun.vm.instructions.log import log0, log1, log2, log3, log4 from ethereum.cancun.vm.stack import push from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder from tests.utils.strategies import memory_lite_access_size, memory_lite_start_position @@ -24,7 +24,7 @@ def test_log0(self, cairo_run, evm: Evm, start_index: U256, size: U256): try: cairo_result = cairo_run("log0", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): log0(evm) return @@ -46,7 +46,7 @@ def test_log1( try: cairo_result = cairo_run("log1", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): log1(evm) return @@ -76,7 +76,7 @@ def test_log2( try: cairo_result = cairo_run("log2", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): log2(evm) return @@ -109,7 +109,7 @@ def test_log3( try: cairo_result = cairo_run("log3", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): log3(evm) return @@ -145,7 +145,7 @@ def test_log4( try: cairo_result = cairo_run("log4", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): log4(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py b/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py index 3825f1d8..d0c214c1 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_memory_instructions.py @@ -1,4 +1,3 @@ -import pytest from ethereum_types.numeric import U256 from hypothesis import given @@ -6,6 +5,7 @@ from ethereum.cancun.vm.instructions.memory import mcopy, mload, msize, mstore, mstore8 from ethereum.cancun.vm.stack import push from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder from tests.utils.strategies import ( memory_lite_access_size, @@ -32,7 +32,7 @@ def test_mstore( try: cairo_result = cairo_run("mstore", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mstore(evm) return @@ -55,7 +55,7 @@ def test_mstore8( try: cairo_result = cairo_run("mstore8", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mstore8(evm) return @@ -76,7 +76,7 @@ def test_mload( try: cairo_result = cairo_run("mload", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mload(evm) return @@ -88,7 +88,7 @@ def test_msize(self, cairo_run, evm: Evm): try: cairo_result = cairo_run("msize", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): msize(evm) return @@ -118,7 +118,7 @@ def test_mcopy( try: cairo_result = cairo_run("mcopy", evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): mcopy(evm) return diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py b/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py index 8479ac63..7c6d28ae 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_stack_instructions.py @@ -4,6 +4,7 @@ import ethereum.cancun.vm.instructions.stack as stack from ethereum.cancun.vm.exceptions import ExceptionalHalt from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder @@ -16,7 +17,7 @@ def test_push_n(self, cairo_run, evm: Evm, num_bytes: int): try: cairo_result = cairo_run(func_name, evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): push_i(evm) return @@ -33,7 +34,7 @@ def test_swap_n(self, cairo_run, evm: Evm, item_number: int): try: cairo_result = cairo_run(func_name, evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): swap_i(evm) return @@ -50,7 +51,7 @@ def test_dup_n(self, cairo_run, evm: Evm, item_number: int): try: cairo_result = cairo_run(func_name, evm) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): dup_i(evm) return dup_i(evm) diff --git a/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py b/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py index 989f4bcd..b3a5b18b 100644 --- a/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py +++ b/cairo/tests/ethereum/cancun/vm/instructions/test_storage.py @@ -1,14 +1,14 @@ from typing import Tuple -import pytest from ethereum_types.bytes import Bytes20, Bytes32 from ethereum_types.numeric import U256 from hypothesis import given from hypothesis import strategies as st from hypothesis.strategies import composite -from ethereum.cancun.vm.instructions.storage import sload, tload, tstore +from ethereum.cancun.vm.instructions.storage import sload, sstore, tload, tstore from tests.utils.args_gen import Evm +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder from tests.utils.strategies import MAX_STORAGE_KEY_SET_SIZE @@ -38,19 +38,30 @@ def test_sload(self, cairo_run, evm: Evm): try: cairo_evm = cairo_run("sload", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): sload(evm) return sload(evm) assert evm == cairo_evm + @given(evm=evm_with_accessed_storage_keys()) + def test_sstore(self, cairo_run, evm: Evm): + try: + cairo_evm = cairo_run("sstore", evm) + except Exception as cairo_error: + with strict_raises(type(cairo_error)): + sstore(evm) + return + sstore(evm) + assert evm == cairo_evm + @given(evm=evm_with_accessed_storage_keys()) def test_tload(self, cairo_run, evm: Evm): try: cairo_evm = cairo_run("tload", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): tload(evm) return @@ -62,7 +73,7 @@ def test_tstore(self, cairo_run, evm: Evm): try: cairo_evm = cairo_run("tstore", evm) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): tstore(evm) return tstore(evm) diff --git a/cairo/tests/ethereum/cancun/vm/test_gas.py b/cairo/tests/ethereum/cancun/vm/test_gas.py index 8964e4e7..7e15c93b 100644 --- a/cairo/tests/ethereum/cancun/vm/test_gas.py +++ b/cairo/tests/ethereum/cancun/vm/test_gas.py @@ -1,6 +1,5 @@ from typing import List, Tuple -import pytest from ethereum_types.numeric import U256, Uint from hypothesis import assume, given from hypothesis import strategies as st @@ -23,6 +22,7 @@ max_message_call_gas, ) from tests.utils.args_gen import Evm, Memory +from tests.utils.errors import strict_raises from tests.utils.evm_builder import EvmBuilder @@ -40,7 +40,7 @@ def test_charge_gas(self, cairo_run, evm: Evm, amount: Uint): try: cairo_result = cairo_run("charge_gas", evm, amount) except Exception as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): charge_gas(evm, amount) return @@ -61,7 +61,7 @@ def test_calculate_gas_extend_memory( try: cairo_result = cairo_run("calculate_gas_extend_memory", memory, extensions) except ExceptionalHalt as cairo_error: - with pytest.raises(type(cairo_error)): + with strict_raises(type(cairo_error)): calculate_gas_extend_memory(memory, extensions) return diff --git a/cairo/tests/fixtures/runner.py b/cairo/tests/fixtures/runner.py index 6aa5457e..3277e13d 100644 --- a/cairo/tests/fixtures/runner.py +++ b/cairo/tests/fixtures/runner.py @@ -460,6 +460,7 @@ def _run(entrypoint, *args, **kwargs): except Exception as e: if "An ASSERT_EQ instruction failed" in str(e): raise AssertionError(e) from e + raise Exception(str(e)) from e cumulative_retdata_offsets = serde.get_offsets(return_data_types) first_return_data_offset = ( diff --git a/cairo/tests/utils/errors.py b/cairo/tests/utils/errors.py index 1cec5c12..9b85acde 100644 --- a/cairo/tests/utils/errors.py +++ b/cairo/tests/utils/errors.py @@ -1,5 +1,6 @@ import re from contextlib import contextmanager +from typing import Type import pytest @@ -16,3 +17,33 @@ def cairo_error(message=None): assert message in error, f"Expected {message}, got {error}" finally: pass + + +@contextmanager +def strict_raises(expected_exception: Type[Exception], match: str = None): + """ + Context manager that extends pytest.raises to enforce strict exception type matching. + Unlike pytest.raises, this doesn't allow subclass exceptions to match. + + Args: + expected_exception: The exact exception type expected + match: Optional string pattern to match against the exception message + + Example: + with strict_raises(ValueError, match="invalid value"): + raise ValueError("invalid value") # passes + + with strict_raises(Exception): + raise ValueError() # fails - more specific exception + """ + with pytest.raises(Exception) as exc_info: + yield exc_info + + if type(exc_info.value) is not expected_exception: + raise AssertionError( + f"Expected exactly {expected_exception.__name__}, but got {type(exc_info.value).__name__}" + ) + + if match is not None: + error_msg = str(exc_info.value) + assert match in error_msg, f"Expected '{match}' in '{error_msg}'"