diff --git a/src/cloudai/_core/test_scenario_parser.py b/src/cloudai/_core/test_scenario_parser.py index 36fade4f..c96146d2 100644 --- a/src/cloudai/_core/test_scenario_parser.py +++ b/src/cloudai/_core/test_scenario_parser.py @@ -15,8 +15,10 @@ # limitations under the License. import logging +import re +from datetime import timedelta from pathlib import Path -from typing import Any, Dict, Literal, Optional +from typing import Any, Dict, List, Literal, Optional import toml from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator @@ -26,6 +28,73 @@ from .test_scenario import TestDependency, TestRun, TestScenario +def parse_time_limit(limit: str) -> timedelta: + try: + if re.match(r"^\d+[smhdw]$", limit, re.IGNORECASE): + return parse_abbreviated_time(limit) + if "-" in limit: + return parse_dashed_time(limit) + if len(limit.split(":")) == 3: + hours, minutes, seconds = map(int, limit.split(":")) + return timedelta(hours=hours, minutes=minutes, seconds=seconds) + if len(limit.split(":")) == 2: + hours, minutes = map(int, limit.split(":")) + return timedelta(hours=hours, minutes=minutes) + except ValueError as err: + raise ValueError(f"Invalid time limit format: {limit}. Refer to SLURM time format documentation.") from err + + raise ValueError(f"Unsupported time limit format: {limit}. Refer to SLURM time format documentation.") + + +def parse_abbreviated_time(limit: str) -> timedelta: + value, unit = int(limit[:-1]), limit[-1].lower() + if unit == "s": + return timedelta(seconds=value) + if unit == "m": + return timedelta(minutes=value) + if unit == "h": + return timedelta(hours=value) + if unit == "d": + return timedelta(days=value) + if unit == "w": + return timedelta(weeks=value) + raise ValueError(f"Invalid abbreviated time format: {limit}") + + +def parse_dashed_time(limit: str) -> timedelta: + days, time_part = limit.split("-", 1) + hours, minutes, seconds = map(int, time_part.split(":")) + return timedelta(days=int(days), hours=hours, minutes=minutes, seconds=seconds) + + +def format_time_limit(total_time: timedelta) -> str: + total_seconds = int(total_time.total_seconds()) + days, remainder = divmod(total_seconds, 86400) + hours, remainder = divmod(remainder, 3600) + minutes, seconds = divmod(remainder, 60) + if days > 0: + return f"{days}-{hours:02}:{minutes:02}:{seconds:02}" + return f"{hours:02}:{minutes:02}:{seconds:02}" + + +def calculate_total_time_limit( + time_limit: Optional[str] = None, test_hooks: Optional[List[TestScenario]] = None +) -> str: + total_time = timedelta() + + if time_limit: + total_time += parse_time_limit(time_limit) + + if test_hooks: + for hook in test_hooks: + if hook: + for test_run in hook.test_runs: + if test_run.time_limit: + total_time += parse_time_limit(test_run.time_limit) + + return format_time_limit(total_time) + + class _TestDependencyTOML(BaseModel): model_config = ConfigDict(extra="forbid") @@ -217,13 +286,16 @@ def _create_test_run( test = Test(test_definition=original_test.test_definition, test_template=original_test.test_template) + hooks = [hook for hook in [pre_test, post_test] if hook is not None] + total_time_limit = calculate_total_time_limit(time_limit=test_info.time_limit, test_hooks=hooks) + tr = TestRun( test_info.id, test, num_nodes=test_info.num_nodes or 1, iterations=test_info.iterations, nodes=test_info.nodes, - time_limit=test_info.time_limit, + time_limit=total_time_limit, sol=test_info.sol, weight=test_info.weight * normalized_weight, ideal_perf=test_info.ideal_perf, diff --git a/tests/test_test_scenario.py b/tests/test_test_scenario.py index 72639068..9f71687b 100644 --- a/tests/test_test_scenario.py +++ b/tests/test_test_scenario.py @@ -21,7 +21,7 @@ import pytest from cloudai import CmdArgs, Test, TestRun, TestScenarioParser, TestScenarioParsingError -from cloudai._core.test_scenario_parser import _TestScenarioTOML +from cloudai._core.test_scenario_parser import _TestScenarioTOML, calculate_total_time_limit from tests.conftest import MyTestDefinition @@ -91,7 +91,7 @@ def test_with_time_limit(test: Test, test_scenario_parser: TestScenarioParser) - test_scenario = test_scenario_parser._parse_data( {"name": "nccl-test", "Tests": [{"id": "1", "test_name": "nccl", "time_limit": "10m"}]} ) - assert test_scenario.test_runs[0].time_limit == "10m" + assert test_scenario.test_runs[0].time_limit == "00:10:00" def test_two_independent_cases(test: Test, test_scenario_parser: TestScenarioParser) -> None: @@ -202,3 +202,43 @@ def test_test_id_must_contain_at_least_one_letter() -> None: with pytest.raises(ValueError) as exc_info: _TestScenarioTOML.model_validate({"name": "name", "Tests": [{"id": "", "test_name": "nccl"}]}) assert exc_info.match("_TestScenarioTOML\nTests.0.id\n String should have at least 1 character") + + +@pytest.mark.parametrize( + "time_str, expected", + [ + ("10m", "00:10:00"), + ("1h", "01:00:00"), + ("2d", "2-00:00:00"), + ("1w", "7-00:00:00"), + ("30s", "00:00:30"), + ("1-12:30:45", "1-12:30:45"), + ("12:30:45", "12:30:45"), + ("12:30", "12:30:00"), + ], +) +def test_calculate_total_time_limit(time_str, expected): + assert calculate_total_time_limit(time_limit=time_str) == expected + + +def test_create_test_run_with_hooks(test: Test, test_scenario_parser: TestScenarioParser): + pre_test = Mock( + test_runs=[TestRun(name="pre1", test=test, num_nodes=1, nodes=[], time_limit="00:30:00", iterations=1)] + ) + post_test = Mock( + test_runs=[TestRun(name="post1", test=test, num_nodes=1, nodes=[], time_limit="00:20:00", iterations=1)] + ) + + test_info = Mock(id="main1", test_name="test1", time_limit="01:00:00", weight=10, iterations=1, num_nodes=1) + test_scenario_parser.test_mapping = {"test1": test} + + test_run = test_scenario_parser._create_test_run( + test_info=test_info, normalized_weight=1.0, pre_test=pre_test, post_test=post_test + ) + + assert test_run.time_limit == "01:50:00" # Main + pre + post hooks + + +def test_total_time_limit_with_empty_hooks(): + result = calculate_total_time_limit("01:00:00", test_hooks=[]) + assert result == "01:00:00"