Skip to content

Commit

Permalink
Merge pull request #294 from NVIDIA/am/tr-everywhere
Browse files Browse the repository at this point in the history
Pass TestRun object into cmd_gen related functions
  • Loading branch information
amaslenn authored Oct 30, 2024
2 parents cb153f5 + 8754357 commit f2a82a5
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,13 @@ def _parse_slurm_args(

return base_args

def generate_test_command(
self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
) -> List[str]:
def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]:
srun_command_parts = [
"python /workspace/param/train/comms/pt/commsTraceReplay.py",
f'--trace-type {cmd_args["trace_type"]}',
f'--trace-path {cmd_args["trace_path"]}',
f'--backend {cmd_args["backend"]}',
f'--device {cmd_args["device"]}',
extra_cmd_args,
tr.test.extra_cmd_args,
]
return srun_command_parts
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ def _parse_slurm_args(
return base_args

def generate_srun_command(
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, Any], extra_cmd_args: str
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, Any], tr: TestRun
) -> str:
self._create_run_script(env_vars, cmd_args, extra_cmd_args)
self._create_run_script(env_vars, cmd_args, tr.test.extra_cmd_args)

commands = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,7 @@ def _parse_slurm_args(

return base_args

def generate_test_command(
self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
) -> List[str]:
def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]:
srun_command_parts = [f"/usr/local/bin/{cmd_args['subtest_name']}"]
nccl_test_args = [
"nthreads",
Expand All @@ -69,7 +67,7 @@ def generate_test_command(
if arg in cmd_args:
srun_command_parts.append(f"--{arg} {cmd_args[arg]}")

if extra_cmd_args:
srun_command_parts.append(extra_cmd_args)
if tr.test.extra_cmd_args:
srun_command_parts.append(tr.test.extra_cmd_args)

return srun_command_parts
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@

from typing import Dict, List

from cloudai import TestRun
from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy


class SleepSlurmCommandGenStrategy(SlurmCommandGenStrategy):
"""Command generation strategy for Sleep on Slurm systems."""

def generate_test_command(
self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
) -> List[str]:
def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]:
return [f'sleep {cmd_args["seconds"]}']
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def _parse_slurm_args(

return base_args

def generate_test_command(
self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
) -> List[str]:
def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]:
srun_command_parts = ["/opt/hpcx/ucc/bin/ucc_perftest"]

# Add collective, minimum bytes, and maximum bytes options if available
Expand All @@ -52,7 +50,7 @@ def generate_test_command(
srun_command_parts.append("-F")

# Append any extra command-line arguments provided
if extra_cmd_args:
srun_command_parts.append(extra_cmd_args)
if tr.test.extra_cmd_args:
srun_command_parts.append(tr.test.extra_cmd_args)

return srun_command_parts
24 changes: 11 additions & 13 deletions src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,13 @@ def job_name(self, job_name_prefix: str) -> str:
return job_name

def generate_srun_command(
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun
) -> str:
srun_command_parts = self.generate_srun_prefix(slurm_args)
test_command_parts = self.generate_test_command(env_vars, cmd_args, extra_cmd_args)
srun_command_parts = self.generate_srun_prefix(slurm_args, tr)
test_command_parts = self.generate_test_command(env_vars, cmd_args, tr)
return " \\\n".join(srun_command_parts + test_command_parts)

def generate_srun_prefix(self, slurm_args: Dict[str, Any]) -> List[str]:
def generate_srun_prefix(self, slurm_args: Dict[str, Any], tr: TestRun) -> List[str]:
srun_command_parts = ["srun", f"--mpi={self.system.mpi}"]
if slurm_args.get("image_path"):
srun_command_parts.append(f'--container-image={slurm_args["image_path"]}')
Expand All @@ -131,12 +131,10 @@ def gen_exec_command(self, tr: TestRun) -> str:
env_vars = self._override_env_vars(self.system.global_env_vars, tr.test.extra_env_vars)
cmd_args = self._override_cmd_args(self.default_cmd_args, tr.test.cmd_args)
slurm_args = self._parse_slurm_args(tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr)
srun_command = self.generate_srun_command(slurm_args, env_vars, cmd_args, tr.test.extra_cmd_args)
return self._write_sbatch_script(slurm_args, env_vars, srun_command, tr.output_path)
srun_command = self.generate_srun_command(slurm_args, env_vars, cmd_args, tr)
return self._write_sbatch_script(slurm_args, env_vars, srun_command, tr)

def generate_test_command(
self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str
) -> List[str]:
def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]:
return []

def _add_reservation(self, batch_script_content: List[str]):
Expand All @@ -157,7 +155,7 @@ def _add_reservation(self, batch_script_content: List[str]):
return batch_script_content

def _write_sbatch_script(
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], srun_command: str, output_path: Path
self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], srun_command: str, tr: TestRun
) -> str:
"""
Write the batch script for Slurm submission and return the sbatch command.
Expand All @@ -166,7 +164,7 @@ def _write_sbatch_script(
slurm_args (Dict[str, Any]): Slurm-specific arguments.
env_vars (Dict[str, str]): Environment variables.
srun_command (str): srun command.
output_path (Path): Output directory for script and logs.
tr (TestRun): Test run object; its output_path is the output directory for the script and logs.
Returns:
str: sbatch command to submit the job.
Expand All @@ -177,12 +175,12 @@ def _write_sbatch_script(
f"#SBATCH -N {slurm_args['num_nodes']}",
]

self._append_sbatch_directives(batch_script_content, slurm_args, output_path)
self._append_sbatch_directives(batch_script_content, slurm_args, tr.output_path)

env_vars_str = self._format_env_vars(env_vars)
batch_script_content.extend([env_vars_str, "", srun_command])

batch_script_path = output_path / "cloudai_sbatch_script.sh"
batch_script_path = tr.output_path / "cloudai_sbatch_script.sh"
with batch_script_path.open("w") as batch_file:
batch_file.write("\n".join(batch_script_content))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cloudai import TestRun
from cloudai.schema.test_template.chakra_replay.slurm_command_gen_strategy import ChakraReplaySlurmCommandGenStrategy
from cloudai.systems import SlurmSystem
from tests.conftest import create_autospec_dataclass


class TestChakraReplaySlurmCommandGenStrategy:
Expand Down Expand Up @@ -123,16 +124,20 @@ def test_generate_test_command(
expected_result: List[str],
slurm_system: SlurmSystem,
) -> None:
command = cmd_gen_strategy.generate_test_command({}, cmd_args, extra_cmd_args)
tr = create_autospec_dataclass(TestRun)
tr.test.extra_cmd_args = extra_cmd_args
command = cmd_gen_strategy.generate_test_command({}, cmd_args, tr)
assert command == expected_result

def test_generate_test_command_invalid_args(
self, cmd_gen_strategy: ChakraReplaySlurmCommandGenStrategy, slurm_system: SlurmSystem
) -> None:
cmd_args: Dict[str, str] = {"trace_type": "comms_trace", "backend": "nccl", "device": "gpu"}
extra_cmd_args: str = "--max-steps 100"

tr = create_autospec_dataclass(TestRun)
tr.test.extra_cmd_args = "--max-steps 100"

with pytest.raises(KeyError) as exc_info:
cmd_gen_strategy.generate_test_command({}, cmd_args, extra_cmd_args)
cmd_gen_strategy.generate_test_command({}, cmd_args, tr)

assert str(exc_info.value) == "'trace_path'", "Expected missing trace_path key"
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,9 @@ def test_filename_generation(strategy_fixture: SlurmCommandGenStrategy, testrun_
env_vars = {"TEST_VAR": "VALUE"}
cmd_args = {"test_arg": "test_value"}
slurm_args = strategy_fixture._parse_slurm_args(job_name_prefix, env_vars, cmd_args, testrun_fixture)
srun_command = strategy_fixture.generate_srun_command(slurm_args, env_vars, cmd_args, "")
srun_command = strategy_fixture.generate_srun_command(slurm_args, env_vars, cmd_args, testrun_fixture)

sbatch_command = strategy_fixture._write_sbatch_script(
slurm_args, env_vars, srun_command, testrun_fixture.output_path
)
sbatch_command = strategy_fixture._write_sbatch_script(slurm_args, env_vars, srun_command, testrun_fixture)
filepath_from_command = sbatch_command.split()[-1]

assert testrun_fixture.output_path.joinpath("cloudai_sbatch_script.sh").exists()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

from typing import Any, Dict, List
from unittest.mock import Mock

import pytest

Expand Down Expand Up @@ -102,5 +103,7 @@ def test_generate_test_command(
expected_command: List[str],
) -> None:
env_vars = {}
command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, extra_cmd_args)
tr = Mock()
tr.test.extra_cmd_args = extra_cmd_args
command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, tr)
assert command == expected_command
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

from typing import Dict, List
from unittest.mock import Mock

import pytest

Expand All @@ -41,6 +42,5 @@ def test_generate_test_command(
expected_command: List[str],
) -> None:
env_vars = {}
extra_cmd_args = ""
command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, extra_cmd_args)
command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, Mock())
assert command == expected_command
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

from typing import Dict, List
from unittest.mock import Mock

import pytest

Expand Down Expand Up @@ -64,5 +65,7 @@ def test_generate_test_command(
expected_command: List[str],
) -> None:
env_vars = {}
command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, extra_cmd_args)
tr = Mock()
tr.test.extra_cmd_args = extra_cmd_args
command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, tr)
assert command == expected_command

0 comments on commit f2a82a5

Please sign in to comment.