Skip to content

Commit

Permalink
Merge branch 'main' into chakra-replay-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo authored Nov 20, 2024
2 parents 23cad7b + ec5a384 commit 848d5f9
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 16 deletions.
File renamed without changes.
16 changes: 16 additions & 0 deletions tests/ref_data/nemo-run-pre-test.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#!/bin/bash
#SBATCH --job-name=__JOB_NAME__
#SBATCH -N 1
#SBATCH --output=__OUTPUT_DIR__/output/stdout.txt
#SBATCH --error=__OUTPUT_DIR__/output/stderr.txt
#SBATCH --partition=main

export SLURM_JOB_MASTER_NODE=$(scontrol show hostname $SLURM_JOB_NODELIST | head -n 1)


srun --output=__OUTPUT_DIR__/output/pre_test/nccl/stdout.txt --error=__OUTPUT_DIR__/output/pre_test/nccl/stderr.txt --mpi=pmix --container-image=nvcr.io/nvidia/pytorch:24.02-py3 /usr/local/bin/all_reduce_perf_mpi --nthreads 1 --ngpus 1 --minbytes 32M --maxbytes 32M --stepbytes 1M --op sum --datatype float --root 0 --iters 20 --warmup_iters 5 --agg_iters 1 --average 1 --parallel_init 0 --check 1 --blocking 0 --cudagraph 0
SUCCESS_0=$(grep -q "Avg bus bandwidth" __OUTPUT_DIR__/output/pre_test/nccl/stdout.txt && echo 1 || echo 0)
PRE_TEST_SUCCESS=$( [ $SUCCESS_0 -eq 1 ] && echo 1 || echo 0 )
if [ $PRE_TEST_SUCCESS -eq 1 ]; then
srun --mpi=pmix --container-image=nvcr.io/nvidia/nemo:24.09 nemo llm pretrain --factory llama_3b -y trainer.num_nodes=1
fi
39 changes: 23 additions & 16 deletions tests/test_acceptance.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,8 @@ def partial_tr(slurm_system: SlurmSystem) -> partial[TestRun]:
"grok-pre-test",
"grok-no-hook",
"nemo-launcher",
"nemo-run",
"nemo-run-pre-test",
"nemo-run-no-hook",
"slurm_container",
]
)
Expand Down Expand Up @@ -162,19 +163,6 @@ def create_test_run(name, test_definition, test_template, command_gen_strategy):
NeMoLauncher,
NeMoLauncherSlurmCommandGenStrategy,
),
"nemo-run": lambda: create_test_run(
"nemo-run",
NeMoRunTestDefinition(
name="nemo-run",
description="nemo-run",
test_template_name="nemo-run",
cmd_args=NeMoRunCmdArgs(
docker_image_url="nvcr.io/nvidia/nemo:24.09", task="pretrain", recipe_name="llama_3b"
),
),
NeMoRun,
NeMoRunSlurmCommandGenStrategy,
),
"slurm_container": lambda: create_test_run(
"slurm_container",
SlurmContainerTestDefinition(
Expand All @@ -196,7 +184,7 @@ def create_test_run(name, test_definition, test_template, command_gen_strategy):
}

# Special cases for gpt and grok
if request.param.startswith("gpt-") or request.param.startswith("grok-"):
if request.param.startswith("gpt-") or request.param.startswith("grok-") or request.param.startswith("nemo-run-"):
if "gpt" in request.param:
test_type = "gpt"
tr = create_test_run(
Expand All @@ -211,6 +199,7 @@ def create_test_run(name, test_definition, test_template, command_gen_strategy):
JaxToolbox,
JaxToolboxSlurmCommandGenStrategy,
)

elif "grok" in request.param:
test_type = "grok"
tr = create_test_run(
Expand All @@ -225,6 +214,21 @@ def create_test_run(name, test_definition, test_template, command_gen_strategy):
JaxToolbox,
JaxToolboxSlurmCommandGenStrategy,
)
elif "nemo-run" in request.param:
test_type = "nemo-run"
tr = create_test_run(
test_type,
NeMoRunTestDefinition(
name=test_type,
description=test_type,
test_template_name=test_type,
cmd_args=NeMoRunCmdArgs(
docker_image_url="nvcr.io/nvidia/nemo:24.09", task="pretrain", recipe_name="llama_3b"
),
),
NeMoRun,
NeMoRunSlurmCommandGenStrategy,
)
else:
raise ValueError(f"Unknown test type: {request.param}")

Expand All @@ -233,7 +237,10 @@ def create_test_run(name, test_definition, test_template, command_gen_strategy):
pre_test_tr = test_mapping["nccl"]()
tr.pre_test = TestScenario(name=f"{pre_test_tr.name} NCCL pre-test", test_runs=[pre_test_tr])

return (tr, f"{request.param}.sbatch", f"{test_type}.run")
if test_type == "nemo-run":
return (tr, f"{request.param}.sbatch", None)
else:
return (tr, f"{request.param}.sbatch", f"{test_type}.run")

# Default handler for simple mappings
if request.param in test_mapping:
Expand Down

0 comments on commit 848d5f9

Please sign in to comment.