From eb3b94ce0e3aba2bc37171a56fcc98916a2d0acd Mon Sep 17 00:00:00 2001 From: Jayson Francis Date: Thu, 2 Jan 2025 19:30:29 -0800 Subject: [PATCH] Fixed a bug + linting --- torchtitan/config_manager.py | 59 +++++++++++++++++------------------- 1 file changed, 27 insertions(+), 32 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index e0cd961c..01c73491 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass, field, fields, is_dataclass, asdict +from dataclasses import asdict, dataclass, field, fields, is_dataclass from typing import Any, Dict, List, Optional, Union import torch @@ -448,30 +448,27 @@ class JobConfig: def to_dict(self) -> Dict[str, Any]: return asdict(self) - @classmethod - def parse_args(cls) -> "JobConfig": + def _update(self, instance: "JobConfig") -> None: + for f in fields(self): + setattr(self, f.name, getattr(instance, f.name, getattr(self, f.name))) + + def parse_args(self): """ Parse CLI arguments, optionally load from a TOML file, merge with defaults, and return a JobConfig instance. """ - cli_config = tyro.cli(cls) - config_file = cli_config.job.config_file - + defaults = tyro.cli(self.__class__) + config_file = defaults.job.config_file if config_file: - logger.info(f"Loading configuration from {config_file}") - toml_data = cls._load_toml(config_file) - toml_config = cls._dict_to_dataclass(cls, toml_data) - - # TOML > deafults - merged_config = cls._merge_with_defaults(cli_config, toml_config) - - # cmdline > TOML > defaults - final_config = tyro.cli(cls, default=merged_config) + toml_data = self._load_toml(config_file) + toml_config = self._dict_to_dataclass(self.__class__, toml_data) + merged_config = self._merge_with_defaults(toml_config, defaults) + # TODO: find a way to make this work without two calls + final_config = tyro.cli(self.__class__, default=merged_config) else: - final_config = cli_config - - final_config._validate_config() - return final_config + final_config = defaults + self._update(final_config) + self._validate_config() @staticmethod def _load_toml(file_path: str) -> Dict[str, Any]: @@ -482,39 +479,37 @@ def _load_toml(file_path: str) -> Dict[str, Any]: logger.exception(f"Error while loading config file: {file_path}") raise e - @classmethod - def _dict_to_dataclass(cls, config_class, data: Dict[str, Any]) -> Any: + def _dict_to_dataclass(self, config_class, data: Dict[str, Any]) -> Any: """Recursively convert dictionaries to nested dataclasses.""" if not is_dataclass(config_class): return data - kwargs = {} for f in fields(config_class): if f.name in data: value = data[f.name] # If target field is also a dataclass and value is a dict, recurse if is_dataclass(f.type) and isinstance(value, dict): - kwargs[f.name] = cls._dict_to_dataclass(f.type, value) + kwargs[f.name] = self._dict_to_dataclass(f.type, value) else: kwargs[f.name] = value return config_class(**kwargs) - @classmethod - def _merge_with_defaults(cls, source: "JobConfig", defaults: "JobConfig") -> "JobConfig": + def _merge_with_defaults( + self, source: "JobConfig", defaults: "JobConfig" + ) -> "JobConfig": """Recursively merge two dataclass instances (source overrides defaults).""" - if not is_dataclass(source) or not is_dataclass(defaults): - return source or defaults - merged_kwargs = {} for f in fields(source): source_val = getattr(source, f.name) default_val = getattr(defaults, f.name) - # If both are dataclasses, merge recursively if is_dataclass(source_val) and is_dataclass(default_val): - merged_kwargs[f.name] = cls._merge_with_defaults(source_val, default_val) + merged_kwargs[f.name] = self._merge_with_defaults( + source_val, default_val + ) else: - merged_kwargs[f.name] = source_val if source_val is not None else default_val - + merged_kwargs[f.name] = ( + source_val if source_val is not None else default_val + ) return type(source)(**merged_kwargs) def _validate_config(self) -> None: