diff --git a/tests/special_distributed/run_all.sh b/tests/special_distributed/run_all.sh
index c34edf2229b..3d6c5c71e54 100644
--- a/tests/special_distributed/run_all.sh
+++ b/tests/special_distributed/run_all.sh
@@ -15,4 +15,5 @@
 #!/usr/bin/env bash
 set -e -x
 
-torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py
\ No newline at end of file
+torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_tensor_dict.py
+torchrun --nproc-per-node=4 --standalone tests/special_distributed/test_torch_functional.py
diff --git a/tests/special_distributed/test_torch_functional.py b/tests/special_distributed/test_torch_functional.py
new file mode 100644
index 00000000000..d07d335f5a3
--- /dev/null
+++ b/tests/special_distributed/test_torch_functional.py
@@ -0,0 +1,35 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import torch
+
+from verl.utils.torch_functional import allgather_dict_into_dict
+
+if __name__ == "__main__":
+    torch.distributed.init_process_group(backend="gloo")
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+
+    metrics_dict = {"loss": [0 + rank, 1 + rank, 2 + rank], "grad_norm": rank}
+
+    result = allgather_dict_into_dict(data=metrics_dict, group=None)
+
+    assert result["loss"] == [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]]
+    assert result["grad_norm"] == [0, 1, 2, 3]
+
+    print(result)
diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py
index d0474d9ab01..2f847ca97d7 100644
--- a/verl/trainer/main_ppo.py
+++ b/verl/trainer/main_ppo.py
@@ -178,7 +178,7 @@ def add_critic_worker(self, config):
         if use_legacy_worker_impl in ["auto", "enable"]:
             from verl.workers.fsdp_workers import CriticWorker
         elif use_legacy_worker_impl == "disable":
-            from verl.workers.roles import CriticWorker
+            from verl.workers.engine_workers import CriticWorker
 
             print("Using new worker implementation")
         else:
@@ -223,17 +223,17 @@ def add_reward_model_worker(self, config):
         if config.reward_model.enable:
             use_legacy_worker_impl = config.trainer.get("use_legacy_worker_impl", "auto")
-            if use_legacy_worker_impl in ["auto", "enable"]:
+            if use_legacy_worker_impl in ["auto", "enable", "disable"]:
                 if config.reward_model.strategy in {"fsdp", "fsdp2"}:
                     from verl.workers.fsdp_workers import RewardModelWorker
                 elif config.reward_model.strategy == "megatron":
                     from verl.workers.megatron_workers import RewardModelWorker
                 else:
                     raise NotImplementedError
-            elif use_legacy_worker_impl == "disable":
-                from verl.workers.roles import RewardModelWorker
-
-                print("Using new worker implementation")
+            # elif use_legacy_worker_impl == "disable":
+            #     from verl.workers.engine_workers import RewardModelWorker
+            #
+            #     print("Using new worker implementation")
             else:
                 raise ValueError(f"Invalid use_legacy_worker_impl: {use_legacy_worker_impl}")
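For reviewers trying the new path end to end, a minimal sketch (hypothetical config values; it mirrors the branching in add_critic_worker above) of how the flag selects the critic implementation:

    from omegaconf import OmegaConf

    # Sketch only: "auto"/"enable" keep the legacy FSDP worker,
    # "disable" switches to the new engine-based worker from this change.
    config = OmegaConf.create({"trainer": {"use_legacy_worker_impl": "disable"}})

    if config.trainer.get("use_legacy_worker_impl", "auto") == "disable":
        from verl.workers.engine_workers import CriticWorker  # noqa: F401
    else:
        from verl.workers.fsdp_workers import CriticWorker  # noqa: F401

Note that the reward model intentionally keeps using the legacy workers for now, even when the flag is "disable" (the new RewardModelWorker import is commented out above).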
diff --git a/verl/trainer/sft_trainer.py b/verl/trainer/sft_trainer.py
index 3d8f72023ae..4fafb4265f5 100644
--- a/verl/trainer/sft_trainer.py
+++ b/verl/trainer/sft_trainer.py
@@ -16,6 +16,8 @@
 import os
 from functools import partial
 
+from tensordict.tensorclass import NonTensorData
+
 os.environ["NCCL_DEBUG"] = "WARN"
 os.environ["TOKENIZERS_PARALLELISM"] = "true"
 
@@ -24,7 +26,6 @@
 import hydra
 import torch
 import torch.distributed
-from codetiming import Timer
 from omegaconf import OmegaConf
 from torch.utils.data import DistributedSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
@@ -36,9 +37,9 @@
 from verl.utils.dataset.multiturn_sft_dataset import MultiTurnSFTDataset
 from verl.utils.device import get_device_name, is_cuda_available, is_npu_available
 from verl.utils.distributed import destroy_global_process_group
-from verl.utils.flops_counter import FlopsCounter
 from verl.utils.logger import log_with_rank
 from verl.utils.tracking import Tracking
+from verl.workers.engine_workers import TrainingWorker
 
 if is_cuda_available:
     pass
@@ -74,12 +75,6 @@ def __init__(
 
         self.device_name = self.config.trainer.device
 
-        from verl.workers.utils.losses import sft_loss
-
-        self.loss_fn = partial(sft_loss, config=None)
-
-        self.flops_counter = FlopsCounter(self.model_config.hf_config)
-
         if self.rank == 0:
             print(self.config)
 
@@ -108,17 +103,24 @@ def _build_config(self):
         self.checkpoint_config = omega_conf_to_dataclass(self.config.checkpoint)
 
     def _build_engine(self):
-        from verl.workers.engine import BaseEngine, EngineRegistry
+        from verl.workers.engine_workers import TrainingWorkerConfig
+        from verl.workers.utils.losses import sft_loss
+
+        self.loss_fn = partial(sft_loss, config=None)
 
-        self.engine: BaseEngine = EngineRegistry.new(
+        config = TrainingWorkerConfig(
             model_type="language_model",
-            backend=self.engine_config.strategy,
             model_config=self.model_config,
             engine_config=self.engine_config,
             optimizer_config=self.optimizer_config,
             checkpoint_config=self.checkpoint_config,
         )
 
+        self.training_client = TrainingWorker(config=config)
+        self.training_client.set_loss_fn(loss_fn=self.loss_fn)
+        # Note that in the SPMD world, this abstraction has to break
+        self.engine = self.training_client.engine
+
     def _init_engine(self):
         # patch optimizer config
         if self.config.trainer.total_training_steps is not None:
@@ -138,7 +140,7 @@ def _init_engine(self):
         if self.test_freq == "after_each_epoch":
             self.test_freq = self.steps_per_epoch
 
-        self.engine.initialize()
+        self.training_client.reset()
 
     def _build_dataset(self):
         config = self.config
@@ -203,6 +205,30 @@ def _build_dataloader(self):
         else:
             self.val_dataloader = None
 
+    def _get_batch_seqlens(self, data):
+        # gather per-sequence lengths over the dp group
+        is_nested = data["input_ids"].is_nested
+        if is_nested:
+            batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff()
+        else:
+            batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
+        batch_seqlens = batch_seqlens.to(self.device_name)  # (global_bsz // dp)
+
+        output_tensor = torch.empty(
+            (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),),
+            dtype=batch_seqlens.dtype,
+            device=self.device_name,
+        )  # (global_bsz,)
+
+        torch.distributed.all_gather_into_tensor(
+            output_tensor=output_tensor,
+            input_tensor=batch_seqlens,
+            group=self.engine.get_data_parallel_group(),
+        )
+
+        batch_seqlens = output_tensor.tolist()
+        return batch_seqlens
+
     def fit(self):
         is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0
 
@@ -267,56 +293,28 @@ def fit(self):
                 # construct tensordict
                 data = tu.get_tensordict(tensor_dict=data, non_tensor_dict=meta_info)
 
+                batch_seqlens = self._get_batch_seqlens(data=data)
+                # this is necessary; otherwise the list is interpreted as a NonTensorStack
+                batch_seqlens = NonTensorData(batch_seqlens)
 
-                with self.engine.train_mode():
-                    with Timer(name="update_policy", logger=None) as timer:
-                        output = self.engine.train_batch(data=data, loss_function=self.loss_fn)
-                    lr = self.engine.lr_scheduler_step()
+                tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=batch_seqlens)
+
+                # train on one batch
+                output = self.training_client.train_batch(data=data)
 
                 if self.engine.is_mp_src_rank_with_outputs():
-                    metrics = output["metrics"]
-
-                    loss = torch.sum(torch.tensor(metrics["loss"], device=self.device_name))
-
-                    # mean over dp group
-                    is_nested = data["input_ids"].is_nested
-                    if is_nested:
-                        batch_seqlens: torch.Tensor = data["input_ids"].offsets().diff()
-                    else:
-                        batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
-                    batch_seqlens = batch_seqlens.to(self.device_name)  # (global_bsz // dp)
-
-                    output_tensor = torch.randint(
-                        0,
-                        100,
-                        (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),),
-                        device=self.device_name,
-                    )  # (global_bsz,)
-
-                    torch.distributed.all_gather_into_tensor(
-                        output_tensor=output_tensor,
-                        input_tensor=batch_seqlens,
-                        group=self.engine.get_data_parallel_group(),
-                    )
-                    torch.distributed.all_reduce(
-                        loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()
-                    )
-
-                    batch_seqlens = output_tensor.tolist()
-                    loss = loss.item()
+                    metrics = tu.get(output, "metrics")
 
                     # TODO: we can actually accumulate metrics for N steps and aggregate them
-                    metrics["loss"] = loss
                     metrics["train/loss"] = metrics.pop("loss")
                     metrics["train/grad_norm"] = metrics.pop("grad_norm")
-                    metrics["train/lr"] = lr
-                    metrics["train/global_tokens"] = output_tensor.sum().item()
+                    metrics["train/lr"] = metrics.pop("lr")
+                    metrics["train/mfu"] = metrics.pop("mfu")
+                    metrics["train/global_tokens"] = torch.sum(
+                        torch.tensor(batch_seqlens, device=self.device_name)
+                    ).item()
                     total_tokens += metrics["train/global_tokens"]
                     metrics["train/total_tokens(B)"] = total_tokens / 1e9
 
-                    # mfu
-                    delta_time = timer.last
-                    estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)
-                    metrics["train/mfu"] = estimated_flops / promised_flops / torch.distributed.get_world_size()
 
                     if self.engine.get_data_parallel_rank() == 0:
                         tracking.log(data=metrics, step=global_step)
@@ -330,12 +328,12 @@ def fit(self):
                     # Perform validation
                     val_losses = []
                     for val_data in self.val_dataloader:
-                        with self.engine.eval_mode():
-                            # construct tensordict
-                            val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info)
-                            output = self.engine.infer_batch(data=val_data, loss_function=self.loss_fn)
-                        if self.engine.is_mp_src_rank_with_outputs():
-                            val_losses.extend(output["metrics"]["loss"])
+                        val_data = tu.get_tensordict(tensor_dict=val_data, non_tensor_dict=meta_info)
+                        output = self.training_client.infer_batch(val_data)
+
+                        if self.engine.is_mp_src_rank_with_outputs():
+                            metrics = tu.get(output, "metrics")
+                            val_losses.append(metrics["loss"])
 
                     if self.engine.is_mp_src_rank_with_outputs():
                         val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name))
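The nested branch in _get_batch_seqlens relies on PyTorch's jagged nested-tensor layout. A toy sketch (made-up shapes; assumes a recent PyTorch with layout=torch.jagged) of the two sequence-length paths:

    import torch

    # Padded case: per-sequence lengths come from the attention mask.
    attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
    assert attention_mask.sum(dim=-1).tolist() == [3, 2]

    # Nested (jagged) case: lengths are the diffs of the storage offsets.
    input_ids = torch.nested.nested_tensor(
        [torch.zeros(3, dtype=torch.long), torch.zeros(2, dtype=torch.long)],
        layout=torch.jagged,
    )
    assert input_ids.is_nested
    assert input_ids.offsets().diff().tolist() == [3, 2]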
diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py
index e644894d677..52deaf9acb3 100644
--- a/verl/utils/torch_functional.py
+++ b/verl/utils/torch_functional.py
@@ -297,6 +297,32 @@ def allgather_dict_tensors(tensors: dict[str, torch.Tensor] | TensorDict, size,
     return output
 
 
+def allgather_dict_into_dict(data: dict, group=None) -> dict:
+    """Allgather a dict from every rank into a dict of lists
+
+    Args:
+        data: the local dict on this rank
+        group: the process group to allgather over
+
+    Returns: a dict with the same keys, where each value is the list of per-rank values ordered by rank
+
+    """
+    assert isinstance(data, dict), f"Expect data to be a dictionary, got {type(data)}"
+
+    group_size = torch.distributed.get_world_size(group=group)
+
+    final_metrics = {}
+    all_metrics_lst = [None for _ in range(group_size)]
+    torch.distributed.all_gather_object(all_metrics_lst, data, group=group)
+
+    for all_metrics in all_metrics_lst:
+        for key, val in all_metrics.items():
+            if key not in final_metrics:
+                final_metrics[key] = []
+            final_metrics[key].append(val)
+    return final_metrics
+
+
 def split_dict_tensor_into_batches(tensors: TensorDict, batch_size) -> list[TensorDict]:
     assert tensors.batch_size[0] % batch_size == 0, (
        f"input data batch size: {tensors.batch_size[0]}, split batch size: {batch_size}"
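For reference, the gather semantics of the new helper, exactly as pinned down by the test added above (4 gloo ranks):

    # On rank r of a 4-rank group:
    #     data = {"loss": [0 + r, 1 + r, 2 + r], "grad_norm": r}
    # allgather_dict_into_dict(data, group=None) returns, on every rank:
    #     {"loss": [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]],
    #      "grad_norm": [0, 1, 2, 3]}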
diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py
index 57dc6a57c43..f0cebf5ad76 100644
--- a/verl/workers/config/engine.py
+++ b/verl/workers/config/engine.py
@@ -17,12 +17,26 @@
 from typing import Any, Optional
 
 from verl.base_config import BaseConfig
+from verl.trainer.config import CheckpointConfig
 
-__all__ = ["FSDPEngineConfig", "McoreEngineConfig"]
+from .model import HFModelConfig
+from .optimizer import OptimizerConfig
+
+__all__ = ["FSDPEngineConfig", "McoreEngineConfig", "TrainingWorkerConfig"]
 
 
 @dataclass
-class McoreEngineConfig(BaseConfig):
+class EngineConfig(BaseConfig):
+    param_offload: bool = False
+    optimizer_offload: bool = False
+    grad_offload: bool = False
+    forward_only: bool = False
+    strategy: Optional[str] = None
+    dtype: str = "bfloat16"  # ["bfloat16", "float16"]
+
+
+@dataclass
+class McoreEngineConfig(EngineConfig):
     """Configuration for Megatron parallelism.
 
     The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
@@ -51,10 +65,7 @@ class McoreEngineConfig(BaseConfig):
     # sequence_parallel is not listed as a frozen field for auto-correction purpose
     _mutable_fields = BaseConfig._mutable_fields | {"sequence_parallel"}
-
-    param_offload: bool = False
-    grad_offload: bool = False
-    optimizer_offload: bool = False
+    # mcore parallelism
     tensor_model_parallel_size: int = 1
     expert_model_parallel_size: int = 1
     expert_tensor_parallel_size: Optional[int] = None
@@ -72,9 +83,7 @@ class McoreEngineConfig(BaseConfig):
     override_mcore_model_config: dict[str, Any] = field(default_factory=dict)
     use_mbridge: bool = False
     vanilla_mbridge: bool = True
-    forward_only: bool = False
     strategy: str = "megatron"
-    dtype: str = "bfloat16"  # ["bfloat16", "float16"]
 
     def __post_init__(self) -> None:
         """config validation logics go here"""
@@ -86,7 +95,7 @@ def __post_init__(self) -> None:
 
 
 @dataclass
-class FSDPEngineConfig(BaseConfig):
+class FSDPEngineConfig(EngineConfig):
     """Configuration for FSDP (Fully Sharded Data Parallel).
 
     The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.
@@ -108,9 +117,8 @@ class FSDPEngineConfig(BaseConfig):
     # ulysses_sequence_parallel_size is mutable for backward compatibility
     _mutable_fields = BaseConfig._mutable_fields | {"ulysses_sequence_parallel_size"}
 
+    # fsdp specific flags
     wrap_policy: dict[str, Any] = field(default_factory=dict)
-    param_offload: bool = False
-    optimizer_offload: bool = False
     offload_policy: bool = False
     reshard_after_forward: bool = True
     fsdp_size: int = -1
@@ -122,9 +130,16 @@ class FSDPEngineConfig(BaseConfig):
     entropy_from_logits_with_chunking: bool = False
     use_torch_compile: bool = True
     entropy_checkpointing: bool = False
-    forward_only: bool = False
     strategy: str = "fsdp"
-    dtype: str = "bfloat16"  # ["bfloat16", "float16"]
 
     def __post_init__(self):
         assert self.strategy in ["fsdp", "fsdp2"], f"strategy {self.strategy} not supported"
+
+
+@dataclass
+class TrainingWorkerConfig(BaseConfig):
+    model_type: Optional[str] = None  # model type (language_model/value_model)
+    model_config: Optional[HFModelConfig] = None
+    engine_config: Optional[EngineConfig] = None
+    optimizer_config: Optional[OptimizerConfig] = None
+    checkpoint_config: Optional[CheckpointConfig] = None
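How the new TrainingWorkerConfig is assembled in practice, following the _build_engine change in sft_trainer.py above (a sketch; the sub-config variables are assumed to come from omega_conf_to_dataclass as in _build_config):

    from verl.workers.engine_workers import TrainingWorker, TrainingWorkerConfig

    config = TrainingWorkerConfig(
        model_type="language_model",        # or "value_model"
        model_config=model_config,          # HFModelConfig
        engine_config=engine_config,        # FSDPEngineConfig or McoreEngineConfig
        optimizer_config=optimizer_config,  # OptimizerConfig
        checkpoint_config=checkpoint_config,
    )
    worker = TrainingWorker(config=config)
    worker.set_loss_fn(loss_fn=loss_fn)
    worker.reset()  # initializes the engine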
diff --git a/verl/workers/engine/base.py b/verl/workers/engine/base.py
index 139bf393acc..ad4676f1f6c 100644
--- a/verl/workers/engine/base.py
+++ b/verl/workers/engine/base.py
@@ -109,6 +109,7 @@ def train_batch(self, data: TensorDict, loss_function: Callable) -> Any:
         outputs = self.forward_backward_batch(data, loss_function, forward_only=False)
         grad_norm = self.optimizer_step()
         if self.is_mp_src_rank_with_outputs():
+            assert "grad_norm" not in outputs["metrics"]
             outputs["metrics"]["grad_norm"] = grad_norm
         return outputs
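The new assert makes the metrics contract explicit: a loss function returns (loss, metrics) and must not set keys that the engine or TrainingWorker injects afterwards ("grad_norm" here; "loss" and "lr" in the worker). A sketch of a conforming loss function, modeled on sft_loss but with a hypothetical model_output key:

    import torch

    def my_loss(config, model_output, data, dp_group=None):
        # "log_probs" is a hypothetical key; sft_loss uses the real model outputs
        log_prob = model_output["log_probs"]
        loss = -log_prob.mean()
        # extra metrics are fine; "grad_norm"/"lr"/"loss" are injected later
        return loss, {"ppl": torch.exp(loss).detach().item()}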
diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py
index 5b4fb9619e3..19e768a4678 100644
--- a/verl/workers/engine/megatron/transformer_impl.py
+++ b/verl/workers/engine/megatron/transformer_impl.py
@@ -703,10 +703,11 @@ def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool
             loss, metrics = loss_function(model_output=model_output, data=data, dp_group=self.get_data_parallel_group())
             # scale loss by num_micro_batch because megatron will scale loss
             # by n_micro_batch inside pp schedule
-            loss = loss * data["num_micro_batch"]
+            scaled_loss = loss * data["num_micro_batch"]
         else:
             assert forward_only, "forward_only must be True when loss_function is None"
             loss = torch.tensor(1.0, device=device)
+            scaled_loss = loss
             metrics = {}
 
         output = {
@@ -716,7 +717,7 @@ def postprocess_micro_batch_func(self, output, data: TensorDict, forward_only: bool
         }
 
         # return loss and stats
-        return loss, output
+        return scaled_loss, output
 
 
 @EngineRegistry.register(model_type="value_model", backend="megatron")
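Why the rename matters: the in-place `loss = loss * num_micro_batch` leaked the scaled value into the returned stats, while megatron's pp schedule only needs the pre-multiplied loss (it divides by the number of micro-batches internally, per the comment above). A toy trace with made-up numbers:

    # loss = 0.5, num_micro_batch = 4
    # pp schedule receives scaled_loss = 2.0 and divides by 4 -> 0.5 in the backward pass
    # the stats dict now keeps the true loss = 0.5 instead of reporting 2.0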
diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py
index 8090b872f6a..d5bf19727a7 100644
--- a/verl/workers/engine_workers.py
+++ b/verl/workers/engine_workers.py
@@ -18,13 +18,16 @@
 from typing import Any, Optional
 
 import psutil
+import torch
 from codetiming import Timer
 from omegaconf import DictConfig, open_dict
+from tensordict import TensorDict
 from torch.distributed.device_mesh import init_device_mesh
 
 from verl import DataProto
 from verl.single_controller.base import Worker
 from verl.single_controller.base.decorator import Dispatch, make_nd_compute_dataproto_dispatch_fn, register
+from verl.utils import tensordict_utils as tu
 from verl.utils.config import omega_conf_to_dataclass
 from verl.utils.device import (
     get_device_id,
@@ -37,7 +40,8 @@
 from verl.utils.memory_utils import aggressive_empty_cache
 from verl.utils.profiler import DistProfiler, DistProfilerExtension, log_gpu_memory_usage
 from verl.utils.py_functional import append_to_dict
-from verl.workers.config import ActorConfig, CriticConfig, HFModelConfig, RolloutConfig
+from verl.utils.torch_functional import allgather_dict_into_dict
+from verl.workers.config import ActorConfig, CriticConfig, HFModelConfig, RolloutConfig, TrainingWorkerConfig
 from verl.workers.rollout.base import BaseRollout, get_rollout_class
 from verl.workers.utils.losses import ppo_loss, value_loss
 from verl.workers.utils.padding import left_right_2_no_padding, no_padding_2_padding
@@ -46,6 +50,170 @@
 logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
 
 
+class TrainingWorker(Worker):
+    """
+    TrainingWorker provides a Tinker-like API (https://thinkingmachines.ai/tinker/) as a RayWorkerGroup
+    to a single controller. Currently, we only provide coarser-grained APIs and do not mirror
+    Tinker's exact APIs, but those can be added in the future.
+    """
+
+    def __init__(self, config: TrainingWorkerConfig):
+        Worker.__init__(self)
+
+        from verl.workers.engine import BaseEngine, EngineRegistry
+
+        initialize_global_process_group_ray(timeout_second=None)
+
+        self.config = config
+        self.model_config = self.config.model_config
+        self.engine_config = self.config.engine_config
+        self.optimizer_config = self.config.optimizer_config
+        self.checkpoint_config = self.config.checkpoint_config
+        self.device_name = get_device_name()
+
+        # TODO: add DistProfilerExtension
+        # self.profiler_config = self.config.profiler_config
+        # tool_config = self.profiler_config.tool_config
+        # DistProfilerExtension.__init__(
+        #     self, DistProfiler(rank=self.rank, config=self.profiler_config, tool_config=tool_config)
+        # )
+
+        self.engine: BaseEngine = EngineRegistry.new(
+            model_type=self.config.model_type,
+            backend=self.engine_config.strategy,
+            model_config=self.model_config,
+            engine_config=self.engine_config,
+            optimizer_config=self.optimizer_config,
+            checkpoint_config=self.checkpoint_config,
+        )
+
+        # build dispatch info
+        self._register_dispatch_collect_info(
+            mesh_name="train",
+            dp_rank=self.engine.get_data_parallel_rank(),
+            is_collect=self.engine.is_mp_src_rank_with_outputs(),
+        )
+
+        self.flops_counter = FlopsCounter(self.model_config.hf_config)
+
+        self.loss_fn = None
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def set_loss_fn(self, loss_fn):
+        self.loss_fn = loss_fn
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def reset(self):
+        """
+        Reset the model engine to the initial state. If the engine is not initialized,
+        we initialize it. Otherwise, reload the checkpoint and reset states.
+        """
+        self.engine.initialize()
+
+    def _postprocess_output(self, output, *, global_token_num, delta_time, forward_only):
+        """Aggregate the per-rank engine output into final metrics.
+
+        Args:
+            output: a dictionary containing loss, model_output and metrics
+
+        Returns:
+            a TensorDict holding the model output with dp-aggregated metrics attached
+        """
+        # TODO: whether to log memory
+        # metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)
+        # metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3)
+        # metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)
+
+        metrics: dict = output.pop("metrics")
+        # aggregate across the dp group so that the logged values are correct.
+        # Here each metric in metrics can be a list (micro-batch metrics) or a single scalar.
+        # We always sum the losses of the micro-batches because the loss function scales by global_bsz/global_token
+        loss = torch.sum(torch.tensor(output.pop("loss"), device=self.device_name))
+        torch.distributed.all_reduce(
+            loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()
+        )
+        loss = loss.item()
+
+        # For grad_norm, we do not perform all-reduce because it has already been done during grad clipping
+        grad_norm = metrics.pop("grad_norm", None)
+        lr = metrics.pop("lr", None)
+
+        # For other metrics, we perform all-gather in the dp group
+        final_metrics = allgather_dict_into_dict(data=metrics, group=self.engine.get_data_parallel_group())
+        final_metrics["loss"] = loss
+        if grad_norm is not None:
+            final_metrics["grad_norm"] = grad_norm
+        if lr is not None:
+            final_metrics["lr"] = lr
+        # compute mfu
+        if global_token_num is not None:
+            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_token_num, delta_time)
+            final_metrics["mfu"] = estimated_flops / promised_flops / torch.distributed.get_world_size()
+            if forward_only:
+                # a forward pass costs roughly 1/3 of the fwd+bwd FLOPs assumed by the counter
+                final_metrics["mfu"] /= 3.0
+        # model outputs
+        model_output = output.pop("model_output", {})
+        # We only return final_metrics
+        final_output = tu.get_tensordict(tensor_dict=model_output, non_tensor_dict={"metrics": final_metrics})
+        return final_output
+
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
+    def train_batch(self, data: TensorDict) -> TensorDict:
+        assert self.loss_fn is not None, "loss function can't be None when calling train_batch"
+        # global_token_num should be a list with the number of tokens of each seq in this batch
+        global_token_num = tu.get(data, key="global_token_num")
+
+        with self.engine.train_mode(), Timer(name="train_batch", logger=None) as timer:
+            output = self.engine.train_batch(data, loss_function=self.loss_fn)
+            # containing loss, model_output and metrics
+            # for training, we only care about loss and metrics
+        delta_time = timer.last
+
+        update_lr_scheduler = tu.get(data, key="update_lr_scheduler", default=False)
+        # update lr scheduler
+        if update_lr_scheduler:
+            lr = self.engine.lr_scheduler_step()
+        else:
+            lr = None
+
+        if self.engine.is_mp_src_rank_with_outputs():
+            # we don't need model_output in training. Maybe we'll change our minds later
+            output.pop("model_output")
+            if lr is not None:
+                output["metrics"]["lr"] = lr
+            final_output = self._postprocess_output(
+                output, global_token_num=global_token_num, delta_time=delta_time, forward_only=False
+            )
+        else:
+            final_output = None
+        return final_output
+
+    @register(dispatch_mode=make_nd_compute_dataproto_dispatch_fn(mesh_name="train"), blocking=False)
+    def infer_batch(self, data: TensorDict) -> TensorDict:
+        # global_token_num is needed to compute mfu
+        global_token_num = tu.get(data, key="global_token_num")
+
+        with self.engine.eval_mode(), Timer(name="eval_batch", logger=None) as timer:
+            output = self.engine.infer_batch(data, loss_function=self.loss_fn)
+        delta_time = timer.last
+
+        if self.engine.is_mp_src_rank_with_outputs():
+            final_output = self._postprocess_output(
+                output, global_token_num=global_token_num, delta_time=delta_time, forward_only=True
+            )
+        else:
+            final_output = None
+        return final_output
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, max_ckpt_to_keep=None):
+        return self.engine.save_checkpoint(local_path, hdfs_path, global_step, max_ckpt_to_keep)
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
+    def load_checkpoint(self, local_path, hdfs_path=None, del_local_after_load=False):
+        return self.engine.load_checkpoint(local_path, hdfs_path, del_local_after_load)
+
+
 class ActorWorker(Worker, DistProfilerExtension):
     """
     This worker can be instantiated as a standalone actor or a standalone reference policy
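Putting the pieces together, one SPMD training step against the new worker, mirroring the fit() changes in sft_trainer.py (a sketch; batch, meta_info, trainer and worker are assumed to exist):

    from tensordict.tensorclass import NonTensorData
    from verl.utils import tensordict_utils as tu

    data = tu.get_tensordict(tensor_dict=batch, non_tensor_dict=meta_info)

    # per-sequence token counts gathered over the dp group, wrapped so the
    # list is treated as opaque metadata rather than a NonTensorStack
    seqlens = NonTensorData(trainer._get_batch_seqlens(data=data))
    tu.assign_non_tensor(data, update_lr_scheduler=True, global_token_num=seqlens)

    output = worker.train_batch(data=data)
    if worker.engine.is_mp_src_rank_with_outputs():
        metrics = tu.get(output, "metrics")  # loss, grad_norm, lr, mfu, ...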
diff --git a/verl/workers/utils/losses.py b/verl/workers/utils/losses.py
index a11f619d061..27178c1aba5 100644
--- a/verl/workers/utils/losses.py
+++ b/verl/workers/utils/losses.py
@@ -50,7 +50,7 @@ def sft_loss(config: ActorConfig, model_output, data: TensorDict, dp_group=None)
     response_mask = data["response_mask"].to(bool)
 
     loss = -masked_sum(log_prob, response_mask) / batch_num_tokens * dp_size
-    return loss, {"loss": loss.detach().item()}
+    return loss, {}
 
 
 def _slice_response_from_unpad_output(tensor: torch.Tensor, data: TensorDict) -> torch.Tensor:
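With this change the "loss" metric is produced centrally by TrainingWorker._postprocess_output rather than by each loss function; a condensed sketch of that aggregation path (toy values; dp_group assumed initialized):

    import torch
    import torch.distributed as dist

    # micro-batch losses reported by the engine on this rank (hypothetical values)
    loss = torch.sum(torch.tensor([0.2, 0.3, 0.1]))
    dist.all_reduce(loss, op=dist.ReduceOp.AVG, group=dp_group)  # mean over dp ranks
    metrics = {"loss": loss.item()}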