Skip to content

Commit

Permalink
Dump compilation artifacts during test collection (#323)
Browse files Browse the repository at this point in the history
Resolves #320

Time down from ~6min to ~3min when restarting the CI and everything is
skipped, which means initialization time down 2x and absolute time down
3 minutes whatever the context
  • Loading branch information
ClementWalter authored Jan 2, 2025
1 parent 37c972b commit f127a97
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 39 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ jobs:

- name: Run tests with caching
run: |
cd cairo
uv run compile_os
uv run pytest tests -n logical --durations=0 -v
uv run pytest cairo/tests -n logical --durations=0 -v -s --log-cli-level=DEBUG
- uses: actions/cache/save@v4
with:
Expand Down
2 changes: 2 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"filterwarnings",
"fixturenames",
"frozendict",
"fspaths",
"hookwrapper",
"intdigest",
"ipykernel",
Expand Down Expand Up @@ -51,6 +52,7 @@
"snakeviz",
"testrunfinished",
"usort",
"workercount",
"workerid",
"workerinput",
"xxhash"
Expand Down
2 changes: 1 addition & 1 deletion cairo/src/utils/dict.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ func dict_copy{range_check_ptr}(dict_start: DictAccess*, dict_end: DictAccess*)
) {
alloc_locals;
let (local new_start: DictAccess*) = alloc();
let new_end = new_start + (dict_end - dict_start);
tempvar new_end = new_start + (dict_end - dict_start);
memcpy(new_start, dict_start, dict_end - dict_start);
// Register the segment as a dict in the DictManager.
%{ dict_copy %}
Expand Down
118 changes: 84 additions & 34 deletions cairo/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import os
import shutil
import time
from pathlib import Path

import pytest
Expand Down Expand Up @@ -115,11 +117,14 @@ def seed(request):

def pytest_sessionstart(session):
session.results = dict()
session.build_dir = Path("build") / ".pytest_build"


def pytest_sessionfinish(session):

if xdist.is_xdist_controller(session):
logger.info("Controller worker: collecting tests to skip")
shutil.rmtree(session.build_dir)
tests_to_skip = session.config.cache.get(f"cairo_run/{CACHED_TESTS_FILE}", [])
for worker_id in range(session.config.option.numprocesses):
tests_to_skip += session.config.cache.get(
Expand All @@ -144,6 +149,7 @@ def pytest_sessionfinish(session):
return

logger.info("Sequential worker: collecting tests to skip")
shutil.rmtree(session.build_dir)
tests_to_skip = session.config.cache.get(f"cairo_run/{CACHED_TESTS_FILE}", [])
tests_to_skip += session_tests_to_skip
session.config.cache.set(f"cairo_run/{CACHED_TESTS_FILE}", list(set(tests_to_skip)))
Expand Down Expand Up @@ -175,39 +181,83 @@ def pytest_collection_modifyitems(session, config, items):
session.cairo_programs = {}
session.main_paths = {}
session.test_hashes = {}
for item in items:
if hasattr(item, "fixturenames") and set(item.fixturenames) & {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
cairo_program = get_cairo_program(cairo_file, main_path)
session.cairo_programs[item.fspath] = cairo_program

test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
+ item.nodeid.encode()
+ file_hash(Path(__file__).parent / "fixtures" / "runner.py")
+ file_hash(Path(__file__).parent / "utils" / "serde.py")
+ file_hash(Path(__file__).parent / "utils" / "args_gen.py")
).hexdigest()
session.test_hashes[item.nodeid] = test_hash

if config.getoption("no_skip_mark"):
item.own_markers = [
mark for mark in item.own_markers if mark.name != "skip"
]

if test_hash in tests_to_skip and config.getoption("skip_cached_tests"):
item.add_marker(pytest.mark.skip(reason="Cached results"))
cairo_items = [
item
for item in items
if (
hasattr(item, "fixturenames")
and set(item.fixturenames)
& {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}
)
]

# Distribute compilation using modulo
worker_count = getattr(config, "workerinput", {}).get("workercount", 1)
worker_id = getattr(config, "workerinput", {}).get("workerid", "master")
worker_index = int(worker_id[2:]) if worker_id != "master" else 0
fspaths = sorted(list({item.fspath for item in cairo_items}))
for fspath in fspaths[worker_index::worker_count]:
cairo_file = get_cairo_file(fspath)
main_path = get_main_path(cairo_file)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
dump_path.parent.mkdir(parents=True, exist_ok=True)
get_cairo_program(cairo_file, main_path, dump_path)

# Wait for all workers to finish
all_paths = []
for item in cairo_items:
cairo_file = get_cairo_file(item.fspath)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
all_paths.append(dump_path)

while not all([dump_path.exists() for dump_path in all_paths]):
logger.info(
f"Worker {worker_id} with index {worker_index} / {worker_count} waiting for other workers to finish"
)
# 0.25 seconds as observed to be one of the smallest time over the current test files
time.sleep(0.25)

# Select tests
for item in cairo_items:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
cairo_program = get_cairo_program(cairo_file, main_path, dump_path)
session.cairo_programs[item.fspath] = cairo_program

cairo_program = session.cairo_programs[item.fspath]
test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
+ item.nodeid.encode()
+ file_hash(Path(__file__).parent / "fixtures" / "runner.py")
+ file_hash(Path(__file__).parent / "utils" / "serde.py")
+ file_hash(Path(__file__).parent / "utils" / "args_gen.py")
).hexdigest()
session.test_hashes[item.nodeid] = test_hash

if config.getoption("no_skip_mark"):
item.own_markers = [
mark for mark in item.own_markers if mark.name != "skip"
]

if test_hash in tests_to_skip and config.getoption("skip_cached_tests"):
item.add_marker(pytest.mark.skip(reason="Cached results"))

yield
18 changes: 16 additions & 2 deletions cairo/tests/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,13 @@
c) Follows project directory structure for test organization
"""

import json
import logging
from pathlib import Path
from time import perf_counter
from typing import Optional

from starkware.cairo.lang.compiler.program import Program
from starkware.cairo.lang.compiler.scoped_name import ScopedName

from src.utils.compiler import cairo_compile, implement_hints
Expand Down Expand Up @@ -56,9 +59,19 @@ def get_main_path(cairo_file):
)


def get_cairo_program(cairo_file, main_path):
def get_cairo_program(cairo_file: Path, main_path, dump_path: Optional[Path] = None):
start = perf_counter()
program = cairo_compile(cairo_file, debug_info=True, proof_mode=False)
if dump_path is not None and dump_path.is_file():
logger.info(f"Loading program from {dump_path}")
program = Program.load(data=json.loads(dump_path.read_text()))
else:
logger.info(f"Compiling {cairo_file}")
program = cairo_compile(cairo_file, debug_info=True, proof_mode=False)
if dump_path is not None:
dump_path.write_text(
json.dumps(program.Schema().dump(program), indent=4, sort_keys=True)
)

program.hints = implement_hints(program)
all_identifiers = list(program.identifiers.dict.items())
# when running the tests, the main file is the test file
Expand All @@ -72,4 +85,5 @@ def get_cairo_program(cairo_file, main_path):
program.identifiers.add_identifier(ScopedName(main_path + k.path[1:]), v)
stop = perf_counter()
logger.info(f"{cairo_file} compiled in {stop - start:.2f}s")

return program

0 comments on commit f127a97

Please sign in to comment.