diff --git a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml
new file mode 100644
index 00000000000..987800221ba
--- /dev/null
+++ b/.github/workflows/kernels.yml
@@ -0,0 +1,58 @@
+name: kernels
+# latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main branch
+  push:
+    branches:
+      - main
+      - v0.2.x
+    paths:
+      - "**/*.py"
+      - .github/workflows/kernels.yml
+  pull_request:
+    branches:
+      - main
+      - v0.2.x
+    paths:
+      - "**/*.py"
+      - "verl/trainer/config/*.yaml"
+      - .github/workflows/kernels.yml
+      - "tests/e2e/*.sh"
+
+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+# Declare permissions just read content.
+permissions:
+  contents: read
+
+jobs:
+  e2e_gsm8k_megatron:
+    runs-on: [self-hosted, l20-0]
+    timeout-minutes: 40 # Increase this timeout value as needed
+    env:
+      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+      NO_PROXY: "localhost,127.0.0.1"
+      HF_HUB_ENABLE_HF_TRANSFER: 1
+    container:
+      image: whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-megatron0.11.0-v0.0.6
+      options: --gpus all --shm-size=10g
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          fetch-depth: 0
+      - name: Install the current repository
+        run: |
+          pip3 install hf_transfer
+          pip3 install -e .[test]
+      - name: Testing LinearCrossEntropy Correctness, Computation Time and Memory Consumption
+        run: |
+          python3 tests/kernel/test_linear_cross_entropy.py
+      - name: Testing VocabParallelEntropy
+        run: |
+          bash tests/kernel/run_vocab_parallel_entropy.sh
\ No newline at end of file
diff --git a/tests/kernel/run_vocab_parallel_entropy.sh b/tests/kernel/run_vocab_parallel_entropy.sh
new file mode 100644
index 00000000000..ee7cd3a7568
--- /dev/null
+++ b/tests/kernel/run_vocab_parallel_entropy.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e -x
+torchrun --nproc-per-node=8 --standalone tests/kernel/test_vocab_parallel_entropy.py
\ No newline at end of file
diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py
new file mode 100644
index 00000000000..e48d8c4864d
--- /dev/null
+++ b/tests/kernel/test_linear_cross_entropy.py
@@ -0,0 +1,585 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import torch
+import torch.distributed as dist
+import typing
+
+try:
+    from verl.utils.kernel import linear_cross_entropy
+except ImportError:
+    # FIXME: remove these manually included paths
+    import sys
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../")))
+finally:
+    from verl.utils.kernel import linear_cross_entropy, set_backward_method, BackwardEnum
+
+import verl.utils.torch_functional as verl_F
+from verl.utils.torch_functional import logprobs_from_logits
+
+compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+
+
+def run_torch_entropy(hidden: torch.Tensor,
+                      weight: torch.Tensor,
+                      labels: torch.Tensor,
+                      reduction="none") -> typing.List[torch.Tensor]:
+    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
+    pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
+    entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
+    entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
+    entropy = entropy_a - entropy_b
+    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)  # [num_tokens]
+    logprobs = torch.neg(logprobs)
+    return logprobs, entropy
+
+
+class TorchEntropyTP(torch.autograd.Function):
+    """
+    it is used for testing the correctness of the kernel
+    it is not efficient and is not recommended to use in practice
+    """
+
+    @staticmethod
+    def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor,
+                dist_process_group:
torch.distributed.ProcessGroup): + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] + whole_logits = torch.empty((logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), + dtype=logits.dtype, + device=logits.device) + whole_logits_ref = [ + whole_logits[:, i * logits.shape[1]:(i + 1) * logits.shape[1]] + for i in range(dist.get_world_size(dist_process_group)) + ] + dist.all_gather(whole_logits_ref, logits, group=dist_process_group) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + + logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") + logprobs = torch.neg(logprobs) + + ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) + ctx.dist_process_group = dist_process_group + + return logprobs, entropy + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): + hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors + dist_process_group = ctx.dist_process_group + + batch_size, hidden_size = hidden.shape + vocab_size = weight.shape[1] + world_size = dist.get_world_size(dist_process_group) + rank = dist.get_rank(dist_process_group) + + # Compute softmax probabilities + maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True) + exp_logits = torch.exp(whole_logits - maximum) + accumulate = exp_logits.sum(dim=-1, keepdim=True) + pd = exp_logits / accumulate + + # Gradient for entropy + # entropy = entropy_a - entropy_b + # entropy_a = log(sum(exp(logits))) + # entropy_b = sum(pd * logits) + # d_entropy_a/d_logits = pd + # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = d_entropy_a - d_entropy_b + # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = -pd * (logits - 
b.unsqueeze(1)) + d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) + + # Gradient for logprobs + # logprobs = -cross_entropy = -log(pd[labels]) + # d_logprobs/d_logits = (pd - one_hot(labels)) + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.unsqueeze(1), 1) + g_logprobs = torch.neg(g_logprobs) + d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot) + # NOTE: This will lead to wrong result + # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot + + # Combine gradients + d_logits = d_logits_entropy + d_logits_logprobs + + # Get local slice of gradients + local_d_logits = d_logits[:, rank * vocab_size:(rank + 1) * vocab_size] + + # Compute gradients for hidden and weight + d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32).T) + d_weight = torch.matmul(hidden.to(torch.float32).T, local_d_logits) + + return d_hidden, d_weight, None, None + + +run_torch_entropy_tp = TorchEntropyTP.apply + +def run_verl_actor_entropy(hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction="none") -> typing.List[torch.Tensor]: + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] + # compute entropy + entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + logprobs = logprobs_from_logits(logits=logits, labels=labels) + return logprobs, entropy + + +class TestLinearCrossEntropy: + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + self.num_tokens = 80 + self.hidden_size = 4096 + self.vocab_size = 152064 + self.dtype = torch.bfloat16 + + def generate_forward_inputs(self): + hidden = (torch.empty((self.num_tokens, self.hidden_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + 
weight = (torch.empty((self.hidden_size, self.vocab_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + labels = torch.randint(0, self.vocab_size, (self.num_tokens,), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)) + g_logprobs = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)) + return g_entropy, g_logprobs + + def verify_correctness(self, iterations=5): + self.cleanup() + self.generate_hyper() + + torch_forward_latency = list() + torch_backward_latency = list() + verl_forward_latency = list() + verl_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(iterations): + print(f"[INFO]: Iteration {i + 1} / {iterations}...", end='\r') + hidden, weight, labels = self.generate_forward_inputs() + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (verl_logprobs, verl_entropy) = run_verl_actor_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + verl_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none") + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4) + 
torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(verl_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) + + # backward + g_entropy, g_logprobs = self.generate_backward_inputs() + + start_event.record() + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (d_verl_hidden, d_verl_weight) = torch.autograd.grad((verl_entropy, verl_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + verl_backward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), + (hidden, weight), (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=1e-2, rtol=1e-4) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + verl_forward_latency = verl_forward_latency[1:] + verl_backward_latency = verl_backward_latency[1:] + kernel_forward_latency = 
kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + print(f"\n[INFO]: Verified forward & backward correctness.") + + print(f"[INFO]: Forward pass: Torch implementation average time: " + f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: torch implementation average time: " + f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: VeRL implementation average time: " + f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: VeRL implementation average time: " + f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: " + f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: " + f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") + + def check_storage(self, method_name, run_forward, reduction="none"): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + torch.cuda.reset_peak_memory_stats() + (logprobs, entropy) = run_forward(hidden, weight, labels, reduction) + torch.cuda.synchronize() + torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") + + g_entropy, g_logprobs = self.generate_backward_inputs() + + torch.cuda.reset_peak_memory_stats() + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((entropy, logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + torch.cuda.synchronize() + torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB") + + def check_storage_all(self): + self.check_storage("Torch", run_torch_entropy) + 
self.check_storage("VeRL", run_verl_actor_entropy) + self.check_storage("Kernel", linear_cross_entropy) + + +class TestLinearCrossEntropy_TensorParallel: + + def __init__(self): + dist.init_process_group(backend="nccl") + self.group = dist.group.WORLD + + self.local_rank = dist.get_rank(self.group) + self.world_size = dist.get_world_size(self.group) + device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}") + + def shutdown(self): + dist.destroy_process_group() + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + self.num_tokens = 80 + self.hidden_size = 4096 + self.vocab_size = 152064 + self.dtype = torch.bfloat16 + self.iterations = 5 + + def generate_forward_inputs(self): + hidden = (torch.empty((self.num_tokens, self.hidden_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + weight = (torch.empty((self.hidden_size, self.vocab_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + labels = torch.randint(0, self.vocab_size * self.world_size, (self.num_tokens,), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)) + g_logprobs = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)) + return g_entropy, g_logprobs + + def verify_torch_itself(self): + self.cleanup() + self.generate_hyper() + + for i in range(self.iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + # forward pass + whole_weight = torch.empty((weight.shape[0], 
weight.shape[1] * self.world_size), + dtype=weight.dtype, + device=weight.device) + whole_weight_ref = [ + whole_weight[:, i * weight.shape[1]:(i + 1) * weight.shape[1]] for i in range(self.world_size) + ] + dist.all_gather(whole_weight_ref, weight, group=self.group) + whole_weight.requires_grad_() + + (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels) + + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) + + torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + (single_d_hidden, single_d_weight) = torch.autograd.grad((single_entropy, single_logprobs), + (hidden, whole_weight), (g_entropy, g_logprobs), + retain_graph=False) + + (tp_d_hidden, tp_d_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(tp_d_weight, + single_d_weight[:, self.local_rank * tp_d_weight.shape[1]:(self.local_rank + 1) * + tp_d_weight.shape[1]]) #, + # atol=1e-3, rtol=1e-4) + if self.local_rank == 0: + print(f"[PASS] torch TP correctness is verified") + + def check_torch_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, 
src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) + torch.cuda.synchronize() + forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_tp_hidden, d_tp_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + torch.cuda.synchronize() + backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") + print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") + + def verify_kernel_correctness(self): + self.cleanup() + self.generate_hyper() + + torch_forward_latency = list() + torch_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(self.iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, 
kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none", self.group) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + start_event.record() + (torch_d_hidden, torch_d_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + start_event.record() + (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), + (hidden, weight), (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=1e-2, rtol=1e-4) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + if self.local_rank == 0: + print(f"\n[PASS]: Verified kernel 
forward & backward correctness.") + + print(f"[INFO]: Forward pass: Torch implementation average time: " + f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: torch implementation average time: " + f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: " + f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: " + f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") + + def check_kernel_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none", self.group) + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: 
Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + + +if __name__ == "__main__": + # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernel/test_linear_cross_entropy.py + + # Check if running with torchrun (distributed mode) + is_distributed = False + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + is_distributed = True + print(f"[INFO]: Running in {'distributed' if is_distributed else 'non-distributed'} mode") + torch.manual_seed(233376 + int(os.environ.get("RANK", 0))) + + # set_backward_method(BackwardEnum._Total_Fuse_MN) + + if not is_distributed: + test = TestLinearCrossEntropy() + + test.verify_correctness() + test.check_storage_all() + else: + test = TestLinearCrossEntropy_TensorParallel() + + test.verify_torch_itself() + test.check_torch_storage() + test.verify_kernel_correctness() + test.check_kernel_storage() + + test.shutdown() diff --git a/tests/kernel/test_vocab_parallel_entropy.py b/tests/kernel/test_vocab_parallel_entropy.py new file mode 100644 index 00000000000..5fd899f2d62 --- /dev/null +++ b/tests/kernel/test_vocab_parallel_entropy.py @@ -0,0 +1,118 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os + +os.environ['NCCL_DEBUG'] = 'WARN' + +import torch +import torch.distributed + +from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy +from verl.utils.torch_functional import logprobs_from_logits, entropy_from_logits + +from verl.utils.debug import log_gpu_memory_usage + +from megatron.core import mpu + + +class Utils: + world_size = torch.cuda.device_count() + rank = int(os.environ.get('LOCAL_RANK', '0')) + + @staticmethod + def initialize_distributed(): + print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}') + torch.cuda.set_device(Utils.rank % torch.cuda.device_count()) + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '7000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group(backend='nccl', + world_size=Utils.world_size, + rank=Utils.rank, + init_method=init_method) + print(f'successfully created process group') + + @staticmethod + def destroy_model_parallel(): + mpu.destroy_model_parallel() + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + @staticmethod + def initialize_model_parallel(tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None): + mpu.destroy_model_parallel() + if not torch.distributed.is_initialized(): + Utils.initialize_distributed() + mpu.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank) + + +def test_vocab_parallel_entropy(): + # check vocab_parallel_entropy + Utils.world_size = 8 + Utils.initialize_model_parallel(8, 1) + + batch_size = 2 + seqlen = 128 + vocab_size = 155136 + + logits = torch.randn(batch_size * seqlen, vocab_size, device='cuda', requires_grad=True) + target = torch.randint(low=0, high=vocab_size, size=(batch_size * 
seqlen,), device='cuda', dtype=torch.int64) + + # broadcast across tp + torch.distributed.broadcast(logits, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + torch.distributed.broadcast(target, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + + tp_rank = mpu.get_tensor_model_parallel_rank() + vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size() + + # get the local logits of each tp + vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * + vocab_size_per_tp].requires_grad_() + logits.grad = None + vocab_parallel_logits.grad = None + + log_gpu_memory_usage('begin') + output_entropy = vocab_parallel_entropy(vocab_parallel_logits) + log_gpu_memory_usage('after forward') + grad_output = torch.randn_like(output_entropy) + output_entropy.backward(grad_output) + log_gpu_memory_usage('after backward') + + target_entropy = entropy_from_logits(logits) + torch.testing.assert_close(output_entropy, target_entropy) + target_entropy.backward(grad_output) + torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp], + vocab_parallel_logits.grad) + # make sure logits is not altered + torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp], + vocab_parallel_logits) + + if mpu.get_tensor_model_parallel_rank() == 0: + print('test_vocab_parallel_entropy passes') + + Utils.destroy_model_parallel() + + +if __name__ == '__main__': + test_vocab_parallel_entropy() diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py index e9598202826..196900aae3d 100644 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -35,6 +35,8 @@ from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import 
TransformerConfig, convert_config from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad +from verl.utils.kernel import linear_cross_entropy +from verl.models.transformers.common import FusedCausalLMOutputWithPast """ TODO: 1. Add weight initialization. Here we need to be careful on TP weight init. @@ -180,7 +182,10 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, + ) -> Union[Tuple, FusedCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -199,17 +204,45 @@ def forward( ) hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + logits = 
self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) @@ -315,7 +348,10 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, + ) -> Union[Tuple, FusedCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -346,25 +382,62 @@ def forward( max_seqlen_in_batch=max_seqlen_in_batch) hidden_states = outputs + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back, maybe done later + # move outside + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = 
linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - return CausalLMOutputWithPast( + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) @@ -391,8 +464,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids, attention_mask, position_ids) + output = super().forward(input_ids, attention_mask, position_ids, labels, temperature, fuse_entropy_logprobs) output.logits = torch.squeeze(output.logits, dim=-1) return output @@ -578,6 +654,9 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -612,24 +691,62 @@ def forward( if self.post_process: hidden_states = outputs - # print(f'hidden_states.shape = {hidden_states.shape}') # torch.Size([4, 32, 4096]) - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back + # move outside + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = 
vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) else: return outputs @@ -659,8 +776,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, temperature=temperature, fuse_entropy_logprobs=fuse_entropy_logprobs) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) return output diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index 28af135ea5d..e64563cce1d 100644 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -33,6 +33,8 @@ from 
verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config +from verl.utils.kernel import linear_cross_entropy +from verl.models.transformers.common import FusedCausalLMOutputWithPast from .layers import ParallelQwen2DecoderLayer, ParallelQwen2RMSNorm, ParallelQwen2DecoderLayerRmPad """ TODO: @@ -179,6 +181,9 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -198,20 +203,48 @@ def forward( ) hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( + + logits = None + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits 
= logits.float() + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -314,7 +347,10 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, + ) -> Union[Tuple, FusedCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -345,25 +381,64 @@ def forward( max_seqlen_in_batch=max_seqlen_in_batch) hidden_states = outputs + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back, move to later + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, 
labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, + seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, + seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - return CausalLMOutputWithPast( + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy, ) @@ -390,8 +465,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids, attention_mask, position_ids) + output = super().forward(input_ids, attention_mask, position_ids, labels, temperature, fuse_entropy_logprobs) output.logits = torch.squeeze(output.logits, dim=-1) return output @@ -626,6 +704,9 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: 
Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -660,23 +741,62 @@ def forward( if self.post_process: hidden_states = outputs - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( + + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. 
If input is already rmpad, we let the caller pad_input + # move to outside + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, + seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, + seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. 
If input is already rmpad, we let the caller pad_input + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy, ) else: return outputs @@ -706,8 +826,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, temperature=temperature, fuse_entropy_logprobs=fuse_entropy_logprobs) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) return output diff --git a/verl/models/transformers/common.py b/verl/models/transformers/common.py new file mode 100644 index 00000000000..c5b9216d62e --- /dev/null +++ b/verl/models/transformers/common.py @@ -0,0 +1,11 @@ +import torch +from transformers.models.llama.modeling_llama import CausalLMOutputWithPast +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast + +class FusedCausalLMOutputWithPast(CausalLMOutputWithPast): + log_probs: torch.Tensor + entropy: torch.Tensor + +class FusedQwen2VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): + log_probs: torch.Tensor + entropy: torch.Tensor \ No newline at end of file diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py index 886ccb67d68..532cea9d9bb 100644 --- a/verl/models/transformers/llama.py +++ b/verl/models/transformers/llama.py @@ -13,7 +13,7 @@ # limitations under the License. 
import torch -from typing import Optional, Tuple, Callable +from typing import Optional, Tuple, Callable, Union, List import sys if sys.version_info >= (3, 11): from typing import Unpack @@ -26,6 +26,8 @@ from transformers.modeling_flash_attention_utils import _flash_attention_forward from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from verl.utils.kernel import linear_cross_entropy +from .common import FusedCausalLMOutputWithPast logger = logging.get_logger(__name__) @@ -224,3 +226,95 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + +def llama_fused_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: Optional[float] = None, + fuse_entropy_logprobs: bool = False, + **kwargs, +) -> Union[Tuple, FusedCausalLMOutputWithPast]: + """ + Codes patch to huggingface/transformers LlamaForCausalLM for Fused lmhead/Entropy/CrossEntropy. 
+ """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # Do not support TP here + + logits = None + loss = None + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weight, labels, reduction="none") + else: + # Inference mode + logits = self.lm_head(hidden_states) + # loss is not needed + # if labels is not None: + # loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None 
else output + + return FusedCausalLMOutputWithPast( + loss=loss, + logits=logits, + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + \ No newline at end of file diff --git a/verl/models/transformers/monkey_patch.py b/verl/models/transformers/monkey_patch.py index 2ee0d3f1a94..60748bebadd 100644 --- a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -90,14 +90,30 @@ def apply_monkey_patch(model: PreTrainedModel): # TODO: VLM models only, unify monkey patch to LLM models. if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope - from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward, qwen2_vl_fused_forward + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2, Qwen2VLForConditionalGeneration + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2, Qwen2_5_VLForConditionalGeneration + + Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_fused_forward + Qwen2VLForConditionalGeneration.forward = qwen2_vl_fused_forward Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in Qwen2VL") return + elif model.config.model_type in ("llama", "qwen2"): + from verl.models.transformers.qwen2 import qwen2_fused_forward + from verl.models.transformers.llama import llama_fused_forward + + from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM + from transformers.models.llama.modeling_llama import LlamaForCausalLM + + LlamaForCausalLM.forward = llama_fused_forward + 
Qwen2ForCausalLM.forward = qwen2_fused_forward + + print("Monkey patch forward in Qwen2 and Llama") + # transformers<=4.47.1 if hasattr(module, "_flash_attention_forward"): diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py index 63d9ae98b5e..295537a9c4a 100644 --- a/verl/models/transformers/qwen2.py +++ b/verl/models/transformers/qwen2.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -from typing import Optional, Tuple, Callable +from typing import Optional, Tuple, Callable, List, Union from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.cache_utils import Cache @@ -22,6 +22,8 @@ from transformers.processing_utils import Unpack from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from verl.utils.kernel import linear_cross_entropy +from .common import FusedCausalLMOutputWithPast logger = logging.get_logger(__name__) @@ -224,3 +226,90 @@ def qwen2_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + +def qwen2_fused_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + fuse_entropy_logprobs: bool = False, + **kwargs, +) -> Union[Tuple, FusedCausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None 
else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + # Do not support TP here + + logits = None + loss = None + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weight, labels, reduction="none") + else: + # Inference mode + logits = self.lm_head(hidden_states) + # loss is not needed + # if labels is not None: + # loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return FusedCausalLMOutputWithPast( + loss=loss, + logits=logits, + 
log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + \ No newline at end of file diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 718b9ca6f5b..e3de116fba7 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Union import inspect import torch +from torch.nn import CrossEntropyLoss import os from transformers.utils import is_flash_attn_greater_or_equal from transformers.modeling_flash_attention_utils import _flash_attention_forward from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from verl.utils.kernel import linear_cross_entropy +from .common import FusedCausalLMOutputWithPast try: from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -288,3 +291,169 @@ def ulysses_flash_attn_forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, None + +def qwen2_vl_fused_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: 
        Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    # When True (training only), skip materializing full logits and compute
    # token log-probs + entropy with the fused linear_cross_entropy kernel.
    fuse_entropy_logprobs: bool = False,
) -> Union[Tuple, FusedCausalLMOutputWithPast]:
    # Continuation of qwen2_vl_fused_forward: a copy of
    # Qwen2VLForConditionalGeneration.forward extended with the fused
    # entropy/logprobs path. Mirrors upstream HF behavior for embedding
    # merge and RoPE index computation.

    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is None:
        inputs_embeds = self.model.embed_tokens(input_ids)
        if pixel_values is not None:
            pixel_values = pixel_values.type(self.visual.get_dtype())
            image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
            # Every image placeholder token must be matched by exactly one
            # vision feature row, otherwise the masked_scatter below would
            # silently misalign text and vision embeddings.
            n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
            n_image_features = image_embeds.shape[0]
            if n_image_tokens != n_image_features:
                raise ValueError(
                    f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                )
            image_mask = (
                (input_ids == self.config.image_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            # Overwrite the placeholder-token embeddings with vision features.
            inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

        if pixel_values_videos is not None:
            pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype())
            video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
            n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
            n_video_features = video_embeds.shape[0]
            if n_video_tokens != n_video_features:
                raise ValueError(
                    f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                )
            video_mask = (
                (input_ids == self.config.video_token_id)
                .unsqueeze(-1)
                .expand_as(inputs_embeds)
                .to(inputs_embeds.device)
            )
            video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
            inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

        if attention_mask is not None:
            attention_mask = attention_mask.to(inputs_embeds.device)

    # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
    if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
        # calculate RoPE index once per generation in the pre-fill stage only
        if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)
        ):
            position_ids, rope_deltas = self.get_rope_index(
                input_ids, image_grid_thw, video_grid_thw, attention_mask
            )
            # Cached on self so decode steps can reuse the prefill deltas.
            self.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, -1).expand(batch_size, -1)
            if cache_position is not None:  # otherwise `deltas` is an int `0`
                delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                delta = delta.to(position_ids.device)
            position_ids = position_ids.add(delta)
            # 3 position-id planes (temporal/height/width) for M-RoPE.
            position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    # DO not support TP here
here + + logits = None + loss = None + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + logits = self.lm_head(hidden_states) + + logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + # Inferencce mode + logits = self.lm_head(hidden_states) + # loss may not needed + # if labels is not None: + # # Upcast to float if we need to compute the loss to avoid potential precision issues + # logits = logits.float() + # # Shift so that tokens < n predict n + # shift_logits = logits[..., :-1, :].contiguous() + # shift_labels = labels[..., 1:].contiguous() + # # Flatten the tokens + # loss_fct = CrossEntropyLoss() + # shift_logits = shift_logits.view(-1, self.config.vocab_size) + # shift_labels = shift_labels.view(-1) + # # Enable model parallelism + # shift_labels = shift_labels.to(shift_logits.device) + # loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return FusedCausalLMOutputWithPast( + loss=loss, + logits=logits, + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + \ No newline at end of file diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 2bd634aaa75..07eeea37c54 
100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -41,6 +41,7 @@ actor_rollout_ref: ulysses_sequence_parallel_size: 1 # sp size checkpoint: contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + use_fused_kernels: True optim: lr: 1e-6 lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py new file mode 100644 index 00000000000..1f1d534f632 --- /dev/null +++ b/verl/utils/kernel/__init__.py @@ -0,0 +1,35 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .linear_cross_entropy import linear_cross_entropy +from .kernels import set_backward_method, BackwardEnum + +__all__ = ["linear_cross_entropy", "set_backward_method", "BackwardEnum"] diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py new file mode 100644 index 00000000000..edd8e78d55f --- /dev/null +++ b/verl/utils/kernel/kernels.py @@ -0,0 +1,871 @@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Implementations of the linear cross entropy with token entropy kernel. 
+""" + +import typing +from dataclasses import dataclass +import torch +import torch.distributed as dist +import triton +import triton.language as tl + + +@dataclass +class EntropyReductionEnum: + """ + Enum for the reduction method of cross entropy. + """ + _None = 0 + _Sum = 1 + _Mean = 2 + + +def get_entropy_reduction_enum_number(reduction: str) -> int: + """ + Get the enum number for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if reduction == "none": + _enum = EntropyReductionEnum._None + elif reduction == "sum": + _enum = EntropyReductionEnum._Sum + elif reduction == "mean": + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid reduction: {reduction}") + return _enum + + +def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + """ + Get the enum for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None + if ce_reduction == 0: + _enum = EntropyReductionEnum._None + elif ce_reduction == 1: + _enum = EntropyReductionEnum._Sum + elif ce_reduction == 2: + _enum = EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid ce_reduction: {ce_reduction}") + return _enum + + +@dataclass +class BackwardEnum: + """ + Enum for the backward method. + """ + _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + + +_BACKWARD: BackwardEnum = BackwardEnum._Total_Separate + + +def set_backward_method(backward_method: BackwardEnum): + """ + Set the backward method. 
+ """ + global _BACKWARD + _BACKWARD = backward_method + + +@triton.autotune( + configs=[triton.Config({ + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64 + }, num_stages=3, num_warps=4)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop( + rank, + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m, + stride_hidden_k, + stride_weight_k, + stride_weight_n, + max_ptr, + stride_max_m, + stride_max_n, + accu_ptr, + stride_accu_m, + stride_accu_n, + entropy_b_ptr, + stride_entropy_b_m, + stride_entropy_b_n, + global_logprobs_ptr, + stride_global_logprobs, + global_logprobs_scalar_ptr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + """ + forward mainloop + """ + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + if pid_m == 0 and pid_n == 0: + tl.store(global_logprobs_scalar_ptr, 0.0) + + # create pointers for the first blocks of hidden + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for n in range(0, num_pid_n): + offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, 
BLOCK_SIZE_N) + weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + + # iterate over K dimension + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + # load the next block of hidden and weight + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0) + _weight = tl.load(weight_ptrs, + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min( + (pid_n + 1) * vocab_per_split, vocab_size))), + other=0.0) + + # GEMM + logits = tl.dot(_hidden, _weight, logits) + + # advance the ptrs to the next K block + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # update global maximum + _max_old = _max + m_pid_n = tl.max(logits, axis=1) + _max = tl.maximum(_max_old, m_pid_n) + + exp_logits = tl.exp(logits - _max[:, None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) + entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m + tl.store(entropy_b_ptrs, _entropy_b, 
             mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits))

    # store logprobs — only the split that owns the label writes, so the
    # per-token slot is written by exactly one program (no atomics needed).
    vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size
    vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size
    mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx)
    mask &= (offs_am < num_tokens)
    global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs
    # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask)
    tl.store(global_logprobs_ptrs, _logprobs, mask=mask)


@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"])
@triton.jit
def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n, num_tokens, num_splits,
                                             global_max_ptr, stride_global_max, accu_ptr, stride_accu_m, stride_accu_n,
                                             global_accu_ptr, stride_global_accu, entropy_b_ptr, stride_entropy_b_m,
                                             stride_entropy_b_n, global_entropy_b_ptr, stride_global_entropy_b,
                                             global_entropy_ptr, stride_global_entropy, global_logprobs_ptr,
                                             stride_global_logprobs, global_logprobs_scalar_ptr, reduction: int,
                                             BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    """
    forward epilogue (single-GPU path)

    Reduces the per-split partial results (num_tokens x num_splits) from the
    mainloop into per-token values: global max, sum of exponentials, entropy
    and log-probs, applying the same online-softmax rescaling across splits.
    `reduction`: 0 = none (per-token logprobs), 1 = sum, 2 = mean (scalar,
    accumulated via atomic_add across token blocks).
    """
    pid_m = tl.program_id(axis=0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n

        _max = tl.load(max_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)

        accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n
        _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0)

        entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n
        _entropy_b = tl.load(entropy_b_ptrs,
                             mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),
                             other=0.0)

        # local reduction (rescale previously accumulated sums by
        # exp(old_max - new_max), then fold in this chunk of splits)
        _max_old = global_max
        _local_max = tl.max(_max, axis=1)
        global_max = tl.maximum(global_max, _local_max)

        _scale = tl.exp(_max - global_max[:, None])
        _coeff = tl.exp(_max_old - global_max)
        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)
        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)

    # store
    maximum_ptrs = global_max_ptr + offs_m * stride_global_max
    tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens)

    # store entropy_b  (normalize: E_p[logits] = sum(logits * exp) / sum(exp))
    global_entropy_b = tl.fdiv(global_entropy_b, global_accu)  # entropy_b
    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)

    # store entropy:  H = log(sum_exp) + max - E_p[logits]
    global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu
    tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens)
    global_entropy = tl.log(global_accu) + global_max - global_entropy_b  # entropy_a
    global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy
    tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens)
    # update logprobs: logprob(label) = label_logit - logsumexp; the buffer
    # currently holds the raw label logit, so compute the negated NLL below.
    global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs
    global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens)
    global_logprobs = global_max + tl.log(global_accu) - global_logprobs

    global_logprobs = -1 * global_logprobs
    if reduction == 0:
        tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens)
    elif reduction == 1:
        global_logprobs_scalar = tl.sum(global_logprobs, axis=0)
        tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)
    elif reduction == 2:
        global_logprobs_scalar = tl.sum(global_logprobs, axis=0) /
            num_tokens.to(tl.float32)
        tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar)


@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"])
@triton.jit
def efficient_entropy_triton_kernel_epilogue_tp(
        num_tokens, num_splits, reduced_max_ptr, stride_reduced_max_m, stride_reduced_max_n, original_max_ptr,
        stride_original_max_m, stride_original_max_n, accu_ptr, stride_accu_m, stride_accu_n, entropy_b_ptr,
        stride_entropy_b_m, stride_entropy_b_n, global_max_ptr, stride_global_max, global_accu_ptr, stride_global_accu,
        global_entropy_b_ptr, stride_global_entropy_b, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
    # Tensor-parallel epilogue: `reduced_max_ptr` holds the all-reduced
    # (cross-rank) per-split maxima while `original_max_ptr` holds this
    # rank's pre-reduction maxima; local partial sums are rescaled by
    # exp(original_max - global_max) so the host can finish with a simple
    # SUM all-reduce over accumulate/entropy_b.
    pid_m = tl.program_id(axis=0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
    for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

        _reduced_max = tl.load(reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m +
                               offs_n[None, :] * stride_reduced_max_n,
                               mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),
                               other=0.0)
        _original_max = tl.load(original_max_ptr + offs_m[:, None] * stride_original_max_m +
                                offs_n[None, :] * stride_original_max_n,
                                mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),
                                other=0.0)
        _accu = tl.load(accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n,
                        mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),
                        other=0.0)

        # local reduce-max
        _max_old = global_max
        _local_max = tl.max(_reduced_max, axis=1)
        global_max = tl.maximum(global_max, _local_max)

        # update accumulate
        _coeff = tl.exp(_max_old - global_max)
        _scale = tl.exp(_original_max - global_max[:, None])
        global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1)

        # update entropy_b
        _entropy_b = tl.load(entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m +
                             offs_n[None, :] * stride_entropy_b_n,
                             mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits),
                             other=0.0)
        global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1)

    # store (entropy_b left UN-normalized here; division by accumulate
    # happens in the tp_update kernel after the cross-rank SUM all-reduce)
    tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens)
    tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens)
    tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens)


@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"])
@triton.jit
def efficient_entropy_triton_epilogue_tp_update(num_tokens, logprobs_ptr, stride_logprobs, maximum_ptr, stride_maximum,
                                                accumulate_ptr, stride_accumulate, entropy_b_ptr, stride_entropy_b,
                                                entropy_ptr, stride_entropy, logprobs_scalar_ptr, reduction: int,
                                                BLOCK_SIZE_M: tl.constexpr):
    # Final TP step, runs after the host all-reduced accumulate/entropy_b:
    # normalizes entropy_b, computes per-token entropy and (negated) NLL,
    # then applies the requested reduction (same encoding as the single-GPU
    # epilogue: 0 = none, 1 = sum, 2 = mean).
    pid_m = tl.program_id(axis=0)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens)
    accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens)

    entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens)
    entropy_b = tl.fdiv(entropy_b, accumulate)
    tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens)

    entropy = tl.log(accumulate) + maximum - entropy_b
    tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens)

    # logprobs buffer holds the raw label logit (summed across ranks).
    logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens)
    logprobs = maximum + tl.log(accumulate) - logprobs

    logprobs = -1 * logprobs
    if reduction == 0:
        tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens)
    elif reduction == 1:
        logprobs_scalar = tl.sum(logprobs, axis=0)
        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)
    elif reduction == 2:
        logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32)
        tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar)


# Lazily-created side stream + events used to overlap the logprobs
# all-reduce with the TP epilogue kernel (initialized on first TP call).
_dedicated_stream, _dedicated_events = None, None


def efficient_entropy_forward(
        hidden: torch.Tensor,
        weight: torch.Tensor,
        labels: torch.Tensor,
        reduction: typing.Optional[int] = 2,
        dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]:
    """
    forward host function

    Computes, without materializing the logits matrix:
      logprobs  — negated label log-likelihood, reduced per `reduction`
                  (0 = none, 1 = sum, 2 = mean),
      entropy   — per-token softmax entropy,
    plus the intermediate per-token statistics (maximum, accumulate,
    entropy_b) that the backward pass reuses.

    Args:
        hidden: (num_tokens, hidden_size) CUDA, contiguous.
        weight: (hidden_size, vocab_size) CUDA, contiguous — note this is the
            transposed layout relative to nn.Linear.weight.
        labels: (num_tokens,) CUDA, contiguous.
        dist_process_group: optional TP group; when set, vocab is sharded
            across ranks and results are all-reduced here.
    """
    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda
    assert weight.device == hidden.device and labels.device == hidden.device
    assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1
    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()
    assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[0]

    _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group)
    _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group)

    # One-time creation of the dedicated stream/events for the TP path
    # (flagged on the function object so it survives across calls).
    if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"):
        global _dedicated_stream, _dedicated_events
        _dedicated_stream = torch.cuda.Stream(hidden.device)
        _dedicated_events = [torch.cuda.Event() for _ in range(2)]
        efficient_entropy_forward._initialized = True

    # NOTE(review): the first two assignments are immediately overwritten by
    # the next two lines — dead stores kept for byte-fidelity.
    num_tokens, hidden_size = hidden.shape
    num_tokens = labels.shape[0]
    hidden_size, vocab_size = weight.shape
    # Kernels rely on 128-divisible K and N extents (see pointer rewinds).
    assert hidden_size % 128 == 0
    assert vocab_size % 128 == 0

    REDUCTION = get_entropy_reduction_enum(reduction)

    if REDUCTION == EntropyReductionEnum._None:
        if dist_process_group is None:
            logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)
        else:
            # zeros (not empty): the TP path SUM-all-reduces this buffer.
            logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32)
    elif REDUCTION in (EntropyReductionEnum._Sum,
                       EntropyReductionEnum._Mean):
        logprobs = torch.empty((), device=hidden.device, dtype=torch.float32)
    else:
        raise ValueError(f"Invalid reduction: {reduction}")

    entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)
    assert logprobs.is_contiguous() and entropy.is_contiguous()

    maximum = torch.empty_like(entropy)
    # accumulate and entropy_b share one backing buffer so the TP path can
    # all-reduce both with a single collective.
    accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32)
    accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens)
    accumulate = accumulate_and_entropy_b_view[0, :]
    entropy_b = accumulate_and_entropy_b_view[1, :]
    assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous()

    vocab_per_split = 1024
    assert vocab_per_split % 128 == 0
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split

    # Per-split partial statistics produced by the mainloop kernel.
    _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)
    _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)
    _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)

    if REDUCTION == EntropyReductionEnum._None:
        # Write per-token logprobs straight into the output buffer.
        _logprobs = logprobs
    else:
        _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)

    assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous()
    assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda

    # 1D kernel launch, then split the tile
    def mainloop_grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,)

    efficient_entropy_kernel_general_mainloop[mainloop_grid](_rank, hidden, weight, labels, num_tokens, hidden_size,
                                                             vocab_size, vocab_per_split, hidden.stride(0),
                                                             hidden.stride(1), weight.stride(0), weight.stride(1), _max,
                                                             _max.stride(0), _max.stride(1), _accu, _accu.stride(0),
                                                             _accu.stride(1), _entropy_b, _entropy_b.stride(0),
                                                             _entropy_b.stride(1), _logprobs, _logprobs.stride(0),
                                                             logprobs)

    # reduction on maximum and maximum_indices
    def epilogue_grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),)

    if dist_process_group is None:
        efficient_entropy_triton_kernel_epilogue[epilogue_grid](_max, _max.stride(0), _max.stride(1), num_tokens,
                                                                num_splits, maximum, maximum.stride(0), _accu,
                                                                _accu.stride(0), _accu.stride(1), accumulate,
                                                                accumulate.stride(0), _entropy_b, _entropy_b.stride(0),
                                                                _entropy_b.stride(1), entropy_b, entropy_b.stride(0),
                                                                entropy, entropy.stride(0), _logprobs,
                                                                _logprobs.stride(0), logprobs, REDUCTION)
    else:
        # tensor-parallel
        # Keep this rank's pre-reduction maxima; the TP epilogue needs both
        # the local and the globally-reduced per-split maxima to rescale.
        _max_backup = _max.clone()
        dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group)

        # Overlap the label-logit SUM all-reduce (side stream) with the
        # epilogue kernel on the main stream; events order the two streams.
        torch.cuda.current_stream().record_event(_dedicated_events[0])
        with torch.cuda.stream(_dedicated_stream):
            _dedicated_stream.wait_event(_dedicated_events[0])
            dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group)
            _dedicated_stream.record_event(_dedicated_events[1])

        efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid](num_tokens, num_splits, _max, _max.stride(0),
                                                                   _max.stride(1), _max_backup, _max_backup.stride(0),
                                                                   _max_backup.stride(1), _accu, _accu.stride(0),
                                                                   _accu.stride(1), _entropy_b, _entropy_b.stride(0),
                                                                   _entropy_b.stride(1), maximum, maximum.stride(0),
                                                                   accumulate, accumulate.stride(0), entropy_b,
                                                                   entropy_b.stride(0))
        torch.cuda.current_stream().wait_event(_dedicated_events[1])

        # Single collective covers both accumulate and entropy_b (shared buffer).
        dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group)

        # update logprobs & entropy
        efficient_entropy_triton_epilogue_tp_update[epilogue_grid](num_tokens, _logprobs, _logprobs.stride(0), maximum,
                                                                   maximum.stride(0), accumulate, accumulate.stride(0),
                                                                   entropy_b, entropy_b.stride(0), entropy,
                                                                   entropy.stride(0), logprobs, REDUCTION)

    return (logprobs, entropy, maximum, accumulate, entropy_b)


# NOTE: merge d_weight & d_hidden here, split along M & N
@triton.autotune(
    configs=[
        triton.Config({
            "BLOCK_SIZE_M": 64,
            "BLOCK_SIZE_N": 64,
            "BLOCK_SIZE_K": 32,
            "GROUP_SIZE_M": 16
        },
                      num_stages=3,
                      num_warps=4)
    ],
    key=["num_tokens", "hidden_size", "vocab_size"],
)
@triton.jit
def efficient_entropy_backward_kernel_general_mainloop_MN(
        num_tokens: int, hidden_size: int, vocab_size: int, rank: int, hidden_ptr, stride_hidden_m, stride_hidden_k,
        weight_ptr, stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr,
        stride_accu, d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, entropy_b_ptr,
        stride_entropy_b, d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, d_weight_ptr, stride_d_weight_k,
        stride_d_weight_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
        GROUP_SIZE_M: tl.constexpr):
    """
    backward mainloop, where d_logits & d_hidden & d_weight are fused

    Implements BackwardEnum._Total_Fuse_MN: recomputes the logits tile,
    forms d_logits in registers, and immediately scatters both gradient
    GEMMs (hidden^T @ d_logits and d_logits @ weight^T) via atomic adds —
    no intermediate d_logits storage, but d_hidden/d_weight must be fp32.
    """
    # simple (pid_m, pid_n) decomposition, kept for reference:
    # pid = tl.program_id(axis=0)
    # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)
    # pid_m = pid % num_pid_m
    # pid_n = pid // num_pid_m

    # Grouped (swizzled) program ordering for better L2 reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)

    # Per-token forward statistics saved by efficient_entropy_forward.
    maximum_ptrs = maximum_ptr + offs_am * stride_maximum
    maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0)
    accu_ptrs = accu_ptr + offs_am * stride_accu
    accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6)  # epsilon to avoid division by zero
    accu_rcp = tl.fdiv(1.0, accu)
    d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy
    d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0)
    # Incoming logprobs gradient depends on the forward reduction:
    # none -> per-token vector; sum -> one scalar; mean -> scalar / num_tokens.
    if reduction == 0:  # none
        d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs
        d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0)
    elif reduction == 1:  # sum
        d_logprobs = tl.load(d_logprobs_ptr)
        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))
    else:  # mean
        d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32))
        d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,))
    # Forward negated the log-likelihood, so negate the gradient back.
    d_logprobs = -1 * d_logprobs

    entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b
    entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0)

    hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k)
    weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n)
    labels_ptrs = labels_ptr + offs_am * stride_labels
    labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0)

    d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k
    d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n

    # Recompute this tile's logits (K-loop GEMM), as in the forward mainloop.
    logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):
        _hidden = tl.load(hidden_ptrs,
                          mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),
                          other=0.0)
        _weight = tl.load(weight_ptrs,
                          mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),
                          other=0.0)

        logits = tl.dot(_hidden, _weight, logits)

        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k
        weight_ptrs += BLOCK_SIZE_K * stride_weight_k
    # Rewind for the second K-loop below (hidden_size % BLOCK_SIZE_K == 0
    # is guaranteed by the host's hidden_size % 128 == 0 assert).
    hidden_ptrs -= hidden_size * stride_hidden_k
    weight_ptrs -= hidden_size * stride_weight_k

    exp_logits = tl.exp(logits - maximum[:, None])

    # d_logits = d_logprobs * (softmax - one_hot(label))
    #          + d_entropy * (-softmax) * (logits - E_p[logits])
    # (softmax = exp_logits / accu; label one-hot is TP-rank aware)
    mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None]
    d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask)
    d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None])

    # loop for d_weight & d_hidden — K-blocked rank-update GEMMs scattered
    # with atomics because many tiles contribute to the same K rows/cols.
    for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)):
        _hidden = tl.load(hidden_ptrs,
                          mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens),
                          other=0.0)
        _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits)
        tl.atomic_add(d_weight_ptrs,
                      _d_weight,
                      mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size))

        _weight = tl.load(weight_ptrs,
                          mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size),
                          other=0.0)
        _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32))
        tl.atomic_add(d_hidden_ptrs,
                      _d_hidden,
                      mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens))

        hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k
        weight_ptrs += BLOCK_SIZE_K * stride_weight_k
        d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k
        d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k


# NOTE: split tile from d_logits' perspective
@triton.autotune(
    configs=[
        triton.Config({
            "BLOCK_SIZE_M": 32,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 16
        },
                      num_stages=3,
                      num_warps=4),
    ],
    key=["num_tokens", "hidden_size", "vocab_size"],
)
@triton.jit
def efficient_entropy_backward_kernel_general_d_logits(
        num_tokens: int, hidden_size: int, vocab_size: int, rank: int, hidden_ptr, stride_hidden_m, stride_hidden_k,
        weight_ptr, stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr,
        stride_accu, d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, entropy_b_ptr,
        stride_entropy_b, d_logits_ptr, stride_d_logits_m, stride_d_logits_n, BLOCK_SIZE_M:
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + backward d_logits + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs + + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:, 
None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), + other=0.0) + _weight = tl.load(weight_ptrs, + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + other=0.0) + + logits = tl.dot(_hidden, _weight, logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) + + # store d_logits + d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + tl.store(d_logits_ptrs, + d_logits.to(hidden_ptr.dtype.element_ty), + mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size)) + + +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: typing.Optional[int] = 2, + dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: + """ + backward host function + """ + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == 
hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[0] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + hidden_size, vocab_size = weight.shape + assert hidden_size % 128 == 0 + assert vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + assert dlogprobs.shape == (num_tokens,) + else: + assert dlogprobs.dim() == 0 + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + d_hidden, d_weight = None, None + if _BACKWARD == BackwardEnum._Total_Fuse_MN: + d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) + d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) + elif _BACKWARD == BackwardEnum._Total_Separate: + d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + assert entropy_b.is_contiguous() and entropy_b.is_cuda + 
assert entropy_b.shape == (num_tokens,) + + if _BACKWARD == BackwardEnum._Total_Fuse_MN: + + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + d_hidden, + d_hidden.stride(0), + d_hidden.stride(1), + d_weight, + d_weight.stride(0), + d_weight.stride(1), + ) + elif _BACKWARD == BackwardEnum._Total_Separate: + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype) + + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( + num_tokens, + hidden_size, + vocab_size, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + ) + + torch.matmul(_d_logits, weight.T, out=d_hidden) + torch.matmul(hidden.T, _d_logits, out=d_weight) + return d_hidden, d_weight diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py new file mode 100644 index 00000000000..20c6f2cbf4f --- /dev/null +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -0,0 +1,73 
@@ +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed as dist +import typing +from . 
import kernels + + +class LinearCrossEntropy(torch.autograd.Function): + + @staticmethod + def forward(ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[str] = "mean", + dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + logprobs, entropy, _maximum, _accumulate, _entropy_b =\ + kernels.efficient_entropy_forward(hidden, weight, labels, REDUCTION, + dist_process_group) + + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + return logprobs, entropy + + @staticmethod + def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> typing.List[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): + (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors + REDUCTION = ctx.REDUCTION + dist_process_group = ctx.dist_process_group + + d_hidden, d_weight = kernels.efficient_entropy_backward(dlogprobs, dentropy, hidden, weight, labels, + _maximum, _accumulate, _entropy_b, REDUCTION, + dist_process_group) + + return (d_hidden, d_weight, None, None, None) + + +linear_cross_entropy = LinearCrossEntropy.apply diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py index 032cabbedbe..bdbee9e04f5 100644 --- a/verl/utils/megatron/tensor_parallel.py +++ b/verl/utils/megatron/tensor_parallel.py @@ -101,19 +101,21 @@ class _VocabParallelEntropy(torch.autograd.Function): @staticmethod def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor: + + @torch.compile(dynamic=True) + def mul_reduce(a, b): + return (a * b).sum(dim=-1, keepdim=True) + logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values dist.all_reduce(logits_max, 
op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group()) normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max - normalized_exp_logits = normalized_vocab_parallel_logits.exp() + normalized_exp_logits = normalized_vocab_parallel_logits.exp_() normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group()) - softmax_logits = normalized_exp_logits / normalized_sum_exp_logits + softmax_logits = normalized_exp_logits.div(normalized_sum_exp_logits) # This consume too much VRAM, causing OOM, try optimize # sum_softmax_times_logits = (softmax_logits * vocab_parallel_logits).sum(dim=-1, keepdim=True) - original_shape = softmax_logits.shape - sum_softmax_times_logits = torch.bmm(softmax_logits.view(-1, 1, original_shape[-1]), - vocab_parallel_logits.view(-1, original_shape[-1], - 1)).view(original_shape[:-1] + (1,)) + sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group()) entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) @@ -122,8 +124,14 @@ def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor: @staticmethod def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor: vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors - grad_input = grad_output.unsqueeze(dim=-1) * softmax_logits * (sum_softmax_times_logits - vocab_parallel_logits) - return grad_input + # reuse softmax_logits as grad + vocab_parallel_logits.sub_(sum_softmax_times_logits) + softmax_logits.mul_(vocab_parallel_logits) + softmax_logits.mul_(grad_output.unsqueeze(dim=-1)) + # recover vocab_parallel_logits + vocab_parallel_logits.add_(sum_softmax_times_logits) + softmax_logits.mul_(-1) + return softmax_logits def 
vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor: diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 32bfda2278f..4dc954c5c92 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -25,9 +25,9 @@ try: from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True + FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True except ImportError: - FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False + FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False def gather_from_labels(data, label): @@ -49,7 +49,7 @@ def logprobs_from_logits(logits, labels): """ See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 """ - if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: + if FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: batch_dim = logits.shape[:-1] last_dim = logits.shape[-1] logits = logits.reshape(-1, last_dim) diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index eff5ffa26d9..5a4c8666581 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -26,7 +26,7 @@ from verl.trainer.ppo import core_algos from verl.workers.actor import BasePPOActor from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import logprobs_from_logits, masked_mean +from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F @@ -53,11 +53,6 @@ def __init__( self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 - self.compute_entropy_from_logits = ( - torch.compile(verl_F.entropy_from_logits, dynamic=True) - if self.config.get('use_torch_compile', True) # use torch compile by default - else 
verl_F.entropy_from_logits) - def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ Returns: @@ -111,16 +106,14 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, attention_mask=None, position_ids=position_ids_rmpad, **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - - logits_rmpad.div_(temperature) - - # compute entropy - entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + use_cache=False, + labels=input_ids_rmpad_rolled, + temperature=temperature, + output_logits=False) # prevent model thinks we are generating + entropy_rmpad = output.entropy # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + log_probs = output.log_probs # gather log_prob if sp > 1 if self.use_ulysses_sp: @@ -149,12 +142,12 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, attention_mask=attention_mask, position_ids=position_ids, **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating - logits = output.logits - logits.div_(temperature) - logits = logits[:, -response_length - 1:-1, :] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, micro_batch['responses']) - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + use_cache=False, + labels=micro_batch['responses'], + temperature=temperature, + output_logits=False) # prevent model thinks we are generating + entropy = output.entropy + log_probs = output.log_probs return entropy, log_probs diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 2c52c3e5971..79cdda3f240 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -159,14 +159,6 @@ def 
compute_log_prob(self, data: DataProto) -> torch.Tensor: """ data.batch = data.batch.contiguous() - def compute_logprobs_fn(output, data): - response = data['responses'] - response_length = response.size(1) - logits = output['logits'] - logits = logits[:, -response_length - 1:-1].contiguous() - log_probs = vocab_parallel_log_probs_from_logits(logits, response) - return {'log_probs': log_probs} - # We make recompute_old_log_prob by default here. # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be handled by user outside recompute_old_log_prob = self.config.get('recompute_old_log_prob', True) @@ -179,7 +171,7 @@ def compute_logprobs_fn(output, data): response = batch['responses'] response_length = response.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn) + output = self.forward_backward_batch(data, forward_only=True) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. It should be on every tp rank log_probs = torch.cat([o['log_probs'] for o in output], dim=0) # (bs, seq_size) @@ -230,7 +222,7 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: epochs=self.config.ppo_epochs, dataloader_kwargs={'shuffle': self.config.shuffle}) - def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None): + def forward_backward_batch(self, data: DataProto, forward_only=False): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). 
No rmpad for the input @@ -262,12 +254,6 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce forward_backward_func = get_forward_backward_func() def loss_func(output, data, meta_info): - if forward_only: - if post_process_fn is None: - return 1.0, {'logits': output.logits} - else: - return 1.0, post_process_fn(output, data) - responses = data['responses'] response_length = responses.size(1) attention_mask = data['attention_mask'] @@ -279,17 +265,13 @@ def loss_func(output, data, meta_info): entropy_coeff = meta_info['entropy_coeff'] # compute policy loss - logits = output.logits - logits = logits[:, -response_length - 1:-1].contiguous() - logits_back = logits.clone() - log_prob = vocab_parallel_log_probs_from_logits(logits, responses) - logits = logits_back + log_prob = output.log_probs + entropy_loss = output.entropy pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob, log_prob=log_prob, advantages=advantages, eos_mask=response_mask, cliprange=clip_ratio) - entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask) policy_loss = pg_loss - entropy_loss * entropy_coeff metrics = {} @@ -320,7 +302,9 @@ def forward_step(batch_iter, model): input_ids = batch['input_ids'] attention_mask = batch['attention_mask'] position_ids = batch['position_ids'] - output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + responses = data['responses'] + temperature = data['temperature'] + output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=responses, temperature=temperature, output_logits=False) if forward_only: meta_info = None else: