Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions nemo_skills/evaluation/evaluate_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,14 @@
# limitations under the License.

import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import hydra
from omegaconf import MISSING, OmegaConf

from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.utils import setup_logging
from nemo_skills.utils import get_help_message, setup_logging

LOG = logging.getLogger(__file__)

Expand Down Expand Up @@ -71,5 +70,9 @@ def evaluate_results(cfg: EvaluateResultsConfig):


if __name__ == "__main__":
setup_logging()
evaluate_results()
if '--help' in sys.argv:
help_msg = get_help_message(EvaluateResultsConfig)
print(help_msg)
else:
setup_logging()
evaluate_results()
10 changes: 7 additions & 3 deletions nemo_skills/finetuning/prepare_masked_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

sys.path.append(str(Path(__file__).absolute().parents[2]))

from nemo_skills.utils import setup_logging, unroll_files
from nemo_skills.utils import get_help_message, setup_logging, unroll_files

LOG = logging.getLogger(__file__)

Expand Down Expand Up @@ -157,5 +157,9 @@ def prepare_masked_data(cfg: PrepareMaskedDataConfig):


if __name__ == '__main__':
setup_logging()
prepare_masked_data()
if '--help' in sys.argv:
help_msg = get_help_message(PrepareMaskedDataConfig)
print(help_msg)
else:
setup_logging()
prepare_masked_data()
14 changes: 9 additions & 5 deletions nemo_skills/finetuning/prepare_sft_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from dataclasses import dataclass, field
from itertools import zip_longest
from pathlib import Path
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import hydra
import numpy as np
Expand All @@ -33,7 +33,7 @@

from nemo_skills.finetuning.filtering_utils import downsample_data, process_bad_solutions
from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt
from nemo_skills.utils import setup_logging, unroll_files
from nemo_skills.utils import get_help_message, setup_logging, unroll_files

LOG = logging.getLogger(__file__)

Expand Down Expand Up @@ -63,7 +63,7 @@ class PrepareSFTDataConfig:
preprocessed_dataset_files: Optional[str] = None # can specify datasets from HF instead of prediction_jsonl_files
output_path: str = MISSING
# can provide additional metadata to store (e.g. dataset or generation_type)
metadata: Dict = field(default_factory=dict)
metadata: Dict[Any, Any] = field(default_factory=dict)
skip_first: int = 0 # useful for skipping validation set from train_full generation (it's always first)
add_correct: bool = True # can set to False if only want to export incorrect solutions
add_incorrect: bool = False # if True, saves only incorrect solutions instead of correct
Expand Down Expand Up @@ -206,5 +206,9 @@ def prepare_sft_data(cfg: PrepareSFTDataConfig):


if __name__ == "__main__":
setup_logging()
prepare_sft_data()
if '--help' in sys.argv:
help_msg = get_help_message(PrepareSFTDataConfig)
print(help_msg)
else:
setup_logging()
prepare_sft_data()
11 changes: 8 additions & 3 deletions nemo_skills/inference/generate_solutions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import json
import logging
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional
Expand All @@ -25,7 +26,7 @@
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.inference.prompt.utils import PromptConfig, get_prompt
from nemo_skills.inference.server.model import get_model
from nemo_skills.utils import setup_logging
from nemo_skills.utils import get_help_message, setup_logging

LOG = logging.getLogger(__file__)

Expand Down Expand Up @@ -137,5 +138,9 @@ def generate_solutions(cfg: GenerateSolutionsConfig):


if __name__ == "__main__":
setup_logging()
generate_solutions()
if '--help' in sys.argv:
help_msg = get_help_message(GenerateSolutionsConfig)
print(help_msg)
else:
setup_logging()
generate_solutions()
123 changes: 123 additions & 0 deletions nemo_skills/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
# limitations under the License.

import collections.abc
import glob
import inspect
import io
import logging
import sys
import tokenize
import typing
from dataclasses import MISSING, fields, is_dataclass


def unroll_files(prediction_jsonl_files):
Expand All @@ -36,3 +41,121 @@ def setup_logging(disable_hydra_logs: bool = True):
sys.argv.extend(
["hydra.run.dir=.", "hydra.output_subdir=null", "hydra/job_logging=none", "hydra/hydra_logging=none"]
)


def extract_comments(code: str):
    """Return the text of every ``#`` comment in the given Python code.

    The leading ``#`` and surrounding whitespace are stripped from each
    comment. *code* must be tokenizable on its own (callers pass single
    complete source lines); otherwise ``tokenize`` raises.
    """
    comments = []
    tokens = tokenize.tokenize(io.BytesIO(code.encode()).readline)

    for tok in tokens:
        # Token types are plain ints: compare with ==, not `is` (the
        # original identity check only worked via CPython small-int caching).
        if tok.type == tokenize.COMMENT:
            comments.append(tok.string.lstrip('#').strip())

    return comments


def type_to_str(type_hint):
    """Convert a type hint to a readable string.

    Examples: ``Optional[int]``, ``dict[str, int]``,
    ``Callable[[int, str], bool]``. Unparametrized hints fall back to their
    ``__name__`` (with ``NoneType`` shown as ``None``) or to ``str()`` with
    the ``typing.`` prefix removed.
    """
    origin = typing.get_origin(type_hint)
    args = typing.get_args(type_hint)

    # Subscripted hints must be handled BEFORE the __name__ fallback: on
    # Python 3.10+ generic aliases proxy __name__ to their origin, so e.g.
    # Dict[str, int] would otherwise collapse to just 'dict'.
    if origin is typing.Union:
        if len(args) == 2 and type(None) in args:
            # Optional[X] == Union[X, None]; None may appear in either slot.
            non_none = args[0] if args[1] is type(None) else args[1]
            return f'Optional[{type_to_str(non_none)}]'
        return ' or '.join(type_to_str(arg) for arg in args)
    # typing.get_origin(Callable[...]) returns collections.abc.Callable,
    # never typing.Callable — comparing against the latter is always False.
    if origin is collections.abc.Callable:
        if not args:
            return 'Callable'
        # get_args(Callable[[p1, p2], ret]) == ([p1, p2], ret): the first
        # element is a *list* of parameter types (or Ellipsis).
        params, ret = args[0], args[-1]
        if params is Ellipsis:
            return f'Callable[..., {type_to_str(ret)}]'
        params_str = ', '.join(type_to_str(p) for p in params)
        return f'Callable[[{params_str}], {type_to_str(ret)}]'
    if origin is not None and args:
        inner_types = ', '.join(type_to_str(arg) for arg in args)
        origin_name = origin.__name__ if hasattr(origin, '__name__') else str(origin)
        return f'{origin_name}[{inner_types}]'
    if hasattr(type_hint, '__name__'):
        return type_hint.__name__.replace('NoneType', 'None')
    return str(type_hint).replace('typing.', '')


def extract_comments_above_fields(dataclass_obj, prefix: str = '', level: int = 0):
    """Map each field of *dataclass_obj* to a formatted help line.

    Scans the dataclass *source text* line by line, collecting ``#`` comments
    and attaching the accumulated comments to the next field declaration.
    Nested dataclass fields are expanded recursively with dotted names
    (``parent.child``) and one extra space of indentation per *level*.
    Returns a dict of ``field_name -> "name: type = default - comments"``.
    """
    source_lines = inspect.getsource(dataclass_obj).split('\n')
    # Per-field metadata from the dataclass machinery; MISSING is normalized
    # to None so the formatting code below can treat "no default" uniformly.
    fields_info = {
        field.name: {
            'type': field.type,
            'default': field.default if field.default != MISSING else None,
            'default_factory': field.default_factory if field.default_factory != MISSING else None,
        }
        for field in fields(dataclass_obj)
    }
    comments, comment_cache = {}, []

    for line in source_lines:
        # skip unfinished multiline comments
        # Heuristic: an odd triple of quote chars means the line opens or
        # closes a multiline string, which would break tokenize below.
        if line.count("'") == 3 or line.count('"') == 3:
            continue
        line_comment = extract_comments(line)
        if line_comment:
            # Buffer the comment; it documents the next field declaration.
            comment_cache.append(line_comment[0])
        if ':' not in line:
            continue

        # A "name: type" line — the name before ':' identifies the field.
        field_name = line.split(':')[0].strip()
        if field_name not in fields_info:
            continue

        field_info = fields_info[field_name]
        field_name = prefix + field_name
        field_type = type_to_str(field_info['type'])
        default = field_info['default']
        default_factory = field_info['default_factory']
        # '???' is presumably omegaconf's MISSING marker — TODO confirm
        # against the configs that use this helper.
        if default == '???':
            default_str = ' = MISSING'
        else:
            default_str = f' = {default}'
        if default_factory:
            try:
                # Call the factory to show the concrete default value.
                default_factory = default_factory()
                default_str = f' = {default_factory}'
            except:
                # NOTE(review): bare except — any factory failure (e.g. one
                # requiring arguments) silently keeps the plain default string.
                pass
            if is_dataclass(default_factory):
                # Nested dataclass defaults are shown as a constructor call.
                default_str = f' = {field_type}()'

        indent = ' ' * level
        comment = f"\n{indent}".join(comment_cache)
        comment = "- " + comment if comment else ""
        field_detail = f"{indent}{field_name}: {field_type}{default_str} {comment}"
        comments[field_name] = field_detail
        comment_cache = []

        # Recursively extract nested dataclasses
        if is_dataclass(field_info['type']):
            nested_comments = extract_comments_above_fields(
                field_info['type'], prefix=field_name + '.', level=level + 1
            )
            for k, v in nested_comments.items():
                comments[f"{field_name}.{k}"] = v

    return comments


def get_fields_docstring(dataclass_obj):
    """Return a newline-joined help line for every field of *dataclass_obj*.

    Thin wrapper over :func:`extract_comments_above_fields`; the identity
    list comprehension in the original was redundant — ``join`` accepts the
    dict values view directly.
    """
    return '\n'.join(extract_comments_above_fields(dataclass_obj).values())


def get_help_message(dataclass_obj):
    """Build a ``--help`` message for a Hydra-configured script.

    Combines a fixed usage heading, a horizontal rule, and the per-field
    documentation extracted from *dataclass_obj*.
    """
    heading = """
This script uses Hydra (https://hydra.cc/) for dynamic configuration management.
You can apply Hydra's command-line syntax for overriding configuration values directly.
Below are the available configuration options and their default values:
""".strip()

    separator = '-' * 75
    return '\n'.join((heading, separator, get_fields_docstring(dataclass_obj)))
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dependencies = [
'tqdm',
'pyyaml',
'numpy',
'requests',
'sympy',
]

[tool.setuptools.packages.find]
Expand Down