Skip to content

Commit

Permalink
Fixed a bug + linting
Browse files Browse the repository at this point in the history
  • Loading branch information
jaysonfrancis committed Jan 3, 2025
1 parent f7c669b commit eb3b94c
Showing 1 changed file with 27 additions and 32 deletions.
59 changes: 27 additions & 32 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field, fields, is_dataclass, asdict
from dataclasses import asdict, dataclass, field, fields, is_dataclass
from typing import Any, Dict, List, Optional, Union

import torch
Expand Down Expand Up @@ -448,30 +448,27 @@ class JobConfig:
def to_dict(self) -> Dict[str, Any]:
return asdict(self)

@classmethod
def parse_args(cls) -> "JobConfig":
def _update(self, instance: "JobConfig") -> None:
for f in fields(self):
setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name)))

def parse_args(self):
"""
Parse CLI arguments, optionally load from a TOML file,
merge with defaults, and return a JobConfig instance.
"""
cli_config = tyro.cli(cls)
config_file = cli_config.job.config_file

defaults = tyro.cli(self.__class__)
config_file = defaults.job.config_file
if config_file:
logger.info(f"Loading configuration from {config_file}")
toml_data = cls._load_toml(config_file)
toml_config = cls._dict_to_dataclass(cls, toml_data)

# TOML > deafults
merged_config = cls._merge_with_defaults(cli_config, toml_config)

# cmdline > TOML > defaults
final_config = tyro.cli(cls, default=merged_config)
toml_data = self._load_toml(config_file)
toml_config = self._dict_to_dataclass(self.__class__, toml_data)
merged_config = self._merge_with_defaults(toml_config, defaults)
# TODO: find a way to make this work without two calls
final_config = tyro.cli(self.__class__, default=merged_config)
else:
final_config = cli_config

final_config._validate_config()
return final_config
final_config = defaults
self._update(final_config)
self._validate_config()

@staticmethod
def _load_toml(file_path: str) -> Dict[str, Any]:
Expand All @@ -482,39 +479,37 @@ def _load_toml(file_path: str) -> Dict[str, Any]:
logger.exception(f"Error while loading config file: {file_path}")
raise e

@classmethod
def _dict_to_dataclass(cls, config_class, data: Dict[str, Any]) -> Any:
def _dict_to_dataclass(self, config_class, data: Dict[str, Any]) -> Any:
"""Recursively convert dictionaries to nested dataclasses."""
if not is_dataclass(config_class):
return data

kwargs = {}
for f in fields(config_class):
if f.name in data:
value = data[f.name]
# If target field is also a dataclass and value is a dict, recurse
if is_dataclass(f.type) and isinstance(value, dict):
kwargs[f.name] = cls._dict_to_dataclass(f.type, value)
kwargs[f.name] = self._dict_to_dataclass(f.type, value)
else:
kwargs[f.name] = value
return config_class(**kwargs)

@classmethod
def _merge_with_defaults(cls, source: "JobConfig", defaults: "JobConfig") -> "JobConfig":
def _merge_with_defaults(
self, source: "JobConfig", defaults: "JobConfig"
) -> "JobConfig":
"""Recursively merge two dataclass instances (source overrides defaults)."""
if not is_dataclass(source) or not is_dataclass(defaults):
return source or defaults

merged_kwargs = {}
for f in fields(source):
source_val = getattr(source, f.name)
default_val = getattr(defaults, f.name)
# If both are dataclasses, merge recursively
if is_dataclass(source_val) and is_dataclass(default_val):
merged_kwargs[f.name] = cls._merge_with_defaults(source_val, default_val)
merged_kwargs[f.name] = self._merge_with_defaults(
source_val, default_val
)
else:
merged_kwargs[f.name] = source_val if source_val is not None else default_val

merged_kwargs[f.name] = (
source_val if source_val is not None else default_val
)
return type(source)(**merged_kwargs)

def _validate_config(self) -> None:
Expand Down

0 comments on commit eb3b94c

Please sign in to comment.