diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8fe30b1e26..c83f1d9c49 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,4 +14,6 @@ repos: hooks: - id: ruff args: ["--fix"] + - id: ruff + args: ["check", "--select", "I", "--fix"] - id: ruff-format diff --git a/docs/helpers.py b/docs/helpers.py index 805d5877d1..538e528bb7 100755 --- a/docs/helpers.py +++ b/docs/helpers.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tempfile import json +import tempfile def make_dpo_dataset(): diff --git a/examples/convert_dcp_to_hf.py b/examples/convert_dcp_to_hf.py index c185d424bf..5dfb1359ab 100644 --- a/examples/convert_dcp_to_hf.py +++ b/examples/convert_dcp_to_hf.py @@ -14,8 +14,7 @@ import argparse import json -import os -import torch + from nemo_rl.utils.native_checkpoint import convert_dcp_to_hf diff --git a/examples/run_dpo.py b/examples/run_dpo.py index 93de160b6e..2ebfa6bc05 100644 --- a/examples/run_dpo.py +++ b/examples/run_dpo.py @@ -16,21 +16,20 @@ import os import pprint import warnings -from typing import Dict, Any +from typing import Any, Dict from omegaconf import OmegaConf from nemo_rl.algorithms.dpo import MasterConfig, dpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.utils.config import load_config, parse_hydra_overrides -from nemo_rl.utils.logger import get_next_experiment_dir from nemo_rl.data import DataConfig, hf_datasets from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import TaskDataSpec, DatumSpec +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec from nemo_rl.data.llm_message_utils import get_formatted_message_log -from transformers import AutoTokenizer +from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.models.policy import PolicyConfig +from nemo_rl.utils.config 
import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir def parse_args(): diff --git a/examples/run_grpo_math.py b/examples/run_grpo_math.py index 2e70cc889b..c687ed9c07 100644 --- a/examples/run_grpo_math.py +++ b/examples/run_grpo_math.py @@ -177,7 +177,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig, env_configs): # Load OpenMathInstruct2Dataset using nemo rl datasets if data_config["dataset_name"] == "OpenMathInstruct-2": - print(f"Loading nvidia/OpenMathInstruct2Dataset for training and validation") + print("Loading nvidia/OpenMathInstruct2Dataset for training and validation") data = OpenMathInstruct2Dataset() else: raise ValueError(f"No processor for dataset {data_config['dataset_name']}.") diff --git a/examples/run_grpo_sliding_puzzle.py b/examples/run_grpo_sliding_puzzle.py index 2076f11fd0..fbdd22da75 100644 --- a/examples/run_grpo_sliding_puzzle.py +++ b/examples/run_grpo_sliding_puzzle.py @@ -13,32 +13,29 @@ # limitations under the License. 
import argparse +import itertools import os import pprint -import itertools -from typing import Any, Dict, Tuple, Iterator import random +from typing import Any, Dict, Iterator, Tuple from omegaconf import OmegaConf -from transformers import AutoTokenizer - from torch.utils.data import IterableDataset +from transformers import AutoTokenizer from nemo_rl.algorithms.grpo import MasterConfig, grpo_train, setup from nemo_rl.algorithms.utils import get_tokenizer - +from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType from nemo_rl.distributed.virtual_cluster import init_ray -from nemo_rl.models.generation.interfaces import configure_generation_config -from nemo_rl.utils.config import load_config, parse_hydra_overrides -from nemo_rl.utils.logger import get_next_experiment_dir - from nemo_rl.environments.games.sliding_puzzle import ( - SlidingPuzzleGameLogic, - SlidingPuzzleEnv, SlidingPuzzleConfig, + SlidingPuzzleEnv, + SlidingPuzzleGameLogic, SlidingPuzzleMetadata, ) -from nemo_rl.data.interfaces import LLMMessageLogType, DatumSpec +from nemo_rl.models.generation.interfaces import configure_generation_config +from nemo_rl.utils.config import load_config, parse_hydra_overrides +from nemo_rl.utils.logger import get_next_experiment_dir def parse_args(): @@ -133,7 +130,7 @@ def __init__( self.length = length def __iter__(self) -> Iterator[DatumSpec]: - print(f"Starting IterablePuzzleDataset (indefinite generation).") + print("Starting IterablePuzzleDataset (indefinite generation).") # Use itertools.count for an infinite index generator for i in itertools.count(): yield generate_puzzle_datum( @@ -166,7 +163,7 @@ def setup_puzzle_data( task_to_env = {task_name: env} print(f"Environment '{task_name}' created.") - print(f"Creating Sliding Puzzle dataset...") + print("Creating Sliding Puzzle dataset...") training_dataset = IterablePuzzleDataset( tokenizer=tokenizer, game_config=dict(env_config["cfg"]["game_config"]), diff --git a/examples/run_sft.py 
b/examples/run_sft.py index 2b7dd9489f..c7fbc5ba74 100644 --- a/examples/run_sft.py +++ b/examples/run_sft.py @@ -16,16 +16,16 @@ import os import pprint from functools import partial -from typing import Dict, Any +from typing import Any, Dict from omegaconf import OmegaConf from transformers import AutoTokenizer -from nemo_rl.algorithms.sft import MasterConfig, sft_train, setup +from nemo_rl.algorithms.sft import MasterConfig, setup, sft_train from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data import DataConfig, hf_datasets from nemo_rl.data.datasets import AllTaskProcessedDataset -from nemo_rl.data.interfaces import TaskDataSpec, DatumSpec +from nemo_rl.data.interfaces import DatumSpec, TaskDataSpec from nemo_rl.data.llm_message_utils import get_formatted_message_log from nemo_rl.distributed.virtual_cluster import init_ray from nemo_rl.utils.config import load_config diff --git a/nemo_rl/__init__.py b/nemo_rl/__init__.py index c755e5ed0f..753a669fc8 100644 --- a/nemo_rl/__init__.py +++ b/nemo_rl/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import os + from nemo_rl.package_info import ( __contact_emails__, __contact_names__, diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 1545c050b6..dd6607ef9d 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -16,26 +16,25 @@ from collections import defaultdict from functools import partial from pathlib import Path -from transformers import AutoTokenizer from typing import Optional, Tuple, TypedDict -from tqdm import tqdm import numpy as np import torch from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + from nemo_rl.algorithms.loss_functions import ( DPOLossFn, ) -from nemo_rl.algorithms.utils import set_seed, get_tokenizer +from nemo_rl.algorithms.utils import set_seed from nemo_rl.data import DataConfig from nemo_rl.data.datasets import AllTaskProcessedDataset, dpo_collate_fn from nemo_rl.data.interfaces import TaskDataSpec -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_rl.models.interfaces import PolicyInterface -from nemo_rl.models.policy.hf_policy import HfPolicy from nemo_rl.models.policy import PolicyConfig -from nemo_rl.utils.checkpoint import CheckpointManager, CheckpointingConfig +from nemo_rl.models.policy.hf_policy import HfPolicy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.timer import Timer @@ -217,7 +216,7 @@ def setup( init_reference_model=True, ) loss_fn = DPOLossFn(master_config["dpo"]) - print(f" ✓ Model initialized") + print(" ✓ Model initialized") print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 88fd3f803b..952a6c172a 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -11,59 +11,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY 
KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Tuple, TypedDict, Iterable, Optional, List - import os from pathlib import Path +from typing import Any, Dict, Optional, Tuple, TypedDict + import numpy as np -import ray import torch from torchdata.stateful_dataloader import StatefulDataLoader from transformers import AutoTokenizer -from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt - -from nemo_rl.environments.interfaces import ( - EnvironmentInterface, - EnvironmentReturn, -) -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster -from nemo_rl.data.interfaces import ( - DatumSpec, - LLMMessageLogType, - FlatMessagesType, -) -from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn -from nemo_rl.models.policy.hf_policy import HfPolicy -from nemo_rl.models.generation.vllm import VllmGeneration +from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.algorithms.loss_functions import ( ClippedPGLossConfig, ClippedPGLossDataDict, ClippedPGLossFn, ) -from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt from nemo_rl.data import DataConfig +from nemo_rl.data.datasets import AllTaskProcessedDataset, rl_collate_fn +from nemo_rl.data.interfaces import ( + DatumSpec, +) from nemo_rl.data.llm_message_utils import ( - get_keys_from_message_log, batched_message_log_to_flat_message, + get_keys_from_message_log, ) -from nemo_rl.utils.logger import ( - print_message_log_samples, +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster +from nemo_rl.environments.interfaces import ( + EnvironmentInterface, ) -from nemo_rl.distributed.virtual_cluster import ClusterConfig 
-from nemo_rl.environments.math_environment import MathEnvConfig +from nemo_rl.experience.rollouts import run_multi_turn_rollout from nemo_rl.models.generation.interfaces import ( GenerationInterface, - GenerationDatumSpec, ) +from nemo_rl.models.generation.vllm import VllmGeneration from nemo_rl.models.interfaces import PolicyInterface from nemo_rl.models.policy import PolicyConfig -from nemo_rl.utils.logger import Logger, LoggerConfig +from nemo_rl.models.policy.hf_policy import HfPolicy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager +from nemo_rl.utils.logger import ( + Logger, + LoggerConfig, + print_message_log_samples, +) from nemo_rl.utils.timer import Timer -from nemo_rl.utils.checkpoint import CheckpointManager, CheckpointingConfig -from nemo_rl.experience.rollouts import run_multi_turn_rollout - # =============================================================================== # Configuration diff --git a/nemo_rl/algorithms/loss_functions.py b/nemo_rl/algorithms/loss_functions.py index 26441fe616..a41c66b115 100644 --- a/nemo_rl/algorithms/loss_functions.py +++ b/nemo_rl/algorithms/loss_functions.py @@ -20,7 +20,6 @@ calculate_kl_penalty_joschu2020, masked_mean, ) - from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.dtensor.parallelize import ( get_logprobs_from_vocab_parallel_logits, diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 6400d287ec..8b5ffcddfd 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -13,13 +13,14 @@ # limitations under the License. 
import os import warnings -from transformers import AutoTokenizer from pathlib import Path from typing import Optional, Tuple, TypedDict import numpy as np import torch from torchdata.stateful_dataloader import StatefulDataLoader +from transformers import AutoTokenizer + from nemo_rl.algorithms.loss_functions import ( NLLLoss, ) @@ -34,9 +35,9 @@ from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import ClusterConfig, RayVirtualCluster from nemo_rl.models.interfaces import PolicyInterface -from nemo_rl.models.policy.hf_policy import HfPolicy from nemo_rl.models.policy import PolicyConfig -from nemo_rl.utils.checkpoint import CheckpointManager, CheckpointingConfig +from nemo_rl.models.policy.hf_policy import HfPolicy +from nemo_rl.utils.checkpoint import CheckpointingConfig, CheckpointManager from nemo_rl.utils.logger import Logger, LoggerConfig from nemo_rl.utils.timer import Timer @@ -195,7 +196,7 @@ def setup( init_reference_model=False, ) loss_fn = NLLLoss() - print(f" ✓ Model initialized") + print(" ✓ Model initialized") print("\n" + "=" * 60) print(" " * 18 + "SETUP COMPLETE") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index 81124ef85e..84feced16c 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ -17,7 +17,6 @@ import numpy as np import torch -from torch.masked import as_masked_tensor from transformers import AutoTokenizer from nemo_rl.data import hf_datasets diff --git a/nemo_rl/data/datasets.py b/nemo_rl/data/datasets.py index 99de1d6520..70b1dd9786 100644 --- a/nemo_rl/data/datasets.py +++ b/nemo_rl/data/datasets.py @@ -11,15 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Union, Tuple +from typing import Any, Dict, List, Tuple, Union import torch from datasets import Dataset from nemo_rl.data.interfaces import ( - TaskDataSpec, - TaskDataProcessFnCallable, DatumSpec, + TaskDataProcessFnCallable, + TaskDataSpec, ) from nemo_rl.data.llm_message_utils import ( add_loss_mask_to_message_log, diff --git a/nemo_rl/data/hf_datasets/helpsteer3.py b/nemo_rl/data/hf_datasets/helpsteer3.py index 73e8828927..a6f67e6032 100644 --- a/nemo_rl/data/hf_datasets/helpsteer3.py +++ b/nemo_rl/data/hf_datasets/helpsteer3.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from datasets import load_dataset from absl import logging +from datasets import load_dataset from nemo_rl.data.interfaces import TaskDataSpec diff --git a/nemo_rl/data/hf_datasets/oasst.py b/nemo_rl/data/hf_datasets/oasst.py index 45307f8704..decf9769dc 100644 --- a/nemo_rl/data/hf_datasets/oasst.py +++ b/nemo_rl/data/hf_datasets/oasst.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +import copy import gzip +import json import os import random + import requests -import copy -from dataclasses import dataclass -from typing import Optional from nemo_rl.data.interfaces import TaskDataSpec diff --git a/nemo_rl/data/hf_datasets/openmathinstruct2.py b/nemo_rl/data/hf_datasets/openmathinstruct2.py index 3c1fc318b4..457db35c7e 100644 --- a/nemo_rl/data/hf_datasets/openmathinstruct2.py +++ b/nemo_rl/data/hf_datasets/openmathinstruct2.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional + from datasets import load_dataset -from dataclasses import dataclass from nemo_rl.data.interfaces import TaskDataSpec @@ -39,7 +38,7 @@ def format_math(data): def prepare_openinstructmath2_dataset(split: str = "train_1M", seed=42, test_size=0.05): """Load and split the OpenMathInstruct-2 dataset into train and validation sets using HF's train_test_split.""" print( - f"WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets." + "WARNING: For reproducible experiments, preprocess the dataset once and define your own HfDataset subclass that directly uses the preprocessed datasets." ) # Load the original dataset diff --git a/nemo_rl/data/hf_datasets/prompt_response_dataset.py b/nemo_rl/data/hf_datasets/prompt_response_dataset.py index a8740527fb..1f35d97f01 100644 --- a/nemo_rl/data/hf_datasets/prompt_response_dataset.py +++ b/nemo_rl/data/hf_datasets/prompt_response_dataset.py @@ -13,6 +13,7 @@ # limitations under the License. from datasets import load_dataset + from nemo_rl.data.interfaces import TaskDataSpec diff --git a/nemo_rl/data/hf_datasets/squad.py b/nemo_rl/data/hf_datasets/squad.py index 4e406011bb..1400dad5ba 100644 --- a/nemo_rl/data/hf_datasets/squad.py +++ b/nemo_rl/data/hf_datasets/squad.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional + from datasets import load_dataset from nemo_rl.data.interfaces import TaskDataSpec diff --git a/nemo_rl/data/interfaces.py b/nemo_rl/data/interfaces.py index 33a44d555d..54d9a5cf6d 100644 --- a/nemo_rl/data/interfaces.py +++ b/nemo_rl/data/interfaces.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass -from typing import Any, Dict, List, TypedDict, Optional, Union, Protocol import os +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Protocol, TypedDict, Union import torch diff --git a/nemo_rl/data/llm_message_utils.py b/nemo_rl/data/llm_message_utils.py index ec7ccdbed5..f2d24fc421 100644 --- a/nemo_rl/data/llm_message_utils.py +++ b/nemo_rl/data/llm_message_utils.py @@ -16,9 +16,10 @@ import torch from datasets import Dataset + from nemo_rl.data.interfaces import ( - LLMMessageLogType, FlatMessagesType, + LLMMessageLogType, TaskDataSpec, ) from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/nemo_rl/distributed/batched_data_dict.py b/nemo_rl/distributed/batched_data_dict.py index c86562a9f1..957fcdd121 100644 --- a/nemo_rl/distributed/batched_data_dict.py +++ b/nemo_rl/distributed/batched_data_dict.py @@ -11,16 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy from collections import UserDict -from typing import List, Dict, Optional, Iterator, TypeVar, Any, Generic, Union -from typing_extensions import Self +from copy import deepcopy +from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar, Union import torch +from typing_extensions import Self from nemo_rl.distributed.collectives import ( - rebalance_nd_tensor, gather_jagged_object_lists, + rebalance_nd_tensor, ) DictT = TypeVar("DictT", bound=Dict[str, Any]) diff --git a/nemo_rl/distributed/virtual_cluster.py b/nemo_rl/distributed/virtual_cluster.py index 944ddb442e..a6579e2fd6 100644 --- a/nemo_rl/distributed/virtual_cluster.py +++ b/nemo_rl/distributed/virtual_cluster.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -import functools -from typing import List, TypedDict, Optional - -from copy import deepcopy -import sys -import os -import ray import logging +import os +import sys import time +from typing import List, Optional, TypedDict + +import ray from ray.util.placement_group import placement_group, remove_placement_group from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy diff --git a/nemo_rl/distributed/worker_groups.py b/nemo_rl/distributed/worker_groups.py index c982718067..76c1be6279 100644 --- a/nemo_rl/distributed/worker_groups.py +++ b/nemo_rl/distributed/worker_groups.py @@ -11,19 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Union, Dict, Any +import os import warnings +from copy import deepcopy from dataclasses import dataclass +from typing import Any, Dict, List, Literal, Optional, Union -import os import ray -from typing import Literal -from copy import deepcopy from ray.util.placement_group import PlacementGroup from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.distributed.batched_data_dict import SlicedDataDict +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.utils.venvs import create_local_venv diff --git a/nemo_rl/environments/games/sliding_puzzle.py b/nemo_rl/environments/games/sliding_puzzle.py index 6a41f004c5..b8bc536f41 100644 --- a/nemo_rl/environments/games/sliding_puzzle.py +++ b/nemo_rl/environments/games/sliding_puzzle.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy +import random +from typing import Any, Dict, List, Optional, Tuple, TypedDict + import ray import torch -from typing import Dict, List, Tuple, Optional, TypedDict, Any -import random -import copy -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES from nemo_rl.environments.interfaces import ( EnvironmentInterface, EnvironmentReturn, ) -from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES class SlidingPuzzleConfig(TypedDict): diff --git a/nemo_rl/environments/interfaces.py b/nemo_rl/environments/interfaces.py index 447f8cc318..e432b829f8 100644 --- a/nemo_rl/environments/interfaces.py +++ b/nemo_rl/environments/interfaces.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from typing import Dict, List, Tuple, NamedTuple, Optional +from typing import Dict, List, NamedTuple, Optional, Tuple from torch import Tensor -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict class EnvironmentReturn(NamedTuple): diff --git a/nemo_rl/environments/math_environment.py b/nemo_rl/environments/math_environment.py index e82cf36050..8da0528652 100644 --- a/nemo_rl/environments/math_environment.py +++ b/nemo_rl/environments/math_environment.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from itertools import tee -from typing import Dict, List, Tuple, TypedDict, Optional +from typing import Dict, List, Optional, Tuple, TypedDict import ray import torch from math_verify import parse, verify from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES from nemo_rl.environments.interfaces import ( EnvironmentInterface, EnvironmentReturn, @@ -27,7 +27,6 @@ calculate_pass_rate_per_prompt, ) from nemo_rl.environments.utils import chunk_list_to_workers -from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES class MathEnvConfig(TypedDict): diff --git a/nemo_rl/environments/utils.py b/nemo_rl/environments/utils.py index d54a74efcf..f6442151d9 100644 --- a/nemo_rl/environments/utils.py +++ b/nemo_rl/environments/utils.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Any +from typing import Any, List def chunk_list_to_workers(to_chunk: List[Any], num_workers: int) -> List[List[Any]]: diff --git a/nemo_rl/evals/eval.py b/nemo_rl/evals/eval.py index 217c4c125b..7f54ce00ba 100644 --- a/nemo_rl/evals/eval.py +++ b/nemo_rl/evals/eval.py @@ -28,7 +28,6 @@ from nemo_rl.models.generation.interfaces import GenerationConfig from nemo_rl.models.generation.vllm import VllmGeneration - # =============================================================================== # Configuration # =============================================================================== diff --git a/nemo_rl/experience/rollouts.py b/nemo_rl/experience/rollouts.py index e781ab73a7..a556a32a42 100644 --- a/nemo_rl/experience/rollouts.py +++ b/nemo_rl/experience/rollouts.py @@ -15,29 +15,29 @@ # Generate rollouts for arbitrary environments # Supports multi-turn rollouts and many simultaneous environments (E.g. 
you can train on math, code, multi-turn games and more at once) +from typing import Any, Dict, List, Tuple + +import ray import torch -from typing import List, Tuple, Dict, Optional, Any, NamedTuple from transformers import AutoTokenizer -import ray -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.data.interfaces import ( DatumSpec, - LLMMessageLogType, FlatMessagesType, ) from nemo_rl.data.llm_message_utils import ( - get_keys_from_message_log, batched_message_log_to_flat_message, + get_keys_from_message_log, ) -from nemo_rl.models.generation.interfaces import ( - GenerationInterface, - GenerationDatumSpec, -) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( EnvironmentInterface, EnvironmentReturn, ) +from nemo_rl.models.generation.interfaces import ( + GenerationDatumSpec, + GenerationInterface, +) def generate_responses( diff --git a/nemo_rl/models/dtensor/parallelize.py b/nemo_rl/models/dtensor/parallelize.py index 8922f203c7..3ae86d70cc 100644 --- a/nemo_rl/models/dtensor/parallelize.py +++ b/nemo_rl/models/dtensor/parallelize.py @@ -12,29 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import torch +from typing import List, Union -from torch.distributed.tensor import DTensor import torch from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper, ) -from torch.distributed.fsdp import fully_shard, CPUOffloadPolicy, MixedPrecisionPolicy +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, fully_shard +from torch.distributed.tensor import DTensor from torch.distributed.tensor.parallel import ( ColwiseParallel, RowwiseParallel, - parallelize_module, SequenceParallel, + parallelize_module, ) - -from torch.distributed.tensor import DTensor -from torch.distributed.tensor.placement_types import Shard, Replicate -from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from torch.distributed.tensor.placement_types import Replicate, Shard from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM -from typing import Union, List from nemo_rl.distributed.model_utils import from_parallel_logits_to_logprobs -from torch.distributed.device_mesh import DeviceMesh def _parallelize_llama( diff --git a/nemo_rl/models/generation/interfaces.py b/nemo_rl/models/generation/interfaces.py index 4d6a5ec4d7..f222f55c86 100644 --- a/nemo_rl/models/generation/interfaces.py +++ b/nemo_rl/models/generation/interfaces.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Any, TypedDict, Union, Tuple, List, Optional +from typing import Any, List, Optional, Tuple, TypedDict, Union import torch from transformers import AutoTokenizer diff --git a/nemo_rl/models/generation/vllm.py b/nemo_rl/models/generation/vllm.py index dee3040cb3..59fcc26320 100644 --- a/nemo_rl/models/generation/vllm.py +++ b/nemo_rl/models/generation/vllm.py @@ -12,26 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union, List, TypedDict import gc -import warnings +from typing import List, Optional, TypedDict, Union import ray import torch +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ( + PY_EXECUTABLES, + RayVirtualCluster, +) +from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup from nemo_rl.models.generation.interfaces import ( - GenerationInterface, + GenerationConfig, GenerationDatumSpec, + GenerationInterface, GenerationOutputSpec, verify_right_padding, - GenerationConfig, -) -from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.distributed.virtual_cluster import ( - RayVirtualCluster, - PY_EXECUTABLES, ) -from nemo_rl.distributed.worker_groups import RayWorkerGroup, RayWorkerBuilder class VllmSpecificArgs(TypedDict): @@ -155,7 +154,7 @@ def __init__( self.SamplingParams = vllm.SamplingParams except ImportError: raise ImportError( - f"vLLM is not installed. Please check that VllmGenerationWorker.DEFAULT_PY_EXECUTABLE covers the vllm dependency. " + "vLLM is not installed. Please check that VllmGenerationWorker.DEFAULT_PY_EXECUTABLE covers the vllm dependency. " "If you are working interactively, you can install by running `uv sync --extra vllm` anywhere in the repo." 
) vllm_kwargs = self.cfg.get("vllm_kwargs", {}).copy() diff --git a/nemo_rl/models/generation/vllm_backend.py b/nemo_rl/models/generation/vllm_backend.py index 41498b5312..c14543df12 100644 --- a/nemo_rl/models/generation/vllm_backend.py +++ b/nemo_rl/models/generation/vllm_backend.py @@ -14,10 +14,10 @@ import torch try: - import vllm + import vllm # noqa: F401 except ImportError: raise ImportError( - f"vLLM is not installed. Please check that VllmGenerationWorker.DEFAULT_PY_EXECUTABLE covers the vllm dependency. " + "vLLM is not installed. Please check that VllmGenerationWorker.DEFAULT_PY_EXECUTABLE covers the vllm dependency. " "If you are working interactively, you can install by running `uv sync --extra vllm` anywhere in the repo." ) diff --git a/nemo_rl/models/interfaces.py b/nemo_rl/models/interfaces.py index f194363edc..f59e3ade0b 100644 --- a/nemo_rl/models/interfaces.py +++ b/nemo_rl/models/interfaces.py @@ -14,8 +14,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.algorithms.interfaces import LossFunction +from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.models.generation.interfaces import GenerationDatumSpec diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py index ca806c4675..47714fb0f5 100644 --- a/nemo_rl/models/policy/__init__.py +++ b/nemo_rl/models/policy/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import TypedDict, Optional, Union +from typing import Optional, TypedDict, Union from nemo_rl.models.generation.interfaces import GenerationConfig diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 27dfb4ac77..29ecd46452 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -12,43 +12,40 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import gc - +import os from collections import defaultdict from contextlib import contextmanager, nullcontext -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Tuple, Union import ray import torch +from torch import nn from torch.distributed.fsdp import ( FSDPModule, ) +from torch.distributed.tensor import DTensor from transformers import AutoModelForCausalLM, AutoTokenizer from transformers.integrations.accelerate import find_tied_parameters -from nemo_rl.models.dtensor.parallelize import _parallelize_model from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.utils import import_class_from_path from nemo_rl.distributed.virtual_cluster import ( PY_EXECUTABLES, ) -from typing import Iterable, Tuple, Union -from torch.distributed.tensor import DTensor from nemo_rl.models.dtensor.parallelize import ( - get_logprobs_from_vocab_parallel_logits, - get_grad_norm, + _parallelize_model, clip_grad_by_total_norm_, + get_grad_norm, + get_logprobs_from_vocab_parallel_logits, to_local_if_dtensor, ) +from nemo_rl.models.policy import PolicyConfig +from nemo_rl.models.policy.utils import get_gpu_info, import_class_from_path from nemo_rl.utils.native_checkpoint import ( - save_checkpoint, load_checkpoint, + save_checkpoint, ) -from torch import nn -from 
nemo_rl.models.policy.utils import get_gpu_info @contextmanager diff --git a/nemo_rl/models/policy/fsdp1_policy_worker.py b/nemo_rl/models/policy/fsdp1_policy_worker.py index b25e930b6f..96e3e71e87 100644 --- a/nemo_rl/models/policy/fsdp1_policy_worker.py +++ b/nemo_rl/models/policy/fsdp1_policy_worker.py @@ -13,11 +13,11 @@ # limitations under the License. import gc +import os import warnings from collections import defaultdict from contextlib import contextmanager, nullcontext from typing import Any, Dict, Optional -import os import ray import torch @@ -28,28 +28,25 @@ MixedPrecision, ) from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.integrations.accelerate import find_tied_parameters from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import ( + PY_EXECUTABLES, +) from nemo_rl.models.generation.interfaces import ( GenerationDatumSpec, GenerationOutputSpec, verify_right_padding, ) - -from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.integrations.accelerate import find_tied_parameters from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.utils import import_class_from_path -from nemo_rl.distributed.virtual_cluster import ( - PY_EXECUTABLES, -) +from nemo_rl.models.policy.utils import get_gpu_info, import_class_from_path from nemo_rl.utils.native_checkpoint import ( - save_checkpoint, load_checkpoint, + save_checkpoint, ) -from nemo_rl.models.policy.utils import get_gpu_info @ray.remote diff --git a/nemo_rl/models/policy/hf_policy.py b/nemo_rl/models/policy/hf_policy.py index b092d7eda3..2a579e3bcd 100644 --- a/nemo_rl/models/policy/hf_policy.py +++ b/nemo_rl/models/policy/hf_policy.py @@ -23,14 +23,14 @@ from nemo_rl.distributed.virtual_cluster import 
RayVirtualCluster from nemo_rl.distributed.worker_groups import RayWorkerBuilder, RayWorkerGroup from nemo_rl.models.generation.interfaces import ( - GenerationInterface, GenerationDatumSpec, + GenerationInterface, GenerationOutputSpec, ) from nemo_rl.models.interfaces import PolicyInterface from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.fsdp1_policy_worker import FSDP1PolicyWorker from nemo_rl.models.policy.dtensor_policy_worker import DTensorPolicyWorker +from nemo_rl.models.policy.fsdp1_policy_worker import FSDP1PolicyWorker class HfPolicy(PolicyInterface, GenerationInterface): diff --git a/nemo_rl/utils/checkpoint.py b/nemo_rl/utils/checkpoint.py index 2425996400..bc916d3d7e 100644 --- a/nemo_rl/utils/checkpoint.py +++ b/nemo_rl/utils/checkpoint.py @@ -17,14 +17,15 @@ own checkpoint saving function (called by the algorithm loop). """ -import os -import json import glob -from typing import Dict, Any, Optional, List, Tuple, TypedDict +import json +import os import shutil from pathlib import Path -import torch +from typing import Any, Dict, List, Optional, Tuple, TypedDict + import numpy as np +import torch class CheckpointingConfig(TypedDict): diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index c48dfc772b..0c56296e2b 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -13,30 +13,30 @@ # limitations under the License. 
+import glob +import json +import logging import os import re -import glob -import time import threading -import requests -import json +import time from abc import ABC, abstractmethod -import logging -from typing import List, Any, Dict, Optional, TypedDict, Union +from typing import Any, Dict, List, Optional, TypedDict + +import ray +import requests +import torch import wandb -from rich.console import Console -from rich.panel import Panel +from prometheus_client.parser import text_string_to_metric_families +from prometheus_client.samples import Sample from rich.box import ROUNDED +from rich.console import Console from rich.logging import RichHandler -import torch +from rich.panel import Panel +from torch.utils.tensorboard import SummaryWriter from nemo_rl.data.interfaces import LLMMessageLogType from nemo_rl.distributed.batched_data_dict import BatchedDataDict -from torch.utils.tensorboard import SummaryWriter - -import ray -from prometheus_client.parser import text_string_to_metric_families -from prometheus_client.samples import Sample # Flag to track if rich logging has been configured _rich_logging_configured = False diff --git a/nemo_rl/utils/native_checkpoint.py b/nemo_rl/utils/native_checkpoint.py index 04d590e133..3573d2d86d 100644 --- a/nemo_rl/utils/native_checkpoint.py +++ b/nemo_rl/utils/native_checkpoint.py @@ -17,18 +17,18 @@ import os from pathlib import Path from typing import Any, Optional + import torch -from torch.distributed.fsdp import FullyShardedDataParallel -from transformers import AutoConfig, AutoTokenizer import torch.distributed.checkpoint as dcp -from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save from torch.distributed.checkpoint.state_dict import ( get_model_state_dict, - set_model_state_dict, get_optimizer_state_dict, + set_model_state_dict, set_optimizer_state_dict, ) -from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from 
torch.distributed.checkpoint.stateful import Stateful +from transformers import AutoConfig, AutoTokenizer ## modified from pytorch tutorial https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html diff --git a/nemo_rl/utils/nvml.py b/nemo_rl/utils/nvml.py index 2f684effef..137374e00b 100644 --- a/nemo_rl/utils/nvml.py +++ b/nemo_rl/utils/nvml.py @@ -13,6 +13,7 @@ # limitations under the License. import contextlib import os + import pynvml diff --git a/nemo_rl/utils/timer.py b/nemo_rl/utils/timer.py index 5796b1da39..188ee14166 100644 --- a/nemo_rl/utils/timer.py +++ b/nemo_rl/utils/timer.py @@ -14,6 +14,7 @@ import time from contextlib import contextmanager from typing import Dict, List, Optional, Union + import numpy as np diff --git a/nemo_rl/utils/venvs.py b/nemo_rl/utils/venvs.py index be34a06c8f..184f529c3d 100644 --- a/nemo_rl/utils/venvs.py +++ b/nemo_rl/utils/venvs.py @@ -11,10 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import logging import os -import subprocess import shlex -import logging +import subprocess from functools import lru_cache dir_path = os.path.dirname(os.path.abspath(__file__)) diff --git a/pyproject.toml b/pyproject.toml index 7da068da54..22be1bf23c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,12 +86,13 @@ python_files = "test_*.py" [tool.ruff.lint] # Enable all `pydocstyle` rules, limiting to those that adhere to the # Google convention via `convention = "google"`, below. -select = ["D"] +select = ["D", "F"] -# On top of the Google convention, disable `D417`, which requires -# documentation for every function parameter. +# - On top of the Google convention, disable `D417`, which requires +# documentation for every function parameter. 
+# - F841: local variable assigned but never used (excluded to favor readability) # TODO: Remove D10 once we are about to release to get all the docstrings written -ignore = ["D417", "D10"] +ignore = ["D417", "D10", "F841"] [tool.ruff.lint.pydocstyle] convention = "google" @@ -102,6 +103,8 @@ convention = "google" "tests/**" = ["D"] # Ignore all files that end in `_test.py`. "*_test.py" = ["D"] +# Ignore F401 (import but unused) in __init__.py +"__init__.py" = ["F401"] [tool.uv] # Users may use different link-modes depending on their scenario: diff --git a/tests/check_metrics.py b/tests/check_metrics.py index b1da6bc924..18f86c9936 100644 --- a/tests/check_metrics.py +++ b/tests/check_metrics.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json -import sys import argparse +import json import statistics -from typing import Dict, Tuple, Any, Union, List +import sys +from typing import Dict, Tuple + from rich.console import Console from rich.table import Table diff --git a/tests/json_dump_tb_logs.py b/tests/json_dump_tb_logs.py index 58554eadf9..7d2e5607fa 100644 --- a/tests/json_dump_tb_logs.py +++ b/tests/json_dump_tb_logs.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import argparse +import datetime import glob import json import os -import datetime import statistics -from collections import defaultdict -from tensorboard.backend.event_processing import event_accumulator import sys +from collections import defaultdict + +from rich.box import SIMPLE from rich.console import Console from rich.table import Table -from rich.box import SIMPLE -from rich.panel import Panel from rich.text import Text +from tensorboard.backend.event_processing import event_accumulator # By default TB tries to be smart about what to load in memory to avoid OOM # Since we expect every step to be there when we do our comparisons, we explicitly @@ -137,7 +137,7 @@ def merge_tb_logs_to_json(log_dir, output_path, allow_conflicts=False): # Create metric header with better highlighting metric_text = Text() - metric_text.append(f"🔹 ", style="bold blue") + metric_text.append("🔹 ", style="bold blue") metric_text.append(f"{metric}", style="bold magenta") metric_text.append(f" - {len(steps)} steps", style="green") console.print(metric_text) @@ -199,7 +199,7 @@ def merge_tb_logs_to_json(log_dir, output_path, allow_conflicts=False): ) else: console.print( - f"[bold red]✓ To save the merged data, use --output_path[/bold red]" + "[bold red]✓ To save the merged data, use --output_path[/bold red]" ) diff --git a/tests/unit/algorithms/test_dpo.py b/tests/unit/algorithms/test_dpo.py index 014d73f9ca..76565730b0 100644 --- a/tests/unit/algorithms/test_dpo.py +++ b/tests/unit/algorithms/test_dpo.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest +from unittest.mock import MagicMock + import torch -from unittest.mock import MagicMock, patch from nemo_rl.algorithms.dpo import add_ref_logprobs_to_data diff --git a/tests/unit/algorithms/test_grpo.py b/tests/unit/algorithms/test_grpo.py index b219a3f8cd..c0a24fee26 100644 --- a/tests/unit/algorithms/test_grpo.py +++ b/tests/unit/algorithms/test_grpo.py @@ -11,18 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + import pytest -import torch import ray -from typing import Dict, List, Tuple +import torch -from nemo_rl.experience.rollouts import calculate_rewards -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.data.interfaces import DatumSpec, LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.environments.interfaces import ( EnvironmentInterface, EnvironmentReturn, ) +from nemo_rl.experience.rollouts import calculate_rewards @ray.remote(num_cpus=0) diff --git a/tests/unit/algorithms/test_loss_functions.py b/tests/unit/algorithms/test_loss_functions.py index d36d8c0b89..faeee3f03f 100644 --- a/tests/unit/algorithms/test_loss_functions.py +++ b/tests/unit/algorithms/test_loss_functions.py @@ -13,18 +13,16 @@ # limitations under the License. 
import pytest import torch -import numpy as np from nemo_rl.algorithms.loss_functions import ( - NLLLoss, ClippedPGLossFn, DPOLossFn, + NLLLoss, ) -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.algorithms.utils import ( - calculate_kl_penalty_joschu2020, masked_mean, ) +from nemo_rl.distributed.batched_data_dict import BatchedDataDict def setup_dpo_loss_test_data(vocab_size=16, batch_size=1): diff --git a/tests/unit/algorithms/test_sft.py b/tests/unit/algorithms/test_sft.py index 93ebdcd511..a41cd35b1c 100644 --- a/tests/unit/algorithms/test_sft.py +++ b/tests/unit/algorithms/test_sft.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from unittest.mock import MagicMock + +import pytest import torch from torchdata.stateful_dataloader import StatefulDataLoader -from nemo_rl.algorithms.sft import sft_train, _default_sft_save_state + from nemo_rl.algorithms.loss_functions import NLLLoss +from nemo_rl.algorithms.sft import _default_sft_save_state, sft_train @pytest.fixture diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py index f2aa579320..82338de026 100755 --- a/tests/unit/algorithms/test_utils.py +++ b/tests/unit/algorithms/test_utils.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from datetime import datetime -from transformers import AutoTokenizer + +import pytest + from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 68df0f4c87..2a3ec3a7c9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -11,32 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -from io import StringIO -import time -import pytest -from nemo_rl.utils.logger import GPUMonitoringConfig -from tests import unit -import torch -import torch.distributed as dist -import torch.multiprocessing as mp +import json import os import random -from typing import Callable -import ray +import time +import unittest.mock +from datetime import datetime +from io import StringIO +from typing import Callable, TypedDict import pytest +import ray import torch import torch.distributed as dist import torch.multiprocessing as mp -import os -import random -from typing import Callable -import ray -import json + from nemo_rl.distributed.virtual_cluster import init_ray -from typing import TypedDict -from datetime import datetime -import unittest.mock dir_path = os.path.dirname(os.path.abspath(__file__)) @@ -393,9 +383,10 @@ def mock_2gpu_distributed_env(): @pytest.fixture(scope="session", autouse=True) def tiny_llama_model_path(): """Fixture that returns a path to a tiny llama model with a dummy tokenizer.""" - from transformers import LlamaConfig, LlamaForCausalLM, AutoTokenizer import shutil + from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM + model_path = TEST_ASSETS.TINY_LLAMA_MODEL_PATH # hidden_size//num_attention_heads = 32 (smallest value to not error due to vllm paged attention) # vocab_size=128256 (so we can re-use llama3.2 1b tokenizer) @@ -420,9 +411,10 @@ def tiny_llama_model_path(): @pytest.fixture(scope="session", autouse=True) def tiny_llama_tied_model_path(): """Fixture that returns a path to a tiny llama model with a dummy tokenizer.""" - from transformers import LlamaConfig, LlamaForCausalLM, AutoTokenizer import shutil + from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM + model_path = TEST_ASSETS.TINY_LLAMA_TIED_MODEL_PATH # hidden_size//num_attention_heads = 32 (smallest value to not error due to vllm paged 
attention) # vocab_size=128256 (so we can re-use llama3.2 1b tokenizer) @@ -447,9 +439,10 @@ def tiny_llama_tied_model_path(): @pytest.fixture(scope="session", autouse=True) def tiny_qwen2_model_path(): """Fixture that returns a path to a tiny llama model with a dummy tokenizer.""" - from transformers import Qwen2Config, Qwen2ForCausalLM, AutoTokenizer import shutil + from transformers import AutoTokenizer, Qwen2Config, Qwen2ForCausalLM + model_path = TEST_ASSETS.TINY_QWEN2_MODEL_PATH # hidden_size//num_attention_heads = 32 (smallest value to not error due to vllm paged attention) # vocab_size=151936 (so we can re-use qwen2 1.5b tokenizer) diff --git a/tests/unit/data/hf_datasets/test_dpo_dataset.py b/tests/unit/data/hf_datasets/test_dpo_dataset.py index 12cf67aef8..ed13df2c99 100644 --- a/tests/unit/data/hf_datasets/test_dpo_dataset.py +++ b/tests/unit/data/hf_datasets/test_dpo_dataset.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import os import tempfile -import json + import pytest -from unittest.mock import patch, MagicMock from nemo_rl.data.hf_datasets.dpo import DPODataset diff --git a/tests/unit/data/hf_datasets/test_helpsteer.py b/tests/unit/data/hf_datasets/test_helpsteer.py index 5d297cfe77..5573179b99 100644 --- a/tests/unit/data/hf_datasets/test_helpsteer.py +++ b/tests/unit/data/hf_datasets/test_helpsteer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest + from nemo_rl.data.hf_datasets.helpsteer3 import ( HelpSteer3Dataset, format_helpsteer3, diff --git a/tests/unit/data/hf_datasets/test_prompt_response.py b/tests/unit/data/hf_datasets/test_prompt_response.py index d0aeccc583..8ff7f5c5f6 100644 --- a/tests/unit/data/hf_datasets/test_prompt_response.py +++ b/tests/unit/data/hf_datasets/test_prompt_response.py @@ -12,14 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import tempfile import json +import tempfile + +import pytest +from transformers import AutoTokenizer + from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES from nemo_rl.data.hf_datasets.prompt_response_dataset import ( PromptResponseDataset, ) -from transformers import AutoTokenizer @pytest.fixture diff --git a/tests/unit/data/hf_datasets/test_squad.py b/tests/unit/data/hf_datasets/test_squad.py index d959f694f8..5e736ee8ac 100644 --- a/tests/unit/data/hf_datasets/test_squad.py +++ b/tests/unit/data/hf_datasets/test_squad.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest - from transformers import AutoTokenizer -from nemo_rl.data.hf_datasets.chat_templates import COMMON_CHAT_TEMPLATES + from nemo_rl.data.hf_datasets.squad import SquadDataset diff --git a/tests/unit/data/test_data_processor.py b/tests/unit/data/test_data_processor.py index 39cd959dd6..302dfece77 100644 --- a/tests/unit/data/test_data_processor.py +++ b/tests/unit/data/test_data_processor.py @@ -13,8 +13,8 @@ # limitations under the License. 
import os -import pytest import sys + from datasets import Dataset abspath = os.path.abspath(__file__) @@ -26,7 +26,6 @@ from nemo_rl.data.interfaces import TaskDataSpec from nemo_rl.models.policy import TokenizerConfig - basic_tokenizer_test_config: TokenizerConfig = { "name": "Qwen/Qwen2.5-Math-1.5B-Instruct", "chat_template": "default", diff --git a/tests/unit/data/test_datasets.py b/tests/unit/data/test_datasets.py index 31cad8c82d..d879b09a85 100755 --- a/tests/unit/data/test_datasets.py +++ b/tests/unit/data/test_datasets.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import torch from unittest.mock import MagicMock +import torch + from nemo_rl.data.datasets import dpo_collate_fn from nemo_rl.data.interfaces import DatumSpec from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/tests/unit/data/test_llm_message_utils.py b/tests/unit/data/test_llm_message_utils.py index 1a7ddc568c..0a5cb3ef4b 100644 --- a/tests/unit/data/test_llm_message_utils.py +++ b/tests/unit/data/test_llm_message_utils.py @@ -12,20 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Dict, List + import pytest import torch -from typing import Dict, List from transformers import AutoTokenizer +from nemo_rl.data.interfaces import LLMMessageLogType, TaskDataSpec from nemo_rl.data.llm_message_utils import ( - message_log_to_flat_messages, - get_keys_from_message_log, - batched_message_log_to_flat_message, - get_formatted_message_log, add_loss_mask_to_message_log, + batched_message_log_to_flat_message, get_first_index_that_differs, + get_formatted_message_log, + get_keys_from_message_log, + message_log_to_flat_messages, ) -from nemo_rl.data.interfaces import LLMMessageLogType, TaskDataSpec @pytest.fixture diff --git a/tests/unit/distributed/test_batched_data_dict.py b/tests/unit/distributed/test_batched_data_dict.py index f95814eac4..63ca784b01 100644 --- a/tests/unit/distributed/test_batched_data_dict.py +++ b/tests/unit/distributed/test_batched_data_dict.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/tests/unit/distributed/test_cluster_visualization.py b/tests/unit/distributed/test_cluster_visualization.py index 00243025e5..d6dc31e1a5 100644 --- a/tests/unit/distributed/test_cluster_visualization.py +++ b/tests/unit/distributed/test_cluster_visualization.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import patch, MagicMock +from unittest.mock import MagicMock, patch + import pytest from nemo_rl.distributed.virtual_cluster import RayVirtualCluster diff --git a/tests/unit/distributed/test_collectives.py b/tests/unit/distributed/test_collectives.py index 72b900cb73..3599546db3 100644 --- a/tests/unit/distributed/test_collectives.py +++ b/tests/unit/distributed/test_collectives.py @@ -14,8 +14,8 @@ import torch from nemo_rl.distributed.collectives import ( - rebalance_nd_tensor, gather_jagged_object_lists, + rebalance_nd_tensor, ) diff --git a/tests/unit/distributed/test_virtual_cluster.py b/tests/unit/distributed/test_virtual_cluster.py index 8d26ed5a33..c3b4de4853 100644 --- a/tests/unit/distributed/test_virtual_cluster.py +++ b/tests/unit/distributed/test_virtual_cluster.py @@ -11,17 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +from unittest.mock import MagicMock, patch + +import pytest +import ray + from nemo_rl.distributed.virtual_cluster import ( - _get_node_ip_and_free_port, PY_EXECUTABLES, RayVirtualCluster, ResourceInsufficientError, + _get_node_ip_and_free_port, ) -import ray -import pytest -import os -from unittest.mock import patch, MagicMock -import importlib def test_get_node_ip_and_free_port_does_not_start_with_zero(): diff --git a/tests/unit/environments/test_math_environment.py b/tests/unit/environments/test_math_environment.py index 4b2d4069cb..77846db7af 100644 --- a/tests/unit/environments/test_math_environment.py +++ b/tests/unit/environments/test_math_environment.py @@ -11,11 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import os +import time + import pytest import ray + from nemo_rl.environments.math_environment import MathEnvironment -import time -import os @pytest.fixture(scope="module") diff --git a/tests/unit/experience/test_rollouts.py b/tests/unit/experience/test_rollouts.py index cc71b366b4..b45811d4f8 100644 --- a/tests/unit/experience/test_rollouts.py +++ b/tests/unit/experience/test_rollouts.py @@ -12,38 +12,35 @@ # See the License for the specific language governing permissions and # limitations under the License. +import gc +from copy import deepcopy + import pytest import ray import torch -from copy import deepcopy -import gc - from transformers import AutoTokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster +from nemo_rl.environments.games.sliding_puzzle import ( + SlidingPuzzleConfig, + SlidingPuzzleEnv, + SlidingPuzzleGameLogic, + SlidingPuzzleMetadata, +) +from nemo_rl.experience.rollouts import run_multi_turn_rollout +from nemo_rl.models.generation.interfaces import configure_generation_config +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.hf_policy import HfPolicy -from nemo_rl.models.generation.interfaces import configure_generation_config -from nemo_rl.experience.rollouts import run_multi_turn_rollout # Import the test environment definitions from tests.unit.test_envs import ( + MultiStepCalcMetadata, MultiStepCalculatorEnv, _MultiStepCalculatorLogic, - MultiStepCalcMetadata, -) - -from nemo_rl.environments.games.sliding_puzzle import ( - SlidingPuzzleGameLogic, - SlidingPuzzleEnv, - SlidingPuzzleConfig, - SlidingPuzzleMetadata, ) -from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration - - MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct" diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py index 
da454a9265..552ea3dae2 100644 --- a/tests/unit/models/generation/test_vllm_generation.py +++ b/tests/unit/models/generation/test_vllm_generation.py @@ -12,19 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from copy import deepcopy import pytest -import torch import ray -import os +import torch from nemo_rl.algorithms.grpo import refit_policy_generation from nemo_rl.algorithms.utils import get_tokenizer -from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation.interfaces import configure_generation_config -from nemo_rl.models.generation.vllm import VllmGeneration, VllmConfig +from nemo_rl.models.generation.vllm import VllmConfig, VllmGeneration from nemo_rl.models.policy import PolicyConfig # Define basic vLLM test config @@ -282,7 +282,7 @@ def test_vllm_worker_seed_behavior(cluster, tokenizer): hf_config = get_basic_hf_test_config(enable_dtensor=False) hf_policy = HfPolicy(cluster, hf_config, tokenizer) - print(f"refitting vllm policy...") + print("refitting vllm policy...") refit_policy_generation(hf_policy, policy, hf_config["refit_buffer_size_gb"]) try: @@ -445,7 +445,7 @@ def test_vllm_generation_with_hf_training(cluster, tokenizer, enable_dtensor): print("Creating HF policy...") hf_policy = HfPolicy(cluster, hf_config, tokenizer) - print(f"refitting vllm policy...") + print("refitting vllm policy...") refit_policy_generation( hf_policy, vllm_policy, hf_config["refit_buffer_size_gb"] ) @@ -788,7 +788,7 @@ def test_vllm_weight_update_memory(cluster, tokenizer, enable_dtensor): hf_config = get_basic_hf_test_config(enable_dtensor=enable_dtensor) hf_policy = HfPolicy(cluster, hf_config, tokenizer) - print(f"refitting vllm policy...") + print("refitting vllm policy...") # take it outside statistics to get clean peak 
memory during refit hf_policy.offload_before_refit() # reset peak memory stats before refit @@ -860,7 +860,7 @@ def test_vllm_generation_with_stop( hf_config = get_basic_hf_test_config(enable_dtensor=enable_dtensor) hf_policy = HfPolicy(cluster, hf_config, tokenizer) - print(f"refitting vllm policy...") + print("refitting vllm policy...") refit_policy_generation( hf_policy, vllm_generation, diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index 9472abeb3a..7f175b3f15 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -11,26 +11,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import ray -import pytest +import os import pprint + +import pytest +import ray import torch -import os # Define a custom marker for model configuration tests pytestmark = pytest.mark.modelconfig +from transformers import AutoModelForCausalLM + from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.generation.interfaces import configure_generation_config from nemo_rl.models.policy import PolicyConfig -from nemo_rl.models.policy.hf_policy import HfPolicy from nemo_rl.models.policy.dtensor_policy_worker import DTensorPolicyWorker -from tests.unit.test_utils import simple_loss +from nemo_rl.models.policy.hf_policy import HfPolicy from tests.unit.conftest import TEST_ASSETS -from transformers import AutoModelForCausalLM +from tests.unit.test_utils import simple_loss def create_test_config( diff --git a/tests/unit/models/policy/test_fsdp1_worker.py b/tests/unit/models/policy/test_fsdp1_worker.py index c4fa020bae..b53c9491b2 100644 --- 
a/tests/unit/models/policy/test_fsdp1_worker.py +++ b/tests/unit/models/policy/test_fsdp1_worker.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import ray -import pytest -import pprint -import torch import os +import pprint from copy import deepcopy +import pytest +import ray +import torch + from nemo_rl.algorithms.interfaces import LossFunction from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict @@ -26,8 +27,7 @@ from nemo_rl.models.generation.interfaces import configure_generation_config from nemo_rl.models.policy import PolicyConfig from nemo_rl.models.policy.hf_policy import HfPolicy -from tests.unit.test_utils import simple_loss, nll_loss - +from tests.unit.test_utils import nll_loss, simple_loss basic_llama_test_config: PolicyConfig = { "model_name": "meta-llama/Llama-3.2-1B", diff --git a/tests/unit/test_envs.py b/tests/unit/test_envs.py index 2e139f272a..a46466979f 100644 --- a/tests/unit/test_envs.py +++ b/tests/unit/test_envs.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from typing import Dict, List, Optional, Tuple, TypedDict + import ray import torch -from typing import Dict, List, Tuple, Optional, TypedDict, Literal, Any -from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.data.interfaces import LLMMessageLogType +from nemo_rl.distributed.batched_data_dict import BatchedDataDict +from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES from nemo_rl.environments.interfaces import ( EnvironmentInterface, EnvironmentReturn, ) -from nemo_rl.distributed.virtual_cluster import PY_EXECUTABLES class MultiStepCalcMetadata(TypedDict): diff --git a/tests/unit/test_meta.py b/tests/unit/test_meta.py index 8b43250dec..7ac7945f85 100644 --- a/tests/unit/test_meta.py +++ b/tests/unit/test_meta.py @@ -14,21 +14,16 @@ # This module tests things outside of any package (e.g., things in the root __init__.py) -import pytest import os def test_usage_stats_disabled_by_default(): - import nemo_rl - assert os.environ["RAY_USAGE_STATS_ENABLED"] == "0", ( - f"Our dockerfile, slurm submission script and default environment setting when importing nemo rl should all disable usage stats collection. This failing is not expected." + "Our dockerfile, slurm submission script and default environment setting when importing nemo rl should all disable usage stats collection. This failing is not expected." ) def test_usage_stats_disabled_in_tests(): - import tests - assert os.environ["RAY_USAGE_STATS_ENABLED"] == "0", ( - f"Our dockerfile, slurm submission script and default environment setting when importing nemo rl should all disable usage stats collection. This failing is not expected." + "Our dockerfile, slurm submission script and default environment setting when importing nemo rl should all disable usage stats collection. This failing is not expected." 
) diff --git a/tests/unit/test_recipes_and_test_suites.py b/tests/unit/test_recipes_and_test_suites.py index edceba3649..9bac39188e 100644 --- a/tests/unit/test_recipes_and_test_suites.py +++ b/tests/unit/test_recipes_and_test_suites.py @@ -11,11 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import os import glob +import os import subprocess +import pytest + dir_path = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.abspath(os.path.join(dir_path, "..", "..")) test_suites_dir = os.path.join(project_root, "tests", "test_suites") diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 2d3318a053..d94935e95c 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Tuple + import torch from nemo_rl.distributed.batched_data_dict import BatchedDataDict diff --git a/tests/unit/utils/test_checkpoint.py b/tests/unit/utils/test_checkpoint.py index dc6ad0dd5d..2a912e94b2 100644 --- a/tests/unit/utils/test_checkpoint.py +++ b/tests/unit/utils/test_checkpoint.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import json +from pathlib import Path + +import numpy as np import pytest import torch -import numpy as np -from pathlib import Path + from nemo_rl.utils.checkpoint import CheckpointManager diff --git a/tests/unit/utils/test_logger.py b/tests/unit/utils/test_logger.py index ea88a69c7f..baf65c02f4 100644 --- a/tests/unit/utils/test_logger.py +++ b/tests/unit/utils/test_logger.py @@ -20,9 +20,9 @@ from nemo_rl.utils.logger import ( Logger, + RayGpuMonitorLogger, TensorboardLogger, WandbLogger, - RayGpuMonitorLogger, flatten_dict, ) diff --git a/tests/unit/utils/test_native_checkpoint.py b/tests/unit/utils/test_native_checkpoint.py index 54962b4f38..7cebeade90 100755 --- a/tests/unit/utils/test_native_checkpoint.py +++ b/tests/unit/utils/test_native_checkpoint.py @@ -11,23 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import copy import os +from tempfile import TemporaryDirectory + import pytest import torch -from tempfile import TemporaryDirectory +from transformers import AutoModelForCausalLM from nemo_rl.algorithms.utils import get_tokenizer from nemo_rl.distributed.batched_data_dict import BatchedDataDict from nemo_rl.distributed.virtual_cluster import RayVirtualCluster from nemo_rl.models.policy.hf_policy import HfPolicy -from transformers import AutoTokenizer, AutoModelForCausalLM from nemo_rl.utils.native_checkpoint import ( - load_checkpoint, - save_checkpoint, ModelState, OptimizerState, convert_dcp_to_hf, + load_checkpoint, + save_checkpoint, ) from tests.unit.test_utils import simple_loss diff --git a/tests/unit/utils/test_pynvml.py b/tests/unit/utils/test_pynvml.py index f9b667779d..aa0c9a90fe 100644 --- a/tests/unit/utils/test_pynvml.py +++ b/tests/unit/utils/test_pynvml.py @@ -15,9 +15,9 @@ from unittest.mock import patch from nemo_rl.utils.nvml import ( - nvml_context, device_id_to_physical_device_id, 
get_device_uuid, + nvml_context, ) diff --git a/tests/unit/utils/test_timer.py b/tests/unit/utils/test_timer.py index 5ae6ae8687..56ba315b55 100644 --- a/tests/unit/utils/test_timer.py +++ b/tests/unit/utils/test_timer.py @@ -13,10 +13,11 @@ # limitations under the License. import time -import pytest -import numpy as np from unittest.mock import patch +import numpy as np +import pytest + from nemo_rl.utils.timer import Timer diff --git a/tests/unit/utils/test_venvs.py b/tests/unit/utils/test_venvs.py index 15b229b9a7..59f7c03565 100644 --- a/tests/unit/utils/test_venvs.py +++ b/tests/unit/utils/test_venvs.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import os -from unittest.mock import patch +import subprocess from tempfile import TemporaryDirectory +from unittest.mock import patch + from nemo_rl.utils.venvs import create_local_venv -import subprocess def test_create_local_venv():