Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysonfrancis committed Jan 4, 2025
1 parent 0553659 commit de1000e
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 36 deletions.
16 changes: 7 additions & 9 deletions tests/unit_tests/test_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,7 @@ def test_job_config_file_cmd_overrides(self):
def test_parse_pp_split_points(self):

toml_splits = ["layers.2", "layers.4", "layers.6"]
toml_split_str = ",".join(toml_splits)
cmdline_splits = ["layers.1", "layers.3", "layers.5"]
cmdline_split_str = ",".join(cmdline_splits)
# no split points specified
config = JobConfig()
config.parse_args(
[
Expand All @@ -68,7 +65,7 @@ def test_parse_pp_split_points(self):
"--job.config_file",
"./train_configs/debug_model.toml",
"--experimental.pipeline_parallel_split_points",
f"{cmdline_split_str}",
*cmdline_splits,
]
)
assert (
Expand All @@ -81,7 +78,7 @@ def test_parse_pp_split_points(self):
tomli_w.dump(
{
"experimental": {
"pipeline_parallel_split_points": toml_split_str,
"pipeline_parallel_split_points": toml_splits,
}
},
f,
Expand All @@ -98,7 +95,7 @@ def test_parse_pp_split_points(self):
tomli_w.dump(
{
"experimental": {
"pipeline_parallel_split_points": toml_split_str,
"pipeline_parallel_split_points": toml_splits,
}
},
f,
Expand All @@ -109,14 +106,15 @@ def test_parse_pp_split_points(self):
"--job.config_file",
fp.name,
"--experimental.pipeline_parallel_split_points",
f"{cmdline_split_str}",
*cmdline_splits,
]
)
assert (
config.experimental.pipeline_parallel_split_points == cmdline_splits
), config.experimental.pipeline_parallel_split_points

def test_print_help(self):
config = JobConfig()
parser = config.parser
from tyro.extras import get_parser

parser = get_parser(JobConfig)
parser.print_help()
43 changes: 16 additions & 27 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import sys
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Callable, Dict, List, Optional, Union

Expand Down Expand Up @@ -221,7 +223,7 @@ class Experimental:
"""
Specify comma-separated names of modules to use as the beginning of a split point.
e.g. "layers.0,layers.2" will cause the model to be split into 3 stages,
e.g. "layers.0" "layers.2" will cause the model to be split into 3 stages,
the first containing all the layers up to layers.0,
the second containing layers.0 and up to layers.2,
the third containing layers.2 and all the remaining layers.
Expand Down Expand Up @@ -444,25 +446,28 @@ class JobConfig:
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

def _update(self, instance: "JobConfig") -> None:
def _update(self, instance) -> None:
for f in fields(self):
setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name)))

def parse_args(self) -> None:
def find_config_file(self, argv):
config_flags = ("--job.config-file", "--job.config_file")
for i, arg in enumerate(argv[:-1]):
if arg in config_flags:
return argv[i + 1]
return None

def parse_args(self, args=None) -> None:
"""
Parse CLI arguments, optionally load from a TOML file,
merge with defaults, and return a JobConfig instance.
"""
defaults = tyro.cli(self.__class__)
config_file = defaults.job.config_file
config = self.__class__ # initialize with defaults
config_file = self.find_config_file(args if args is not None else sys.argv[1:])
if config_file:
toml_data = self._load_toml(config_file)
toml_config = self._dict_to_dataclass(self.__class__, toml_data)
merged_config = self._merge_with_defaults(toml_config, defaults)
# TODO: find a way to make this work without two calls
final_config = tyro.cli(self.__class__, default=merged_config)
else:
final_config = defaults
config = self._dict_to_dataclass(self.__class__, toml_data)
final_config = tyro.cli(self.__class__, args=args, default=config)
self._update(final_config)
self._validate_config()

Expand Down Expand Up @@ -490,22 +495,6 @@ def _dict_to_dataclass(self, config_class: Callable, data: Dict[str, Any]) -> An
kwargs[f.name] = value
return config_class(**kwargs)

def _merge_with_defaults(self, target, defaults) -> Any:
"""Recursively merge two dataclass instances (source overrides defaults)."""
merged_kwargs = {}
for f in fields(target):
target_val = getattr(target, f.name)
default_val = getattr(defaults, f.name)
if is_dataclass(target_val) and is_dataclass(default_val):
merged_kwargs[f.name] = self._merge_with_defaults(
target_val, default_val
)
else:
merged_kwargs[f.name] = (
target_val if target_val is not None else default_val
)
return type(target)(**merged_kwargs)

def _validate_config(self) -> None:
# TODO: Add more mandatory validations
assert self.model.name, "Model name is required"
Expand Down

0 comments on commit de1000e

Please sign in to comment.