Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ We have a reference GRPO experiment config set up trained for math benchmarks us

#### Single Node

To run GRPO on a single GPU for `Llama-3.2-1B-Instruct`:
To run GRPO on a single GPU for `Qwen/Qwen2.5-1.5B`:

```sh
# Run the GRPO math example using a 1B parameter model
Expand All @@ -87,10 +87,10 @@ You can override any of the parameters listed in the yaml configuration file. Fo

```sh
uv run python examples/run_grpo_math.py \
policy.model_name="Qwen/Qwen2-1.5B" \
checkpointing.checkpoint_dir="results/qwen1_5b_math" \
policy.model_name="Llama-3.2-1B-Instruct" \
checkpointing.checkpoint_dir="results/llama1b_math" \
logger.wandb_enabled=True \
logger.wandb.name="grpo-qwen1_5b_math" \
logger.wandb.name="grpo-llama1b_math" \
logger.num_val_samples_to_print=10 \
```

Expand Down
5 changes: 3 additions & 2 deletions examples/configs/grpo_math_1B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ checkpointing:
save_period: 10

policy:
model_name: "meta-llama/Llama-3.2-1B-Instruct"
# Qwen/Qwen2.5-1.5B has tied weights which are only supported with dtensor policy with tp size 1 (https://github.com/NVIDIA/reinforcer/issues/227)
model_name: "Qwen/Qwen2.5-1.5B"
tokenizer:
name: ${policy.model_name} ## specify if you'd like to use a tokenizer different from the model's default
train_global_batch_size: 512
Expand All @@ -42,7 +43,7 @@ policy:
refit_buffer_size_gb: 4 # used for refitting inference engine, the unit is GB

dtensor_cfg:
enabled: false
enabled: true
cpu_offload: False
sequence_parallel: false
activation_checkpointing: false
Expand Down
3 changes: 3 additions & 0 deletions examples/configs/grpo_math_8B.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ policy:
activation_checkpointing_enabled: false
refit_buffer_size_gb: 4 # used for refitting inference engine, the unit is GB

dtensor_cfg:
enabled: False

optimizer:
name: "torch.optim.AdamW"
kwargs:
Expand Down
14 changes: 14 additions & 0 deletions nemo_reinforcer/models/policy/dtensor_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import gc

from collections import defaultdict
Expand All @@ -24,6 +25,7 @@
FSDPModule,
)
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import _get_tied_weight_keys
from nemo_reinforcer.models.dtensor.parallelize import _parallelize_model

from nemo_reinforcer.algorithms.interfaces import LossFunction
Expand Down Expand Up @@ -140,6 +142,7 @@ def __init__(
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)

self.tokenizer = tokenizer
# ------------------------------------------------
# 3) Move to GPU + Composable FSDP
Expand Down Expand Up @@ -253,6 +256,17 @@ def train(
mbs: Optional[int] = None,
) -> Dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
num_tied_weights = len(_get_tied_weight_keys(self.model))
skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
if (
num_tied_weights != 0
and self.cfg["dtensor_cfg"]["tensor_parallel_size"] > 1
and not skip_tie_check
):
raise ValueError(
f"Using dtensor policy with tp size {self.cfg['dtensor_cfg']['tensor_parallel_size']} for model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={num_tied_weights}) is not supported (https://github.com/NVIDIA/reinforcer/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
)

if gbs is None:
gbs = self.cfg["train_global_batch_size"]
if mbs is None:
Expand Down
11 changes: 11 additions & 0 deletions nemo_reinforcer/models/policy/fsdp1_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from collections import defaultdict
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, Optional
import os

import ray
import torch
Expand All @@ -38,6 +39,7 @@
)

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import _get_tied_weight_keys
from nemo_reinforcer.models.policy import PolicyConfig
from nemo_reinforcer.models.policy.utils import import_class_from_path
from nemo_reinforcer.distributed.virtual_cluster import (
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
device_map="cpu", # load weights onto CPU initially
torch_dtype=torch.float32, # use full precision in sft until https://github.com/NVIDIA/reinforcer/issues/13 is fixed
)

if init_reference_model:
self.reference_model = AutoModelForCausalLM.from_pretrained(
model_name,
Expand Down Expand Up @@ -225,6 +228,14 @@ def train(
mbs: Optional[int] = None,
) -> Dict[str, Any]:
"""Train the policy on a batch of data with a given loss function."""
# Check if the model has tied weights
num_tied_weights = len(_get_tied_weight_keys(self.model))
skip_tie_check = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
if num_tied_weights != 0 and not skip_tie_check:
raise ValueError(
            f"Using FSDP1 with a model ({self.cfg['model_name']}) that has tied weights (num_tied_weights={num_tied_weights}) is not supported (https://github.com/NVIDIA/reinforcer/issues/227). Please use dtensor policy with tensor parallel == 1 instead."
)

if gbs is None:
gbs = self.cfg["train_global_batch_size"]
if mbs is None:
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/models/generation/test_vllm_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import pytest
import torch
import ray
import os

from nemo_reinforcer.algorithms.grpo import refit_policy_generation
from nemo_reinforcer.algorithms.utils import get_tokenizer
Expand All @@ -26,7 +27,6 @@
from nemo_reinforcer.models.generation.vllm import VllmGeneration, VllmConfig
from nemo_reinforcer.models.policy import PolicyConfig


# Define basic vLLM test config
basic_vllm_test_config: VllmConfig = {
"backend": "vllm",
Expand Down Expand Up @@ -161,6 +161,17 @@ def test_input_data(tokenizer):
)


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
    """Automatically skip tied weight check for all tests in this module.

    Saves any pre-existing NRL_SKIP_TIED_WEIGHT_CHECK value and restores it
    after the module finishes, instead of unconditionally deleting it.
    """
    original = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

    yield

    # Restore the pre-test value (or unset it if it wasn't set before).
    if original is None:
        os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
    else:
        os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = original


def test_vllm_missing_required_config_key(cluster):
"""Test that an assertion error is raised when a required config key is missing."""
# Create a config missing a required key by removing 'model_name'
Expand Down
13 changes: 11 additions & 2 deletions tests/unit/models/policy/test_dtensor_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
import pprint
import torch
import os
import unittest.mock
import torch.distributed as dist

# Define a custom marker for model configuration tests
pytestmark = pytest.mark.modelconfig
Expand Down Expand Up @@ -88,6 +86,17 @@ def create_test_config(
}


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
    """Automatically skip tied weight check for all tests in this module.

    Saves any pre-existing NRL_SKIP_TIED_WEIGHT_CHECK value and restores it
    after the module finishes, instead of unconditionally deleting it.
    """
    original = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

    yield

    # Restore the pre-test value (or unset it if it wasn't set before).
    if original is None:
        os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
    else:
        os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = original


@pytest.fixture(scope="module")
def two_gpu_virtual_cluster():
cluster_name = "test"
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/models/policy/test_fsdp1_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ray
import pytest
import pprint
import torch
import os
from copy import deepcopy

from nemo_reinforcer.algorithms.interfaces import LossFunction
Expand Down Expand Up @@ -73,6 +75,17 @@
}


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
    """Automatically skip tied weight check for all tests in this module.

    Saves any pre-existing NRL_SKIP_TIED_WEIGHT_CHECK value and restores it
    after the module finishes, instead of unconditionally deleting it.
    """
    original = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

    yield

    # Restore the pre-test value (or unset it if it wasn't set before).
    if original is None:
        os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
    else:
        os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = original


@pytest.fixture(scope="function")
def gc_collect():
"""Helper function to force garbage collection after a test"""
Expand Down
11 changes: 11 additions & 0 deletions tests/unit/utils/test_native_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,17 @@ def policy(cluster, tokenizer):
policy.worker_group.shutdown()


@pytest.fixture(scope="module", autouse=True)
def skip_tied_weight_check_for_all():
    """Automatically skip tied weight check for all tests in this module.

    Saves any pre-existing NRL_SKIP_TIED_WEIGHT_CHECK value and restores it
    after the module finishes, instead of unconditionally deleting it.
    """
    original = os.environ.get("NRL_SKIP_TIED_WEIGHT_CHECK")
    os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = "1"

    yield

    # Restore the pre-test value (or unset it if it wasn't set before).
    if original is None:
        os.environ.pop("NRL_SKIP_TIED_WEIGHT_CHECK", None)
    else:
        os.environ["NRL_SKIP_TIED_WEIGHT_CHECK"] = original


def get_dummy_state_dict(state_dict, dummy_dict={}):
"""Recursively get the dummy state dict
by replacing tensors with random ones of the same shape.
Expand Down