Skip to content

Commit

Permalink
Fix time limit calculation to include pre and post test hook durations (
Browse files Browse the repository at this point in the history
#345)

* Fix time limit calculation to include pre and post test hook durations

* Fix time limit format in test_with_time_limit

* Add tests for Slurm time parsing and total time limit calculation

* Merge and simplify Slurm time format tests
  • Loading branch information
TaekyungHeo authored Jan 21, 2025
1 parent 36e3993 commit 2c26087
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 4 deletions.
76 changes: 74 additions & 2 deletions src/cloudai/_core/test_scenario_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# limitations under the License.

import logging
import re
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Literal, Optional
from typing import Any, Dict, List, Literal, Optional

import toml
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
Expand All @@ -26,6 +28,73 @@
from .test_scenario import TestDependency, TestRun, TestScenario


def parse_time_limit(limit: str) -> timedelta:
try:
if re.match(r"^\d+[smhdw]$", limit, re.IGNORECASE):
return parse_abbreviated_time(limit)
if "-" in limit:
return parse_dashed_time(limit)
if len(limit.split(":")) == 3:
hours, minutes, seconds = map(int, limit.split(":"))
return timedelta(hours=hours, minutes=minutes, seconds=seconds)
if len(limit.split(":")) == 2:
hours, minutes = map(int, limit.split(":"))
return timedelta(hours=hours, minutes=minutes)
except ValueError as err:
raise ValueError(f"Invalid time limit format: {limit}. Refer to SLURM time format documentation.") from err

raise ValueError(f"Unsupported time limit format: {limit}. Refer to SLURM time format documentation.")


def parse_abbreviated_time(limit: str) -> timedelta:
value, unit = int(limit[:-1]), limit[-1].lower()
if unit == "s":
return timedelta(seconds=value)
if unit == "m":
return timedelta(minutes=value)
if unit == "h":
return timedelta(hours=value)
if unit == "d":
return timedelta(days=value)
if unit == "w":
return timedelta(weeks=value)
raise ValueError(f"Invalid abbreviated time format: {limit}")


def parse_dashed_time(limit: str) -> timedelta:
days, time_part = limit.split("-", 1)
hours, minutes, seconds = map(int, time_part.split(":"))
return timedelta(days=int(days), hours=hours, minutes=minutes, seconds=seconds)


def format_time_limit(total_time: timedelta) -> str:
total_seconds = int(total_time.total_seconds())
days, remainder = divmod(total_seconds, 86400)
hours, remainder = divmod(remainder, 3600)
minutes, seconds = divmod(remainder, 60)
if days > 0:
return f"{days}-{hours:02}:{minutes:02}:{seconds:02}"
return f"{hours:02}:{minutes:02}:{seconds:02}"


def calculate_total_time_limit(
time_limit: Optional[str] = None, test_hooks: Optional[List[TestScenario]] = None
) -> str:
total_time = timedelta()

if time_limit:
total_time += parse_time_limit(time_limit)

if test_hooks:
for hook in test_hooks:
if hook:
for test_run in hook.test_runs:
if test_run.time_limit:
total_time += parse_time_limit(test_run.time_limit)

return format_time_limit(total_time)


class _TestDependencyTOML(BaseModel):
model_config = ConfigDict(extra="forbid")

Expand Down Expand Up @@ -217,13 +286,16 @@ def _create_test_run(

test = Test(test_definition=original_test.test_definition, test_template=original_test.test_template)

hooks = [hook for hook in [pre_test, post_test] if hook is not None]
total_time_limit = calculate_total_time_limit(time_limit=test_info.time_limit, test_hooks=hooks)

tr = TestRun(
test_info.id,
test,
num_nodes=test_info.num_nodes or 1,
iterations=test_info.iterations,
nodes=test_info.nodes,
time_limit=test_info.time_limit,
time_limit=total_time_limit,
sol=test_info.sol,
weight=test_info.weight * normalized_weight,
ideal_perf=test_info.ideal_perf,
Expand Down
44 changes: 42 additions & 2 deletions tests/test_test_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pytest

from cloudai import CmdArgs, Test, TestRun, TestScenarioParser, TestScenarioParsingError
from cloudai._core.test_scenario_parser import _TestScenarioTOML
from cloudai._core.test_scenario_parser import _TestScenarioTOML, calculate_total_time_limit
from tests.conftest import MyTestDefinition


Expand Down Expand Up @@ -91,7 +91,7 @@ def test_with_time_limit(test: Test, test_scenario_parser: TestScenarioParser) -
test_scenario = test_scenario_parser._parse_data(
{"name": "nccl-test", "Tests": [{"id": "1", "test_name": "nccl", "time_limit": "10m"}]}
)
assert test_scenario.test_runs[0].time_limit == "10m"
assert test_scenario.test_runs[0].time_limit == "00:10:00"


def test_two_independent_cases(test: Test, test_scenario_parser: TestScenarioParser) -> None:
Expand Down Expand Up @@ -202,3 +202,43 @@ def test_test_id_must_contain_at_least_one_letter() -> None:
with pytest.raises(ValueError) as exc_info:
_TestScenarioTOML.model_validate({"name": "name", "Tests": [{"id": "", "test_name": "nccl"}]})
assert exc_info.match("_TestScenarioTOML\nTests.0.id\n String should have at least 1 character")


@pytest.mark.parametrize(
"time_str, expected",
[
("10m", "00:10:00"),
("1h", "01:00:00"),
("2d", "2-00:00:00"),
("1w", "7-00:00:00"),
("30s", "00:00:30"),
("1-12:30:45", "1-12:30:45"),
("12:30:45", "12:30:45"),
("12:30", "12:30:00"),
],
)
def test_calculate_total_time_limit(time_str, expected):
assert calculate_total_time_limit(time_limit=time_str) == expected


def test_create_test_run_with_hooks(test: Test, test_scenario_parser: TestScenarioParser):
pre_test = Mock(
test_runs=[TestRun(name="pre1", test=test, num_nodes=1, nodes=[], time_limit="00:30:00", iterations=1)]
)
post_test = Mock(
test_runs=[TestRun(name="post1", test=test, num_nodes=1, nodes=[], time_limit="00:20:00", iterations=1)]
)

test_info = Mock(id="main1", test_name="test1", time_limit="01:00:00", weight=10, iterations=1, num_nodes=1)
test_scenario_parser.test_mapping = {"test1": test}

test_run = test_scenario_parser._create_test_run(
test_info=test_info, normalized_weight=1.0, pre_test=pre_test, post_test=post_test
)

assert test_run.time_limit == "01:50:00" # Main + pre + post hooks


def test_total_time_limit_with_empty_hooks():
result = calculate_total_time_limit("01:00:00", test_hooks=[])
assert result == "01:00:00"

0 comments on commit 2c26087

Please sign in to comment.