From 517b90e70884b4bc6a013fa2941d3d8031a9f8c9 Mon Sep 17 00:00:00 2001 From: Andrey Maslennikov Date: Fri, 25 Oct 2024 17:07:16 +0200 Subject: [PATCH] Pass TestRun into all cmd_gen related functions --- .../slurm_command_gen_strategy.py | 6 ++--- .../jax_toolbox/slurm_command_gen_strategy.py | 4 ++-- .../nccl_test/slurm_command_gen_strategy.py | 8 +++---- .../sleep/slurm_command_gen_strategy.py | 5 ++-- .../ucc_test/slurm_command_gen_strategy.py | 8 +++---- .../strategy/slurm_command_gen_strategy.py | 24 +++++++++---------- ...hakra_replay_slurm_command_gen_strategy.py | 11 ++++++--- .../test_common_slurm_command_gen_strategy.py | 6 ++--- .../test_nccl_slurm_command_gen_strategy.py | 5 +++- .../test_sleep_slurm_command_gen_strategy.py | 4 ++-- .../test_ucc_slurm_command_gen_strategy.py | 5 +++- 11 files changed, 43 insertions(+), 43 deletions(-) diff --git a/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py index bdcad07ce..2ace8e149 100644 --- a/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/chakra_replay/slurm_command_gen_strategy.py @@ -35,15 +35,13 @@ def _parse_slurm_args( return base_args - def generate_test_command( - self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str - ) -> List[str]: + def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: srun_command_parts = [ "python /workspace/param/train/comms/pt/commsTraceReplay.py", f'--trace-type {cmd_args["trace_type"]}', f'--trace-path {cmd_args["trace_path"]}', f'--backend {cmd_args["backend"]}', f'--device {cmd_args["device"]}', - extra_cmd_args, + tr.test.extra_cmd_args, ] return srun_command_parts diff --git a/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py index ff70b5c49..9e20aae52 100644 --- a/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/jax_toolbox/slurm_command_gen_strategy.py @@ -147,9 +147,9 @@ def _parse_slurm_args( return base_args def generate_srun_command( - self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, Any], extra_cmd_args: str + self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, Any], tr: TestRun ) -> str: - self._create_run_script(env_vars, cmd_args, extra_cmd_args) + self._create_run_script(env_vars, cmd_args, tr.test.extra_cmd_args) commands = [] diff --git a/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py index 8805202c0..d1d8d5fc8 100644 --- a/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/nccl_test/slurm_command_gen_strategy.py @@ -43,9 +43,7 @@ def _parse_slurm_args( return base_args - def generate_test_command( - self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str - ) -> List[str]: + def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: srun_command_parts = [f"/usr/local/bin/{cmd_args['subtest_name']}"] nccl_test_args = [ "nthreads", @@ -69,7 +67,7 @@ def generate_test_command( if arg in cmd_args: srun_command_parts.append(f"--{arg} {cmd_args[arg]}") - if extra_cmd_args: - srun_command_parts.append(extra_cmd_args) + if tr.test.extra_cmd_args: + srun_command_parts.append(tr.test.extra_cmd_args) return srun_command_parts diff --git a/src/cloudai/schema/test_template/sleep/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/sleep/slurm_command_gen_strategy.py index f90607256..5b732655c 100644 --- a/src/cloudai/schema/test_template/sleep/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/sleep/slurm_command_gen_strategy.py @@ -16,13 +16,12 @@ from typing import Dict, List +from cloudai import TestRun from cloudai.systems.slurm.strategy import SlurmCommandGenStrategy class SleepSlurmCommandGenStrategy(SlurmCommandGenStrategy): """Command generation strategy for Sleep on Slurm systems.""" - def generate_test_command( - self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str - ) -> List[str]: + def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: return [f'sleep {cmd_args["seconds"]}'] diff --git a/src/cloudai/schema/test_template/ucc_test/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/ucc_test/slurm_command_gen_strategy.py index 75c0fc1cc..ebb07840e 100644 --- a/src/cloudai/schema/test_template/ucc_test/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/ucc_test/slurm_command_gen_strategy.py @@ -34,9 +34,7 @@ def _parse_slurm_args( return base_args - def generate_test_command( - self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str - ) -> List[str]: + def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: srun_command_parts = ["/opt/hpcx/ucc/bin/ucc_perftest"] # Add collective, minimum bytes, and maximum bytes options if available @@ -52,7 +50,7 @@ def generate_test_command( srun_command_parts.append("-F") # Append any extra command-line arguments provided - if extra_cmd_args: - srun_command_parts.append(extra_cmd_args) + if tr.test.extra_cmd_args: + srun_command_parts.append(tr.test.extra_cmd_args) return srun_command_parts diff --git a/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py b/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py index 4a052a47b..7c5f69c3f 100644 --- a/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py +++ b/src/cloudai/systems/slurm/strategy/slurm_command_gen_strategy.py @@ -109,13 +109,13 @@ def job_name(self, job_name_prefix: str) -> str: return job_name def generate_srun_command( - self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str + self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun ) -> str: - srun_command_parts = self.generate_srun_prefix(slurm_args) - test_command_parts = self.generate_test_command(env_vars, cmd_args, extra_cmd_args) + srun_command_parts = self.generate_srun_prefix(slurm_args, tr) + test_command_parts = self.generate_test_command(env_vars, cmd_args, tr) return " \\\n".join(srun_command_parts + test_command_parts) - def generate_srun_prefix(self, slurm_args: Dict[str, Any]) -> List[str]: + def generate_srun_prefix(self, slurm_args: Dict[str, Any], tr: TestRun) -> List[str]: srun_command_parts = ["srun", f"--mpi={self.system.mpi}"] if slurm_args.get("image_path"): srun_command_parts.append(f'--container-image={slurm_args["image_path"]}') @@ -131,12 +131,10 @@ def gen_exec_command(self, tr: TestRun) -> str: env_vars = self._override_env_vars(self.system.global_env_vars, tr.test.extra_env_vars) cmd_args = self._override_cmd_args(self.default_cmd_args, tr.test.cmd_args) slurm_args = self._parse_slurm_args(tr.test.test_template.__class__.__name__, env_vars, cmd_args, tr) - srun_command = self.generate_srun_command(slurm_args, env_vars, cmd_args, tr.test.extra_cmd_args) - return self._write_sbatch_script(slurm_args, env_vars, srun_command, tr.output_path) + srun_command = self.generate_srun_command(slurm_args, env_vars, cmd_args, tr) + return self._write_sbatch_script(slurm_args, env_vars, srun_command, tr) - def generate_test_command( - self, env_vars: Dict[str, str], cmd_args: Dict[str, str], extra_cmd_args: str - ) -> List[str]: + def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: return [] def _add_reservation(self, batch_script_content: List[str]): @@ -157,7 +155,7 @@ def _add_reservation(self, batch_script_content: List[str]): return batch_script_content def _write_sbatch_script( - self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], srun_command: str, output_path: Path + self, slurm_args: Dict[str, Any], env_vars: Dict[str, str], srun_command: str, tr: TestRun ) -> str: """ Write the batch script for Slurm submission and return the sbatch command. @@ -166,7 +164,7 @@ def _write_sbatch_script( slurm_args (Dict[str, Any]): Slurm-specific arguments. env_vars (env_vars: Dict[str, str]): Environment variables. srun_command (str): srun command. - output_path (Path): Output directory for script and logs. + tr (TestRun): Test run object. Returns: str: sbatch command to submit the job. @@ -177,12 +175,12 @@ def _write_sbatch_script( f"#SBATCH -N {slurm_args['num_nodes']}", ] - self._append_sbatch_directives(batch_script_content, slurm_args, output_path) + self._append_sbatch_directives(batch_script_content, slurm_args, tr.output_path) env_vars_str = self._format_env_vars(env_vars) batch_script_content.extend([env_vars_str, "", srun_command]) - batch_script_path = output_path / "cloudai_sbatch_script.sh" + batch_script_path = tr.output_path / "cloudai_sbatch_script.sh" with batch_script_path.open("w") as batch_file: batch_file.write("\n".join(batch_script_content)) diff --git a/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py index 701c234c4..d4c88034c 100644 --- a/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_chakra_replay_slurm_command_gen_strategy.py @@ -22,6 +22,7 @@ from cloudai import TestRun from cloudai.schema.test_template.chakra_replay.slurm_command_gen_strategy import ChakraReplaySlurmCommandGenStrategy from cloudai.systems import SlurmSystem +from tests.conftest import create_autospec_dataclass class TestChakraReplaySlurmCommandGenStrategy: @@ -123,16 +124,20 @@ def test_generate_test_command( expected_result: List[str], slurm_system: SlurmSystem, ) -> None: - command = cmd_gen_strategy.generate_test_command({}, cmd_args, extra_cmd_args) + tr = create_autospec_dataclass(TestRun) + tr.test.extra_cmd_args = extra_cmd_args + command = cmd_gen_strategy.generate_test_command({}, cmd_args, tr) assert command == expected_result def test_generate_test_command_invalid_args( self, cmd_gen_strategy: ChakraReplaySlurmCommandGenStrategy, slurm_system: SlurmSystem ) -> None: cmd_args: Dict[str, str] = {"trace_type": "comms_trace", "backend": "nccl", "device": "gpu"} - extra_cmd_args: str = "--max-steps 100" + + tr = create_autospec_dataclass(TestRun) + tr.test.extra_cmd_args = "--max-steps 100" with pytest.raises(KeyError) as exc_info: - cmd_gen_strategy.generate_test_command({}, cmd_args, extra_cmd_args) + cmd_gen_strategy.generate_test_command({}, cmd_args, tr) assert str(exc_info.value) == "'trace_path'", "Expected missing trace_path key" diff --git a/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py index 37d6a962e..87927f9c7 100644 --- a/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_common_slurm_command_gen_strategy.py @@ -54,11 +54,9 @@ def test_filename_generation(strategy_fixture: SlurmCommandGenStrategy, testrun_ env_vars = {"TEST_VAR": "VALUE"} cmd_args = {"test_arg": "test_value"} slurm_args = strategy_fixture._parse_slurm_args(job_name_prefix, env_vars, cmd_args, testrun_fixture) - srun_command = strategy_fixture.generate_srun_command(slurm_args, env_vars, cmd_args, "") + srun_command = strategy_fixture.generate_srun_command(slurm_args, env_vars, cmd_args, testrun_fixture) - sbatch_command = strategy_fixture._write_sbatch_script( - slurm_args, env_vars, srun_command, testrun_fixture.output_path - ) + sbatch_command = strategy_fixture._write_sbatch_script(slurm_args, env_vars, srun_command, testrun_fixture) filepath_from_command = sbatch_command.split()[-1] assert testrun_fixture.output_path.joinpath("cloudai_sbatch_script.sh").exists() diff --git a/tests/slurm_command_gen_strategy/test_nccl_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_nccl_slurm_command_gen_strategy.py index bfc5820d7..4c151469f 100644 --- a/tests/slurm_command_gen_strategy/test_nccl_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_nccl_slurm_command_gen_strategy.py @@ -15,6 +15,7 @@ # limitations under the License. from typing import Any, Dict, List +from unittest.mock import Mock import pytest @@ -102,5 +103,7 @@ def test_generate_test_command( expected_command: List[str], ) -> None: env_vars = {} - command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, extra_cmd_args) + tr = Mock() + tr.test.extra_cmd_args = extra_cmd_args + command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, tr) assert command == expected_command diff --git a/tests/slurm_command_gen_strategy/test_sleep_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_sleep_slurm_command_gen_strategy.py index 4561c928b..3276fa667 100644 --- a/tests/slurm_command_gen_strategy/test_sleep_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_sleep_slurm_command_gen_strategy.py @@ -15,6 +15,7 @@ # limitations under the License. from typing import Dict, List +from unittest.mock import Mock import pytest @@ -41,6 +42,5 @@ def test_generate_test_command( expected_command: List[str], ) -> None: env_vars = {} - extra_cmd_args = "" - command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, extra_cmd_args) + command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, Mock()) assert command == expected_command diff --git a/tests/slurm_command_gen_strategy/test_ucc_slurm_command_gen_strategy.py b/tests/slurm_command_gen_strategy/test_ucc_slurm_command_gen_strategy.py index fe2eae036..aa14b1c84 100644 --- a/tests/slurm_command_gen_strategy/test_ucc_slurm_command_gen_strategy.py +++ b/tests/slurm_command_gen_strategy/test_ucc_slurm_command_gen_strategy.py @@ -15,6 +15,7 @@ # limitations under the License. from typing import Dict, List +from unittest.mock import Mock import pytest @@ -64,5 +65,7 @@ def test_generate_test_command( expected_command: List[str], ) -> None: env_vars = {} - command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, extra_cmd_args) + tr = Mock() + tr.test.extra_cmd_args = extra_cmd_args + command = cmd_gen_strategy.generate_test_command(env_vars, cmd_args, tr) assert command == expected_command