Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dump compilation artifacts during test collection #323

Merged
3 changes: 1 addition & 2 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ jobs:

- name: Run tests with caching
run: |
cd cairo
uv run compile_os
uv run pytest tests -n logical --durations=0 -v
uv run pytest cairo/tests -n logical --durations=0 -v -s --log-cli-level=DEBUG
enitrat marked this conversation as resolved.
Show resolved Hide resolved

- uses: actions/cache/save@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"filterwarnings",
"fixturenames",
"frozendict",
"fspaths",
"hookwrapper",
"intdigest",
"ipykernel",
Expand Down Expand Up @@ -51,6 +52,7 @@
"snakeviz",
"testrunfinished",
"usort",
"workercount",
"workerid",
"workerinput",
"xxhash"
Expand Down
2 changes: 1 addition & 1 deletion cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func dict_copy{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*)
) {
alloc_locals;
let (local new_start: DictAccess*) = alloc();
let new_end = new_start + (dict_end - dict_start);
tempvar new_end = new_start + (dict_end - dict_start);
memcpy(new_start, dict_start, dict_end - dict_start);
// Register the segment as a dict in the DictManager.
%{ dict_copy %}
Expand Down
118 changes: 84 additions & 34 deletions cairo/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
import shutil
import time
from pathlib import Path

import pytest
Expand Down Expand Up @@ -115,11 +117,14 @@ def seed(request):

def pytest_sessionstart(session):
session.results = dict()
session.build_dir = Path("build") / ".pytest_build"


def pytest_sessionfinish(session):

if xdist.is_xdist_controller(session):
logger.info("Controller worker: collecting tests to skip")
shutil.rmtree(session.build_dir)
tests_to_skip = session.config.cache.get(f"cairo_run/{CACHED_TESTS_FILE}", [])
for worker_id in range(session.config.option.numprocesses):
tests_to_skip += session.config.cache.get(
Expand All @@ -144,6 +149,7 @@ def pytest_sessionfinish(session):
return

logger.info("Sequential worker: collecting tests to skip")
shutil.rmtree(session.build_dir)
tests_to_skip = session.config.cache.get(f"cairo_run/{CACHED_TESTS_FILE}", [])
tests_to_skip += session_tests_to_skip
session.config.cache.set(f"cairo_run/{CACHED_TESTS_FILE}", list(set(tests_to_skip)))
Expand Down Expand Up @@ -175,39 +181,83 @@ def pytest_collection_modifyitems(session, config, items):
session.cairo_programs = {}
session.main_paths = {}
session.test_hashes = {}
for item in items:
if hasattr(item, "fixturenames") and set(item.fixturenames) & {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
cairo_program = get_cairo_program(cairo_file, main_path)
session.cairo_programs[item.fspath] = cairo_program

test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
+ item.nodeid.encode()
+ file_hash(Path(__file__).parent / "fixtures" / "runner.py")
+ file_hash(Path(__file__).parent / "utils" / "serde.py")
+ file_hash(Path(__file__).parent / "utils" / "args_gen.py")
).hexdigest()
session.test_hashes[item.nodeid] = test_hash

if config.getoption("no_skip_mark"):
item.own_markers = [
mark for mark in item.own_markers if mark.name != "skip"
]

if test_hash in tests_to_skip and config.getoption("skip_cached_tests"):
item.add_marker(pytest.mark.skip(reason="Cached results"))
cairo_items = [
item
for item in items
if (
hasattr(item, "fixturenames")
and set(item.fixturenames)
& {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}
)
]

# Distribute compilation using modulo
worker_count = getattr(config, "workerinput", {}).get("workercount", 1)
worker_id = getattr(config, "workerinput", {}).get("workerid", "master")
worker_index = int(worker_id[2:]) if worker_id != "master" else 0
fspaths = sorted(list({item.fspath for item in cairo_items}))
for fspath in fspaths[worker_index::worker_count]:
cairo_file = get_cairo_file(fspath)
main_path = get_main_path(cairo_file)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
dump_path.parent.mkdir(parents=True, exist_ok=True)
get_cairo_program(cairo_file, main_path, dump_path)

# Wait for all workers to finish
all_paths = []
for item in cairo_items:
cairo_file = get_cairo_file(item.fspath)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
all_paths.append(dump_path)

while not all([dump_path.exists() for dump_path in all_paths]):
logger.info(
f"Worker {worker_id} with index {worker_index} / {worker_count} waiting for other workers to finish"
)
# 0.25 seconds as observed to be one of the smallest time over the current test files
time.sleep(0.25)

# Select tests
for item in cairo_items:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
cairo_program = get_cairo_program(cairo_file, main_path, dump_path)
session.cairo_programs[item.fspath] = cairo_program

cairo_program = session.cairo_programs[item.fspath]
test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
+ item.nodeid.encode()
+ file_hash(Path(__file__).parent / "fixtures" / "runner.py")
+ file_hash(Path(__file__).parent / "utils" / "serde.py")
+ file_hash(Path(__file__).parent / "utils" / "args_gen.py")
).hexdigest()
session.test_hashes[item.nodeid] = test_hash

if config.getoption("no_skip_mark"):
item.own_markers = [
mark for mark in item.own_markers if mark.name != "skip"
]

if test_hash in tests_to_skip and config.getoption("skip_cached_tests"):
item.add_marker(pytest.mark.skip(reason="Cached results"))

yield
18 changes: 16 additions & 2 deletions cairo/tests/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
c) Follows project directory structure for test organization
"""

import json
import logging
from pathlib import Path
from time import perf_counter
from typing import Optional

from starkware.cairo.lang.compiler.program import Program
from starkware.cairo.lang.compiler.scoped_name import ScopedName

from src.utils.compiler import cairo_compile, implement_hints
Expand Down Expand Up @@ -56,9 +59,19 @@ def get_main_path(cairo_file):
)


def get_cairo_program(cairo_file, main_path):
def get_cairo_program(cairo_file: Path, main_path, dump_path: Optional[Path] = None):
start = perf_counter()
program = cairo_compile(cairo_file, debug_info=True, proof_mode=False)
if dump_path is not None and dump_path.is_file():
logger.info(f"Loading program from {dump_path}")
program = Program.load(data=json.loads(dump_path.read_text()))
else:
logger.info(f"Compiling {cairo_file}")
program = cairo_compile(cairo_file, debug_info=True, proof_mode=False)
if dump_path is not None:
dump_path.write_text(
json.dumps(program.Schema().dump(program), indent=4, sort_keys=True)
)

program.hints = implement_hints(program)
all_identifiers = list(program.identifiers.dict.items())
# when running the tests, the main file is the test file
Expand All @@ -72,4 +85,5 @@ def get_cairo_program(cairo_file, main_path):
program.identifiers.add_identifier(ScopedName(main_path + k.path[1:]), v)
stop = perf_counter()
logger.info(f"{cairo_file} compiled in {stop - start:.2f}s")

return program
Loading