Skip to content

Commit

Permalink
Distribute compilation across workers
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementWalter committed Jan 1, 2025
1 parent e400c64 commit 61c9b89
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 27 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
"snakeviz",
"testrunfinished",
"usort",
"workercount",
"workerid",
"workerinput",
"xxhash"
Expand Down
82 changes: 60 additions & 22 deletions cairo/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os
import random
import shutil
import time
from pathlib import Path

import pytest
Expand Down Expand Up @@ -181,44 +181,82 @@ def pytest_collection_modifyitems(session, config, items):
session.cairo_programs = {}
session.main_paths = {}
session.test_hashes = {}
fspaths = list(
{
item.fspath
for item in items
if (
hasattr(item, "fixturenames")
and set(item.fixturenames)
& {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}
)
}
fspaths = sorted(
list(
{
item.fspath
for item in items
if (
hasattr(item, "fixturenames")
and set(item.fixturenames)
& {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}
)
}
)
)
random.shuffle(fspaths)

# Distribute compilation using modulo
worker_count = getattr(config, "workerinput", {}).get("workercount", 1)
worker_id = getattr(config, "workerinput", {}).get("workerid", "master")
worker_index = int(worker_id[2:]) if worker_id != "master" else 0
fspaths = [p for p in fspaths[worker_index::worker_count]]
for fspath in fspaths:
cairo_file = get_cairo_file(fspath)
session.cairo_files[fspath] = cairo_file
main_path = get_main_path(cairo_file)
session.main_paths[fspath] = main_path
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
dump_path.parent.mkdir(parents=True, exist_ok=True)
cairo_program = get_cairo_program(cairo_file, main_path, dump_path)
session.cairo_programs[fspath] = cairo_program
get_cairo_program(cairo_file, main_path, dump_path)

# Wait for all workers to finish
all_paths = []
for item in items:
if hasattr(item, "fixturenames") and set(item.fixturenames) & {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}:
cairo_program = session.cairo_programs[item.fspath]
cairo_file = get_cairo_file(item.fspath)
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
all_paths.append(dump_path)

while not all([dump_path.exists() for dump_path in all_paths]):
logger.info(
f"Worker {worker_id} with index {worker_index} / {worker_count} waiting for other workers to finish"
)
time.sleep(0.25)

# Select tests
for item in items:
if hasattr(item, "fixturenames") and set(item.fixturenames) & {
"cairo_file",
"main_path",
"cairo_program",
"cairo_run",
}:
if item.fspath not in session.cairo_files:
cairo_file = get_cairo_file(item.fspath)
session.cairo_files[item.fspath] = cairo_file
if item.fspath not in session.main_paths:
main_path = get_main_path(cairo_file)
session.main_paths[item.fspath] = main_path
if item.fspath not in session.cairo_programs:
dump_path = session.build_dir / cairo_file.relative_to(
Path().cwd()
).with_suffix(".json")
cairo_program = get_cairo_program(cairo_file, main_path, dump_path)
session.cairo_programs[item.fspath] = cairo_program

cairo_program = session.cairo_programs[item.fspath]
test_hash = xxhash.xxh64(
program_hash(cairo_program)
+ file_hash(item.fspath)
Expand Down
8 changes: 3 additions & 5 deletions cairo/tests/utils/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from time import perf_counter
from typing import Optional

from filelock import FileLock
from starkware.cairo.lang.compiler.program import Program
from starkware.cairo.lang.compiler.scoped_name import ScopedName

Expand Down Expand Up @@ -81,8 +80,7 @@ def get_cairo_program(cairo_file: Path, main_path, dump_path: Optional[Path] = N
stop = perf_counter()
logger.info(f"{cairo_file} compiled in {stop - start:.2f}s")
if dump_path is not None:
with FileLock(str(dump_path) + ".lock"):
dump_path.write_text(
json.dumps(program.Schema().dump(program), indent=4, sort_keys=True)
)
dump_path.write_text(
json.dumps(program.Schema().dump(program), indent=4, sort_keys=True)
)
return program

0 comments on commit 61c9b89

Please sign in to comment.