From 0d8545fbe168d984de8d6bb4b69f48d4d6cc726a Mon Sep 17 00:00:00 2001 From: imoshkov Date: Tue, 20 Feb 2024 08:58:38 -0800 Subject: [PATCH 1/9] added dataclass help utils --- nemo_skills/utils.py | 79 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/nemo_skills/utils.py b/nemo_skills/utils.py index 59e672be5c..01b0b226a4 100644 --- a/nemo_skills/utils.py +++ b/nemo_skills/utils.py @@ -13,8 +13,11 @@ # limitations under the License. import glob +import inspect import logging import sys +import typing +from dataclasses import fields, is_dataclass, MISSING def unroll_files(prediction_jsonl_files): @@ -36,3 +39,79 @@ def setup_logging(disable_hydra_logs: bool = True): sys.argv.extend( ["hydra.run.dir=.", "hydra.output_subdir=null", "hydra/job_logging=none", "hydra/hydra_logging=none"] ) + + +def type_to_str(type_hint): + """Convert type hints to a more readable string.""" + origin = typing.get_origin(type_hint) + args = typing.get_args(type_hint) + + if hasattr(type_hint, '__name__'): + return type_hint.__name__.replace('NoneType', 'None') + elif origin is typing.Union: + if len(args) == 2 and type(None) in args: + return f'Optional[{type_to_str(args[0])}]' + else: + return ' or '.join(type_to_str(arg) for arg in args) + elif origin is typing.Callable: + if args[0] is Ellipsis: + args_str = '...' + else: + args_str = ', '.join(type_to_str(arg) for arg in args[:-1]) + return f'Callable[[{args_str}], {type_to_str(args[-1])}]' + elif origin: + inner_types = ', '.join(type_to_str(arg) for arg in args) + origin_name = origin.__name__ if hasattr(origin, '__name__') else str(origin) + return f'{origin_name}[{inner_types}]' + else: + return str(type_hint) + + +def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = 0): + source_lines = inspect.getsource(dataclass_obj).split('\n') + fields_info = { + field.name: { + 'type': field.type, + 'default': field.default if field.default != MISSING else None, + 'default_factory': field.default_factory if field.default_factory != MISSING else None + } for field in fields(dataclass_obj) + } + comments, comment_cache = {}, [] + + for line in source_lines: + if line.strip().startswith('#'): + comment_cache.append(' '.join(line.strip().lstrip('# ').strip().split())) + continue + if ':' not in line: + continue + + field_name = line.split(':')[0].strip() + if field_name not in fields_info: + continue + + field_info = fields_info[field_name] + field_type = type_to_str(field_info['type']) + default = field_info['default'] + default_factory = field_info['default_factory'] + default_factory_str = f', Default Factory: {default_factory.__name__}' if default_factory else '' + default_str = f', Default: {default}' if not default_factory else default_factory_str + + comment = '. '.join(comment_cache) + field_detail = f"{' ' * level}{prefix + field_name}: {comment}\n{' ' * level}Type: {field_type}{default_str}" + comments[field_name] = field_detail + comment_cache = [] + + # Recursively extract nested dataclasses + if is_dataclass(field_info['type']): + nested_comments = extract_comments_above_fields(field_info['type'], prefix=prefix + field_name + '.', level=level + 1) + for k, v in nested_comments.items(): + comments[f"{prefix}{field_name}.{k}"] = v + + return comments + + +def print_fields_docstring(dataclass_obj): + commented_fields = extract_comments_above_fields(dataclass_obj) + for content in commented_fields.values(): + print(content) + print() From bcb61fddcb6baa58e5699bcdfea290c60518ef4d Mon Sep 17 00:00:00 2001 From: imoshkov Date: Tue, 20 Feb 2024 09:00:33 -0800 Subject: [PATCH 2/9] format fixes --- nemo_skills/utils.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/nemo_skills/utils.py b/nemo_skills/utils.py index 01b0b226a4..d6b67becc7 100644 --- a/nemo_skills/utils.py +++ b/nemo_skills/utils.py @@ -17,7 +17,7 @@ import logging import sys import typing -from dataclasses import fields, is_dataclass, MISSING +from dataclasses import MISSING, fields, is_dataclass def unroll_files(prediction_jsonl_files): @@ -73,8 +73,9 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = field.name: { 'type': field.type, 'default': field.default if field.default != MISSING else None, - 'default_factory': field.default_factory if field.default_factory != MISSING else None - } for field in fields(dataclass_obj) + 'default_factory': field.default_factory if field.default_factory != MISSING else None, + } + for field in fields(dataclass_obj) } comments, comment_cache = {}, [] @@ -95,7 +96,7 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = default_factory = field_info['default_factory'] default_factory_str = f', Default Factory: {default_factory.__name__}' if default_factory else '' default_str = f', Default: {default}' if not default_factory else default_factory_str - + comment = '. '.join(comment_cache) field_detail = f"{' ' * level}{prefix + field_name}: {comment}\n{' ' * level}Type: {field_type}{default_str}" comments[field_name] = field_detail @@ -103,7 +104,9 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = # Recursively extract nested dataclasses if is_dataclass(field_info['type']): - nested_comments = extract_comments_above_fields(field_info['type'], prefix=prefix + field_name + '.', level=level + 1) + nested_comments = extract_comments_above_fields( + field_info['type'], prefix=prefix + field_name + '.', level=level + 1 + ) for k, v in nested_comments.items(): comments[f"{prefix}{field_name}.{k}"] = v From 46d0a3bc04862119bd314ed11f67ae9d7c37e2f9 Mon Sep 17 00:00:00 2001 From: imoshkov Date: Tue, 20 Feb 2024 09:01:59 -0800 Subject: [PATCH 3/9] added help printing --- nemo_skills/evaluation/evaluate_results.py | 12 +++++++----- nemo_skills/finetuning/prepare_sft_data.py | 9 ++++++--- nemo_skills/inference/generate_solutions.py | 10 +++++++--- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/nemo_skills/evaluation/evaluate_results.py b/nemo_skills/evaluation/evaluate_results.py index d271dc0ffa..c1e43b8f8c 100644 --- a/nemo_skills/evaluation/evaluate_results.py +++ b/nemo_skills/evaluation/evaluate_results.py @@ -13,15 +13,14 @@ # limitations under the License. import logging +import sys from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional import hydra from omegaconf import MISSING, OmegaConf from nemo_skills.code_execution.sandbox import get_sandbox -from nemo_skills.utils import setup_logging +from nemo_skills.utils import print_fields_docstring, setup_logging LOG = logging.getLogger(__file__) @@ -71,5 +70,8 @@ def evaluate_results(cfg: EvaluateResultsConfig): if __name__ == "__main__": - setup_logging() - evaluate_results() + if '--help' in sys.argv: + print_fields_docstring(EvaluateResultsConfig) + else: + setup_logging() + evaluate_results() diff --git a/nemo_skills/finetuning/prepare_sft_data.py b/nemo_skills/finetuning/prepare_sft_data.py index ee24e17bce..356b57867f 100644 --- a/nemo_skills/finetuning/prepare_sft_data.py +++ b/nemo_skills/finetuning/prepare_sft_data.py @@ -33,7 +33,7 @@ from nemo_skills.finetuning.filtering_utils import downsample_data, process_bad_solutions from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt -from nemo_skills.utils import setup_logging, unroll_files +from nemo_skills.utils import print_fields_docstring, setup_logging, unroll_files LOG = logging.getLogger(__file__) @@ -206,5 +206,8 @@ def prepare_sft_data(cfg: PrepareSFTDataConfig): if __name__ == "__main__": - setup_logging() - prepare_sft_data() + if '--help' in sys.argv: + print_fields_docstring(PrepareSFTDataConfig) + else: + setup_logging() + prepare_sft_data() diff --git a/nemo_skills/inference/generate_solutions.py b/nemo_skills/inference/generate_solutions.py index bfdb2e7045..0c3b098b55 100644 --- a/nemo_skills/inference/generate_solutions.py +++ b/nemo_skills/inference/generate_solutions.py @@ -14,6 +14,7 @@ import json import logging +import sys from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional @@ -25,7 +26,7 @@ from nemo_skills.code_execution.sandbox import get_sandbox from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt from nemo_skills.inference.server.model import get_model -from nemo_skills.utils import setup_logging +from nemo_skills.utils import print_fields_docstring, setup_logging LOG = logging.getLogger(__file__) @@ -137,5 +138,8 @@ def generate_solutions(cfg: GenerateSolutionsConfig): if __name__ == "__main__": - setup_logging() - generate_solutions() + if '--help' in sys.argv: + print_fields_docstring(GenerateSolutionsConfig) + else: + setup_logging() + generate_solutions() From deaffe41e631ae639e77363a782ce8d2c494ac0b Mon Sep 17 00:00:00 2001 From: imoshkov Date: Tue, 20 Feb 2024 09:02:26 -0800 Subject: [PATCH 4/9] added missing requirements --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 50b685f82b..0077dd723b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,8 @@ dependencies = [ 'tqdm', 'pyyaml', 'numpy', + 'requests', + 'sympy', ] [tool.setuptools.packages.find] From 303561de764750d37a46b729ab8cda9ce11760f1 Mon Sep 17 00:00:00 2001 From: imoshkov Date: Tue, 20 Feb 2024 09:53:16 -0800 Subject: [PATCH 5/9] improved comment parsing --- nemo_skills/utils.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/nemo_skills/utils.py b/nemo_skills/utils.py index d6b67becc7..76adc20d80 100644 --- a/nemo_skills/utils.py +++ b/nemo_skills/utils.py @@ -14,8 +14,10 @@ import glob import inspect +import io import logging import sys +import tokenize import typing from dataclasses import MISSING, fields, is_dataclass @@ -41,6 +43,18 @@ def setup_logging(disable_hydra_logs: bool = True): ) +def extract_comments(code: str): + """Extract a list of comments from the given Python code.""" + comments = [] + tokens = tokenize.tokenize(io.BytesIO(code.encode()).readline) + + for token, line, *_ in tokens: + if token is tokenize.COMMENT: + comments.append(line.lstrip('#').strip()) + + return comments + + def type_to_str(type_hint): """Convert type hints to a more readable string.""" origin = typing.get_origin(type_hint) @@ -80,9 +94,9 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = comments, comment_cache = {}, [] for line in source_lines: - if line.strip().startswith('#'): - comment_cache.append(' '.join(line.strip().lstrip('# ').strip().split())) - continue + line_comment = extract_comments(line) + if line_comment: + comment_cache.append(line_comment[0]) if ':' not in line: continue From a572f7f7bd757badf2ec208b3322050fff90e9a0 Mon Sep 17 00:00:00 2001 From: imoshkov Date: Wed, 21 Feb 2024 02:45:59 -0800 Subject: [PATCH 6/9] improved help readability --- nemo_skills/utils.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/nemo_skills/utils.py b/nemo_skills/utils.py index 76adc20d80..1d3cdf7996 100644 --- a/nemo_skills/utils.py +++ b/nemo_skills/utils.py @@ -94,6 +94,9 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = comments, comment_cache = {}, [] for line in source_lines: + # skip unfinished multiline comments + if line.count("'") % 2 or line.count('"') % 2: + continue line_comment = extract_comments(line) if line_comment: comment_cache.append(line_comment[0]) @@ -105,30 +108,37 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = continue field_info = fields_info[field_name] + field_name = prefix + field_name field_type = type_to_str(field_info['type']) default = field_info['default'] default_factory = field_info['default_factory'] - default_factory_str = f', Default Factory: {default_factory.__name__}' if default_factory else '' - default_str = f', Default: {default}' if not default_factory else default_factory_str - - comment = '. '.join(comment_cache) - field_detail = f"{' ' * level}{prefix + field_name}: {comment}\n{' ' * level}Type: {field_type}{default_str}" + default_factory_str = f'default factory: {default_factory.__name__}' if default_factory else '' + if default == '???': + default_str = ', required' + else: + default_str = f', default: {default}' if not default_factory else default_factory_str + if is_dataclass(default_factory): + default_str = '' + + indent = ' ' * level + comment = f"\n{indent}".join(comment_cache) + comment = "- " + comment if comment else "" + field_detail = f"{indent}{field_name} ({field_type}{default_str}) {comment}" comments[field_name] = field_detail comment_cache = [] # Recursively extract nested dataclasses if is_dataclass(field_info['type']): nested_comments = extract_comments_above_fields( - field_info['type'], prefix=prefix + field_name + '.', level=level + 1 + field_info['type'], prefix=field_name + '.', level=level + 1 ) for k, v in nested_comments.items(): - comments[f"{prefix}{field_name}.{k}"] = v + comments[f"{field_name}.{k}"] = v return comments -def print_fields_docstring(dataclass_obj): +def get_fields_docstring(dataclass_obj): commented_fields = extract_comments_above_fields(dataclass_obj) - for content in commented_fields.values(): - print(content) - print() + docstring = [content for content in commented_fields.values()] + return '\n\n'.join(docstring) From dc0cd0caa4c63092d4a8ed20974c92fabff0d2d2 Mon Sep 17 00:00:00 2001 From: imoshkov Date: Wed, 21 Feb 2024 02:48:10 -0800 Subject: [PATCH 7/9] added help as a string --- nemo_skills/evaluation/evaluate_results.py | 5 +++-- nemo_skills/finetuning/prepare_sft_data.py | 6 ++++-- nemo_skills/inference/generate_solutions.py | 5 +++-- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/nemo_skills/evaluation/evaluate_results.py b/nemo_skills/evaluation/evaluate_results.py index c1e43b8f8c..d3b5694724 100644 --- a/nemo_skills/evaluation/evaluate_results.py +++ b/nemo_skills/evaluation/evaluate_results.py @@ -20,7 +20,7 @@ from omegaconf import MISSING, OmegaConf from nemo_skills.code_execution.sandbox import get_sandbox -from nemo_skills.utils import print_fields_docstring, setup_logging +from nemo_skills.utils import get_fields_docstring, setup_logging LOG = logging.getLogger(__file__) @@ -71,7 +71,8 @@ def evaluate_results(cfg: EvaluateResultsConfig): if __name__ == "__main__": if '--help' in sys.argv: - print_fields_docstring(EvaluateResultsConfig) + help_msg = get_fields_docstring(EvaluateResultsConfig) + print(help_msg) else: setup_logging() evaluate_results() diff --git a/nemo_skills/finetuning/prepare_sft_data.py b/nemo_skills/finetuning/prepare_sft_data.py index 356b57867f..fa63476869 100644 --- a/nemo_skills/finetuning/prepare_sft_data.py +++ b/nemo_skills/finetuning/prepare_sft_data.py @@ -21,6 +21,7 @@ from dataclasses import dataclass, field from itertools import zip_longest from pathlib import Path +from pydoc import doc from typing import Dict, List, Optional import hydra @@ -33,7 +34,7 @@ from nemo_skills.finetuning.filtering_utils import downsample_data, process_bad_solutions from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt -from nemo_skills.utils import print_fields_docstring, setup_logging, unroll_files +from nemo_skills.utils import get_fields_docstring, setup_logging, unroll_files LOG = logging.getLogger(__file__) @@ -207,7 +208,8 @@ def prepare_sft_data(cfg: PrepareSFTDataConfig): if __name__ == "__main__": if '--help' in sys.argv: - print_fields_docstring(PrepareSFTDataConfig) + help_msg = get_fields_docstring(PrepareSFTDataConfig) + # print(help_msg) else: setup_logging() prepare_sft_data() diff --git a/nemo_skills/inference/generate_solutions.py b/nemo_skills/inference/generate_solutions.py index 0c3b098b55..af24ad3857 100644 --- a/nemo_skills/inference/generate_solutions.py +++ b/nemo_skills/inference/generate_solutions.py @@ -26,7 +26,7 @@ from nemo_skills.code_execution.sandbox import get_sandbox from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt from nemo_skills.inference.server.model import get_model -from nemo_skills.utils import print_fields_docstring, setup_logging +from nemo_skills.utils import get_fields_docstring, setup_logging LOG = logging.getLogger(__file__) @@ -139,7 +139,8 @@ def generate_solutions(cfg: GenerateSolutionsConfig): if __name__ == "__main__": if '--help' in sys.argv: - print_fields_docstring(GenerateSolutionsConfig) + help_msg = get_fields_docstring(GenerateSolutionsConfig) + print(help_msg) else: setup_logging() generate_solutions() From eb9e504e671bb78acd34029061877b6be3451d26 Mon Sep 17 00:00:00 2001 From: imoshkov Date: Thu, 22 Feb 2024 02:50:39 -0800 Subject: [PATCH 8/9] added reference to hydra in help msg --- nemo_skills/evaluation/evaluate_results.py | 4 ++-- nemo_skills/finetuning/prepare_masked_data.py | 10 +++++++--- nemo_skills/finetuning/prepare_sft_data.py | 6 +++--- nemo_skills/inference/generate_solutions.py | 4 ++-- nemo_skills/utils.py | 14 +++++++++++++- 5 files changed, 27 insertions(+), 11 deletions(-) diff --git a/nemo_skills/evaluation/evaluate_results.py b/nemo_skills/evaluation/evaluate_results.py index d3b5694724..67cef6a774 100644 --- a/nemo_skills/evaluation/evaluate_results.py +++ b/nemo_skills/evaluation/evaluate_results.py @@ -20,7 +20,7 @@ from omegaconf import MISSING, OmegaConf from nemo_skills.code_execution.sandbox import get_sandbox -from nemo_skills.utils import get_fields_docstring, setup_logging +from nemo_skills.utils import get_help_message, setup_logging LOG = logging.getLogger(__file__) @@ -71,7 +71,7 @@ def evaluate_results(cfg: EvaluateResultsConfig): if __name__ == "__main__": if '--help' in sys.argv: - help_msg = get_fields_docstring(EvaluateResultsConfig) + help_msg = get_help_message(EvaluateResultsConfig) print(help_msg) else: setup_logging() diff --git a/nemo_skills/finetuning/prepare_masked_data.py b/nemo_skills/finetuning/prepare_masked_data.py index e828d93c3d..16909ee018 100644 --- a/nemo_skills/finetuning/prepare_masked_data.py +++ b/nemo_skills/finetuning/prepare_masked_data.py @@ -24,7 +24,7 @@ sys.path.append(str(Path(__file__).absolute().parents[2])) -from nemo_skills.utils import setup_logging, unroll_files +from nemo_skills.utils import get_help_message, setup_logging, unroll_files LOG = logging.getLogger(__file__) @@ -157,5 +157,9 @@ def prepare_masked_data(cfg: PrepareMaskedDataConfig): if __name__ == '__main__': - setup_logging() - prepare_masked_data() + if '--help' in sys.argv: + help_msg = get_help_message(PrepareMaskedDataConfig) + print(help_msg) + else: + setup_logging() + prepare_masked_data() diff --git a/nemo_skills/finetuning/prepare_sft_data.py b/nemo_skills/finetuning/prepare_sft_data.py index fa63476869..79d8c5fa79 100644 --- a/nemo_skills/finetuning/prepare_sft_data.py +++ b/nemo_skills/finetuning/prepare_sft_data.py @@ -34,7 +34,7 @@ from nemo_skills.finetuning.filtering_utils import downsample_data, process_bad_solutions from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt -from nemo_skills.utils import get_fields_docstring, setup_logging, unroll_files +from nemo_skills.utils import get_help_message, setup_logging, unroll_files LOG = logging.getLogger(__file__) @@ -208,8 +208,8 @@ def prepare_sft_data(cfg: PrepareSFTDataConfig): if __name__ == "__main__": if '--help' in sys.argv: - help_msg = get_fields_docstring(PrepareSFTDataConfig) - # print(help_msg) + help_msg = get_help_message(PrepareSFTDataConfig) + print(help_msg) else: setup_logging() prepare_sft_data() diff --git a/nemo_skills/inference/generate_solutions.py b/nemo_skills/inference/generate_solutions.py index af24ad3857..062ba9ee7c 100644 --- a/nemo_skills/inference/generate_solutions.py +++ b/nemo_skills/inference/generate_solutions.py @@ -26,7 +26,7 @@ from nemo_skills.code_execution.sandbox import get_sandbox from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt from nemo_skills.inference.server.model import get_model -from nemo_skills.utils import get_fields_docstring, setup_logging +from nemo_skills.utils import get_help_message, setup_logging LOG = logging.getLogger(__file__) @@ -139,7 +139,7 @@ def generate_solutions(cfg: GenerateSolutionsConfig): if __name__ == "__main__": if '--help' in sys.argv: - help_msg = get_fields_docstring(GenerateSolutionsConfig) + help_msg = get_help_message(GenerateSolutionsConfig) print(help_msg) else: setup_logging() diff --git a/nemo_skills/utils.py b/nemo_skills/utils.py index 1d3cdf7996..28ca113f06 100644 --- a/nemo_skills/utils.py +++ b/nemo_skills/utils.py @@ -95,7 +95,7 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = for line in source_lines: # skip unfinished multiline comments - if line.count("'") % 2 or line.count('"') % 2: + if line.count("'") == 3 or line.count('"') == 3: continue line_comment = extract_comments(line) if line_comment: @@ -142,3 +142,15 @@ def get_fields_docstring(dataclass_obj): commented_fields = extract_comments_above_fields(dataclass_obj) docstring = [content for content in commented_fields.values()] return '\n\n'.join(docstring) + + +def get_help_message(dataclass_obj): + heading = """ +This script uses Hydra for dynamic configuration management. +You can apply Hydra's command-line syntax for overriding configuration values directly. +Below are the available configuration options and their default values: + """.strip() + + docstring = get_fields_docstring(dataclass_obj) + + return f"{heading}\n\n{'-' * 75}\n\n{docstring}" From 0708b354e7bfb27c2dafaa4426c6d0e3b99a204a Mon Sep 17 00:00:00 2001 From: imoshkov Date: Fri, 23 Feb 2024 14:53:12 -0800 Subject: [PATCH 9/9] formatting fixes --- nemo_skills/finetuning/prepare_sft_data.py | 5 ++--- nemo_skills/utils.py | 25 +++++++++++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/nemo_skills/finetuning/prepare_sft_data.py b/nemo_skills/finetuning/prepare_sft_data.py index 79d8c5fa79..a7b148228b 100644 --- a/nemo_skills/finetuning/prepare_sft_data.py +++ b/nemo_skills/finetuning/prepare_sft_data.py @@ -21,8 +21,7 @@ from dataclasses import dataclass, field from itertools import zip_longest from pathlib import Path -from pydoc import doc -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import hydra import numpy as np @@ -64,7 +63,7 @@ class PrepareSFTDataConfig: preprocessed_dataset_files: Optional[str] = None # can specify datasets from HF instead of prediction_jsonl_files output_path: str = MISSING # can provide additional metadata to store (e.g. dataset or generation_type) - metadata: Dict = field(default_factory=dict) + metadata: Dict[Any, Any] = field(default_factory=dict) skip_first: int = 0 # useful for skipping validation set from train_full generation (it's always first) add_correct: bool = True # can set to False if only want to export incorrect solutions add_incorrect: bool = False # if True, saves only incorrect solutions instead of correct diff --git a/nemo_skills/utils.py b/nemo_skills/utils.py index 28ca113f06..4aa2e6ea93 100644 --- a/nemo_skills/utils.py +++ b/nemo_skills/utils.py @@ -78,7 +78,7 @@ def type_to_str(type_hint): origin_name = origin.__name__ if hasattr(origin, '__name__') else str(origin) return f'{origin_name}[{inner_types}]' else: - return str(type_hint) + return str(type_hint).replace('typing.', '') def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = 0): @@ -112,18 +112,23 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = field_type = type_to_str(field_info['type']) default = field_info['default'] default_factory = field_info['default_factory'] - default_factory_str = f'default factory: {default_factory.__name__}' if default_factory else '' if default == '???': - default_str = ', required' + default_str = ' = MISSING' else: - default_str = f', default: {default}' if not default_factory else default_factory_str - if is_dataclass(default_factory): - default_str = '' + default_str = f' = {default}' + if default_factory: + try: + default_factory = default_factory() + default_str = f' = {default_factory}' + except: + pass + if is_dataclass(default_factory): + default_str = f' = {field_type}()' indent = ' ' * level comment = f"\n{indent}".join(comment_cache) comment = "- " + comment if comment else "" - field_detail = f"{indent}{field_name} ({field_type}{default_str}) {comment}" + field_detail = f"{indent}{field_name}: {field_type}{default_str} {comment}" comments[field_name] = field_detail comment_cache = [] @@ -141,16 +146,16 @@ def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = def get_fields_docstring(dataclass_obj): commented_fields = extract_comments_above_fields(dataclass_obj) docstring = [content for content in commented_fields.values()] - return '\n\n'.join(docstring) + return '\n'.join(docstring) def get_help_message(dataclass_obj): heading = """ -This script uses Hydra for dynamic configuration management. +This script uses Hydra (https://hydra.cc/) for dynamic configuration management. You can apply Hydra's command-line syntax for overriding configuration values directly. Below are the available configuration options and their default values: """.strip() docstring = get_fields_docstring(dataclass_obj) - return f"{heading}\n\n{'-' * 75}\n\n{docstring}" + return f"{heading}\n{'-' * 75}\n{docstring}"