Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dump compilation artifacts during test collection #323

Merged
3 changes: 1 addition & 2 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ jobs:

- name: Run tests with caching
run: |
cd cairo
uv run compile_os
uv run pytest tests -n logical --durations=0 -v
uv run pytest cairo/tests -n logical --durations=0 -v -s --log-cli-level=DEBUG
enitrat marked this conversation as resolved.
Show resolved Hide resolved

- uses: actions/cache/save@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"filterwarnings",
"fixturenames",
"frozendict",
"fspaths",
"hookwrapper",
"intdigest",
"ipykernel",
Expand Down Expand Up @@ -51,6 +52,7 @@
"snakeviz",
"testrunfinished",
"usort",
"workercount",
"workerid",
"workerinput",
"xxhash"
Expand Down
165 changes: 165 additions & 0 deletions cairo/pyproject.toml
enitrat marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
[project]
name = "cairo"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"

dependencies = [
"cairo-lang>=0.13.2",
"ethereum",
"filelock>=3.16.1",
"marshmallow-dataclass>=8.6.1",
"python-dotenv>=1.0.1",
"toml>=0.10.2",
"web3>=7.2.0",
"xxhash>=3.5.0",
]


[tool.pytest.ini_options]
filterwarnings = [
"ignore:Using or importing the ABCs:DeprecationWarning", # from frozendict
"ignore:lexer_state will be removed in subsequent releases. Use lexer_thread instead.", # from lark
"ignore:abi:DeprecationWarning", # from web3
"ignore::marshmallow.warnings.RemovedInMarshmallow4Warning", # from marshmallow
]
asyncio_default_fixture_loop_scope = 'session'
markers = [
"ArithmeticOperations",
"ADD: Opcode Value 0x01 - Addition operation",
"MUL: Opcode Value 0x02 - Multiplication operation",
"SUB: Opcode Value 0x03 - Subtraction operation",
"DIV: Opcode Value 0x04 - Integer division operation",
"SDIV: Opcode Value 0x05 - Signed integer division operation (truncated)",
"MOD: Opcode Value 0x06 - Modulo remainder operation",
"SMOD: Opcode Value 0x07 - Signed modulo remainder operation",
"ADDMOD: Opcode Value 0x08 - Modulo addition operation",
"MULMOD: Opcode Value 0x09 - Modulo multiplication operation",
"EXP: Opcode Value 0x0a - Exponential operation",
"SIGNEXTEND: Opcode Value 0x0b - Extend length of two's complement signed integer",
"ComparisonBitwiseLogicOperations",
"LT: Opcode Value 0x10 - Less-than comparison",
"GT: Opcode Value 0x11 - Greater-than comparison",
"SLT: Opcode Value 0x12 - Signed less-than comparison",
"SGT: Opcode Value 0x13 - Signed greater-than comparison",
"EQ: Opcode Value 0x14 - Equality comparison",
"ISZERO: Opcode Value 0x15 - Simple not operator",
"AND: Opcode Value 0x16 - Bitwise AND operation",
"OR: Opcode Value 0x17 - Bitwise OR operation",
"NOT: Opcode Value 0x19 - Bitwise NOT operation",
"SHL: Opcode Value 0x1b - Shift left",
"SHR: Opcode Value 0x1c - Logical shift right",
"SAR: Opcode Value 0x1d - Arithmetic shift right",
"SHA3: Opcode Value 0x20 - Compute Keccak-256 hash",
"EnvironmentalInformation",
"ADDRESS: Opcode Value 0x30 - Get address of currently executing account",
"BALANCE: Opcode Value 0x31 - Get balance of the given account",
"ORIGIN: Opcode Value 0x32 - Get execution origination address",
"CALLER: Opcode Value 0x33 - Get caller address",
"CALLVALUE: Opcode Value 0x34 - Get deposited value by the instruction/transaction responsible for this execution",
"CALLDATALOAD: Opcode Value 0x35 - Get input data of current environment",
"CALLDATASIZE: Opcode Value 0x36 - Get size of input data in current environment",
"CALLDATACOPY: Opcode Value 0x37 - Copy input data in current environment to memory",
"CODESIZE: Opcode Value 0x38 - Get size of code running in current environment",
"CODECOPY: Opcode Value 0x39 - Copy code running in current environment to memory",
"RETURNDATASIZE: Opcode Value 0x3d - Get size of output data from the previous call from the current environment",
"BlockInformation",
"BLOCKHASH: Opcode Value 0x40 - Get the hash of one of the 256 most recent complete blocks",
"COINBASE: Opcode Value 0x41 - Get the block's beneficiary address",
"TIMESTAMP: Opcode Value 0x42 - Get the block's timestamp",
"NUMBER: Opcode Value 0x43 - Get the block's number",
"DIFFICULTY: Opcode Value 0x44 - Get the block's difficulty",
"GASLIMIT: Opcode Value 0x45 - Get the block's gas limit",
"CHAINID: Opcode Value 0x46 - Get the chain ID",
"SELFBALANCE: Opcode Value 0x47 - Get the balance of the current contract",
"BASEFEE: Opcode Value 0x48 - Get the base fee of the current block",
"BLOBHASH: Opcode Value 0x49 - Get the versioned hash at the requested index",
"BLOBBASEFEE: Opcode Value 0x4a - Get the blob base-fee of the current block",
"StackMemoryStorageFlowOperations",
"MLOAD: Opcode Value 0x51 - Load word from memory",
"MSTORE: Opcode Value 0x52 - Save word to memory",
"MSTORE8: Opcode Value 0x53 - Save byte to memory",
"SLOAD: Opcode Value 0x54 - Load word from storage",
"SSTORE: Opcode Value 0x55 - Save word to storage",
"JUMP: Opcode Value 0x56 - Alter the program counter",
"JUMPI: Opcode Value 0x57 - Conditionally alter the program counter",
"PC: Opcode Value 0x58 - Get the value of the program counter prior to the increment",
"MSIZE: Opcode Value 0x59 - Get the size of active memory in bytes",
"JUMPDEST: Opcode Value 0x5b - Mark a valid destination for jumps",
"TLOAD: Opcode Value 0x5c - Load word from transient storage",
"TSTORE: Opcode Value 0x5d - Save word to transient storage",
"MCOPY: Opcode Value 0x5e - Copy memory from one location to another",
"PushOperations",
"PUSH Opcodes 0x60 ~ 7f - Place n-byte item on stack",
"DuplicationOperations",
"DUP: Opcodes 0x80 ~ 8f - Duplicate nth stack item",
"ExchangeOperations",
"SWAP: Opcodes 0x90 ~ 9f - Exchange 1st and nth stack items",
"LoggingOperations",
"LOG: Opcodes 0xa0 ~ a4 - Append log record with n topics",
"SystemOperations",
"RETURN: Opcode Value 0xf3 - Halt execution returning output data",
"REVERT: Opcode value 0xfd - Halt execution reverting state changes",
"INVALID: Opcode Value 0xfe - Designated invalid instruction",
"Precompiles",
"EC_RECOVER: Precompile Value 0x01 - Elliptic curve digital signature algorithm (ECDSA) public key recovery function",
"SHA256: Precompile Value 0x02 - Hash function",
"RIPEMD160: Precompile Value 0x03 - Hash function",
"MOD_EXP: Precompile Value 0x05 - Modular exponentiation MVP - missing support for bigint",
"EC_ADD: Precompile Value 0x06 - Point addition (ADD) on the elliptic curve 'alt_bn128'",
"EC_MUL: Precompile Value 0x07 - Scalar multiplication (MUL) on the elliptic curve 'alt_bn128'",
"BLAKE2F: Precompile Value 0x09 - Blake2 compression function",
"Counter",
"PlainOpcodes",
"SolmateERC20",
"SolmateERC721",
"UniswapV2ERC20",
"UniswapV2Factory",
"RIP7212",
"CairoPrecompiles",
"UniswapV2Router",
"AccountContract",
"Utils",
"Safe",
"EFTests",
"SSTORE",
"SLOAD",
"NoCI",
"slow",
]
norecursedirs = ".* tests/ef_tests/test_data"

[project.scripts]
compile_os = "src.utils.compile_cairo:compile_os"
compile_fibonacci = "src.utils.compile_cairo:compile_fibonacci"
transpile = "scripts.convert_py_to_cairo:main"
generate-tests = "scripts.generate_tests:main"

[tool.uv]
dev-dependencies = [
"eth-abi>=5.1.0",
"eth-account>=0.13.3",
"eth-keys>=0.5.1",
"eth-utils>=5.0.0",
"hypothesis>=6.112.1",
"ipykernel>=6.29.5",
"pytest-xdist>=3.6.1",
"pytest>=8.3.3",
"pydantic>=2.9.1",
"polars>=1.18.0",
]

[tool.isort]
profile = "black"
src_paths = ["src", "tests"]

[tool.uv.sources]
ethereum = { git = "https://github.com/kkrt-labs/execution-specs.git", rev = "b255036441d64437bd4fc9f9068bc64c45470e93" }

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["src"]
2 changes: 1 addition & 1 deletion cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func dict_copy{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*)
) {
alloc_locals;
let (local new_start: DictAccess*) = alloc();
let new_end = new_start + (dict_end - dict_start);
tempvar new_end = new_start + (dict_end - dict_start);
memcpy(new_start, dict_start, dict_end - dict_start);
// Register the segment as a dict in the DictManager.
%{ dict_copy %}
Expand Down
118 changes: 84 additions & 34 deletions cairo/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
import shutil
import time
from pathlib import Path

import pytest
Expand Down Expand Up @@ -115,11 +117,14 @@ def seed(request):

def pytest_sessionstart(session):
session.results = dict()
session.build_dir = Path("build") / ".pytest_build"


def pytest_sessionfinish(session):

if xdist.is_xdist_controller(session):
logger.info("Controller worker: collecting tests to skip")
shutil.rmtree(session.build_dir)
tests_to_skip = session.config.cache.get(f"cairo_run/{CACHED_TESTS_FILE}", [])
for worker_id in range(session.config.option.numprocesses):
tests_to_skip += session.config.cache.get(
Expand All @@ -144,6 +149,7 @@ def pytest_sessionfinish(session):
return

logger.info("Sequential worker: collecting tests to skip")
shutil.rmtree(session.build_dir)
tests_to_skip = session.config.cache.get(f"cairo_run/{CACHED_TESTS_FILE}", [])
tests_to_skip += session_tests_to_skip
session.config.cache.set(f"cairo_run/{CACHED_TESTS_FILE}", list(set(tests_to_skip)))
Expand Down Expand Up @@ -175,39 +181,83 @@ def pytest_collection_modifyitems(session, config, items):
session.cairo_programs = {}
session.main_paths = {}
session.test_hashes = {}
for item in items:
if hasattr(item, "fixturenames") and set(item.fixturenames) & {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
cairo_program = get_cairo_program(cairo_file, main_path)
session.cairo_programs[item.fspath] = cairo_program

test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
+ item.nodeid.encode()
+ file_hash(Path(__file__).parent / "fixtures" / "runner.py")
+ file_hash(Path(__file__).parent / "utils" / "serde.py")
+ file_hash(Path(__file__).parent / "utils" / "args_gen.py")
).hexdigest()
session.test_hashes[item.nodeid] = test_hash

if config.getoption("no_skip_mark"):
item.own_markers = [
mark for mark in item.own_markers if mark.name != "skip"
]

if test_hash in tests_to_skip and config.getoption("skip_cached_tests"):
item.add_marker(pytest.mark.skip(reason="Cached results"))
cairo_items = [
item
for item in items
if (
hasattr(item, "fixturenames")
and set(item.fixturenames)
& {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}
)
]

# Distribute compilation using modulo
worker_count = getattr(config, "workerinput", {}).get("workercount", 1)
worker_id = getattr(config, "workerinput", {}).get("workerid", "master")
worker_index = int(worker_id[2:]) if worker_id != "master" else 0
fspaths = sorted(list({item.fspath for item in cairo_items}))
for fspath in fspaths[worker_index::worker_count]:
cairo_file = get_cairo_file(fspath)
main_path = get_main_path(cairo_file)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
dump_path.parent.mkdir(parents=True, exist_ok=True)
get_cairo_program(cairo_file, main_path, dump_path)

# Wait for all workers to finish
all_paths = []
for item in cairo_items:
cairo_file = get_cairo_file(item.fspath)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
all_paths.append(dump_path)

while not all([dump_path.exists() for dump_path in all_paths]):
logger.info(
f"Worker {worker_id} with index {worker_index} / {worker_count} waiting for other workers to finish"
)
# 0.25 seconds as observed to be one of the smallest time over the current test files
time.sleep(0.25)

# Select tests
for item in cairo_items:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
cairo_program = get_cairo_program(cairo_file, main_path, dump_path)
session.cairo_programs[item.fspath] = cairo_program

cairo_program = session.cairo_programs[item.fspath]
test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
+ item.nodeid.encode()
+ file_hash(Path(__file__).parent / "fixtures" / "runner.py")
+ file_hash(Path(__file__).parent / "utils" / "serde.py")
+ file_hash(Path(__file__).parent / "utils" / "args_gen.py")
).hexdigest()
session.test_hashes[item.nodeid] = test_hash

if config.getoption("no_skip_mark"):
item.own_markers = [
mark for mark in item.own_markers if mark.name != "skip"
]

if test_hash in tests_to_skip and config.getoption("skip_cached_tests"):
item.add_marker(pytest.mark.skip(reason="Cached results"))

yield
18 changes: 16 additions & 2 deletions cairo/tests/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
c) Follows project directory structure for test organization
"""

import json
import logging
from pathlib import Path
from time import perf_counter
from typing import Optional

from starkware.cairo.lang.compiler.program import Program
from starkware.cairo.lang.compiler.scoped_name import ScopedName

from src.utils.compiler import cairo_compile, implement_hints
Expand Down Expand Up @@ -56,9 +59,19 @@ def get_main_path(cairo_file):
)


def get_cairo_program(cairo_file, main_path):
def get_cairo_program(cairo_file: Path, main_path, dump_path: Optional[Path] = None):
start = perf_counter()
program = cairo_compile(cairo_file, debug_info=True, proof_mode=False)
if dump_path is not None and dump_path.is_file():
logger.info(f"Loading program from {dump_path}")
program = Program.load(data=json.loads(dump_path.read_text()))
else:
logger.info(f"Compiling {cairo_file}")
program = cairo_compile(cairo_file, debug_info=True, proof_mode=False)
if dump_path is not None:
dump_path.write_text(
json.dumps(program.Schema().dump(program), indent=4, sort_keys=True)
)

program.hints = implement_hints(program)
all_identifiers = list(program.identifiers.dict.items())
# when running the tests, the main file is the test file
Expand All @@ -72,4 +85,5 @@ def get_cairo_program(cairo_file, main_path):
program.identifiers.add_identifier(ScopedName(main_path + k.path[1:]), v)
stop = perf_counter()
logger.info(f"{cairo_file} compiled in {stop - start:.2f}s")

return program
Loading