Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ repos:
hooks:
- id: ruff
args: ["--fix"]
- id: ruff
args: ["check", "--select", "I", "--fix"]
- id: ruff-format
2 changes: 1 addition & 1 deletion docs/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import json
import tempfile


def make_dpo_dataset():
Expand Down
3 changes: 1 addition & 2 deletions examples/convert_dcp_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@

import argparse
import json
import os
import torch

from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf


Expand Down
11 changes: 5 additions & 6 deletions examples/run_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@
import os
import pprint
import warnings
from typing import Dict, Any
from typing import Any, Dict

from omegaconf import OmegaConf

from nemo_rl.algorithms.dpo import MasterConfig, dpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir
from nemo_rl.data import DataConfig, hf_datasets
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.interfaces import TaskDataSpec, DatumSpec
from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec
from nemo_rl.data.llm_message_utils import get_formatted_message_log
from transformers import AutoTokenizer
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir


def parse_args():
Expand Down
2 changes: 1 addition & 1 deletion examples/run_grpo_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, env_configs):

# Load OpenMathInstruct2Dataset using nemo rl datasets
if data_config["dataset_name"] == "OpenMathInstruct-2":
print(f"Loading nvidia/OpenMathInstruct2Dataset for training and validation")
print("Loading nvidia/OpenMathInstruct2Dataset for training and validation")
data = OpenMathInstruct2Dataset()
else:
raise ValueError(f"No processor for dataset {data_config['dataset_name']}.")
Expand Down
25 changes: 11 additions & 14 deletions examples/run_grpo_sliding_puzzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,29 @@
# limitations under the License.

import argparse
import itertools
import os
import pprint
import itertools
from typing import Any, Dict, Tuple, Iterator
import random
from typing import Any, Dict, Iterator, Tuple

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup
from nemo_rl.algorithms.utils import get_tokenizer

from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.models.generation.interfaces import configure_generation_config
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir

from nemo_rl.environments.games.sliding_puzzle import (
SlidingPuzzleGameLogic,
SlidingPuzzleEnv,
SlidingPuzzleConfig,
SlidingPuzzleEnv,
SlidingPuzzleGameLogic,
SlidingPuzzleMetadata,
)
from nemo_rl.data.interfaces import LLMMessageLogType, DatumSpec
from nemo_rl.models.generation.interfaces import configure_generation_config
from nemo_rl.utils.config import load_config, parse_hydra_overrides
from nemo_rl.utils.logger import get_next_experiment_dir


def parse_args():
Expand Down Expand Up @@ -133,7 +130,7 @@ def __init__(
self.length = length

def __iter__(self) -> Iterator[DatumSpec]:
print(f"Starting IterablePuzzleDataset (indefinite generation).")
print("Starting IterablePuzzleDataset (indefinite generation).")
# Use itertools.count for an infinite index generator
for i in itertools.count():
yield generate_puzzle_datum(
Expand Down Expand Up @@ -166,7 +163,7 @@ def setup_puzzle_data(
task_to_env = {task_name: env}
print(f"Environment '{task_name}' created.")

print(f"Creating Sliding Puzzle dataset...")
print("Creating Sliding Puzzle dataset...")
training_dataset = IterablePuzzleDataset(
tokenizer=tokenizer,
game_config=dict(env_config["cfg"]["game_config"]),
Expand Down
6 changes: 3 additions & 3 deletions examples/run_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
import os
import pprint
from functools import partial
from typing import Dict, Any
from typing import Any, Dict

from omegaconf import OmegaConf
from transformers import AutoTokenizer

from nemo_rl.algorithms.sft import MasterConfig, sft_train, setup
from nemo_rl.algorithms.sft import MasterConfig, setup, sft_train
from nemo_rl.algorithms.utils import get_tokenizer
from nemo_rl.data import DataConfig, hf_datasets
from nemo_rl.data.datasets import AllTaskProcessedDataset
from nemo_rl.data.interfaces import TaskDataSpec, DatumSpec
from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec
from nemo_rl.data.llm_message_utils import get_formatted_message_log
from nemo_rl.distributed.virtual_cluster import init_ray
from nemo_rl.utils.config import load_config
Expand Down
1 change: 1 addition & 0 deletions nemo_rl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os

from nemo_rl.package_info import (
__contact_emails__,
__contact_names__,
Expand Down
13 changes: 6 additions & 7 deletions nemo_rl/algorithms/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,25 @@
from collections import defaultdict
from functools import partial
from pathlib import Path
from transformers import AutoTokenizer
from typing import Optional, Tuple, TypedDict
from tqdm import tqdm

import numpy as np
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from nemo_rl.algorithms.loss_functions import (
DPOLossFn,
)
from nemo_rl.algorithms.utils import set_seed, get_tokenizer
from nemo_rl.algorithms.utils import set_seed
from nemo_rl.data import DataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, dpo_collate_fn
from nemo_rl.data.interfaces import TaskDataSpec
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.models.interfaces import PolicyInterface
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.utils.checkpoint import CheckpointManager, CheckpointingConfig
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
from nemo_rl.utils.logger import Logger, LoggerConfig
from nemo_rl.utils.timer import Timer

Expand Down Expand Up @@ -217,7 +216,7 @@ def setup(
init_reference_model=True,
)
loss_fn = DPOLossFn(master_config["dpo"])
print(f" ✓ Model initialized")
print(" ✓ Model initialized")

print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
Expand Down
52 changes: 22 additions & 30 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,59 +11,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Tuple, TypedDict, Iterable, Optional, List

import os
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, TypedDict

import numpy as np
import ray
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt

from nemo_rl.environments.interfaces import (
EnvironmentInterface,
EnvironmentReturn,
)
from nemo_rl.distributed.virtual_cluster import RayVirtualCluster
from nemo_rl.data.interfaces import (
DatumSpec,
LLMMessageLogType,
FlatMessagesType,
)
from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.models.generation.vllm import VllmGeneration
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.loss_functions import (
ClippedPGLossConfig,
ClippedPGLossDataDict,
ClippedPGLossFn,
)
from nemo_rl.algorithms.interfaces import LossFunction
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt
from nemo_rl.data import DataConfig
from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn
from nemo_rl.data.interfaces import (
DatumSpec,
)
from nemo_rl.data.llm_message_utils import (
get_keys_from_message_log,
batched_message_log_to_flat_message,
get_keys_from_message_log,
)
from nemo_rl.utils.logger import (
print_message_log_samples,
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.environments.interfaces import (
EnvironmentInterface,
)
from nemo_rl.distributed.virtual_cluster import ClusterConfig
from nemo_rl.environments.math_environment import MathEnvConfig
from nemo_rl.experience.rollouts import run_multi_turn_rollout
from nemo_rl.models.generation.interfaces import (
GenerationInterface,
GenerationDatumSpec,
)
from nemo_rl.models.generation.vllm import VllmGeneration
from nemo_rl.models.interfaces import PolicyInterface
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.utils.logger import Logger, LoggerConfig
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
from nemo_rl.utils.logger import (
Logger,
LoggerConfig,
print_message_log_samples,
)
from nemo_rl.utils.timer import Timer
from nemo_rl.utils.checkpoint import CheckpointManager, CheckpointingConfig
from nemo_rl.experience.rollouts import run_multi_turn_rollout


# ===============================================================================
# Configuration
Expand Down
1 change: 0 additions & 1 deletion nemo_rl/algorithms/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
calculate_kl_penalty_joschu2020,
masked_mean,
)

from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.models.dtensor.parallelize import (
get_logprobs_from_vocab_parallel_logits,
Expand Down
9 changes: 5 additions & 4 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
# limitations under the License.
import os
import warnings
from transformers import AutoTokenizer
from pathlib import Path
from typing import Optional, Tuple, TypedDict

import numpy as np
import torch
from torchdata.stateful_dataloader import StatefulDataLoader
from transformers import AutoTokenizer

from nemo_rl.algorithms.loss_functions import (
NLLLoss,
)
Expand All @@ -34,9 +35,9 @@
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster
from nemo_rl.models.interfaces import PolicyInterface
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.models.policy import PolicyConfig
from nemo_rl.utils.checkpoint import CheckpointManager, CheckpointingConfig
from nemo_rl.models.policy.hf_policy import HfPolicy
from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager
from nemo_rl.utils.logger import Logger, LoggerConfig
from nemo_rl.utils.timer import Timer

Expand Down Expand Up @@ -195,7 +196,7 @@ def setup(
init_reference_model=False,
)
loss_fn = NLLLoss()
print(f" ✓ Model initialized")
print(" ✓ Model initialized")

print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
Expand Down
1 change: 0 additions & 1 deletion nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import numpy as np
import torch
from torch.masked import as_masked_tensor
from transformers import AutoTokenizer

from nemo_rl.data import hf_datasets
Expand Down
6 changes: 3 additions & 3 deletions nemo_rl/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Union, Tuple
from typing import Any, Dict, List, Tuple, Union

import torch
from datasets import Dataset

from nemo_rl.data.interfaces import (
TaskDataSpec,
TaskDataProcessFnCallable,
DatumSpec,
TaskDataProcessFnCallable,
TaskDataSpec,
)
from nemo_rl.data.llm_message_utils import (
add_loss_mask_to_message_log,
Expand Down
2 changes: 1 addition & 1 deletion nemo_rl/data/hf_datasets/helpsteer3.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datasets import load_dataset
from absl import logging
from datasets import load_dataset

from nemo_rl.data.interfaces import TaskDataSpec

Expand Down
7 changes: 3 additions & 4 deletions nemo_rl/data/hf_datasets/oasst.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import copy
import gzip
import json
import os
import random

import requests
import copy
from dataclasses import dataclass
from typing import Optional

from nemo_rl.data.interfaces import TaskDataSpec

Expand Down
5 changes: 2 additions & 3 deletions nemo_rl/data/hf_datasets/openmathinstruct2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from datasets import load_dataset
from dataclasses import dataclass

from nemo_rl.data.interfaces import TaskDataSpec

Expand All @@ -39,7 +38,7 @@ def format_math(data):
def prepare_openinstructmath2_dataset(split: str = "train_1M", seed=42, test_size=0.05):
"""Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split."""
print(
f"WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets."
"WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets."
)

# Load the original dataset
Expand Down
Loading
Loading