diff --git a/nemo_skills/evaluation/evaluate_results.py b/nemo_skills/evaluation/evaluate_results.py index d271dc0ffa..67cef6a774 100644 --- a/nemo_skills/evaluation/evaluate_results.py +++ b/nemo_skills/evaluation/evaluate_results.py @@ -13,15 +13,14 @@ # limitations under the License. import logging +import sys from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional import hydra from omegaconf import MISSING, OmegaConf from nemo_skills.code_execution.sandbox import get_sandbox -from nemo_skills.utils import setup_logging +from nemo_skills.utils import get_help_message, setup_logging LOG = logging.getLogger(__file__) @@ -71,5 +70,9 @@ def evaluate_results(cfg: EvaluateResultsConfig): if __name__ == "__main__": - setup_logging() - evaluate_results() + if '--help' in sys.argv: + help_msg = get_help_message(EvaluateResultsConfig) + print(help_msg) + else: + setup_logging() + evaluate_results() diff --git a/nemo_skills/finetuning/prepare_masked_data.py b/nemo_skills/finetuning/prepare_masked_data.py index e828d93c3d..16909ee018 100644 --- a/nemo_skills/finetuning/prepare_masked_data.py +++ b/nemo_skills/finetuning/prepare_masked_data.py @@ -24,7 +24,7 @@ sys.path.append(str(Path(__file__).absolute().parents[2])) -from nemo_skills.utils import setup_logging, unroll_files +from nemo_skills.utils import get_help_message, setup_logging, unroll_files LOG = logging.getLogger(__file__) @@ -157,5 +157,9 @@ def prepare_masked_data(cfg: PrepareMaskedDataConfig): if __name__ == '__main__': - setup_logging() - prepare_masked_data() + if '--help' in sys.argv: + help_msg = get_help_message(PrepareMaskedDataConfig) + print(help_msg) + else: + setup_logging() + prepare_masked_data() diff --git a/nemo_skills/finetuning/prepare_sft_data.py b/nemo_skills/finetuning/prepare_sft_data.py index ee24e17bce..a7b148228b 100644 --- a/nemo_skills/finetuning/prepare_sft_data.py +++ b/nemo_skills/finetuning/prepare_sft_data.py @@ -21,7 +21,7 @@ from 
dataclasses import dataclass, field from itertools import zip_longest from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import hydra import numpy as np @@ -33,7 +33,7 @@ from nemo_skills.finetuning.filtering_utils import downsample_data, process_bad_solutions from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt -from nemo_skills.utils import setup_logging, unroll_files +from nemo_skills.utils import get_help_message, setup_logging, unroll_files LOG = logging.getLogger(__file__) @@ -63,7 +63,7 @@ class PrepareSFTDataConfig: preprocessed_dataset_files: Optional[str] = None # can specify datasets from HF instead of prediction_jsonl_files output_path: str = MISSING # can provide additional metadata to store (e.g. dataset or generation_type) - metadata: Dict = field(default_factory=dict) + metadata: Dict[Any, Any] = field(default_factory=dict) skip_first: int = 0 # useful for skipping validation set from train_full generation (it's always first) add_correct: bool = True # can set to False if only want to export incorrect solutions add_incorrect: bool = False # if True, saves only incorrect solutions instead of correct @@ -206,5 +206,9 @@ def prepare_sft_data(cfg: PrepareSFTDataConfig): if __name__ == "__main__": - setup_logging() - prepare_sft_data() + if '--help' in sys.argv: + help_msg = get_help_message(PrepareSFTDataConfig) + print(help_msg) + else: + setup_logging() + prepare_sft_data() diff --git a/nemo_skills/inference/generate_solutions.py b/nemo_skills/inference/generate_solutions.py index bfdb2e7045..062ba9ee7c 100644 --- a/nemo_skills/inference/generate_solutions.py +++ b/nemo_skills/inference/generate_solutions.py @@ -14,6 +14,7 @@ import json import logging +import sys from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Optional @@ -25,7 +26,7 @@ from nemo_skills.code_execution.sandbox import get_sandbox from 
def extract_comments(code: str):
    """Extract a list of comments from the given Python code.

    Args:
        code: Python source text. It may be a fragment (for example a single
            line that opens a bracket which is closed on a later line).

    Returns:
        The text of every ``#`` comment found, in order, with the leading
        ``#`` characters and surrounding whitespace removed.
    """
    comments = []
    try:
        for token in tokenize.generate_tokens(io.StringIO(code).readline):
            if token.type == tokenize.COMMENT:
                comments.append(token.string.lstrip('#').strip())
    except (tokenize.TokenError, SyntaxError):
        # Callers feed this function one source line at a time, so incomplete
        # statements (unclosed brackets / strings) are expected. Keep whatever
        # comments were tokenized before the tokenizer gave up.
        pass

    return comments


def type_to_str(type_hint):
    """Convert type hints to a more readable string.

    Subscripted generics are inspected *before* falling back to ``__name__``:
    on Python 3.10+ aliases such as ``Optional[int]`` or ``Dict[str, int]``
    expose a ``__name__`` and would otherwise lose their arguments.

    Args:
        type_hint: A class, a ``typing`` construct, or any other annotation
            object (e.g. a string forward reference).

    Returns:
        A compact, human-readable rendering such as ``Optional[int]``,
        ``dict[str, int]``, ``int or str`` or ``Callable[[int], str]``.
    """
    # Function-scope imports: only this helper needs these modules.
    import collections.abc
    import types

    if type_hint is type(None):
        return 'None'

    origin = typing.get_origin(type_hint)
    args = typing.get_args(type_hint)

    # ``X | Y`` (PEP 604) reports ``types.UnionType``, which only exists on 3.10+.
    pep604_union = getattr(types, 'UnionType', None)
    if origin is typing.Union or (pep604_union is not None and origin is pep604_union):
        non_none = [arg for arg in args if arg is not type(None)]
        if len(args) == 2 and len(non_none) == 1:
            return f'Optional[{type_to_str(non_none[0])}]'
        return ' or '.join(type_to_str(arg) for arg in args)

    # ``typing.get_origin`` returns ``collections.abc.Callable`` for callables.
    if origin is collections.abc.Callable:
        if not args:
            return 'Callable'
        *params, result = args
        if params and params[0] is Ellipsis:
            return f'Callable[..., {type_to_str(result)}]'
        # ``get_args`` normally returns the parameters as one nested list.
        if len(params) == 1 and isinstance(params[0], list):
            params = params[0]
        params_str = ', '.join(type_to_str(param) for param in params)
        return f'Callable[[{params_str}], {type_to_str(result)}]'

    if origin is not None:
        origin_name = getattr(origin, '__name__', None) or str(origin).replace('typing.', '')
        if not args:
            return origin_name
        inner_types = ', '.join(type_to_str(arg) for arg in args)
        return f'{origin_name}[{inner_types}]'

    if hasattr(type_hint, '__name__'):
        return type_hint.__name__.replace('NoneType', 'None')

    return str(type_hint).replace('typing.', '')


def _collect_field_comments(dataclass_type, field_names):
    """Map each field name found in the class source to the comments attached to it.

    A field's comments are every comment seen since the previous field line,
    including a trailing comment on the field's own line. Returns an empty
    dict when the class source cannot be retrieved.
    """
    try:
        source_lines = inspect.getsource(dataclass_type).split('\n')
    except (OSError, TypeError):
        # No source available (e.g. the class was built dynamically).
        return {}

    found, pending = {}, []
    for line in source_lines:
        # skip unfinished multiline comments
        if line.count("'") == 3 or line.count('"') == 3:
            continue
        line_comments = extract_comments(line)
        if line_comments:
            pending.append(line_comments[0])
        if ':' not in line:
            continue

        name = line.split(':')[0].strip()
        if name not in field_names:
            continue

        # The first match wins: field declarations precede any later line
        # (e.g. inside a method) that happens to start with the same name.
        found.setdefault(name, pending)
        pending = []

    return found


def _format_default(field_obj, type_name):
    """Render the `` = <default>`` suffix for a dataclass field ('' if it has no default)."""
    if field_obj.default is not MISSING:
        default = field_obj.default
        # '???' is the string marker treated as "MISSING" (mandatory value).
        if isinstance(default, str) and default == '???':
            return ' = MISSING'
        return f' = {default}'

    factory = field_obj.default_factory
    if factory is MISSING:
        return ''

    try:
        produced = factory()
    except Exception:
        # The factory could not be evaluated here; still show that one exists.
        if is_dataclass(factory):
            return f' = {type_name}()'
        return f" = {getattr(factory, '__name__', 'default_factory')}()"

    if is_dataclass(produced):
        # Nested configs are expanded field by field below; keep this line short.
        return f' = {type_name}()'
    return f' = {produced}'


def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = 0):
    """Describe every field of a dataclass together with its source comments.

    Args:
        dataclass_obj: A dataclass (class or instance).
        prefix: Dotted prefix prepended to field names; used for nested dataclasses.
        level: Nesting depth; each level indents the rendered line by one space.

    Returns:
        A dict mapping the fully-qualified field name (``prefix`` included) to a
        line of the form ``name: type = default - comment``. Fields of nested
        dataclasses follow their parent field, keyed as ``parent.child``.

    Raises:
        TypeError: If ``dataclass_obj`` is not a dataclass (raised by ``fields``).
    """
    all_fields = fields(dataclass_obj)
    dataclass_type = dataclass_obj if isinstance(dataclass_obj, type) else type(dataclass_obj)
    field_comments = _collect_field_comments(dataclass_type, {f.name for f in all_fields})

    indent = ' ' * level
    comments = {}
    for field_obj in all_fields:
        full_name = prefix + field_obj.name
        type_name = type_to_str(field_obj.type)
        field_detail = f"{indent}{full_name}: {type_name}{_format_default(field_obj, type_name)}"

        comment_lines = field_comments.get(field_obj.name)
        if comment_lines:
            field_detail += " - " + f"\n{indent}".join(comment_lines)
        comments[full_name] = field_detail

        # Recursively extract nested dataclasses. The nested keys already carry
        # the full dotted prefix, so they are merged as-is.
        if is_dataclass(field_obj.type):
            comments.update(
                extract_comments_above_fields(field_obj.type, prefix=full_name + '.', level=level + 1)
            )

    return comments


def get_fields_docstring(dataclass_obj):
    """Return the rendered field descriptions of a dataclass, one per line."""
    return '\n'.join(extract_comments_above_fields(dataclass_obj).values())


def get_help_message(dataclass_obj):
    """Build the ``--help`` text for a script configured by a dataclass.

    Args:
        dataclass_obj: The dataclass describing the script's configuration.

    Returns:
        A short Hydra usage heading, a 75-character dashed separator and the
        list of configuration fields with their types, defaults and comments.
    """
    heading = """
This script uses Hydra (https://hydra.cc/) for dynamic configuration management.
You can apply Hydra's command-line syntax for overriding configuration values directly.
Below are the available configuration options and their default values:
    """.strip()

    docstring = get_fields_docstring(dataclass_obj)

    return f"{heading}\n{'-' * 75}\n{docstring}"