Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/special_distributed/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
#!/usr/bin/env bash

set -e -x
torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py
torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py
torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_torch_functional.py
35 changes: 35 additions & 0 deletions tests/special_distributed/test_torch_functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

from verl.utils.torch_functional import allgather_dict_into_dict

if __name__ == "__main__":
torch.distributed.init_process_group(backend="gloo")

local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])

metrics_dict = {"loss": [0 + rank, 1 + rank, 2 + rank], "grad_norm": rank}

result = allgather_dict_into_dict(data=metrics_dict, group=None)

assert result["loss"] == [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]]
assert result["grad_norm"] == [0, 1, 2, 3]

print(result)
12 changes: 6 additions & 6 deletions verl/trainer/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def add_critic_worker(self, config):
if use_legacy_worker_impl in ["auto", "enable"]:
from verl.workers.fsdp_workers import CriticWorker
elif use_legacy_worker_impl == "disable":
from verl.workers.roles import CriticWorker
from verl.workers.engine_workers import CriticWorker

print("Using new worker implementation")
else:
Expand Down Expand Up @@ -223,17 +223,17 @@ def add_reward_model_worker(self, config):

if config.reward_model.enable:
use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
if use_legacy_worker_impl in ["auto", "enable"]:
if use_legacy_worker_impl in ["auto", "enable", "disable"]:
if config.reward_model.strategy in {"fsdp", "fsdp2"}:
from verl.workers.fsdp_workers import RewardModelWorker
elif config.reward_model.strategy == "megatron":
from verl.workers.megatron_workers import RewardModelWorker
else:
raise NotImplementedError
elif use_legacy_worker_impl == "disable":
from verl.workers.roles import RewardModelWorker

print("Using new worker implementation")
# elif use_legacy_worker_impl == "disable":
# from verl.workers.engine_workers import RewardModelWorker
#
# print("Using new worker implementation")
else:
raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")

Expand Down
116 changes: 57 additions & 59 deletions verl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import os
from functools import partial

from tensordict.tensorclass import NonTensorData

os.environ["NCCL_DEBUG"] = "WARN"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

Expand All @@ -24,7 +26,6 @@
import hydra
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import OmegaConf
from torch.utils.data import DistributedSampler
from torchdata.stateful_dataloader import StatefulDataLoader
Expand All @@ -36,9 +37,9 @@
from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
from verl.utils.distributed import destroy_global_process_group
from verl.utils.flops_counter import FlopsCounter
from verl.utils.logger import log_with_rank
from verl.utils.tracking import Tracking
from verl.workers.engine_workers import TrainingWorker

if is_cuda_available:
pass
Expand Down Expand Up @@ -74,12 +75,6 @@ def __init__(

self.device_name = self.config.trainer.device

from verl.workers.utils.losses import sft_loss

self.loss_fn = partial(sft_loss, config=None)

self.flops_counter = FlopsCounter(self.model_config.hf_config)

if self.rank == 0:
print(self.config)

Expand Down Expand Up @@ -108,17 +103,24 @@ def _build_config(self):
self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint)

def _build_engine(self):
from verl.workers.engine import BaseEngine, EngineRegistry
from verl.workers.engine_workers import TrainingWorkerConfig
from verl.workers.utils.losses import sft_loss

self.loss_fn = partial(sft_loss, config=None)

self.engine: BaseEngine = EngineRegistry.new(
config = TrainingWorkerConfig(
model_type="language_model",
backend=self.engine_config.strategy,
model_config=self.model_config,
engine_config=self.engine_config,
optimizer_config=self.optimizer_config,
checkpoint_config=self.checkpoint_config,
)

self.training_client = TrainingWorker(config=config)
self.training_client.set_loss_fn(loss_fn=self.loss_fn)
# Note that in SPMD world, this abstraction has to break
self.engine = self.training_client.engine

def _init_engine(self):
# patch optimizer config
if self.config.trainer.total_training_steps is not None:
Expand All @@ -138,7 +140,7 @@ def _init_engine(self):
if self.test_freq == "after_each_epoch":
self.test_freq = self.steps_per_epoch

self.engine.initialize()
self.training_client.reset()

def _build_dataset(self):
config = self.config
Expand Down Expand Up @@ -203,6 +205,30 @@ def _build_dataloader(self):
else:
self.val_dataloader = None

def _get_batch_seqlens(self, data):
    """Collect per-sample sequence lengths from every data-parallel rank.

    Args:
        data: batch TensorDict holding ``input_ids`` (possibly nested) and
            ``attention_mask``.

    Returns:
        list[int]: sequence lengths of the whole global batch, ordered by
        data-parallel rank.
    """
    input_ids = data["input_ids"]
    if input_ids.is_nested:
        # Nested tensor: lengths are the deltas between consecutive offsets.
        local_lens = input_ids.offsets().diff()
    else:
        # Dense tensor: count the non-padding positions per sample.
        local_lens = data["attention_mask"].sum(dim=-1)
    local_lens = local_lens.to(self.device_name)  # (global_bsz // dp)

    dp_size = self.engine.get_data_parallel_size()
    gathered = torch.empty(
        (local_lens.shape[0] * dp_size,),
        dtype=local_lens.dtype,
        device=self.device_name,
    )  # (global_bsz,)
    torch.distributed.all_gather_into_tensor(
        output_tensor=gathered,
        input_tensor=local_lens,
        group=self.engine.get_data_parallel_group(),
    )

    return gathered.tolist()

def fit(self):
is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0

Expand Down Expand Up @@ -267,56 +293,28 @@ def fit(self):

# construct tensordict
data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info)
batch_seqlens = self._get_batch_seqlens(data=data)
# this is necessary. Otherwise, it is interpreted as NonTensorStack
batch_seqlens = NonTensorData(batch_seqlens)

with self.engine.train_mode():
with Timer(name="update_policy", logger=None) as timer:
output = self.engine.train_batch(data=data, loss_function=self.loss_fn)
lr = self.engine.lr_scheduler_step()
tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens)

# train for on batch
output = self.training_client.train_batch(data=data)

if self.engine.is_mp_src_rank_with_outputs():
metrics = output["metrics"]

loss = torch.sum(torch.tensor(metrics["loss"], device=self.device_name))

# mean over dp group
is_nested = data["input_ids"].is_nested
if is_nested:
batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff()
else:
batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp)

output_tensor = torch.randint(
0,
100,
(batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),),
device=self.device_name,
) # (global_bsz,)

torch.distributed.all_gather_into_tensor(
output_tensor=output_tensor,
input_tensor=batch_seqlens,
group=self.engine.get_data_parallel_group(),
)
torch.distributed.all_reduce(
loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()
)

batch_seqlens = output_tensor.tolist()
loss = loss.item()
metrics = tu.get(output, "metrics")

# TODO: we can actual accumulate metrics for N steps and perform aggregate metrics
metrics["loss"] = loss
metrics["train/loss"] = metrics.pop("loss")
metrics["train/grad_norm"] = metrics.pop("grad_norm")
metrics["train/lr"] = lr
metrics["train/global_tokens"] = output_tensor.sum().item()
metrics["train/lr"] = metrics.pop("lr")
metrics["train/mfu"] = metrics.pop("mfu")
metrics["train/global_tokens"] = torch.sum(
torch.tensor(batch_seqlens, device=self.device_name)
).item()
total_tokens += metrics["train/global_tokens"]
metrics["train/total_tokens(B)"] = total_tokens / 1e9
# mfu
delta_time = timer.last
estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)
metrics["train/mfu"] = estimated_flops / promised_flops / torch.distributed.get_world_size()

if self.engine.get_data_parallel_rank() == 0:
tracking.log(data=metrics, step=global_step)
Expand All @@ -330,12 +328,12 @@ def fit(self):
# Perform validation
val_losses = []
for val_data in self.val_dataloader:
with self.engine.eval_mode():
# construct tensordict
val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info)
output = self.engine.infer_batch(data=val_data, loss_function=self.loss_fn)
if self.engine.is_mp_src_rank_with_outputs():
val_losses.extend(output["metrics"]["loss"])
val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info)
output = self.training_client.infer_batch(val_data)

if self.engine.is_mp_src_rank_with_outputs():
metrics = tu.get(output, "metrics")
val_losses.append(metrics["loss"])

if self.engine.is_mp_src_rank_with_outputs():
val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name))
Expand Down
26 changes: 26 additions & 0 deletions verl/utils/torch_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,32 @@ def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size,
return output


def allgather_dict_into_dict(data: dict, group=None) -> dict:
    """Gather a per-rank dict from every rank and merge the results key-wise.

    Args:
        data: the dict local to this rank (values must be picklable, since
            the gather goes through ``all_gather_object``).
        group: the process group to allgather over; ``None`` uses the default
            group.

    Returns:
        dict mapping each key to the list of per-rank values, ordered by rank.
    """
    assert isinstance(data, dict), f"Expect data to be a dictionary, Got {type(data)}"

    world = torch.distributed.get_world_size(group=group)

    gathered: list = [None] * world
    torch.distributed.all_gather_object(gathered, data, group=group)

    merged: dict = {}
    for rank_dict in gathered:
        for key, value in rank_dict.items():
            merged.setdefault(key, []).append(value)
    return merged


def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]:
assert tensors.batch_size[0] % batch_size == 0, (
f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}"
Expand Down
41 changes: 28 additions & 13 deletions verl/workers/config/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,26 @@
from typing import Any, Optional

from verl.base_config import BaseConfig
from verl.trainer.config import CheckpointConfig

__all__ = ["FSDPEngineConfig", "McoreEngineConfig"]
from .model import HFModelConfig
from .optimizer import OptimizerConfig

__all__ = ["FSDPEngineConfig", "McoreEngineConfig", "TrainingWorkerConfig"]


@dataclass
class McoreEngineConfig(BaseConfig):
class EngineConfig(BaseConfig):
    """Base configuration shared by the concrete engine backends
    (subclassed by the Megatron and FSDP engine configs in this module).

    Inheriting from BaseConfig provides an omegaconf.DictConfig-like
    interface over the dataclass fields.
    """

    # NOTE(review): field names suggest CPU offloading of params/optimizer
    # state/grads — semantics are enforced by the engine, not here; confirm
    # against the backend implementation.
    param_offload: bool = False
    optimizer_offload: bool = False
    grad_offload: bool = False
    # When True the engine is built without training machinery
    # (presumably inference-only; verify against the engine builder).
    forward_only: bool = False
    # Backend identifier; left unset here and supplied by subclasses
    # (e.g. "megatron", "fsdp"/"fsdp2").
    strategy: Optional[str] = None
    dtype: str = "bfloat16"  # ["bfloat16", "float16"]


@dataclass
class McoreEngineConfig(EngineConfig):
"""Configuration for Megatron parallelism.

The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Expand Down Expand Up @@ -51,10 +65,7 @@ class McoreEngineConfig(BaseConfig):

# sequence_parallel is not listed as a frozen field for auto-correction purpose
_mutable_fields = BaseConfig._mutable_fields | {"sequence_parallel"}

param_offload: bool = False
grad_offload: bool = False
optimizer_offload: bool = False
# mcore parallelism
tensor_model_parallel_size: int = 1
expert_model_parallel_size: int = 1
expert_tensor_parallel_size: Optional[int] = None
Expand All @@ -72,9 +83,7 @@ class McoreEngineConfig(BaseConfig):
override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
use_mbridge: bool = False
vanilla_mbridge: bool = True
forward_only: bool = False
strategy: str = "megatron"
dtype: str = "bfloat16" # ["bfloat16", "float16"]

def __post_init__(self) -> None:
"""config validation logics go here"""
Expand All @@ -86,7 +95,7 @@ def __post_init__(self) -> None:


@dataclass
class FSDPEngineConfig(BaseConfig):
class FSDPEngineConfig(EngineConfig):
"""Configuration for FSDP (Fully Sharded Data Parallel).

The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
Expand All @@ -108,9 +117,8 @@ class FSDPEngineConfig(BaseConfig):
# ulysses_sequence_parallel_size is mutable for backward compatibility
_mutable_fields = BaseConfig._mutable_fields | {"ulysses_sequence_parallel_size"}

# fsdp specific flags
wrap_policy: dict[str, Any] = field(default_factory=dict)
param_offload: bool = False
optimizer_offload: bool = False
offload_policy: bool = False
reshard_after_forward: bool = True
fsdp_size: int = -1
Expand All @@ -122,9 +130,16 @@ class FSDPEngineConfig(BaseConfig):
entropy_from_logits_with_chunking: bool = False
use_torch_compile: bool = True
entropy_checkpointing: bool = False
forward_only: bool = False
strategy: str = "fsdp"
dtype: str = "bfloat16" # ["bfloat16", "float16"]

def __post_init__(self):
assert self.strategy in ["fsdp", "fsdp2"], f"strategy {self.strategy} not supported"


@dataclass
class TrainingWorkerConfig(BaseConfig):
    """Aggregate configuration consumed by a training worker, bundling the
    model, engine, optimizer and checkpoint sub-configs.

    All fields default to None; callers are expected to supply them
    explicitly when constructing the config.
    """

    model_type: Optional[str] = None  # model type (language_model/value_model)
    # Huggingface-derived model configuration.
    model_config: Optional[HFModelConfig] = None
    # Backend engine configuration (McoreEngineConfig or FSDPEngineConfig).
    engine_config: Optional[EngineConfig] = None
    optimizer_config: Optional[OptimizerConfig] = None
    checkpoint_config: Optional[CheckpointConfig] = None
Loading