From 12122d515683cf143c4375f8a1553084f7602256 Mon Sep 17 00:00:00 2001 From: Taekyung Heo Date: Tue, 19 Nov 2024 06:17:03 -0500 Subject: [PATCH] Reflect Andrei's comments --- .../test_template/nemo_run/slurm_command_gen_strategy.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/cloudai/schema/test_template/nemo_run/slurm_command_gen_strategy.py b/src/cloudai/schema/test_template/nemo_run/slurm_command_gen_strategy.py index d96f056f..6f14a10a 100644 --- a/src/cloudai/schema/test_template/nemo_run/slurm_command_gen_strategy.py +++ b/src/cloudai/schema/test_template/nemo_run/slurm_command_gen_strategy.py @@ -36,12 +36,9 @@ def _parse_slurm_args( return base_args def generate_test_command(self, env_vars: Dict[str, str], cmd_args: Dict[str, str], tr: TestRun) -> List[str]: - command = ["nemo", "llm"] - tdef: NeMoRunTestDefinition = cast(NeMoRunTestDefinition, tr.test.test_definition) - command.append(tdef.cmd_args.task) - command.extend(["--factory", tdef.cmd_args.recipe_name]) - command.append("-y") + + command = ["nemo", "llm", tdef.cmd_args.task, "--factory", tdef.cmd_args.recipe_name, "-y"] if tr.nodes: command.append(f"trainer.num_nodes={len(tr.nodes)}")