From b7aa6d50180551666a18b9ef8c3d910f4656156a Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Tue, 4 Mar 2025 23:00:02 -0800 Subject: [PATCH 01/15] Add linear_cross_entropy Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 128 ++++ verl/utils/kernel/__init__.py | 20 + verl/utils/kernel/kernels.py | 841 ++++++++++++++++++++++ verl/utils/kernel/linear_cross_entropy.py | 53 ++ 4 files changed, 1042 insertions(+) create mode 100644 tests/kernel/test_linear_cross_entropy.py create mode 100644 verl/utils/kernel/__init__.py create mode 100644 verl/utils/kernel/kernels.py create mode 100644 verl/utils/kernel/linear_cross_entropy.py diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py new file mode 100644 index 00000000000..1b34e4e8ec9 --- /dev/null +++ b/tests/kernel/test_linear_cross_entropy.py @@ -0,0 +1,128 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import typing + +try: + from verl.utils.kernel import linear_cross_entropy +except ImportError: + # FIXME: remove these manually included paths + import os + import sys + sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))) +finally: + from verl.utils.kernel import linear_cross_entropy, set_backward_method, BackwardEnum + +def run_torch_entropy(hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor) -> typing.List[torch.Tensor]: + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] + pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] + entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + logprobs = torch.nn.functional.cross_entropy(logits, labels) # [1] + return logprobs, entropy + +if __name__ == "__main__": + num_tokens = 80 + hidden_size = 4096 + vocab_size = 152064 + + dtype = torch.bfloat16 + + enabled_fileds = { + "forward": {"Torch": True, "Kernel": True}, + "backward": {"Torch": True, "Kernel": True} + } + + # set_backward_method(BackwardEnum._Total_Separate) + + iterations = 5 + for i in range(iterations): + print(f"[INFO]: ---------- Iteration {i} starts. 
----------") + with torch.cuda.nvtx.range(f"iteration_{i}"): + hidden = (torch.empty((num_tokens, hidden_size), dtype=dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_()) + weight = (torch.empty((hidden_size, vocab_size), dtype=dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_()) + labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda") + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + if enabled_fileds["forward"]["Torch"]: + torch.cuda.reset_peak_memory_stats() + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + torch_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Forward pass: Torch implementation peak memory: {torch_max_memory:.2f} MB") + print(f"[INFO]: Forward pass: Torch implementation time: {start_event.elapsed_time(end_event):.2f} ms") + + if enabled_fileds["forward"]["Kernel"]: + torch.cuda.reset_peak_memory_stats() + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Forward pass: Kernel implementation peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: Forward pass: Kernel implementation time: {start_event.elapsed_time(end_event):.2f} ms") + + if enabled_fileds["forward"]["Torch"]: + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-3, rtol=1e-3) + print(f"[INFO]: Forward pass: Kernel implementation passed.") + + if enabled_fileds["backward"]["Torch"] or enabled_fileds["backward"]["Kernel"]: + g_entropy = (torch.empty((num_tokens,), dtype=dtype, device="cuda") + .uniform_(-0.5, 0.5)) + g_logprobs = (torch.empty((), dtype=dtype, 
device="cuda") + .uniform_(-1, 1)) + + if enabled_fileds["backward"]["Torch"]: + torch.cuda.reset_peak_memory_stats() + start_event.record() + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Backward pass: torch implementation peak memory: {torch_backward_max_memory:.2f} MB") + print(f"[INFO]: Backward pass: torch gradient time: {start_event.elapsed_time(end_event):.2f} ms") + + if enabled_fileds["backward"]["Kernel"]: + torch.cuda.reset_peak_memory_stats() + start_event.record() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Backward pass: kernel implementation peak memory: {kernel_backward_max_memory:.2f} MB") + print(f"[INFO]: Backward pass: kernel gradient time: {start_event.elapsed_time(end_event):.2f} ms") + + if enabled_fileds["backward"]["Torch"]: + torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-3) + torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=1e-2, rtol=1e-3) + print(f"[INFO]: Backward pass: kernel implementation passed.") \ No newline at end of file diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py new file mode 100644 index 00000000000..ae59d82f645 --- /dev/null +++ b/verl/utils/kernel/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .linear_cross_entropy import linear_cross_entropy +from .kernels import set_backward_method, BackwardEnum + +__all__ = ["linear_cross_entropy", + "set_backward_method", + "BackwardEnum"] diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py new file mode 100644 index 00000000000..3daea5221b2 --- /dev/null +++ b/verl/utils/kernel/kernels.py @@ -0,0 +1,841 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import typing +from dataclasses import dataclass +import triton +import triton.language as tl + +@dataclass +class EntropyReductionEnum: + _None = 0 + _Sum = 1 + _Mean = 2 + +def get_entropy_reduction_enum_number(reduction: str) -> int: + if reduction == "none": + return EntropyReductionEnum._None + elif reduction == "sum": + return EntropyReductionEnum._Sum + elif reduction == "mean": + return EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid reduction: {reduction}") + + +def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + if ce_reduction == 0: + return EntropyReductionEnum._None + elif ce_reduction == 1: + return EntropyReductionEnum._Sum + elif ce_reduction == 2: + return EntropyReductionEnum._Mean + else: + raise ValueError(f"Invalid ce_reduction: {ce_reduction}") + + +@dataclass +class BackwardEnum: + _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + +_BACKWARD: BackwardEnum = BackwardEnum._Total_Separate + +def set_backward_method(backward_method: BackwardEnum): + global _BACKWARD + _BACKWARD = backward_method + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, num_warps=4)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr, + num_tokens, hidden_size, vocab_size, vocab_per_split, + stride_hidden_m, stride_hidden_k, + stride_weight_k, stride_weight_n, + max_ptr, stride_max_m, stride_max_n, + max_indices_ptr, stride_max_indices_m, stride_max_indices_n, + accu_ptr, stride_accu_m, stride_accu_n, + entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, + global_logprobs_ptr, stride_global_logprobs, + global_logprobs_scalar_ptr, + reduction: int, + # Meta-parameters + BLOCK_SIZE_M: 
tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + if pid_m == 0 and pid_n == 0: + tl.store(global_logprobs_scalar_ptr, 0.0) + + # create pointers for the first blocks of hidden + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _max_indices = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int64) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for n in range(0, num_pid_n): + offs_bn = (pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + + # iterate over K dimension + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + # load the next block of hidden and weight + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), + other=0.0) + _weight = tl.load(weight_ptrs, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < (min((pid_n + 1) * vocab_per_split, vocab_size))), + other=0.0) + + # GEMM + logits = tl.dot(_hidden, _weight, logits) + + # advance the ptrs to the next K block + hidden_ptrs += 
BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # update global maximum + _max_old = _max + m_pid_n, m_pid_n_idx = tl.max(logits, axis=1, return_indices=True) + _max = tl.maximum(_max_old, m_pid_n) + # update indices when we find a new maximum + local_indices = pid_n * vocab_per_split + n * BLOCK_SIZE_N + m_pid_n_idx + _max_indices = tl.where(_max > _max_old, local_indices, _max_indices) + + exp_logits = tl.exp(logits - _max[:,None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = offs_bn[None,:] == labels[:,None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + maximum_indices_ptrs = max_indices_ptr + offs_max_n * stride_max_indices_n + offs_max_m * stride_max_indices_m + tl.store(maximum_indices_ptrs, _max_indices, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store(accu_ptrs, _accu, mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits)) + entropy_b_ptrs = entropy_b_ptr + offs_max_n * stride_entropy_b_n + offs_max_m * stride_entropy_b_m + tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + + # store logprobs + mask = (labels >= pid_n * vocab_per_split) & (labels < min((pid_n + 1) * vocab_per_split, vocab_size)) + mask &= (offs_am < num_tokens) + global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs + # tl.atomic_add(global_logprobs_ptrs, 
_logprobs, mask=mask) + tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64}) + ], + key=["num_tokens", "num_splits"] +) +@triton.jit +def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n, + max_indices_ptr, stride_max_indices_m, stride_max_indices_n, + num_tokens, num_splits, + global_max_ptr, stride_global_max, + global_max_indices_ptr, stride_global_max_indices, + accu_ptr, stride_accu_m, stride_accu_n, + global_accu_ptr, stride_global_accu, + entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, + global_entropy_ptr, stride_global_entropy, + global_logprobs_ptr, stride_global_logprobs, + global_logprobs_scalar_ptr, + reduction: int, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr): + pid_m = tl.program_id(axis=0) + + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_max_indices = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + max_ptrs = max_ptr + offs_m[:,None] * stride_max_m + offs_n[None,:] * stride_max_n + max_indices_ptrs = max_indices_ptr + offs_m[:,None] * stride_max_indices_m + offs_n[None,:] * stride_max_indices_n + + _max = tl.load(max_ptrs, + mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits), + other=0.0) + + accu_ptrs = accu_ptr + offs_m[:,None] * stride_accu_m + offs_n[None,:] * stride_accu_n + _accu = tl.load(accu_ptrs, + mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits), + other=0.0) + + entropy_b_ptrs = entropy_b_ptr + offs_m[:,None] * stride_entropy_b_m + offs_n[None,:] * stride_entropy_b_n + _entropy_b = tl.load(entropy_b_ptrs, + mask=(offs_m[:,None] < 
num_tokens) & (offs_n[None,:] < num_splits), + other=0.0) + + # local reduction + _max_old = global_max + _local_max, _local_indices = tl.max(_max, axis=1, return_indices=True) + global_max = tl.maximum(global_max, _local_max) + _local_indices += pid_n * BLOCK_SIZE_N + global_max_indices = tl.where(global_max > _max_old, _local_indices, global_max_indices) + + _scale = tl.exp(_max - global_max[:,None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + maximum_ptrs = global_max_ptr + offs_m * stride_global_max + tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + + # gather values from max_indices_ptr using global_max_indices + offs_n = global_max_indices + max_indices_ptrs = max_indices_ptr + offs_m * stride_max_indices_m + offs_n * stride_max_indices_n + final_indices = tl.load(max_indices_ptrs, mask=(offs_m < num_tokens) & (offs_n < num_splits)) + + # store to global max indices ptr + maximum_indices_ptrs = global_max_indices_ptr + offs_m * stride_global_max_indices + tl.store(maximum_indices_ptrs, final_indices, mask=offs_m < num_tokens) + + # store entropy + global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu + tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) + + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + global_entropy_b = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy + tl.store(global_entropy_ptrs, global_entropy_b, mask=offs_m < num_tokens) + + # update logprobs + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + + if reduction == 0: + tl.store(global_logprobs_ptrs, 
global_logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + elif reduction == 2: + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + +def efficient_entropy_foward(hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[0] + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + hidden_size, vocab_size = weight.shape + assert hidden_size % 128 == 0 + assert vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + elif REDUCTION == EntropyReductionEnum._Sum or REDUCTION == EntropyReductionEnum._Mean: + logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) + else: + raise ValueError(f"Invalid reduction: {reduction}") + + entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert logprobs.is_contiguous() and entropy.is_contiguous() + + maximum = torch.empty_like(entropy) + maximum_indices = torch.empty_like(entropy, dtype=torch.int64) + acc = torch.empty_like(entropy) + assert maximum.is_contiguous() and maximum_indices.is_contiguous() and acc.is_contiguous() + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // 
vocab_per_split + + _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _max_indices = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.int64) + _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + + if REDUCTION == EntropyReductionEnum._None: + _logprobs = logprobs + else: + _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() and _max_indices.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda and _max_indices.is_cuda + + # 1D kernel launch, then split the tile + grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + efficient_entropy_kernel_general_mainloop[grid]( + hidden, weight, labels, + num_tokens, hidden_size, vocab_size, vocab_per_split, + hidden.stride(0), hidden.stride(1), + weight.stride(0), weight.stride(1), + _max, _max.stride(0), _max.stride(1), + _max_indices, _max_indices.stride(0), _max_indices.stride(1), + _accu, _accu.stride(0), _accu.stride(1), + _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), + _logprobs, _logprobs.stride(0), + logprobs, REDUCTION + ) + + # reduction on maximum and maximum_indices + grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + efficient_entropy_triton_kernel_epilogue[grid]( + _max, _max.stride(0), _max.stride(1), + _max_indices, _max_indices.stride(0), _max_indices.stride(1), + num_tokens, num_splits, + maximum, maximum.stride(0), + maximum_indices, maximum_indices.stride(0), + _accu, _accu.stride(0), _accu.stride(1), + acc, acc.stride(0), + _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), + entropy, entropy.stride(0), + _logprobs, _logprobs.stride(0), + logprobs, REDUCTION + ) + + return (logprobs, entropy, maximum, 
maximum_indices, acc) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, num_warps=4), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, + hidden_size: int, + vocab_size: int, + vocab_per_split: int, + hidden_ptr, stride_hidden_m, stride_hidden_k, + weight_ptr, stride_weight_k, stride_weight_n, + labels_ptr, stride_labels, + d_entropy_ptr, stride_d_entropy, + d_logprobs_ptr, stride_d_logprobs, + reduction: int, + maximum_ptr, stride_maximum, + maximum_indices_ptr, stride_maximum_indices, + accu_ptr, stride_accu, + d_max_ptr, stride_d_max, + d_max_additional_ptr, stride_d_max_additional, + d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + pid = tl.program_id(0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + # pointers for this block + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) + + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + d_entropy = tl.load(d_entropy_ptr + offs_am, mask=offs_am < num_tokens) + if reduction == 0: # none + d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr).broadcast_to((BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)).broadcast_to((BLOCK_SIZE_M,)) + + maximum = tl.load(maximum_ptr + offs_am, mask=offs_am < num_tokens) + accu = tl.load(accu_ptr + offs_am, mask=offs_am < 
num_tokens) + + d_acc_exp_logits = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + d_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + d_max_additional = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for n in range(0, num_pid_n): + offs_bn = (pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), + other=0.0) + _weight = tl.load(weight_ptrs, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < min((pid_n + 1) * vocab_per_split, vocab_size)), + other=0.0) + logits = tl.dot(_hidden, _weight, logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + + exp_logits = tl.exp(logits - maximum[:,None]) + pd = tl.fdiv(exp_logits, accu[:,None]) + + d_pd = logits * -d_entropy[:,None] + label_mask = offs_bn[None,:] == labels[:,None] + d_pd += tl.fdiv(-d_logprobs[:,None], pd) * label_mask + + accu_rcp = tl.fdiv(1.0, accu) + d_acc_exp_logits += tl.sum(-d_pd * (accu_rcp * accu_rcp)[:,None] * exp_logits, axis=1) + d_max += tl.sum(d_pd * accu_rcp[:,None] * exp_logits, axis=1) + d_max_additional += tl.sum(exp_logits, axis=1) + + # store + # NOTE: perhaps we need to store those results separately, so that numerical determinism is guaranteed + d_max_ptrs = d_max_ptr + offs_am * stride_d_max + tl.atomic_add(d_max_ptrs, d_max, mask=(offs_am < num_tokens)) + d_max_additional_ptrs = d_max_additional_ptr + offs_am * stride_d_max_additional + tl.atomic_add(d_max_additional_ptrs, d_max_additional, mask=(offs_am < num_tokens)) + d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits + 
tl.atomic_add(d_acc_exp_logits_ptrs, d_acc_exp_logits, mask=(offs_am < num_tokens)) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 32})], + key=["num_tokens"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_preprocess_update(num_tokens: int, + accu_ptr, stride_accu, + d_entropy_ptr, stride_d_entropy, + d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + d_max_additional_ptr, stride_d_max_additional, + d_max_ptr, stride_d_max, + BLOCK_SIZE_M: tl.constexpr): + pid_m = tl.program_id(0) + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + + d_max_ptrs = d_max_ptr + offs_am * stride_d_max + d_max = tl.load(d_max_ptrs, mask=offs_am < num_tokens, other=0.0) + + d_max_additional_ptrs = d_max_additional_ptr + offs_am * stride_d_max_additional + d_max_additional = tl.load(d_max_additional_ptrs, mask=offs_am < num_tokens, other=0.0) + + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=0.0) + + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + + d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits + d_acc_exp_logits = tl.load(d_acc_exp_logits_ptrs, mask=offs_am < num_tokens, other=0.0) + + d_acc_exp_logits += tl.fdiv(d_entropy, accu) + d_max += d_max_additional * d_acc_exp_logits + + # store + tl.store(d_acc_exp_logits_ptrs, d_acc_exp_logits, mask=offs_am < num_tokens) + tl.store(d_max_ptrs, d_max, mask=offs_am < num_tokens) + + +# NOTE: merge d_weight & d_hidden here, split along M & N +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + num_stages=3, num_warps=4)], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, + hidden_size: int, + vocab_size: int, + hidden_ptr, stride_hidden_m, stride_hidden_k, + weight_ptr, 
stride_weight_k, stride_weight_n, + labels_ptr, stride_labels, + maximum_ptr, stride_maximum, + maximum_indices_ptr, stride_maximum_indices, + accu_ptr, stride_accu, + d_entropy_ptr, stride_d_entropy, + d_logprobs_ptr, stride_d_logprobs, + reduction: int, + d_max_ptr, stride_d_max, + d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, + d_weight_ptr, stride_d_weight_k, stride_d_weight_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + # TODO: perhaps block swizzling here + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + maximum_indices_ptrs = maximum_indices_ptr + offs_am * stride_maximum_indices + maximum_indices = tl.load(maximum_indices_ptrs, mask=offs_am < num_tokens, other=0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr).broadcast_to((BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)).broadcast_to((BLOCK_SIZE_M,)) + + d_max_ptrs = d_max_ptr + offs_am * stride_d_max + d_max = tl.load(d_max_ptrs, mask=offs_am < num_tokens, 
other=0.0) + d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits + d_acc_exp_logits = tl.load(d_acc_exp_logits_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + d_hidden_ptrs = d_hidden_ptr + offs_am[:,None] * stride_d_hidden_m + offs_k[None,:] * stride_d_hidden_k + d_weight_ptrs = d_weight_ptr + offs_k[:,None] * stride_d_weight_k + offs_bn[None,:] * stride_d_weight_n + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:,None] < num_tokens), + other=0.0) + _weight = tl.load(weight_ptrs, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < vocab_size), + other=0.0) + + logits = tl.dot(_hidden, _weight, logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + exp_logits = tl.exp(logits - maximum[:,None]) + accu_rcp = tl.fdiv(1.0, accu) + + d_logits = (-d_entropy * accu_rcp)[:, None] * exp_logits + + d_pd = logits * -d_entropy[:,None] + mask = offs_bn[None,:] == labels[:,None] + d_pd += tl.fdiv((-d_logprobs * accu)[:,None], exp_logits) * mask + + d_exp_logits = d_pd * accu_rcp[:,None] + d_exp_logits += d_acc_exp_logits[:,None] + d_logits += d_exp_logits * exp_logits + + d_max = d_entropy - d_max + mask = offs_bn[None,:] == maximum_indices[:,None] + d_logits += tl.where(mask, d_max[:,None], 0.0) + + # loop for d_weight & d_hidden + for k in range(0, tl.cdiv(hidden_size, 
BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:,None] < num_tokens), + other=0.0) + # TODO: perhaps we can convert d_logits to bfloat16 here + _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) + tl.atomic_add(d_weight_ptrs, _d_weight, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < vocab_size)) + + _weight = tl.load(weight_ptrs, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < vocab_size), + other=0.0) + _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) + tl.atomic_add(d_hidden_ptrs, _d_hidden, + mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:,None] < num_tokens)) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + d_hidden_ptrs += BLOCK_SIZE_K * stride_d_hidden_k + d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k + + +# NOTE: split tile from d_logits' perspective +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + num_stages=3, num_warps=4), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, + hidden_size: int, + vocab_size: int, + hidden_ptr, stride_hidden_m, stride_hidden_k, + weight_ptr, stride_weight_k, stride_weight_n, + labels_ptr, stride_labels, + maximum_ptr, stride_maximum, + maximum_indices_ptr, stride_maximum_indices, + accu_ptr, stride_accu, + d_entropy_ptr, stride_d_entropy, + d_logprobs_ptr, stride_d_logprobs, + reduction: int, + d_max_ptr, stride_d_max, + d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + d_logits_ptr, stride_d_logits_m, stride_d_logits_n, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): + pid = tl.program_id(0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) + # TODO: 
perhaps block swizzling here + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum_ptrs = maximum_ptr + offs_am * stride_maximum + maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) + maximum_indices_ptrs = maximum_indices_ptr + offs_am * stride_maximum_indices + maximum_indices = tl.load(maximum_indices_ptrs, mask=offs_am < num_tokens, other=0) + accu_ptrs = accu_ptr + offs_am * stride_accu + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy + d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) + if reduction == 0: # none + d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs + d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) + elif reduction == 1: # sum + d_logprobs = tl.load(d_logprobs_ptr).broadcast_to((BLOCK_SIZE_M,)) + else: # mean + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)).broadcast_to((BLOCK_SIZE_M,)) + + d_max_ptrs = d_max_ptr + offs_am * stride_d_max + d_max = tl.load(d_max_ptrs, mask=offs_am < num_tokens, other=0.0) + d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits + d_acc_exp_logits = tl.load(d_acc_exp_logits_ptrs, mask=offs_am < num_tokens, other=0.0) + + hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + labels_ptrs = labels_ptr + offs_am * stride_labels + labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + _hidden = tl.load(hidden_ptrs, 
+ mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:,None] < num_tokens), + other=0.0) + _weight = tl.load(weight_ptrs, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < vocab_size), + other=0.0) + + logits = tl.dot(_hidden, _weight, logits) + + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + hidden_ptrs -= hidden_size * stride_hidden_k + weight_ptrs -= hidden_size * stride_weight_k + + exp_logits = tl.exp(logits - maximum[:,None]) + accu_rcp = tl.fdiv(1.0, accu) + + d_logits = (-d_entropy * accu_rcp)[:, None] * exp_logits + + d_pd = logits * -d_entropy[:,None] + mask = offs_bn[None,:] == labels[:,None] + d_pd += tl.fdiv((-d_logprobs * accu)[:,None], exp_logits) * mask + + d_exp_logits = d_pd * accu_rcp[:,None] + d_exp_logits += d_acc_exp_logits[:,None] + d_logits += d_exp_logits * exp_logits + + d_max = d_entropy - d_max + mask = offs_bn[None,:] == maximum_indices[:,None] + d_logits += tl.where(mask, d_max[:,None], 0.0) + + # store d_logits + d_logits_ptrs = d_logits_ptr + offs_am[:,None] * stride_d_logits_m + offs_bn[None,:] * stride_d_logits_n + tl.store(d_logits_ptrs, d_logits.to(hidden_ptr.dtype.element_ty), + mask=(offs_am[:,None] < num_tokens) + & (offs_bn[None,:] < vocab_size)) + + +def efficient_entropy_backward(dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + maximum_indices: torch.Tensor, + acc: torch.Tensor, + reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[0] + + 
num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + hidden_size, vocab_size = weight.shape + assert hidden_size % 128 == 0 + assert vocab_size % 128 == 0 + + REDUCTION = get_entropy_reduction_enum(reduction) + + if REDUCTION == EntropyReductionEnum._None: + assert dlogprobs.shape == (num_tokens,) + else: + assert dlogprobs.dim() == 0 + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + if _BACKWARD == BackwardEnum._Total_Fuse_MN: + d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) + d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) + elif _BACKWARD == BackwardEnum._Total_Separate: + d_hidden = torch.empty_like(hidden, dtype=hidden.dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and maximum_indices.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and maximum_indices.device == hidden.device and acc.device == hidden.device + assert maximum.shape == maximum_indices.shape == labels.shape == acc.shape + assert maximum.is_cuda and maximum_indices.is_cuda and acc.is_cuda + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _d_max = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + _d_max_additional = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + _d_acc_exp_logits = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + assert _d_max.is_contiguous() and _d_acc_exp_logits.is_contiguous() + assert _d_max.is_cuda and _d_acc_exp_logits.is_cuda + + grid = lambda meta: 
(triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + efficient_entropy_backward_kernel_general_preprocess[grid]( + num_tokens, hidden_size, vocab_size, vocab_per_split, + hidden, hidden.stride(0), hidden.stride(1), + weight, weight.stride(0), weight.stride(1), + labels, labels.stride(0), + dentropy, dentropy.stride(0), + dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + maximum, maximum.stride(0), + maximum_indices, maximum_indices.stride(0), + acc, acc.stride(0), + _d_max, _d_max.stride(0), + _d_max_additional, _d_max_additional.stride(0), + _d_acc_exp_logits, _d_acc_exp_logits.stride(0), + ) + + grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + efficient_entropy_backward_kernel_general_preprocess_update[grid]( + num_tokens, + acc, acc.stride(0), + dentropy, dentropy.stride(0), + _d_acc_exp_logits, _d_acc_exp_logits.stride(0), + _d_max_additional, _d_max_additional.stride(0), + _d_max, _d_max.stride(0), + ) + + if _BACKWARD == BackwardEnum._Total_Fuse_MN: + grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + efficient_entropy_backward_kernel_general_mainloop_MN[grid]( + num_tokens, hidden_size, vocab_size, + hidden, hidden.stride(0), hidden.stride(1), + weight, weight.stride(0), weight.stride(1), + labels, labels.stride(0), + maximum, maximum.stride(0), + maximum_indices, maximum_indices.stride(0), + acc, acc.stride(0), + dentropy, dentropy.stride(0), + dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + _d_max, _d_max.stride(0), + _d_acc_exp_logits, _d_acc_exp_logits.stride(0), + d_hidden, d_hidden.stride(0), d_hidden.stride(1), + d_weight, d_weight.stride(0), d_weight.stride(1), + ) + elif _BACKWARD == BackwardEnum._Total_Separate: + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=torch.bfloat16) + grid = lambda meta: (triton.cdiv(num_tokens, 
meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + efficient_entropy_backward_kernel_general_d_logits[grid]( + num_tokens, hidden_size, vocab_size, + hidden, hidden.stride(0), hidden.stride(1), + weight, weight.stride(0), weight.stride(1), + labels, labels.stride(0), + maximum, maximum.stride(0), + maximum_indices, maximum_indices.stride(0), + acc, acc.stride(0), + dentropy, dentropy.stride(0), + dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + REDUCTION, + _d_max, _d_max.stride(0), + _d_acc_exp_logits, _d_acc_exp_logits.stride(0), + _d_logits, _d_logits.stride(0), _d_logits.stride(1), + ) + + torch.matmul(_d_logits, weight.T, out=d_hidden) + torch.matmul(hidden.T, _d_logits, out=d_weight) + + return d_hidden, d_weight diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py new file mode 100644 index 00000000000..cd5385497b8 --- /dev/null +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -0,0 +1,53 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import typing +from . 
import kernels + +class LinearCrossEntropy(torch.autograd.Function): + @staticmethod + def forward(ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[str] = "mean") -> typing.List[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + logprobs, entropy, _maximum, _maximum_indices, _acc =\ + kernels.efficient_entropy_foward(hidden, weight, labels, REDUCTION) + + ctx.save_for_backward(hidden, weight, labels, _maximum, _maximum_indices, _acc) + ctx.REDUCTION = REDUCTION + + return logprobs, entropy + + @staticmethod + def backward(ctx, + dlogprobs: torch.Tensor, + dentropy: torch.Tensor) -> typing.List[torch.Tensor]: + with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): + (hidden, weight, labels, _maximum, _maximum_indices, _acc) = ctx.saved_tensors + REDUCTION = ctx.REDUCTION + + d_hidden, d_weight = kernels.efficient_entropy_backward( + dlogprobs, dentropy, + hidden, weight, labels, + _maximum, _maximum_indices, _acc, + REDUCTION) + + return (d_hidden, d_weight, None, None) + +linear_cross_entropy = LinearCrossEntropy.apply \ No newline at end of file From b8a1678e7ffb0bdf16adec0721af4201c56074b1 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Tue, 4 Mar 2025 23:11:38 -0800 Subject: [PATCH 02/15] prepend nvidia's copyright Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 10 ++++++++++ verl/utils/kernel/__init__.py | 10 ++++++++++ verl/utils/kernel/kernels.py | 10 ++++++++++ verl/utils/kernel/linear_cross_entropy.py | 10 ++++++++++ 4 files changed, 40 insertions(+) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index 1b34e4e8ec9..ca7b37e8ea9 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -1,3 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + # Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py index ae59d82f645..b6bf219f125 100644 --- a/verl/utils/kernel/__init__.py +++ b/verl/utils/kernel/__init__.py @@ -1,3 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + # Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index 3daea5221b2..98401ba5c7b 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -1,3 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + # Copyright 2024 Bytedance Ltd. and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py index cd5385497b8..18ecfce3f4f 100644 --- a/verl/utils/kernel/linear_cross_entropy.py +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -1,3 +1,13 @@ +# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. + # Copyright 2024 Bytedance Ltd. 
and/or its affiliates # # Licensed under the Apache License, Version 2.0 (the "License"); From a26cf02fd95f3b8554e4596264508e46b7eaf441 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Wed, 5 Mar 2025 00:32:41 -0800 Subject: [PATCH 03/15] update license Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 23 +++++++++++++-------- verl/utils/kernel/__init__.py | 25 +++++++++++++++-------- verl/utils/kernel/kernels.py | 23 +++++++++++++-------- verl/utils/kernel/linear_cross_entropy.py | 25 +++++++++++++++-------- 4 files changed, 62 insertions(+), 34 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index ca7b37e8ea9..181f1d0d7b4 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# # Copyright 2024 Bytedance Ltd. and/or its affiliates # diff --git a/verl/utils/kernel/__init__.py b/verl/utils/kernel/__init__.py index b6bf219f125..2453d149bad 100644 --- a/verl/utils/kernel/__init__.py +++ b/verl/utils/kernel/__init__.py @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Copyright 2024 Bytedance Ltd. and/or its affiliates # diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index 98401ba5c7b..b4826f924fa 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: LicenseRef-NvidiaProprietary # -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Copyright 2024 Bytedance Ltd. and/or its affiliates # diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py index 18ecfce3f4f..581a3f50b02 100644 --- a/verl/utils/kernel/linear_cross_entropy.py +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -1,12 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. 
Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. +# +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Copyright 2024 Bytedance Ltd. and/or its affiliates # From f5abf5ae32077f79694bf57ab635dcb179083099 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Sun, 16 Mar 2025 23:59:19 -0700 Subject: [PATCH 04/15] refactored test_linear_cross_entropy Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 230 ++++++++++++++-------- 1 file changed, 149 insertions(+), 81 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index 181f1d0d7b4..fb51bdbe9ca 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -53,93 +53,161 @@ def run_torch_entropy(hidden: torch.Tensor, logprobs = torch.nn.functional.cross_entropy(logits, labels) # [1] return logprobs, entropy -if __name__ == "__main__": - num_tokens = 80 - hidden_size = 4096 - vocab_size = 152064 - - dtype = torch.bfloat16 - - enabled_fileds = { - "forward": {"Torch": True, "Kernel": True}, - "backward": {"Torch": True, "Kernel": True} - } - - # set_backward_method(BackwardEnum._Total_Separate) - - iterations 
= 5 - for i in range(iterations): - print(f"[INFO]: ---------- Iteration {i} starts. ----------") - with torch.cuda.nvtx.range(f"iteration_{i}"): - hidden = (torch.empty((num_tokens, hidden_size), dtype=dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_()) - weight = (torch.empty((hidden_size, vocab_size), dtype=dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_()) - labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda") - - start_event = torch.cuda.Event(enable_timing=True) - end_event = torch.cuda.Event(enable_timing=True) - - if enabled_fileds["forward"]["Torch"]: - torch.cuda.reset_peak_memory_stats() - start_event.record() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) - end_event.record() - torch.cuda.synchronize() - torch_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Forward pass: Torch implementation peak memory: {torch_max_memory:.2f} MB") - print(f"[INFO]: Forward pass: Torch implementation time: {start_event.elapsed_time(end_event):.2f} ms") - - if enabled_fileds["forward"]["Kernel"]: - torch.cuda.reset_peak_memory_stats() - start_event.record() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) - end_event.record() - torch.cuda.synchronize() - kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Forward pass: Kernel implementation peak memory: {kernel_max_memory:.2f} MB") - print(f"[INFO]: Forward pass: Kernel implementation time: {start_event.elapsed_time(end_event):.2f} ms") - - if enabled_fileds["forward"]["Torch"]: - torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-3, rtol=1e-3) - torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-3, rtol=1e-3) - print(f"[INFO]: Forward pass: Kernel implementation passed.") - - if enabled_fileds["backward"]["Torch"] or enabled_fileds["backward"]["Kernel"]: - g_entropy = (torch.empty((num_tokens,), dtype=dtype, 
device="cuda") +class TestLinearCrossEntropy: + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + self.num_tokens = 80 + self.hidden_size = 4096 + self.vocab_size = 152064 + self.dtype = torch.bfloat16 + + def generate_forward_inputs(self): + hidden = (torch.empty((self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_()) + weight = (torch.empty((self.hidden_size, self.vocab_size), dtype=self.dtype, device="cuda") + .uniform_(-0.5, 0.5) + .requires_grad_()) + labels = torch.randint(0, self.vocab_size, (self.num_tokens,), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda") .uniform_(-0.5, 0.5)) - g_logprobs = (torch.empty((), dtype=dtype, device="cuda") + g_logprobs = (torch.empty((), dtype=self.dtype, device="cuda") .uniform_(-1, 1)) + return g_entropy, g_logprobs + + def verify_correctness(self): + self.cleanup() + self.generate_hyper() + + iterations = 5 + + torch_forward_latency = list() + torch_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(iterations): + hidden, weight, labels = self.generate_forward_inputs() + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, 
kernel_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) + + # backward + g_entropy, g_logprobs = self.generate_backward_inputs() + + start_event.record() + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), + (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) - if enabled_fileds["backward"]["Torch"]: - torch.cuda.reset_peak_memory_stats() - start_event.record() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), + torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=1e-2, rtol=1e-4) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + print(f"[INFO]: Verified forward & backward correctness.") + + print(f"[INFO]: Forward pass: Torch implementation average time: " + f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: torch implementation average time: " + f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: " + f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: " + 
f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") + + + def check_torch_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + torch.cuda.reset_peak_memory_stats() + (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + torch.cuda.synchronize() + torch_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Torch Forward pass peak memory: {torch_max_memory:.2f} MB") + + g_entropy, g_logprobs = self.generate_backward_inputs() + + torch.cuda.reset_peak_memory_stats() + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) - end_event.record() - torch.cuda.synchronize() - torch_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Backward pass: torch implementation peak memory: {torch_backward_max_memory:.2f} MB") - print(f"[INFO]: Backward pass: torch gradient time: {start_event.elapsed_time(end_event):.2f} ms") - - if enabled_fileds["backward"]["Kernel"]: - torch.cuda.reset_peak_memory_stats() - start_event.record() - (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), + torch.cuda.synchronize() + torch_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Torch Backward pass peak memory: {torch_backward_max_memory:.2f} MB") + + def check_kernel_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + torch.cuda.reset_peak_memory_stats() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + + g_entropy, g_logprobs = self.generate_backward_inputs() + + 
torch.cuda.reset_peak_memory_stats() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) - end_event.record() - torch.cuda.synchronize() - kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Backward pass: kernel implementation peak memory: {kernel_backward_max_memory:.2f} MB") - print(f"[INFO]: Backward pass: kernel gradient time: {start_event.elapsed_time(end_event):.2f} ms") - - if enabled_fileds["backward"]["Torch"]: - torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-3) - torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=1e-2, rtol=1e-3) - print(f"[INFO]: Backward pass: kernel implementation passed.") \ No newline at end of file + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + + +if __name__ == "__main__": + test = TestLinearCrossEntropy() + + test.verify_correctness() + test.check_torch_storage() + test.check_kernel_storage() \ No newline at end of file From 2b146a2c3950214ba2c274b81a9e4000d09c2091 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Mon, 17 Mar 2025 00:29:13 -0700 Subject: [PATCH 05/15] simplified backward pass Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 5 +- verl/utils/kernel/kernels.py | 469 ++++++++++------------ verl/utils/kernel/linear_cross_entropy.py | 8 +- 3 files changed, 213 insertions(+), 269 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index fb51bdbe9ca..79bd793d38e 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -99,6 +99,7 @@ def verify_correctness(self): end_event = torch.cuda.Event(enable_timing=True) for i in range(iterations): + 
print(f"[INFO]: Iteration {i} / {iterations}...", end="\r") hidden, weight, labels = self.generate_forward_inputs() start_event.record() @@ -146,7 +147,7 @@ def verify_correctness(self): kernel_forward_latency = kernel_forward_latency[1:] kernel_backward_latency = kernel_backward_latency[1:] - print(f"[INFO]: Verified forward & backward correctness.") + print(f"\n[INFO]: Verified forward & backward correctness.") print(f"[INFO]: Forward pass: Torch implementation average time: " f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") @@ -206,6 +207,8 @@ def check_kernel_storage(self): if __name__ == "__main__": + # set_backward_method(BackwardEnum._Total_Fuse_MN) + test = TestLinearCrossEntropy() test.verify_correctness() diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index b4826f924fa..17e1a095bba 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -29,48 +29,70 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch +""" +Implementations of the linear cross entropy with token entropy kernel. +""" + import typing from dataclasses import dataclass +import torch import triton import triton.language as tl @dataclass class EntropyReductionEnum: + """ + Enum for the reduction method of cross entropy. + """ _None = 0 _Sum = 1 _Mean = 2 def get_entropy_reduction_enum_number(reduction: str) -> int: + """ + Get the enum number for the reduction method of cross entropy. 
+ """ + _enum = EntropyReductionEnum._None if reduction == "none": - return EntropyReductionEnum._None + _enum = EntropyReductionEnum._None elif reduction == "sum": - return EntropyReductionEnum._Sum + _enum = EntropyReductionEnum._Sum elif reduction == "mean": - return EntropyReductionEnum._Mean + _enum = EntropyReductionEnum._Mean else: raise ValueError(f"Invalid reduction: {reduction}") + return _enum def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: + """ + Get the enum for the reduction method of cross entropy. + """ + _enum = EntropyReductionEnum._None if ce_reduction == 0: - return EntropyReductionEnum._None + _enum = EntropyReductionEnum._None elif ce_reduction == 1: - return EntropyReductionEnum._Sum + _enum = EntropyReductionEnum._Sum elif ce_reduction == 2: - return EntropyReductionEnum._Mean + _enum = EntropyReductionEnum._Mean else: raise ValueError(f"Invalid ce_reduction: {ce_reduction}") - + return _enum @dataclass class BackwardEnum: + """ + Enum for the backward method. + """ _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight _BACKWARD: BackwardEnum = BackwardEnum._Total_Separate def set_backward_method(backward_method: BackwardEnum): + """ + Set the backward method. 
+ """ global _BACKWARD _BACKWARD = backward_method @@ -85,16 +107,17 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr stride_hidden_m, stride_hidden_k, stride_weight_k, stride_weight_n, max_ptr, stride_max_m, stride_max_n, - max_indices_ptr, stride_max_indices_m, stride_max_indices_n, accu_ptr, stride_accu_m, stride_accu_n, entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, global_logprobs_ptr, stride_global_logprobs, - global_logprobs_scalar_ptr, - reduction: int, + global_logprobs_scalar_ptr, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + """ + forward mainloop + """ pid = tl.program_id(axis=0) num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) @@ -106,7 +129,7 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr tl.store(global_logprobs_scalar_ptr, 0.0) # create pointers for the first blocks of hidden - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = tl.arange(0, BLOCK_SIZE_K) hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) @@ -116,24 +139,23 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr # traverse over N dimension # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) - _max_indices = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int64) _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): - offs_bn = (pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) weight_ptrs = weight_ptr + (offs_k[:,None] * 
stride_weight_k + offs_bn[None,:] * stride_weight_n) # iterate over K dimension logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): # load the next block of hidden and weight - _hidden = tl.load(hidden_ptrs, - mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), + _hidden = tl.load(hidden_ptrs, + mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), other=0.0) - _weight = tl.load(weight_ptrs, - mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < (min((pid_n + 1) * vocab_per_split, vocab_size))), + _weight = tl.load(weight_ptrs, + mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[None,:] < (min((pid_n + 1) * vocab_per_split, vocab_size))), other=0.0) # GEMM @@ -147,11 +169,8 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr # update global maximum _max_old = _max - m_pid_n, m_pid_n_idx = tl.max(logits, axis=1, return_indices=True) + m_pid_n = tl.max(logits, axis=1) _max = tl.maximum(_max_old, m_pid_n) - # update indices when we find a new maximum - local_indices = pid_n * vocab_per_split + n * BLOCK_SIZE_N + m_pid_n_idx - _max_indices = tl.where(_max > _max_old, local_indices, _max_indices) exp_logits = tl.exp(logits - _max[:,None]) coeff = tl.exp(_max_old - _max) @@ -168,8 +187,6 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr offs_max_n = pid_n maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m tl.store(maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) - maximum_indices_ptrs = max_indices_ptr + offs_max_n * stride_max_indices_n + offs_max_m * stride_max_indices_m - tl.store(maximum_indices_ptrs, _max_indices, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) # store entropy accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * 
stride_accu_m
@@ -193,10 +210,8 @@ )
 @triton.jit
 def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n,
-                                             max_indices_ptr, stride_max_indices_m, stride_max_indices_n,
                                              num_tokens, num_splits,
                                              global_max_ptr, stride_global_max,
-                                             global_max_indices_ptr, stride_global_max_indices,
                                              accu_ptr, stride_accu_m, stride_accu_n,
                                              global_accu_ptr, stride_global_accu,
                                              entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n,
@@ -206,25 +221,25 @@ reduction: int,
                                              BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):
+    """
+    forward epilogue
+    """
     pid_m = tl.program_id(axis=0)
-    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
-
+    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
-    global_max_indices = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32)
     global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
     global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)
 
     for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)):
-        offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
+        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
 
         max_ptrs = max_ptr + offs_m[:,None] * stride_max_m + offs_n[None,:] * stride_max_n
-        max_indices_ptrs = max_indices_ptr + offs_m[:,None] * stride_max_indices_m + offs_n[None,:] * stride_max_indices_n
-        _max = tl.load(max_ptrs,
+        _max = tl.load(max_ptrs,
                        mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits),
                        other=0.0)
 
         accu_ptrs = accu_ptr + offs_m[:,None] * stride_accu_m + offs_n[None,:] * stride_accu_n
-        _accu = tl.load(accu_ptrs,
+        _accu = tl.load(accu_ptrs,
                        mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits),
                        other=0.0)
@@ -235,10 +250,8 @@ # local reduction
         _max_old = 
global_max - _local_max, _local_indices = tl.max(_max, axis=1, return_indices=True) + _local_max = tl.max(_max, axis=1) global_max = tl.maximum(global_max, _local_max) - _local_indices += pid_n * BLOCK_SIZE_N - global_max_indices = tl.where(global_max > _max_old, _local_indices, global_max_indices) _scale = tl.exp(_max - global_max[:,None]) _coeff = tl.exp(_max_old - global_max) @@ -249,24 +262,13 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n maximum_ptrs = global_max_ptr + offs_m * stride_global_max tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) - # gather values from max_indices_ptr using global_max_indices - offs_n = global_max_indices - max_indices_ptrs = max_indices_ptr + offs_m * stride_max_indices_m + offs_n * stride_max_indices_n - final_indices = tl.load(max_indices_ptrs, mask=(offs_m < num_tokens) & (offs_n < num_splits)) - - # store to global max indices ptr - maximum_indices_ptrs = global_max_indices_ptr + offs_m * stride_global_max_indices - tl.store(maximum_indices_ptrs, final_indices, mask=offs_m < num_tokens) - # store entropy global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) - global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b global_entropy_b = tl.log(global_accu) + global_max - global_entropy_b # entropy_a global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy tl.store(global_entropy_ptrs, global_entropy_b, mask=offs_m < num_tokens) - # update logprobs global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) @@ -274,7 +276,7 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n if reduction == 0: tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) - elif reduction == 1: + elif reduction == 1: global_logprobs_scalar = 
tl.sum(global_logprobs, axis=0) tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) elif reduction == 2: @@ -285,6 +287,9 @@ def efficient_entropy_foward(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: + """ + forward host function + """ assert hidden.is_cuda and weight.is_cuda and labels.is_cuda assert weight.device == hidden.device and labels.device == hidden.device assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 @@ -301,7 +306,7 @@ def efficient_entropy_foward(hidden: torch.Tensor, if REDUCTION == EntropyReductionEnum._None: logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - elif REDUCTION == EntropyReductionEnum._Sum or REDUCTION == EntropyReductionEnum._Mean: + elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) else: raise ValueError(f"Invalid reduction: {reduction}") @@ -310,16 +315,14 @@ def efficient_entropy_foward(hidden: torch.Tensor, assert logprobs.is_contiguous() and entropy.is_contiguous() maximum = torch.empty_like(entropy) - maximum_indices = torch.empty_like(entropy, dtype=torch.int64) acc = torch.empty_like(entropy) - assert maximum.is_contiguous() and maximum_indices.is_contiguous() and acc.is_contiguous() + assert maximum.is_contiguous() and acc.is_contiguous() vocab_per_split = 1024 assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _max_indices = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.int64) _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) _entropy_b = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) @@ -328,32 +331,31 @@ def efficient_entropy_foward(hidden: 
torch.Tensor, else: _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() and _max_indices.is_contiguous() - assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda and _max_indices.is_cuda + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda # 1D kernel launch, then split the tile - grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) - efficient_entropy_kernel_general_mainloop[grid]( + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + efficient_entropy_kernel_general_mainloop[mainloop_grid]( hidden, weight, labels, num_tokens, hidden_size, vocab_size, vocab_per_split, hidden.stride(0), hidden.stride(1), weight.stride(0), weight.stride(1), _max, _max.stride(0), _max.stride(1), - _max_indices, _max_indices.stride(0), _max_indices.stride(1), _accu, _accu.stride(0), _accu.stride(1), _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), _logprobs, _logprobs.stride(0), - logprobs, REDUCTION + logprobs ) # reduction on maximum and maximum_indices - grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) - efficient_entropy_triton_kernel_epilogue[grid]( + def epilogue_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + efficient_entropy_triton_kernel_epilogue[epilogue_grid]( _max, _max.stride(0), _max.stride(1), - _max_indices, _max_indices.stride(0), _max_indices.stride(1), num_tokens, num_splits, maximum, maximum.stride(0), - maximum_indices, maximum_indices.stride(0), _accu, _accu.stride(0), _accu.stride(1), acc, acc.stride(0), _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), @@ -362,7 +364,7 @@ def efficient_entropy_foward(hidden: torch.Tensor, logprobs, REDUCTION ) - return (logprobs, entropy, maximum, maximum_indices, acc) + return 
(logprobs, entropy, maximum, acc) @triton.autotune( @@ -376,58 +378,41 @@ def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, hidden_size: int, vocab_size: int, vocab_per_split: int, + num_splits: int, hidden_ptr, stride_hidden_m, stride_hidden_k, weight_ptr, stride_weight_k, stride_weight_n, - labels_ptr, stride_labels, - d_entropy_ptr, stride_d_entropy, - d_logprobs_ptr, stride_d_logprobs, - reduction: int, - maximum_ptr, stride_maximum, - maximum_indices_ptr, stride_maximum_indices, - accu_ptr, stride_accu, - d_max_ptr, stride_d_max, - d_max_additional_ptr, stride_d_max_additional, - d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + maximum_ptr, + d_scale_ptr, stride_d_scale_m, stride_d_scale_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): + """ + backward preprocess + """ pid = tl.program_id(0) - num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) pid_m = pid % num_pid_m pid_n = pid // num_pid_m # pointers for this block - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = tl.arange(0, BLOCK_SIZE_K) hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) - labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) - d_entropy = tl.load(d_entropy_ptr + offs_am, mask=offs_am < num_tokens) - if reduction == 0: # none - d_logprobs = tl.load(d_logprobs_ptr + offs_am * stride_d_logprobs, mask=offs_am < num_tokens) - elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr).broadcast_to((BLOCK_SIZE_M,)) - else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)).broadcast_to((BLOCK_SIZE_M,)) - maximum = tl.load(maximum_ptr + offs_am, mask=offs_am < num_tokens) - accu = tl.load(accu_ptr + offs_am, mask=offs_am < num_tokens) - 
d_acc_exp_logits = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - d_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - d_max_additional = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): - offs_bn = (pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load(hidden_ptrs, + _hidden = tl.load(hidden_ptrs, mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), other=0.0) - _weight = tl.load(weight_ptrs, + _weight = tl.load(weight_ptrs, mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None,:] < min((pid_n + 1) * vocab_per_split, vocab_size)), other=0.0) @@ -438,68 +423,18 @@ def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, hidden_ptrs -= hidden_size * stride_hidden_k exp_logits = tl.exp(logits - maximum[:,None]) - pd = tl.fdiv(exp_logits, accu[:,None]) - - d_pd = logits * -d_entropy[:,None] - label_mask = offs_bn[None,:] == labels[:,None] - d_pd += tl.fdiv(-d_logprobs[:,None], pd) * label_mask - - accu_rcp = tl.fdiv(1.0, accu) - d_acc_exp_logits += tl.sum(-d_pd * (accu_rcp * accu_rcp)[:,None] * exp_logits, axis=1) - d_max += tl.sum(d_pd * accu_rcp[:,None] * exp_logits, axis=1) - d_max_additional += tl.sum(exp_logits, axis=1) - - # store - # NOTE: perhaps we need to store those results separately, so that numerical determinism is guaranteed - d_max_ptrs = d_max_ptr + offs_am * stride_d_max - tl.atomic_add(d_max_ptrs, d_max, mask=(offs_am < num_tokens)) - d_max_additional_ptrs = d_max_additional_ptr + offs_am * stride_d_max_additional - tl.atomic_add(d_max_additional_ptrs, 
d_max_additional, mask=(offs_am < num_tokens)) - d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits - tl.atomic_add(d_acc_exp_logits_ptrs, d_acc_exp_logits, mask=(offs_am < num_tokens)) - - -@triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 32})], - key=["num_tokens"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_preprocess_update(num_tokens: int, - accu_ptr, stride_accu, - d_entropy_ptr, stride_d_entropy, - d_acc_exp_logits_ptr, stride_d_acc_exp_logits, - d_max_additional_ptr, stride_d_max_additional, - d_max_ptr, stride_d_max, - BLOCK_SIZE_M: tl.constexpr): - pid_m = tl.program_id(0) - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - - d_max_ptrs = d_max_ptr + offs_am * stride_d_max - d_max = tl.load(d_max_ptrs, mask=offs_am < num_tokens, other=0.0) - - d_max_additional_ptrs = d_max_additional_ptr + offs_am * stride_d_max_additional - d_max_additional = tl.load(d_max_additional_ptrs, mask=offs_am < num_tokens, other=0.0) - accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=0.0) + _scale += tl.sum(exp_logits * logits, axis=1) - d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy - d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - - d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits - d_acc_exp_logits = tl.load(d_acc_exp_logits_ptrs, mask=offs_am < num_tokens, other=0.0) - - d_acc_exp_logits += tl.fdiv(d_entropy, accu) - d_max += d_max_additional * d_acc_exp_logits - - # store - tl.store(d_acc_exp_logits_ptrs, d_acc_exp_logits, mask=offs_am < num_tokens) - tl.store(d_max_ptrs, d_max, mask=offs_am < num_tokens) + # tl.store(d_scale_ptr + offs_am * stride_d_scale_m + pid_n * stride_d_scale_n, + # _scale, mask=offs_am < num_tokens) + tl.store(d_scale_ptr + pid_n * stride_d_scale_m + offs_am * stride_d_scale_n, + _scale, mask=(offs_am < num_tokens) & (pid_n < num_splits)) # NOTE: 
merge d_weight & d_hidden here, split along M & N @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32}, + configs=[triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=4)], key=["num_tokens", "hidden_size", "vocab_size"], ) @@ -511,49 +446,60 @@ def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, weight_ptr, stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, - maximum_indices_ptr, stride_maximum_indices, accu_ptr, stride_accu, d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, - d_max_ptr, stride_d_max, - d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + d_scale_ptr, stride_d_scale, d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, d_weight_ptr, stride_d_weight_k, stride_d_weight_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr): - pid = tl.program_id(0) + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + backward mainloop, where d_logits & d_hidden & d_weight are fused + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - # TODO: perhaps block swizzling here - pid_m = pid % num_pid_m - pid_n = pid // num_pid_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M 
+ tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) maximum_ptrs = maximum_ptr + offs_am * stride_maximum maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) - maximum_indices_ptrs = maximum_indices_ptr + offs_am * stride_maximum_indices - maximum_indices = tl.load(maximum_indices_ptrs, mask=offs_am < num_tokens, other=0) accu_ptrs = accu_ptr + offs_am * stride_accu accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) if reduction == 0: # none d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr).broadcast_to((BLOCK_SIZE_M,)) + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)).broadcast_to((BLOCK_SIZE_M,)) - - d_max_ptrs = d_max_ptr + offs_am * stride_d_max - d_max = tl.load(d_max_ptrs, mask=offs_am < num_tokens, other=0.0) - d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits - d_acc_exp_logits = tl.load(d_acc_exp_logits_ptrs, mask=offs_am < num_tokens, other=0.0) + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + + d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, + mask=offs_am < num_tokens, other=0.0) hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) @@ -582,21 +528,14 @@ def 
efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, weight_ptrs -= hidden_size * stride_weight_k exp_logits = tl.exp(logits - maximum[:,None]) - accu_rcp = tl.fdiv(1.0, accu) - - d_logits = (-d_entropy * accu_rcp)[:, None] * exp_logits d_pd = logits * -d_entropy[:,None] mask = offs_bn[None,:] == labels[:,None] - d_pd += tl.fdiv((-d_logprobs * accu)[:,None], exp_logits) * mask - - d_exp_logits = d_pd * accu_rcp[:,None] - d_exp_logits += d_acc_exp_logits[:,None] - d_logits += d_exp_logits * exp_logits + d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:,None], exp_logits) * mask - d_max = d_entropy - d_max - mask = offs_bn[None,:] == maximum_indices[:,None] - d_logits += tl.where(mask, d_max[:,None], 0.0) + coeff = d_scale * d_entropy * accu_rcp * accu_rcp + d_logprobs * accu_rcp + d_logits = exp_logits * coeff[:,None] + d_logits += exp_logits * d_pd * accu_rcp[:,None] # loop for d_weight & d_hidden for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): @@ -604,30 +543,29 @@ def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), other=0.0) - # TODO: perhaps we can convert d_logits to bfloat16 here _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) - tl.atomic_add(d_weight_ptrs, _d_weight, + tl.atomic_add(d_weight_ptrs, _d_weight, mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None,:] < vocab_size)) - _weight = tl.load(weight_ptrs, + _weight = tl.load(weight_ptrs, mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None,:] < vocab_size), other=0.0) _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) - tl.atomic_add(d_hidden_ptrs, _d_hidden, + tl.atomic_add(d_hidden_ptrs, _d_hidden, mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens)) hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k weight_ptrs += BLOCK_SIZE_K * stride_weight_k d_hidden_ptrs += BLOCK_SIZE_K * 
stride_d_hidden_k - d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k + d_weight_ptrs += BLOCK_SIZE_K * stride_d_weight_k # NOTE: split tile from d_logits' perspective @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, + configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, num_stages=3, num_warps=4), ], key=["num_tokens", "hidden_size", "vocab_size"], @@ -640,48 +578,68 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, weight_ptr, stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, - maximum_indices_ptr, stride_maximum_indices, accu_ptr, stride_accu, d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, - d_max_ptr, stride_d_max, - d_acc_exp_logits_ptr, stride_d_acc_exp_logits, + d_scale_ptr, stride_d_scale, d_logits_ptr, stride_d_logits_m, stride_d_logits_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr): - pid = tl.program_id(0) + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): + """ + backward d_logits + """ + # block swizzling + # pid = tl.program_id(axis=0) + # num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + # pid_m = pid % num_pid_m + # pid_n = pid // num_pid_m + + pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) num_pid_n = tl.cdiv(vocab_size, BLOCK_SIZE_N) - # TODO: perhaps block swizzling here - pid_m = pid % num_pid_m - pid_n = pid // num_pid_m - - offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) - offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = 
pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) offs_k = tl.arange(0, BLOCK_SIZE_K) maximum_ptrs = maximum_ptr + offs_am * stride_maximum maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) - maximum_indices_ptrs = maximum_indices_ptr + offs_am * stride_maximum_indices - maximum_indices = tl.load(maximum_indices_ptrs, mask=offs_am < num_tokens, other=0) accu_ptrs = accu_ptr + offs_am * stride_accu accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu_rcp = tl.fdiv(1.0, accu) + d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) if reduction == 0: # none d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) elif reduction == 1: # sum - d_logprobs = tl.load(d_logprobs_ptr).broadcast_to((BLOCK_SIZE_M,)) + d_logprobs = tl.load(d_logprobs_ptr) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) else: # mean - d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)).broadcast_to((BLOCK_SIZE_M,)) - - d_max_ptrs = d_max_ptr + offs_am * stride_d_max - d_max = tl.load(d_max_ptrs, mask=offs_am < num_tokens, other=0.0) - d_acc_exp_logits_ptrs = d_acc_exp_logits_ptr + offs_am * stride_d_acc_exp_logits - d_acc_exp_logits = tl.load(d_acc_exp_logits_ptrs, mask=offs_am < num_tokens, other=0.0) + d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) + d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + + d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, + mask=offs_am < num_tokens, other=0.0) + + # d_acc_exp_logits = d_scale * d_entropy * accu_rcp * accu_rcp + # d_acc_exp_logits += d_logprobs * accu_rcp + # d_acc_exp_logits += d_entropy * accu_rcp + + # These equal to d_max = d_entropy + # d_max = d_scale * -d_entropy * 
accu_rcp + # d_max -= d_logprobs + # d_max += accu * d_acc_exp_logits hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) @@ -707,21 +665,21 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, weight_ptrs -= hidden_size * stride_weight_k exp_logits = tl.exp(logits - maximum[:,None]) - accu_rcp = tl.fdiv(1.0, accu) - - d_logits = (-d_entropy * accu_rcp)[:, None] * exp_logits d_pd = logits * -d_entropy[:,None] mask = offs_bn[None,:] == labels[:,None] - d_pd += tl.fdiv((-d_logprobs * accu)[:,None], exp_logits) * mask - - d_exp_logits = d_pd * accu_rcp[:,None] - d_exp_logits += d_acc_exp_logits[:,None] - d_logits += d_exp_logits * exp_logits + d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:,None], exp_logits) * mask + + coeff = d_scale * d_entropy * accu_rcp * accu_rcp + d_logprobs * accu_rcp + d_logits = exp_logits * coeff[:,None] + d_logits += exp_logits * d_pd * accu_rcp[:,None] + # d_logits += exp_logits * logits * (-d_entropy * accu_rcp)[:,None] + # d_logits -= tl.where(mask, d_logprobs[:,None], 0.0) - d_max = d_entropy - d_max - mask = offs_bn[None,:] == maximum_indices[:,None] - d_logits += tl.where(mask, d_max[:,None], 0.0) + # d_max is always zeros + # d_max = d_entropy - d_max + # mask = offs_bn[None,:] == maximum_indices[:,None] + # d_logits += tl.where(mask, d_max[:,None], 0.0) # store d_logits d_logits_ptrs = d_logits_ptr + offs_am[:,None] * stride_d_logits_m + offs_bn[None,:] * stride_d_logits_n @@ -730,15 +688,17 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, & (offs_bn[None,:] < vocab_size)) -def efficient_entropy_backward(dlogprobs: torch.Tensor, +def efficient_entropy_backward(dlogprobs: torch.Tensor, dentropy: torch.Tensor, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, maximum: torch.Tensor, - maximum_indices: torch.Tensor, acc: torch.Tensor, 
reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: + """ + backward host function + """ assert hidden.is_cuda and weight.is_cuda and labels.is_cuda assert weight.device == hidden.device and labels.device == hidden.device assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 @@ -763,6 +723,7 @@ def efficient_entropy_backward(dlogprobs: torch.Tensor, assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device assert dentropy.shape == (num_tokens,) + d_hidden, d_weight = None, None if _BACKWARD == BackwardEnum._Total_Fuse_MN: d_hidden = torch.zeros_like(hidden, dtype=torch.float32, device=hidden.device) d_weight = torch.zeros_like(weight, dtype=torch.float32, device=weight.device) @@ -771,88 +732,68 @@ def efficient_entropy_backward(dlogprobs: torch.Tensor, d_weight = torch.empty_like(weight, dtype=hidden.dtype, device=weight.device) assert d_hidden.is_contiguous() and d_weight.is_contiguous() - assert maximum.is_contiguous() and maximum_indices.is_contiguous() and acc.is_contiguous() - assert maximum.device == hidden.device and maximum_indices.device == hidden.device and acc.device == hidden.device - assert maximum.shape == maximum_indices.shape == labels.shape == acc.shape - assert maximum.is_cuda and maximum_indices.is_cuda and acc.is_cuda + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda vocab_per_split = 1024 assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - _d_max = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) - _d_max_additional = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) - _d_acc_exp_logits = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) - assert _d_max.is_contiguous() and _d_acc_exp_logits.is_contiguous() - 
assert _d_max.is_cuda and _d_acc_exp_logits.is_cuda + # _d_scale_non_reduced = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _d_scale_non_reduced = torch.empty((num_splits, num_tokens), device=hidden.device, dtype=torch.float32) + assert _d_scale_non_reduced.is_contiguous() and _d_scale_non_reduced.is_cuda - grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) - efficient_entropy_backward_kernel_general_preprocess[grid]( - num_tokens, hidden_size, vocab_size, vocab_per_split, + def preprocess_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + efficient_entropy_backward_kernel_general_preprocess[preprocess_grid]( + num_tokens, hidden_size, vocab_size, vocab_per_split, num_splits, hidden, hidden.stride(0), hidden.stride(1), weight, weight.stride(0), weight.stride(1), - labels, labels.stride(0), - dentropy, dentropy.stride(0), - dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, - REDUCTION, - maximum, maximum.stride(0), - maximum_indices, maximum_indices.stride(0), - acc, acc.stride(0), - _d_max, _d_max.stride(0), - _d_max_additional, _d_max_additional.stride(0), - _d_acc_exp_logits, _d_acc_exp_logits.stride(0), - ) - - grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) - efficient_entropy_backward_kernel_general_preprocess_update[grid]( - num_tokens, - acc, acc.stride(0), - dentropy, dentropy.stride(0), - _d_acc_exp_logits, _d_acc_exp_logits.stride(0), - _d_max_additional, _d_max_additional.stride(0), - _d_max, _d_max.stride(0), + maximum, + _d_scale_non_reduced, _d_scale_non_reduced.stride(0), _d_scale_non_reduced.stride(1), ) + _d_scale = _d_scale_non_reduced.sum(dim=0) if _BACKWARD == BackwardEnum._Total_Fuse_MN: - grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) - * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) - efficient_entropy_backward_kernel_general_mainloop_MN[grid]( + def 
mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( num_tokens, hidden_size, vocab_size, hidden, hidden.stride(0), hidden.stride(1), weight, weight.stride(0), weight.stride(1), labels, labels.stride(0), maximum, maximum.stride(0), - maximum_indices, maximum_indices.stride(0), acc, acc.stride(0), dentropy, dentropy.stride(0), dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - _d_max, _d_max.stride(0), - _d_acc_exp_logits, _d_acc_exp_logits.stride(0), + _d_scale, _d_scale.stride(0), d_hidden, d_hidden.stride(0), d_hidden.stride(1), d_weight, d_weight.stride(0), d_weight.stride(1), ) elif _BACKWARD == BackwardEnum._Total_Separate: - _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=torch.bfloat16) - grid = lambda meta: (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) - * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) - efficient_entropy_backward_kernel_general_d_logits[grid]( + _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype) + def d_logits_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( num_tokens, hidden_size, vocab_size, hidden, hidden.stride(0), hidden.stride(1), weight, weight.stride(0), weight.stride(1), labels, labels.stride(0), maximum, maximum.stride(0), - maximum_indices, maximum_indices.stride(0), acc, acc.stride(0), dentropy, dentropy.stride(0), dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - _d_max, _d_max.stride(0), - _d_acc_exp_logits, _d_acc_exp_logits.stride(0), + _d_scale, _d_scale.stride(0), _d_logits, _d_logits.stride(0), _d_logits.stride(1), ) torch.matmul(_d_logits, weight.T, out=d_hidden) 
torch.matmul(hidden.T, _d_logits, out=d_weight) - return d_hidden, d_weight + diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py index 581a3f50b02..d0058e6a874 100644 --- a/verl/utils/kernel/linear_cross_entropy.py +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -43,10 +43,10 @@ def forward(ctx, with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) - logprobs, entropy, _maximum, _maximum_indices, _acc =\ + logprobs, entropy, _maximum, _acc =\ kernels.efficient_entropy_foward(hidden, weight, labels, REDUCTION) - ctx.save_for_backward(hidden, weight, labels, _maximum, _maximum_indices, _acc) + ctx.save_for_backward(hidden, weight, labels, _maximum, _acc) ctx.REDUCTION = REDUCTION return logprobs, entropy @@ -56,13 +56,13 @@ def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> typing.List[torch.Tensor]: with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): - (hidden, weight, labels, _maximum, _maximum_indices, _acc) = ctx.saved_tensors + (hidden, weight, labels, _maximum, _acc) = ctx.saved_tensors REDUCTION = ctx.REDUCTION d_hidden, d_weight = kernels.efficient_entropy_backward( dlogprobs, dentropy, hidden, weight, labels, - _maximum, _maximum_indices, _acc, + _maximum, _acc, REDUCTION) return (d_hidden, d_weight, None, None) From c11bdf1d5ba77a7a55ec16415d0f1475e3791302 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Mon, 17 Mar 2025 00:36:11 -0700 Subject: [PATCH 06/15] yapf formatting Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 66 ++- verl/utils/kernel/__init__.py | 4 +- verl/utils/kernel/kernels.py | 488 ++++++++++++---------- verl/utils/kernel/linear_cross_entropy.py | 16 +- 4 files changed, 294 insertions(+), 280 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index 79bd793d38e..decafcb1532 100644 --- 
a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -42,18 +42,19 @@ finally: from verl.utils.kernel import linear_cross_entropy, set_backward_method, BackwardEnum -def run_torch_entropy(hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor) -> typing.List[torch.Tensor]: - logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] - pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] - entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] - entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] + +def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> typing.List[torch.Tensor]: + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] + pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] + entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels) # [1] + logprobs = torch.nn.functional.cross_entropy(logits, labels) # [1] return logprobs, entropy + class TestLinearCrossEntropy: + def cleanup(self): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() @@ -68,20 +69,16 @@ def generate_hyper(self): self.dtype = torch.bfloat16 def generate_forward_inputs(self): - hidden = (torch.empty((self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_()) - weight = (torch.empty((self.hidden_size, self.vocab_size), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5) - .requires_grad_()) + hidden = (torch.empty((self.num_tokens, self.hidden_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + weight = (torch.empty((self.hidden_size, self.vocab_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 
0.5).requires_grad_()) labels = torch.randint(0, self.vocab_size, (self.num_tokens,), device="cuda") return hidden, weight, labels def generate_backward_inputs(self): - g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda") - .uniform_(-0.5, 0.5)) - g_logprobs = (torch.empty((), dtype=self.dtype, device="cuda") - .uniform_(-1, 1)) + g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)) + g_logprobs = (torch.empty((), dtype=self.dtype, device="cuda").uniform_(-1, 1)) return g_entropy, g_logprobs def verify_correctness(self): @@ -99,7 +96,7 @@ def verify_correctness(self): end_event = torch.cuda.Event(enable_timing=True) for i in range(iterations): - print(f"[INFO]: Iteration {i} / {iterations}...", end="\r") + print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r") hidden, weight, labels = self.generate_forward_inputs() start_event.record() @@ -121,19 +118,17 @@ def verify_correctness(self): g_entropy, g_logprobs = self.generate_backward_inputs() start_event.record() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), - (hidden, weight), - (g_entropy, g_logprobs), - retain_graph=False) + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) end_event.record() torch.cuda.synchronize() torch_backward_latency.append(start_event.elapsed_time(end_event)) start_event.record() (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), - (hidden, weight), - (g_entropy, g_logprobs), - retain_graph=False) + (hidden, weight), (g_entropy, g_logprobs), + retain_graph=False) end_event.record() torch.cuda.synchronize() kernel_backward_latency.append(start_event.elapsed_time(end_event)) @@ -158,7 +153,6 @@ def verify_correctness(self): print(f"[INFO]: Backward pass: kernel implementation average time: " f"{sum(kernel_backward_latency) / 
len(kernel_backward_latency):.2f} ms") - def check_torch_storage(self): self.cleanup() self.generate_hyper() @@ -174,10 +168,9 @@ def check_torch_storage(self): g_entropy, g_logprobs = self.generate_backward_inputs() torch.cuda.reset_peak_memory_stats() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), - (hidden, weight), - (g_entropy, g_logprobs), - retain_graph=False) + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) torch.cuda.synchronize() torch_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 print(f"[INFO]: Torch Backward pass peak memory: {torch_backward_max_memory:.2f} MB") @@ -192,15 +185,14 @@ def check_kernel_storage(self): (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) torch.cuda.synchronize() kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") g_entropy, g_logprobs = self.generate_backward_inputs() torch.cuda.reset_peak_memory_stats() - (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), - (hidden, weight), - (g_entropy, g_logprobs), - retain_graph=False) + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) torch.cuda.synchronize() kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") @@ -213,4 +205,4 @@ def check_kernel_storage(self): test.verify_correctness() test.check_torch_storage() - test.check_kernel_storage() \ No newline at end of file + test.check_kernel_storage() diff --git a/verl/utils/kernel/__init__.py 
b/verl/utils/kernel/__init__.py index 2453d149bad..1f1d534f632 100644 --- a/verl/utils/kernel/__init__.py +++ b/verl/utils/kernel/__init__.py @@ -32,6 +32,4 @@ from .linear_cross_entropy import linear_cross_entropy from .kernels import set_backward_method, BackwardEnum -__all__ = ["linear_cross_entropy", - "set_backward_method", - "BackwardEnum"] +__all__ = ["linear_cross_entropy", "set_backward_method", "BackwardEnum"] diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index 17e1a095bba..212da0418df 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -28,7 +28,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Implementations of the linear cross entropy with token entropy kernel. """ @@ -39,6 +38,7 @@ import triton import triton.language as tl + @dataclass class EntropyReductionEnum: """ @@ -48,6 +48,7 @@ class EntropyReductionEnum: _Sum = 1 _Mean = 2 + def get_entropy_reduction_enum_number(reduction: str) -> int: """ Get the enum number for the reduction method of cross entropy. @@ -79,16 +80,19 @@ def get_entropy_reduction_enum(ce_reduction: int) -> EntropyReductionEnum: raise ValueError(f"Invalid ce_reduction: {ce_reduction}") return _enum + @dataclass class BackwardEnum: """ Enum for the backward method. """ - _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight - _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + _Total_Fuse_MN = 0 # Fuse d_logits & d_hidden & d_weight, no intermediate storage, requires fp32 for d_hidden & d_weight + _Total_Separate = 1 # Store d_logits, no special requirements for d_hidden & d_weight + _BACKWARD: BackwardEnum = BackwardEnum._Total_Separate + def set_backward_method(backward_method: BackwardEnum): """ Set the backward method. 
@@ -96,25 +100,44 @@ def set_backward_method(backward_method: BackwardEnum): global _BACKWARD _BACKWARD = backward_method + @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, num_warps=4)], + configs=[triton.Config({ + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64 + }, num_stages=3, num_warps=4)], key=["num_tokens", "hidden_size", "vocab_size"], ) @triton.jit -def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr, - num_tokens, hidden_size, vocab_size, vocab_per_split, - stride_hidden_m, stride_hidden_k, - stride_weight_k, stride_weight_n, - max_ptr, stride_max_m, stride_max_n, - accu_ptr, stride_accu_m, stride_accu_n, - entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, - global_logprobs_ptr, stride_global_logprobs, - global_logprobs_scalar_ptr, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr): +def efficient_entropy_kernel_general_mainloop( + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m, + stride_hidden_k, + stride_weight_k, + stride_weight_n, + max_ptr, + stride_max_m, + stride_max_n, + accu_ptr, + stride_accu_m, + stride_accu_n, + entropy_b_ptr, + stride_entropy_b_m, + stride_entropy_b_n, + global_logprobs_ptr, + stride_global_logprobs, + global_logprobs_scalar_ptr, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr): """ forward mainloop """ @@ -131,7 +154,7 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr # create pointers for the first blocks of hidden offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = tl.arange(0, BLOCK_SIZE_K) - hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + 
offs_k[None, :] * stride_hidden_k) # load labels for this block labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) @@ -144,18 +167,18 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) # iterate over K dimension logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): # load the next block of hidden and weight _hidden = tl.load(hidden_ptrs, - mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) _weight = tl.load(weight_ptrs, - mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < (min((pid_n + 1) * vocab_per_split, vocab_size))), + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < (min( + (pid_n + 1) * vocab_per_split, vocab_size))), other=0.0) # GEMM @@ -172,16 +195,15 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr m_pid_n = tl.max(logits, axis=1) _max = tl.maximum(_max_old, m_pid_n) - exp_logits = tl.exp(logits - _max[:,None]) + exp_logits = tl.exp(logits - _max[:, None]) coeff = tl.exp(_max_old - _max) _accu = coeff * _accu + tl.sum(exp_logits, axis=1) _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) - label_mask = offs_bn[None,:] == labels[:,None] + label_mask = offs_bn[None, :] == labels[:, None] _logprobs += tl.sum(logits * label_mask, axis=1) - # store maximum offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_max_n = pid_n 
@@ -202,25 +224,14 @@ def efficient_entropy_kernel_general_mainloop(hidden_ptr, weight_ptr, labels_ptr tl.store(global_logprobs_ptrs, _logprobs, mask=mask) -@triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64}) - ], - key=["num_tokens", "num_splits"] -) +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) @triton.jit -def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n, - num_tokens, num_splits, - global_max_ptr, stride_global_max, - accu_ptr, stride_accu_m, stride_accu_n, - global_accu_ptr, stride_global_accu, - entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, - global_entropy_ptr, stride_global_entropy, - global_logprobs_ptr, stride_global_logprobs, - global_logprobs_scalar_ptr, - reduction: int, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr): +def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n, num_tokens, num_splits, + global_max_ptr, stride_global_max, accu_ptr, stride_accu_m, stride_accu_n, + global_accu_ptr, stride_global_accu, entropy_b_ptr, stride_entropy_b_m, + stride_entropy_b_n, global_entropy_ptr, stride_global_entropy, + global_logprobs_ptr, stride_global_logprobs, global_logprobs_scalar_ptr, + reduction: int, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): """ foward epilogue """ @@ -232,20 +243,16 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - max_ptrs = max_ptr + offs_m[:,None] * stride_max_m + offs_n[None,:] * stride_max_n + max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n - _max = tl.load(max_ptrs, - mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits), - other=0.0) + _max = tl.load(max_ptrs, 
mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) - accu_ptrs = accu_ptr + offs_m[:,None] * stride_accu_m + offs_n[None,:] * stride_accu_n - _accu = tl.load(accu_ptrs, - mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits), - other=0.0) + accu_ptrs = accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + _accu = tl.load(accu_ptrs, mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) - entropy_b_ptrs = entropy_b_ptr + offs_m[:,None] * stride_entropy_b_m + offs_n[None,:] * stride_entropy_b_n + entropy_b_ptrs = entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + offs_n[None, :] * stride_entropy_b_n _entropy_b = tl.load(entropy_b_ptrs, - mask=(offs_m[:,None] < num_tokens) & (offs_n[None,:] < num_splits), + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), other=0.0) # local reduction @@ -253,7 +260,7 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n _local_max = tl.max(_max, axis=1) global_max = tl.maximum(global_max, _local_max) - _scale = tl.exp(_max - global_max[:,None]) + _scale = tl.exp(_max - global_max[:, None]) _coeff = tl.exp(_max_old - global_max) global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) @@ -265,8 +272,8 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n # store entropy global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) - global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b - global_entropy_b = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + global_entropy_b = tl.log(global_accu) + global_max - global_entropy_b # entropy_a global_entropy_ptrs = global_entropy_ptr + offs_m * 
stride_global_entropy tl.store(global_entropy_ptrs, global_entropy_b, mask=offs_m < num_tokens) # update logprobs @@ -283,6 +290,7 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + def efficient_entropy_foward(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, @@ -337,55 +345,46 @@ def efficient_entropy_foward(hidden: torch.Tensor, # 1D kernel launch, then split the tile def mainloop_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) - efficient_entropy_kernel_general_mainloop[mainloop_grid]( - hidden, weight, labels, - num_tokens, hidden_size, vocab_size, vocab_per_split, - hidden.stride(0), hidden.stride(1), - weight.stride(0), weight.stride(1), - _max, _max.stride(0), _max.stride(1), - _accu, _accu.stride(0), _accu.stride(1), - _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), - _logprobs, _logprobs.stride(0), - logprobs - ) + + efficient_entropy_kernel_general_mainloop[mainloop_grid](hidden, weight, labels, num_tokens, hidden_size, + vocab_size, vocab_per_split, hidden.stride(0), + hidden.stride(1), weight.stride(0), weight.stride(1), _max, + _max.stride(0), _max.stride(1), _accu, _accu.stride(0), + _accu.stride(1), _entropy_b, _entropy_b.stride(0), + _entropy_b.stride(1), _logprobs, _logprobs.stride(0), + logprobs) # reduction on maximum and maximum_indices def epilogue_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) - efficient_entropy_triton_kernel_epilogue[epilogue_grid]( - _max, _max.stride(0), _max.stride(1), - num_tokens, num_splits, - maximum, maximum.stride(0), - _accu, _accu.stride(0), _accu.stride(1), - acc, acc.stride(0), - _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), - entropy, entropy.stride(0), - _logprobs, _logprobs.stride(0), - logprobs, REDUCTION - ) + + 
efficient_entropy_triton_kernel_epilogue[epilogue_grid](_max, _max.stride(0), _max.stride(1), num_tokens, + num_splits, maximum, maximum.stride(0), _accu, + _accu.stride(0), _accu.stride(1), acc, + acc.stride(0), _entropy_b, _entropy_b.stride(0), + _entropy_b.stride(1), entropy, entropy.stride(0), _logprobs, + _logprobs.stride(0), logprobs, REDUCTION) return (logprobs, entropy, maximum, acc) @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64}, - num_stages=3, num_warps=4), - ], + configs=[ + triton.Config({ + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64 + }, num_stages=3, num_warps=4), + ], key=["num_tokens", "hidden_size", "vocab_size"], ) @triton.jit -def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, - hidden_size: int, - vocab_size: int, - vocab_per_split: int, - num_splits: int, - hidden_ptr, stride_hidden_m, stride_hidden_k, - weight_ptr, stride_weight_k, stride_weight_n, - maximum_ptr, - d_scale_ptr, stride_d_scale_m, stride_d_scale_n, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr): +def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, hidden_size: int, vocab_size: int, + vocab_per_split: int, num_splits: int, hidden_ptr, + stride_hidden_m, stride_hidden_k, weight_ptr, stride_weight_k, + stride_weight_n, maximum_ptr, d_scale_ptr, stride_d_scale_m, + stride_d_scale_n, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): """ backward preprocess """ @@ -398,23 +397,23 @@ def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, # pointers for this block offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_k = tl.arange(0, BLOCK_SIZE_K) - hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) maximum = 
tl.load(maximum_ptr + offs_am, mask=offs_am < num_tokens) _scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load(hidden_ptrs, - mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:,None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) _weight = tl.load(weight_ptrs, - mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < min((pid_n + 1) * vocab_per_split, vocab_size)), + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < min( + (pid_n + 1) * vocab_per_split, vocab_size)), other=0.0) logits = tl.dot(_hidden, _weight, logits) @@ -422,41 +421,38 @@ def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, weight_ptrs += BLOCK_SIZE_K * stride_weight_k hidden_ptrs -= hidden_size * stride_hidden_k - exp_logits = tl.exp(logits - maximum[:,None]) + exp_logits = tl.exp(logits - maximum[:, None]) _scale += tl.sum(exp_logits * logits, axis=1) # tl.store(d_scale_ptr + offs_am * stride_d_scale_m + pid_n * stride_d_scale_n, # _scale, mask=offs_am < num_tokens) tl.store(d_scale_ptr + pid_n * stride_d_scale_m + offs_am * stride_d_scale_n, - _scale, mask=(offs_am < num_tokens) & (pid_n < num_splits)) + _scale, + mask=(offs_am < num_tokens) & (pid_n < num_splits)) # NOTE: merge d_weight & d_hidden here, split along M & N @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 16}, - num_stages=3, num_warps=4)], + configs=[ + 
triton.Config({ + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16 + }, + num_stages=3, + num_warps=4) + ], key=["num_tokens", "hidden_size", "vocab_size"], ) @triton.jit -def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, - hidden_size: int, - vocab_size: int, - hidden_ptr, stride_hidden_m, stride_hidden_k, - weight_ptr, stride_weight_k, stride_weight_n, - labels_ptr, stride_labels, - maximum_ptr, stride_maximum, - accu_ptr, stride_accu, - d_entropy_ptr, stride_d_entropy, - d_logprobs_ptr, stride_d_logprobs, - reduction: int, - d_scale_ptr, stride_d_scale, - d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, - d_weight_ptr, stride_d_weight_k, stride_d_weight_n, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): +def efficient_entropy_backward_kernel_general_mainloop_MN( + num_tokens: int, hidden_size: int, vocab_size: int, hidden_ptr, stride_hidden_m, stride_hidden_k, weight_ptr, + stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr, stride_accu, + d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, d_scale_ptr, stride_d_scale, + d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, d_weight_ptr, stride_d_weight_k, stride_d_weight_n, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """ backward mainloop, where d_logits & d_hidden & d_weight are fused """ @@ -483,41 +479,38 @@ def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, maximum_ptrs = maximum_ptr + offs_am * stride_maximum maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to 
avoid division by zero accu_rcp = tl.fdiv(1.0, accu) d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: # none + if reduction == 0: # none d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: # sum + elif reduction == 1: # sum d_logprobs = tl.load(d_logprobs_ptr) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: # mean + else: # mean d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, - mask=offs_am < num_tokens, other=0.0) + d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, mask=offs_am < num_tokens, other=0.0) - hidden_ptrs = hidden_ptr + (offs_am[:,None] * stride_hidden_m + offs_k[None,:] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) labels_ptrs = labels_ptr + offs_am * stride_labels labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) - d_hidden_ptrs = d_hidden_ptr + offs_am[:,None] * stride_d_hidden_m + offs_k[None,:] * stride_d_hidden_k - d_weight_ptrs = d_weight_ptr + offs_k[:,None] * stride_d_weight_k + offs_bn[None,:] * stride_d_weight_n + d_hidden_ptrs = d_hidden_ptr + offs_am[:, None] * stride_d_hidden_m + offs_k[None, :] * stride_d_hidden_k + d_weight_ptrs = d_weight_ptr + offs_k[:, None] * stride_d_weight_k + offs_bn[None, :] * stride_d_weight_n logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load(hidden_ptrs, - 
mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) - & (offs_am[:,None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) _weight = tl.load(weight_ptrs, - mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < vocab_size), + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), other=0.0) logits = tl.dot(_hidden, _weight, logits) @@ -527,35 +520,33 @@ def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, hidden_ptrs -= hidden_size * stride_hidden_k weight_ptrs -= hidden_size * stride_weight_k - exp_logits = tl.exp(logits - maximum[:,None]) + exp_logits = tl.exp(logits - maximum[:, None]) - d_pd = logits * -d_entropy[:,None] - mask = offs_bn[None,:] == labels[:,None] - d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:,None], exp_logits) * mask + d_pd = logits * -d_entropy[:, None] + mask = offs_bn[None, :] == labels[:, None] + d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:, None], exp_logits) * mask coeff = d_scale * d_entropy * accu_rcp * accu_rcp + d_logprobs * accu_rcp - d_logits = exp_logits * coeff[:,None] - d_logits += exp_logits * d_pd * accu_rcp[:,None] + d_logits = exp_logits * coeff[:, None] + d_logits += exp_logits * d_pd * accu_rcp[:, None] # loop for d_weight & d_hidden for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load(hidden_ptrs, - mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) - & (offs_am[:,None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) _d_weight = tl.dot(tl.trans(_hidden).to(tl.float32), d_logits) - tl.atomic_add(d_weight_ptrs, _d_weight, - mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < vocab_size)) + tl.atomic_add(d_weight_ptrs, + _d_weight, + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size)) _weight = tl.load(weight_ptrs, - 
mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < vocab_size), - other=0.0) + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), + other=0.0) _d_hidden = tl.dot(d_logits, tl.trans(_weight).to(tl.float32)) - tl.atomic_add(d_hidden_ptrs, _d_hidden, - mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) - & (offs_am[:,None] < num_tokens)) + tl.atomic_add(d_hidden_ptrs, + _d_hidden, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens)) hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k weight_ptrs += BLOCK_SIZE_K * stride_weight_k @@ -565,29 +556,25 @@ def efficient_entropy_backward_kernel_general_mainloop_MN(num_tokens: int, # NOTE: split tile from d_logits' perspective @triton.autotune( - configs=[triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 16}, - num_stages=3, num_warps=4), - ], + configs=[ + triton.Config({ + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16 + }, + num_stages=3, + num_warps=4), + ], key=["num_tokens", "hidden_size", "vocab_size"], ) @triton.jit -def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, - hidden_size: int, - vocab_size: int, - hidden_ptr, stride_hidden_m, stride_hidden_k, - weight_ptr, stride_weight_k, stride_weight_n, - labels_ptr, stride_labels, - maximum_ptr, stride_maximum, - accu_ptr, stride_accu, - d_entropy_ptr, stride_d_entropy, - d_logprobs_ptr, stride_d_logprobs, - reduction: int, - d_scale_ptr, stride_d_scale, - d_logits_ptr, stride_d_logits_m, stride_d_logits_n, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr): +def efficient_entropy_backward_kernel_general_d_logits( + num_tokens: int, hidden_size: int, vocab_size: int, hidden_ptr, stride_hidden_m, stride_hidden_k, weight_ptr, + stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, 
stride_maximum, accu_ptr, stride_accu, + d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, d_scale_ptr, stride_d_scale, + d_logits_ptr, stride_d_logits_m, stride_d_logits_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """ backward d_logits """ @@ -614,23 +601,22 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, maximum_ptrs = maximum_ptr + offs_am * stride_maximum maximum = tl.load(maximum_ptrs, mask=offs_am < num_tokens, other=0.0) accu_ptrs = accu_ptr + offs_am * stride_accu - accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero + accu = tl.load(accu_ptrs, mask=offs_am < num_tokens, other=1e-6) # epsilon to avoid division by zero accu_rcp = tl.fdiv(1.0, accu) d_entropy_ptrs = d_entropy_ptr + offs_am * stride_d_entropy d_entropy = tl.load(d_entropy_ptrs, mask=offs_am < num_tokens, other=0.0) - if reduction == 0: # none + if reduction == 0: # none d_logprobs_ptrs = d_logprobs_ptr + offs_am * stride_d_logprobs d_logprobs = tl.load(d_logprobs_ptrs, mask=offs_am < num_tokens, other=0.0) - elif reduction == 1: # sum + elif reduction == 1: # sum d_logprobs = tl.load(d_logprobs_ptr) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - else: # mean + else: # mean d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) - d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, - mask=offs_am < num_tokens, other=0.0) + d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, mask=offs_am < num_tokens, other=0.0) # d_acc_exp_logits = d_scale * d_entropy * accu_rcp * accu_rcp # d_acc_exp_logits += d_logprobs * accu_rcp @@ -641,20 +627,18 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, # d_max -= d_logprobs # d_max += accu * d_acc_exp_logits - hidden_ptrs = hidden_ptr + (offs_am[:,None] * 
stride_hidden_m + offs_k[None,:] * stride_hidden_k) - weight_ptrs = weight_ptr + (offs_k[:,None] * stride_weight_k + offs_bn[None,:] * stride_weight_n) + hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) + weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) labels_ptrs = labels_ptr + offs_am * stride_labels labels = tl.load(labels_ptrs, mask=offs_am < num_tokens, other=0) logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): _hidden = tl.load(hidden_ptrs, - mask=(offs_k[None,:] < hidden_size - k * BLOCK_SIZE_K) - & (offs_am[:,None] < num_tokens), + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), other=0.0) _weight = tl.load(weight_ptrs, - mask=(offs_k[:,None] < hidden_size - k * BLOCK_SIZE_K) - & (offs_bn[None,:] < vocab_size), + mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < vocab_size), other=0.0) logits = tl.dot(_hidden, _weight, logits) @@ -664,15 +648,15 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, hidden_ptrs -= hidden_size * stride_hidden_k weight_ptrs -= hidden_size * stride_weight_k - exp_logits = tl.exp(logits - maximum[:,None]) + exp_logits = tl.exp(logits - maximum[:, None]) - d_pd = logits * -d_entropy[:,None] - mask = offs_bn[None,:] == labels[:,None] - d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:,None], exp_logits) * mask + d_pd = logits * -d_entropy[:, None] + mask = offs_bn[None, :] == labels[:, None] + d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:, None], exp_logits) * mask coeff = d_scale * d_entropy * accu_rcp * accu_rcp + d_logprobs * accu_rcp - d_logits = exp_logits * coeff[:,None] - d_logits += exp_logits * d_pd * accu_rcp[:,None] + d_logits = exp_logits * coeff[:, None] + d_logits += exp_logits * d_pd * accu_rcp[:, None] # d_logits += exp_logits * logits * (-d_entropy * 
accu_rcp)[:,None] # d_logits -= tl.where(mask, d_logprobs[:,None], 0.0) @@ -682,20 +666,20 @@ def efficient_entropy_backward_kernel_general_d_logits(num_tokens: int, # d_logits += tl.where(mask, d_max[:,None], 0.0) # store d_logits - d_logits_ptrs = d_logits_ptr + offs_am[:,None] * stride_d_logits_m + offs_bn[None,:] * stride_d_logits_n - tl.store(d_logits_ptrs, d_logits.to(hidden_ptr.dtype.element_ty), - mask=(offs_am[:,None] < num_tokens) - & (offs_bn[None,:] < vocab_size)) + d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n + tl.store(d_logits_ptrs, + d_logits.to(hidden_ptr.dtype.element_ty), + mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size)) def efficient_entropy_backward(dlogprobs: torch.Tensor, - dentropy: torch.Tensor, - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - maximum: torch.Tensor, - acc: torch.Tensor, - reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: """ backward host function """ @@ -747,53 +731,95 @@ def efficient_entropy_backward(dlogprobs: torch.Tensor, def preprocess_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + efficient_entropy_backward_kernel_general_preprocess[preprocess_grid]( - num_tokens, hidden_size, vocab_size, vocab_per_split, num_splits, - hidden, hidden.stride(0), hidden.stride(1), - weight, weight.stride(0), weight.stride(1), + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + num_splits, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), maximum, - _d_scale_non_reduced, _d_scale_non_reduced.stride(0), _d_scale_non_reduced.stride(1), + _d_scale_non_reduced, + _d_scale_non_reduced.stride(0), + 
_d_scale_non_reduced.stride(1), ) _d_scale = _d_scale_non_reduced.sum(dim=0) if _BACKWARD == BackwardEnum._Total_Fuse_MN: + def mainloop_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) - * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + efficient_entropy_backward_kernel_general_mainloop_MN[mainloop_grid]( - num_tokens, hidden_size, vocab_size, - hidden, hidden.stride(0), hidden.stride(1), - weight, weight.stride(0), weight.stride(1), - labels, labels.stride(0), - maximum, maximum.stride(0), - acc, acc.stride(0), - dentropy, dentropy.stride(0), - dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + num_tokens, + hidden_size, + vocab_size, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - _d_scale, _d_scale.stride(0), - d_hidden, d_hidden.stride(0), d_hidden.stride(1), - d_weight, d_weight.stride(0), d_weight.stride(1), + _d_scale, + _d_scale.stride(0), + d_hidden, + d_hidden.stride(0), + d_hidden.stride(1), + d_weight, + d_weight.stride(0), + d_weight.stride(1), ) elif _BACKWARD == BackwardEnum._Total_Separate: _d_logits = torch.empty((num_tokens, vocab_size), device=hidden.device, dtype=hidden.dtype) + def d_logits_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) - * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * triton.cdiv(vocab_size, meta["BLOCK_SIZE_N"]),) + efficient_entropy_backward_kernel_general_d_logits[d_logits_grid]( - num_tokens, hidden_size, vocab_size, - hidden, hidden.stride(0), hidden.stride(1), - weight, weight.stride(0), weight.stride(1), - labels, 
labels.stride(0), - maximum, maximum.stride(0), - acc, acc.stride(0), - dentropy, dentropy.stride(0), - dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, + num_tokens, + hidden_size, + vocab_size, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - _d_scale, _d_scale.stride(0), - _d_logits, _d_logits.stride(0), _d_logits.stride(1), + _d_scale, + _d_scale.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), ) torch.matmul(_d_logits, weight.T, out=d_hidden) torch.matmul(hidden.T, _d_logits, out=d_weight) return d_hidden, d_weight - diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py index d0058e6a874..eafa12e1f6e 100644 --- a/verl/utils/kernel/linear_cross_entropy.py +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -33,7 +33,9 @@ import typing from . 
import kernels + class LinearCrossEntropy(torch.autograd.Function): + @staticmethod def forward(ctx, hidden: torch.Tensor, @@ -52,19 +54,15 @@ def forward(ctx, return logprobs, entropy @staticmethod - def backward(ctx, - dlogprobs: torch.Tensor, - dentropy: torch.Tensor) -> typing.List[torch.Tensor]: + def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> typing.List[torch.Tensor]: with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): (hidden, weight, labels, _maximum, _acc) = ctx.saved_tensors REDUCTION = ctx.REDUCTION - d_hidden, d_weight = kernels.efficient_entropy_backward( - dlogprobs, dentropy, - hidden, weight, labels, - _maximum, _acc, - REDUCTION) + d_hidden, d_weight = kernels.efficient_entropy_backward(dlogprobs, dentropy, hidden, weight, labels, + _maximum, _acc, REDUCTION) return (d_hidden, d_weight, None, None) -linear_cross_entropy = LinearCrossEntropy.apply \ No newline at end of file + +linear_cross_entropy = LinearCrossEntropy.apply From 9db4b76ed630d4e4321c538249c9dfce4428b5fb Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Mon, 17 Mar 2025 00:58:09 -0700 Subject: [PATCH 07/15] add torch.neg to logprobs Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 9 +++++---- verl/utils/kernel/kernels.py | 3 +++ 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index decafcb1532..37069457fb9 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -49,7 +49,8 @@ def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch. 
entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels) # [1] + logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction="none") # [num_tokens] + logprobs = torch.neg(logprobs) return logprobs, entropy @@ -78,7 +79,7 @@ def generate_forward_inputs(self): def generate_backward_inputs(self): g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)) - g_logprobs = (torch.empty((), dtype=self.dtype, device="cuda").uniform_(-1, 1)) + g_logprobs = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)) return g_entropy, g_logprobs def verify_correctness(self): @@ -106,7 +107,7 @@ def verify_correctness(self): torch_forward_latency.append(start_event.elapsed_time(end_event)) start_event.record() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none") end_event.record() torch.cuda.synchronize() kernel_forward_latency.append(start_event.elapsed_time(end_event)) @@ -182,7 +183,7 @@ def check_kernel_storage(self): hidden, weight, labels = self.generate_forward_inputs() torch.cuda.reset_peak_memory_stats() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels) + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none") torch.cuda.synchronize() kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index 212da0418df..c6e2fba95bf 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -281,6 +281,7 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n 
global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) global_logprobs = global_max + tl.log(global_accu) - global_logprobs + global_logprobs = -1 * global_logprobs if reduction == 0: tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) elif reduction == 1: @@ -493,6 +494,7 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( else: # mean d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, mask=offs_am < num_tokens, other=0.0) @@ -615,6 +617,7 @@ def efficient_entropy_backward_kernel_general_d_logits( else: # mean d_logprobs = tl.fdiv(tl.load(d_logprobs_ptr), num_tokens.to(tl.float32)) d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) + d_logprobs = -1 * d_logprobs d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, mask=offs_am < num_tokens, other=0.0) From d4fdf7eef9efa64628f1253a8c89047978ba01ec Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Mon, 17 Mar 2025 01:17:34 -0700 Subject: [PATCH 08/15] moved d_scale into forward pass Signed-off-by: Jianbing Dong --- verl/utils/kernel/kernels.py | 147 ++++++++-------------- verl/utils/kernel/linear_cross_entropy.py | 8 +- 2 files changed, 53 insertions(+), 102 deletions(-) diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index c6e2fba95bf..b058fdecd0b 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -134,6 +134,9 @@ def efficient_entropy_kernel_general_mainloop( global_logprobs_ptr, stride_global_logprobs, global_logprobs_scalar_ptr, + d_scale_non_reduced_ptr, + stride_d_scale_non_reduced_m, + stride_d_scale_non_reduced_n, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -165,6 +168,7 @@ def efficient_entropy_kernel_general_mainloop( _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _entropy_b = 
tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) @@ -204,6 +208,9 @@ def efficient_entropy_kernel_general_mainloop( label_mask = offs_bn[None, :] == labels[:, None] _logprobs += tl.sum(logits * label_mask, axis=1) + # preprocess for backward + _scale = coeff * _scale + tl.sum(exp_logits * logits, axis=1) + # store maximum offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_max_n = pid_n @@ -223,6 +230,12 @@ def efficient_entropy_kernel_general_mainloop( # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + # store d_scale_non_reduced + tl.store(d_scale_non_reduced_ptr + offs_max_n * stride_d_scale_non_reduced_n + + offs_max_m * stride_d_scale_non_reduced_m, + _scale, + mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) + @triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) @triton.jit @@ -231,7 +244,10 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_accu_ptr, stride_global_accu, entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, global_entropy_ptr, stride_global_entropy, global_logprobs_ptr, stride_global_logprobs, global_logprobs_scalar_ptr, - reduction: int, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + reduction: int, + d_scale_non_reduced_ptr, stride_d_scale_non_reduced_m, stride_d_scale_non_reduced_n, + d_scale_ptr, stride_d_scale, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): """ foward epilogue """ @@ -241,6 +257,7 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_max = 
tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_d_scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n @@ -265,6 +282,13 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + # preprocess for backward + d_scale = tl.load(d_scale_non_reduced_ptr + offs_m[:, None] * stride_d_scale_non_reduced_m + + offs_n[None, :] * stride_d_scale_non_reduced_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0) + global_d_scale = _coeff * global_d_scale + tl.sum(_scale * d_scale, axis=1) + # store maximum_ptrs = global_max_ptr + offs_m * stride_global_max tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) @@ -291,6 +315,10 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + # store d_scale + d_scale_ptrs = d_scale_ptr + offs_m * stride_d_scale + tl.store(d_scale_ptrs, global_d_scale, mask=offs_m < num_tokens) + def efficient_entropy_foward(hidden: torch.Tensor, weight: torch.Tensor, @@ -343,6 +371,11 @@ def efficient_entropy_foward(hidden: torch.Tensor, assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda + # preprocess for backward + _d_scale_non_reduced = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) + _d_scale = 
torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert _d_scale_non_reduced.is_contiguous() and _d_scale_non_reduced.is_cuda + # 1D kernel launch, then split the tile def mainloop_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) @@ -353,8 +386,9 @@ def mainloop_grid(meta): _max.stride(0), _max.stride(1), _accu, _accu.stride(0), _accu.stride(1), _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), _logprobs, _logprobs.stride(0), - logprobs) - + logprobs, + _d_scale_non_reduced, _d_scale_non_reduced.stride(0), + _d_scale_non_reduced.stride(1)) # reduction on maximum and maximum_indices def epilogue_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) @@ -364,73 +398,12 @@ def epilogue_grid(meta): _accu.stride(0), _accu.stride(1), acc, acc.stride(0), _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), entropy, entropy.stride(0), _logprobs, - _logprobs.stride(0), logprobs, REDUCTION) + _logprobs.stride(0), logprobs, REDUCTION, + _d_scale_non_reduced, _d_scale_non_reduced.stride(0), + _d_scale_non_reduced.stride(1), + _d_scale, _d_scale.stride(0)) - return (logprobs, entropy, maximum, acc) - - -@triton.autotune( - configs=[ - triton.Config({ - "BLOCK_SIZE_M": 32, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64 - }, num_stages=3, num_warps=4), - ], - key=["num_tokens", "hidden_size", "vocab_size"], -) -@triton.jit -def efficient_entropy_backward_kernel_general_preprocess(num_tokens: int, hidden_size: int, vocab_size: int, - vocab_per_split: int, num_splits: int, hidden_ptr, - stride_hidden_m, stride_hidden_k, weight_ptr, stride_weight_k, - stride_weight_n, maximum_ptr, d_scale_ptr, stride_d_scale_m, - stride_d_scale_n, BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr): - """ - backward preprocess - """ - pid = tl.program_id(0) - num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) - num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) - pid_m = pid % num_pid_m - pid_n = 
pid // num_pid_m - - # pointers for this block - offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_k = tl.arange(0, BLOCK_SIZE_K) - hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) - - maximum = tl.load(maximum_ptr + offs_am, mask=offs_am < num_tokens) - - _scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - for n in range(0, num_pid_n): - offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) - - logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): - _hidden = tl.load(hidden_ptrs, - mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) & (offs_am[:, None] < num_tokens), - other=0.0) - _weight = tl.load(weight_ptrs, - mask=(offs_k[:, None] < hidden_size - k * BLOCK_SIZE_K) & (offs_bn[None, :] < min( - (pid_n + 1) * vocab_per_split, vocab_size)), - other=0.0) - logits = tl.dot(_hidden, _weight, logits) - - hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k - weight_ptrs += BLOCK_SIZE_K * stride_weight_k - hidden_ptrs -= hidden_size * stride_hidden_k - - exp_logits = tl.exp(logits - maximum[:, None]) - - _scale += tl.sum(exp_logits * logits, axis=1) - - # tl.store(d_scale_ptr + offs_am * stride_d_scale_m + pid_n * stride_d_scale_n, - # _scale, mask=offs_am < num_tokens) - tl.store(d_scale_ptr + pid_n * stride_d_scale_m + offs_am * stride_d_scale_n, - _scale, - mask=(offs_am < num_tokens) & (pid_n < num_splits)) + return (logprobs, entropy, maximum, acc, _d_scale) # NOTE: merge d_weight & d_hidden here, split along M & N @@ -682,6 +655,7 @@ def efficient_entropy_backward(dlogprobs: torch.Tensor, labels: torch.Tensor, maximum: torch.Tensor, acc: torch.Tensor, + d_scale: torch.Tensor, reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: """ backward host function @@ -728,31 +702,8 @@ def 
efficient_entropy_backward(dlogprobs: torch.Tensor, assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - # _d_scale_non_reduced = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _d_scale_non_reduced = torch.empty((num_splits, num_tokens), device=hidden.device, dtype=torch.float32) - assert _d_scale_non_reduced.is_contiguous() and _d_scale_non_reduced.is_cuda - - def preprocess_grid(meta): - return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) - - efficient_entropy_backward_kernel_general_preprocess[preprocess_grid]( - num_tokens, - hidden_size, - vocab_size, - vocab_per_split, - num_splits, - hidden, - hidden.stride(0), - hidden.stride(1), - weight, - weight.stride(0), - weight.stride(1), - maximum, - _d_scale_non_reduced, - _d_scale_non_reduced.stride(0), - _d_scale_non_reduced.stride(1), - ) - _d_scale = _d_scale_non_reduced.sum(dim=0) + assert d_scale.is_contiguous() and d_scale.is_cuda + assert d_scale.shape == (num_tokens,) if _BACKWARD == BackwardEnum._Total_Fuse_MN: @@ -780,8 +731,8 @@ def mainloop_grid(meta): dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - _d_scale, - _d_scale.stride(0), + d_scale, + d_scale.stride(0), d_hidden, d_hidden.stride(0), d_hidden.stride(1), @@ -816,8 +767,8 @@ def d_logits_grid(meta): dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - _d_scale, - _d_scale.stride(0), + d_scale, + d_scale.stride(0), _d_logits, _d_logits.stride(0), _d_logits.stride(1), diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py index eafa12e1f6e..1a8788ef831 100644 --- a/verl/utils/kernel/linear_cross_entropy.py +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -45,10 +45,10 @@ def forward(ctx, with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): REDUCTION = 
kernels.get_entropy_reduction_enum_number(reduction.lower()) - logprobs, entropy, _maximum, _acc =\ + logprobs, entropy, _maximum, _acc, _d_scale =\ kernels.efficient_entropy_foward(hidden, weight, labels, REDUCTION) - ctx.save_for_backward(hidden, weight, labels, _maximum, _acc) + ctx.save_for_backward(hidden, weight, labels, _maximum, _acc, _d_scale) ctx.REDUCTION = REDUCTION return logprobs, entropy @@ -56,11 +56,11 @@ def forward(ctx, @staticmethod def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> typing.List[torch.Tensor]: with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): - (hidden, weight, labels, _maximum, _acc) = ctx.saved_tensors + (hidden, weight, labels, _maximum, _acc, _d_scale) = ctx.saved_tensors REDUCTION = ctx.REDUCTION d_hidden, d_weight = kernels.efficient_entropy_backward(dlogprobs, dentropy, hidden, weight, labels, - _maximum, _acc, REDUCTION) + _maximum, _acc, _d_scale, REDUCTION) return (d_hidden, d_weight, None, None) From 1e88718cb589f2466a6bd8b6307baaf75daff046 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Mon, 17 Mar 2025 01:19:11 -0700 Subject: [PATCH 09/15] yapf formatting Signed-off-by: Jianbing Dong --- verl/utils/kernel/kernels.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index b058fdecd0b..5fb57fb170b 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -231,8 +231,8 @@ def efficient_entropy_kernel_general_mainloop( tl.store(global_logprobs_ptrs, _logprobs, mask=mask) # store d_scale_non_reduced - tl.store(d_scale_non_reduced_ptr + offs_max_n * stride_d_scale_non_reduced_n - + offs_max_m * stride_d_scale_non_reduced_m, + tl.store(d_scale_non_reduced_ptr + offs_max_n * stride_d_scale_non_reduced_n + + offs_max_m * stride_d_scale_non_reduced_m, _scale, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) @@ -244,9 +244,8 @@ def 
efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_accu_ptr, stride_global_accu, entropy_b_ptr, stride_entropy_b_m, stride_entropy_b_n, global_entropy_ptr, stride_global_entropy, global_logprobs_ptr, stride_global_logprobs, global_logprobs_scalar_ptr, - reduction: int, - d_scale_non_reduced_ptr, stride_d_scale_non_reduced_m, stride_d_scale_non_reduced_n, - d_scale_ptr, stride_d_scale, + reduction: int, d_scale_non_reduced_ptr, stride_d_scale_non_reduced_m, + stride_d_scale_non_reduced_n, d_scale_ptr, stride_d_scale, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): """ foward epilogue @@ -283,10 +282,10 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) # preprocess for backward - d_scale = tl.load(d_scale_non_reduced_ptr + offs_m[:, None] * stride_d_scale_non_reduced_m - + offs_n[None, :] * stride_d_scale_non_reduced_n, - mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), - other=0.0) + d_scale = tl.load(d_scale_non_reduced_ptr + offs_m[:, None] * stride_d_scale_non_reduced_m + + offs_n[None, :] * stride_d_scale_non_reduced_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0) global_d_scale = _coeff * global_d_scale + tl.sum(_scale * d_scale, axis=1) # store @@ -386,9 +385,10 @@ def mainloop_grid(meta): _max.stride(0), _max.stride(1), _accu, _accu.stride(0), _accu.stride(1), _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), _logprobs, _logprobs.stride(0), - logprobs, - _d_scale_non_reduced, _d_scale_non_reduced.stride(0), + logprobs, _d_scale_non_reduced, + _d_scale_non_reduced.stride(0), _d_scale_non_reduced.stride(1)) + # reduction on maximum and maximum_indices def epilogue_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) @@ -400,8 +400,8 @@ def epilogue_grid(meta): _entropy_b.stride(1), entropy, entropy.stride(0), _logprobs, 
_logprobs.stride(0), logprobs, REDUCTION, _d_scale_non_reduced, _d_scale_non_reduced.stride(0), - _d_scale_non_reduced.stride(1), - _d_scale, _d_scale.stride(0)) + _d_scale_non_reduced.stride(1), _d_scale, + _d_scale.stride(0)) return (logprobs, entropy, maximum, acc, _d_scale) From 453a6958216f1afd3d0595a27f3f6adfe623d181 Mon Sep 17 00:00:00 2001 From: Blue Space <57280232+ETOgaosion@users.noreply.github.com> Date: Thu, 20 Mar 2025 16:31:49 +0800 Subject: [PATCH 10/15] Integrate fused Linear Cross Entropy and memory efficient Vocab Parallel Entropy (#1) * try to test linear cross attention and integrate it * add solid veRL unit tests * fix correctness * add @vermouth1992's VocabParallelEntropy optimization * add @vermouth1992's VocabParallelEntropy optimization * fix some bugs * add unit tests to kernels * add bytedance copyright * add vocab_parallel_entropy tests * format --- .github/workflows/kernels.yml | 58 +++++++++ tests/kernel/run_vocab_parallel_entropy.sh | 18 +++ tests/kernel/test_linear_cross_entropy.py | 100 +++++++++------ tests/kernel/test_vocab_parallel_entropy.py | 118 ++++++++++++++++++ verl/trainer/config/ppo_megatron_trainer.yaml | 1 + verl/trainer/config/ppo_trainer.yaml | 1 + verl/utils/megatron/tensor_parallel.py | 24 ++-- verl/utils/torch_functional.py | 6 +- verl/workers/actor/dp_actor.py | 17 ++- 9 files changed, 291 insertions(+), 52 deletions(-) create mode 100644 .github/workflows/kernels.yml create mode 100644 tests/kernel/run_vocab_parallel_entropy.sh create mode 100644 tests/kernel/test_vocab_parallel_entropy.py diff --git a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml new file mode 100644 index 00000000000..987800221ba --- /dev/null +++ b/.github/workflows/kernels.yml @@ -0,0 +1,58 @@ +name: kernels +# latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0 + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + 
- v0.2.x + paths: + - "**/*.py" + - .github/workflows/kernels.yml + pull_request: + branches: + - main + - v0.2.x + paths: + - "**/*.py" + - "verl/trainer/config/*.yaml" + - .github/workflows/kernels.yml + - "tests/e2e/*.sh" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +jobs: + e2e_gsm8k_megatron: + runs-on: [self-hosted, l20-0] + timeout-minutes: 40 # Increase this timeout value as needed + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-megatron0.11.0-v0.0.6 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install -e .[test] + - name: Testing LinearCrossEntropy Correction, Computation Time and Memory Consumption + run: | + python3 tests/kernel/test_linear_cross_entropy.py + - name: Testing VocabParallelEntropy + run: | + bash tests/kernel/run_vocab_parallel_entropy.sh \ No newline at end of file diff --git a/tests/kernel/run_vocab_parallel_entropy.sh b/tests/kernel/run_vocab_parallel_entropy.sh new file mode 100644 index 00000000000..ee7cd3a7568 --- /dev/null +++ b/tests/kernel/run_vocab_parallel_entropy.sh @@ -0,0 +1,18 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/usr/bin/env bash + +set -e -x +torchrun --nproc-per-node=8 --standalone tests/kernel/test_vocab_parallel_entropy.py \ No newline at end of file diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index 37069457fb9..cafe2b49c54 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -42,18 +42,38 @@ finally: from verl.utils.kernel import linear_cross_entropy, set_backward_method, BackwardEnum +import verl.utils.torch_functional as verl_F +from verl.utils.torch_functional import logprobs_from_logits -def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> typing.List[torch.Tensor]: +compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True) + + +def run_torch_entropy(hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction="none") -> typing.List[torch.Tensor]: logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] pd = torch.nn.functional.softmax(logits, dim=-1) # [num_tokens, vocab_size] entropy_a = torch.logsumexp(logits, dim=-1) # [num_tokens] entropy_b = torch.sum(pd * logits, dim=-1) # [num_tokens] entropy = entropy_a - entropy_b - logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction="none") # [num_tokens] + logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction) # [num_tokens] logprobs = torch.neg(logprobs) return logprobs, entropy +def run_verl_actor_entropy(hidden: 
torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction="none") -> typing.List[torch.Tensor]: + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] + # compute entropy + entropy = compute_entropy_from_logits(logits) # ((total_nnz / sp) + pad) + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + logprobs = logprobs_from_logits(logits=logits, labels=labels) + return logprobs, entropy + + class TestLinearCrossEntropy: def cleanup(self): @@ -82,14 +102,14 @@ def generate_backward_inputs(self): g_logprobs = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)) return g_entropy, g_logprobs - def verify_correctness(self): + def verify_correctness(self, iterations=5): self.cleanup() self.generate_hyper() - iterations = 5 - torch_forward_latency = list() torch_backward_latency = list() + verl_forward_latency = list() + verl_backward_latency = list() kernel_forward_latency = list() kernel_backward_latency = list() @@ -97,7 +117,7 @@ def verify_correctness(self): end_event = torch.cuda.Event(enable_timing=True) for i in range(iterations): - print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r") + print(f"[INFO]: Iteration {i + 1} / {iterations}...", end='\r') hidden, weight, labels = self.generate_forward_inputs() start_event.record() @@ -106,14 +126,24 @@ def verify_correctness(self): torch.cuda.synchronize() torch_forward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (verl_logprobs, verl_entropy) = run_verl_actor_entropy(hidden, weight, labels) + end_event.record() + torch.cuda.synchronize() + verl_forward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none") end_event.record() torch.cuda.synchronize() kernel_forward_latency.append(start_event.elapsed_time(end_event)) + 
torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4) torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(verl_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) # backward g_entropy, g_logprobs = self.generate_backward_inputs() @@ -126,6 +156,14 @@ def verify_correctness(self): torch.cuda.synchronize() torch_backward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() + (d_verl_hidden, d_verl_weight) = torch.autograd.grad((verl_entropy, verl_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + verl_backward_latency.append(start_event.elapsed_time(end_event)) + start_event.record() (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), (g_entropy, g_logprobs), @@ -134,12 +172,18 @@ def verify_correctness(self): torch.cuda.synchronize() kernel_backward_latency.append(start_event.elapsed_time(end_event)) + torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4) torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=1e-2, rtol=1e-4) # remove first latency torch_forward_latency = torch_forward_latency[1:] torch_backward_latency = torch_backward_latency[1:] + verl_forward_latency = verl_forward_latency[1:] + verl_backward_latency = 
verl_backward_latency[1:] kernel_forward_latency = kernel_forward_latency[1:] kernel_backward_latency = kernel_backward_latency[1:] @@ -149,54 +193,41 @@ def verify_correctness(self): f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") print(f"[INFO]: Backward pass: torch implementation average time: " f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: VeRL implementation average time: " + f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: VeRL implementation average time: " + f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms") print(f"[INFO]: Forward pass: Kernel implementation average time: " f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") print(f"[INFO]: Backward pass: kernel implementation average time: " f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") - def check_torch_storage(self): + def check_storage(self, method_name, run_forward, reduction="none"): self.cleanup() self.generate_hyper() hidden, weight, labels = self.generate_forward_inputs() torch.cuda.reset_peak_memory_stats() - (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels) + (logprobs, entropy) = run_forward(hidden, weight, labels, reduction) torch.cuda.synchronize() torch_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Torch Forward pass peak memory: {torch_max_memory:.2f} MB") + print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") g_entropy, g_logprobs = self.generate_backward_inputs() torch.cuda.reset_peak_memory_stats() - (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), + (d_torch_hidden, d_torch_weight) = torch.autograd.grad((entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False) torch.cuda.synchronize() torch_backward_max_memory = 
torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Torch Backward pass peak memory: {torch_backward_max_memory:.2f} MB") - - def check_kernel_storage(self): - self.cleanup() - self.generate_hyper() + print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB") - hidden, weight, labels = self.generate_forward_inputs() - - torch.cuda.reset_peak_memory_stats() - (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none") - torch.cuda.synchronize() - kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") - - g_entropy, g_logprobs = self.generate_backward_inputs() - - torch.cuda.reset_peak_memory_stats() - (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), - (g_entropy, g_logprobs), - retain_graph=False) - torch.cuda.synchronize() - kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 - print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + def check_storage_all(self): + self.check_storage("Torch", run_torch_entropy) + self.check_storage("VeRL", run_verl_actor_entropy) + self.check_storage("Kernel", linear_cross_entropy) if __name__ == "__main__": @@ -204,6 +235,5 @@ def check_kernel_storage(self): test = TestLinearCrossEntropy() - test.verify_correctness() - test.check_torch_storage() - test.check_kernel_storage() + test.verify_correctness(100) + test.check_storage_all() diff --git a/tests/kernel/test_vocab_parallel_entropy.py b/tests/kernel/test_vocab_parallel_entropy.py new file mode 100644 index 00000000000..5fd899f2d62 --- /dev/null +++ b/tests/kernel/test_vocab_parallel_entropy.py @@ -0,0 +1,118 @@ +# Copyright 2024 Bytedance Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +os.environ['NCCL_DEBUG'] = 'WARN' + +import torch +import torch.distributed + +from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy +from verl.utils.torch_functional import logprobs_from_logits, entropy_from_logits + +from verl.utils.debug import log_gpu_memory_usage + +from megatron.core import mpu + + +class Utils: + world_size = torch.cuda.device_count() + rank = int(os.environ.get('LOCAL_RANK', '0')) + + @staticmethod + def initialize_distributed(): + print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}') + torch.cuda.set_device(Utils.rank % torch.cuda.device_count()) + init_method = 'tcp://' + master_ip = os.getenv('MASTER_ADDR', 'localhost') + master_port = os.getenv('MASTER_PORT', '7000') + init_method += master_ip + ':' + master_port + torch.distributed.init_process_group(backend='nccl', + world_size=Utils.world_size, + rank=Utils.rank, + init_method=init_method) + print(f'successfully created process group') + + @staticmethod + def destroy_model_parallel(): + mpu.destroy_model_parallel() + torch.distributed.barrier() + torch.distributed.destroy_process_group() + + @staticmethod + def initialize_model_parallel(tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None): + mpu.destroy_model_parallel() + if not 
torch.distributed.is_initialized(): + Utils.initialize_distributed() + mpu.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank) + + +def test_vocab_parallel_entropy(): + # check vocab_parallel_entropy + Utils.world_size = 8 + Utils.initialize_model_parallel(8, 1) + + batch_size = 2 + seqlen = 128 + vocab_size = 155136 + + logits = torch.randn(batch_size * seqlen, vocab_size, device='cuda', requires_grad=True) + target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device='cuda', dtype=torch.int64) + + # broadcast across tp + torch.distributed.broadcast(logits, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + torch.distributed.broadcast(target, + mpu.get_tensor_model_parallel_src_rank(), + group=mpu.get_tensor_model_parallel_group()) + + tp_rank = mpu.get_tensor_model_parallel_rank() + vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size() + + # get the local logits of each tp + vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * + vocab_size_per_tp].requires_grad_() + logits.grad = None + vocab_parallel_logits.grad = None + + log_gpu_memory_usage('begin') + output_entropy = vocab_parallel_entropy(vocab_parallel_logits) + log_gpu_memory_usage('after forward') + grad_output = torch.randn_like(output_entropy) + output_entropy.backward(grad_output) + log_gpu_memory_usage('after backward') + + target_entropy = entropy_from_logits(logits) + torch.testing.assert_close(output_entropy, target_entropy) + target_entropy.backward(grad_output) + torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp], + vocab_parallel_logits.grad) + # make sure logits is not altered + torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp], + vocab_parallel_logits) + + 
if mpu.get_tensor_model_parallel_rank() == 0: + print('test_vocab_parallel_entropy passes') + + Utils.destroy_model_parallel() + + +if __name__ == '__main__': + test_vocab_parallel_entropy() diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index afd697faaae..0bb8bbf6298 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -34,6 +34,7 @@ actor_rollout_ref: kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: True + use_fused_kernels: True optim: lr: 1e-6 clip_grad: 1.0 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 2bd634aaa75..07eeea37c54 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -41,6 +41,7 @@ actor_rollout_ref: ulysses_sequence_parallel_size: 1 # sp size checkpoint: contents: ['model', 'hf_model', 'optimizer', 'extra'] # with 'hf_model' you can save whole model as hf format, now only use sharded model checkpoint to save space + use_fused_kernels: True optim: lr: 1e-6 lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio. 
diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py index 032cabbedbe..bdbee9e04f5 100644 --- a/verl/utils/megatron/tensor_parallel.py +++ b/verl/utils/megatron/tensor_parallel.py @@ -101,19 +101,21 @@ class _VocabParallelEntropy(torch.autograd.Function): @staticmethod def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor: + + @torch.compile(dynamic=True) + def mul_reduce(a, b): + return (a * b).sum(dim=-1, keepdim=True) + logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group()) normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max - normalized_exp_logits = normalized_vocab_parallel_logits.exp() + normalized_exp_logits = normalized_vocab_parallel_logits.exp_() normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True) dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group()) - softmax_logits = normalized_exp_logits / normalized_sum_exp_logits + softmax_logits = normalized_exp_logits.div(normalized_sum_exp_logits) # This consume too much VRAM, causing OOM, try optimize # sum_softmax_times_logits = (softmax_logits * vocab_parallel_logits).sum(dim=-1, keepdim=True) - original_shape = softmax_logits.shape - sum_softmax_times_logits = torch.bmm(softmax_logits.view(-1, 1, original_shape[-1]), - vocab_parallel_logits.view(-1, original_shape[-1], - 1)).view(original_shape[:-1] + (1,)) + sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits) dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group()) entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits) @@ -122,8 +124,14 @@ def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor: @staticmethod def backward(ctx, grad_output: 
torch.Tensor) -> torch.Tensor: vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors - grad_input = grad_output.unsqueeze(dim=-1) * softmax_logits * (sum_softmax_times_logits - vocab_parallel_logits) - return grad_input + # reuse softmax_logits as grad + vocab_parallel_logits.sub_(sum_softmax_times_logits) + softmax_logits.mul_(vocab_parallel_logits) + softmax_logits.mul_(grad_output.unsqueeze(dim=-1)) + # recover vocab_parallel_logits + vocab_parallel_logits.add_(sum_softmax_times_logits) + softmax_logits.mul_(-1) + return softmax_logits def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor: diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 32bfda2278f..4dc954c5c92 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -25,9 +25,9 @@ try: from flash_attn.ops.triton.cross_entropy import cross_entropy_loss - FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True + FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True except ImportError: - FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False + FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False def gather_from_labels(data, label): @@ -49,7 +49,7 @@ def logprobs_from_logits(logits, labels): """ See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591 """ - if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: + if FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE: batch_dim = logits.shape[:-1] last_dim = logits.shape[-1] logits = logits.reshape(-1, last_dim) diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index eff5ffa26d9..2b74479857c 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -30,6 +30,7 @@ from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F +from verl.utils.kernel.linear_cross_entropy import 
linear_cross_entropy from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis @@ -52,6 +53,7 @@ def __init__( print(f'Actor use_remove_padding={self.use_remove_padding}') self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1 + self.use_fused_kernels = config.use_fused_kernels self.compute_entropy_from_logits = ( torch.compile(verl_F.entropy_from_logits, dynamic=True) @@ -114,13 +116,16 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, use_cache=False) # prevent model thinks we are generating logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - logits_rmpad.div_(temperature) + if not self.use_fused_kernels: + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) - # compute entropy - entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) - - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + else: + weights = torch.eye(logits_rmpad.size(-1), device=logits_rmpad.device) / temperature + log_probs, entropy_rmpad = linear_cross_entropy(logits_rmpad, weights, input_ids_rmpad_rolled) # gather log_prob if sp > 1 if self.use_ulysses_sp: From ce7eedc99881058302757859eb28333d410c3177 Mon Sep 17 00:00:00 2001 From: Blue Space <57280232+ETOgaosion@users.noreply.github.com> Date: Thu, 20 Mar 2025 16:53:41 +0800 Subject: [PATCH 11/15] fix reference also use dp_actor bug (#2) Fix bugs in CI. 
--- verl/trainer/config/ppo_megatron_trainer.yaml | 1 + verl/trainer/config/ppo_trainer.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 0bb8bbf6298..9513e391bba 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -65,6 +65,7 @@ actor_rollout_ref: param_offload: False log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null + use_fused_kernels: True rollout: name: vllm temperature: 1.0 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 07eeea37c54..7467cae167c 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -67,6 +67,7 @@ actor_rollout_ref: log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size + use_fused_kernels: True rollout: name: vllm temperature: 1.0 From fbf4c575aac0def3d76fc61cae2e7fe701ac692c Mon Sep 17 00:00:00 2001 From: Jianbing-D <69858819+Jianbing-D@users.noreply.github.com> Date: Tue, 25 Mar 2025 16:53:56 +0800 Subject: [PATCH 12/15] Linear cross entropy Add TP Support (#3) * add tp support to torch's ops, for correctness checking Signed-off-by: Jianbing Dong * removed redundant d_scale Signed-off-by: Jianbing Dong * add tp along vocab_size dimension Signed-off-by: Jianbing Dong * merge accumulate & entropy_b to one buffer Signed-off-by: Jianbing Dong * add dedicated stream for overlapping _logprobs Signed-off-by: Jianbing Dong * format Signed-off-by: Jianbing Dong * update test api Signed-off-by: Jianbing Dong --------- Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 352 +++++++++++++++++++++- 
verl/utils/kernel/kernels.py | 324 +++++++++++++------- verl/utils/kernel/linear_cross_entropy.py | 21 +- 3 files changed, 570 insertions(+), 127 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index cafe2b49c54..b7cde24b1c6 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -30,6 +30,7 @@ # limitations under the License. import torch +import torch.distributed as dist import typing try: @@ -62,6 +63,90 @@ def run_torch_entropy(hidden: torch.Tensor, return logprobs, entropy +class TorchEntropyTP(torch.autograd.Function): + """ + it is used for testing the correctness of the kernel + it is not efficient and is not recommended to use in practice + """ + + @staticmethod + def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, + dist_process_group: torch.distributed.ProcessGroup): + logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32)) # [num_tokens, vocab_size] + whole_logits = torch.empty((logits.shape[0], logits.shape[1] * dist.get_world_size(dist_process_group)), + dtype=logits.dtype, + device=logits.device) + whole_logits_ref = [ + whole_logits[:, i * logits.shape[1]:(i + 1) * logits.shape[1]] + for i in range(dist.get_world_size(dist_process_group)) + ] + dist.all_gather(whole_logits_ref, logits, group=dist_process_group) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + entropy_a = torch.logsumexp(whole_logits, dim=-1) # [num_tokens] + entropy_b = torch.sum(pd * whole_logits, dim=-1) # [num_tokens] + entropy = entropy_a - entropy_b + + logprobs = torch.nn.functional.cross_entropy(whole_logits, labels, reduction="none") + logprobs = torch.neg(logprobs) + + ctx.save_for_backward(hidden, weight, labels, whole_logits, entropy_b) + ctx.dist_process_group = dist_process_group + + return logprobs, entropy + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor, g_entropy: torch.Tensor): + 
hidden, weight, labels, whole_logits, entropy_b = ctx.saved_tensors + dist_process_group = ctx.dist_process_group + + batch_size, hidden_size = hidden.shape + vocab_size = weight.shape[1] + world_size = dist.get_world_size(dist_process_group) + rank = dist.get_rank(dist_process_group) + + # Compute softmax probabilities + maximum, _ = torch.max(whole_logits, dim=-1, keepdim=True) + exp_logits = torch.exp(whole_logits - maximum) + accumulate = exp_logits.sum(dim=-1, keepdim=True) + pd = exp_logits / accumulate + + # Gradient for entropy + # entropy = entropy_a - entropy_b + # entropy_a = log(sum(exp(logits))) + # entropy_b = sum(pd * logits) + # d_entropy_a/d_logits = pd + # d_entropy_b/d_logits = pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = d_entropy_a - d_entropy_b + # d_entropy/d_logits = pd - pd * (logits - b.unsqueeze(1) + 1) + # d_entropy/d_logits = -pd * (logits - b.unsqueeze(1)) + d_logits_entropy = g_entropy.unsqueeze(1) * (-pd * (whole_logits - entropy_b.unsqueeze(1))) + + # Gradient for logprobs + # logprobs = -cross_entropy = -log(pd[labels]) + # d_logprobs/d_logits = (pd - one_hot(labels)) + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.unsqueeze(1), 1) + g_logprobs = torch.neg(g_logprobs) + d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - one_hot) + # NOTE: This will lead to wrong result + # d_logits_logprobs = g_logprobs.unsqueeze(1) * (pd - 1) * one_hot + + # Combine gradients + d_logits = d_logits_entropy + d_logits_logprobs + + # Get local slice of gradients + local_d_logits = d_logits[:, rank * vocab_size:(rank + 1) * vocab_size] + + # Compute gradients for hidden and weight + d_hidden = torch.matmul(local_d_logits, weight.to(torch.float32).T) + d_weight = torch.matmul(hidden.to(torch.float32).T, local_d_logits) + + return d_hidden, d_weight, None, None + + +run_torch_entropy_tp = TorchEntropyTP.apply + def run_verl_actor_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, @@ -230,10 
+315,271 @@ def check_storage_all(self): self.check_storage("Kernel", linear_cross_entropy) +class TestLinearCrossEntropy_TensorParallel: + + def __init__(self): + dist.init_process_group(backend="nccl") + self.group = dist.group.WORLD + + self.local_rank = dist.get_rank(self.group) + self.world_size = dist.get_world_size(self.group) + device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: Local rank: {self.local_rank}, World size: {self.world_size}") + + def shutdown(self): + dist.destroy_process_group() + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + gc.collect() + torch.cuda.synchronize() + + def generate_hyper(self): + self.num_tokens = 80 + self.hidden_size = 4096 + self.vocab_size = 152064 + self.dtype = torch.bfloat16 + self.iterations = 5 + + def generate_forward_inputs(self): + hidden = (torch.empty((self.num_tokens, self.hidden_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + weight = (torch.empty((self.hidden_size, self.vocab_size), dtype=self.dtype, + device="cuda").uniform_(-0.5, 0.5).requires_grad_()) + labels = torch.randint(0, self.vocab_size * self.world_size, (self.num_tokens,), device="cuda") + return hidden, weight, labels + + def generate_backward_inputs(self): + g_entropy = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)) + g_logprobs = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)) + return g_entropy, g_logprobs + + def verify_torch_itself(self): + self.cleanup() + self.generate_hyper() + + for i in range(self.iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + # forward pass + whole_weight = torch.empty((weight.shape[0], weight.shape[1] * 
self.world_size), + dtype=weight.dtype, + device=weight.device) + whole_weight_ref = [ + whole_weight[:, i * weight.shape[1]:(i + 1) * weight.shape[1]] for i in range(self.world_size) + ] + dist.all_gather(whole_weight_ref, weight, group=self.group) + whole_weight.requires_grad_() + + (single_logprobs, single_entropy) = run_torch_entropy(hidden, whole_weight, labels) + + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) + + torch.testing.assert_close(single_logprobs, tp_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(single_entropy, tp_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + (single_d_hidden, single_d_weight) = torch.autograd.grad((single_entropy, single_logprobs), + (hidden, whole_weight), (g_entropy, g_logprobs), + retain_graph=False) + + (tp_d_hidden, tp_d_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(tp_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(tp_d_hidden, single_d_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(tp_d_weight, + single_d_weight[:, self.local_rank * tp_d_weight.shape[1]:(self.local_rank + 1) * + tp_d_weight.shape[1]]) #, + # atol=1e-3, rtol=1e-4) + if self.local_rank == 0: + print(f"[PASS] torch TP correctness is verified") + + def check_torch_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, 
group=self.group) + + torch.cuda.reset_peak_memory_stats() + (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) + torch.cuda.synchronize() + forward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_tp_hidden, d_tp_weight) = torch.autograd.grad((tp_entropy, tp_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + torch.cuda.synchronize() + backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Torch Forward pass peak memory: {forward_max_memory:.2f} MB") + print(f"[INFO]: Torch Backward pass peak memory: {backward_max_memory:.2f} MB") + + def verify_kernel_correctness(self): + self.cleanup() + self.generate_hyper() + + torch_forward_latency = list() + torch_backward_latency = list() + kernel_forward_latency = list() + kernel_backward_latency = list() + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + for i in range(self.iterations): + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + start_event.record() + (torch_logprobs, torch_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) + end_event.record() + torch.cuda.synchronize() + torch_forward_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (kernel_logprobs, 
kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none", self.group) + end_event.record() + torch.cuda.synchronize() + kernel_forward_latency.append(start_event.elapsed_time(end_event)) + + torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4) + torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-4, rtol=1e-4) + + # backward pass + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + start_event.record() + (torch_d_hidden, torch_d_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + torch_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(torch_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + start_event.record() + (kernel_d_hidden, kernel_d_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), + (hidden, weight), (g_entropy, g_logprobs), + retain_graph=False) + end_event.record() + torch.cuda.synchronize() + kernel_backward_latency.append(start_event.elapsed_time(end_event)) + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(kernel_d_hidden, op=dist.ReduceOp.SUM, group=self.group) + + torch.testing.assert_close(torch_d_hidden, kernel_d_hidden, atol=1e-2, rtol=1e-4) + torch.testing.assert_close(torch_d_weight, kernel_d_weight, atol=1e-2, rtol=1e-4) + + # remove first latency + torch_forward_latency = torch_forward_latency[1:] + torch_backward_latency = torch_backward_latency[1:] + kernel_forward_latency = kernel_forward_latency[1:] + kernel_backward_latency = kernel_backward_latency[1:] + + if self.local_rank == 0: + print(f"\n[PASS]: Verified kernel 
forward & backward correctness.") + + print(f"[INFO]: Forward pass: Torch implementation average time: " + f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: torch implementation average time: " + f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms") + print(f"[INFO]: Forward pass: Kernel implementation average time: " + f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms") + print(f"[INFO]: Backward pass: kernel implementation average time: " + f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms") + + def check_kernel_storage(self): + self.cleanup() + self.generate_hyper() + + hidden, weight, labels = self.generate_forward_inputs() + + # NOTE: we need to manually synchronize hidden and labels among Process Group + dist.broadcast(hidden, src=0, group=self.group) + dist.broadcast(labels, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none", self.group) + torch.cuda.synchronize() + kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + + g_entropy, g_logprobs = self.generate_backward_inputs() + # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group + dist.broadcast(g_entropy, src=0, group=self.group) + dist.broadcast(g_logprobs, src=0, group=self.group) + + torch.cuda.reset_peak_memory_stats() + (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight), + (g_entropy, g_logprobs), + retain_graph=False) + torch.cuda.synchronize() + kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + # NOTE: all-reduce on hidden is conducted outside the kernel + dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) + + if self.local_rank == 0: + print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB") + print(f"[INFO]: Kernel 
Backward pass peak memory: {kernel_backward_max_memory:.2f} MB") + + if __name__ == "__main__": + # TP command: torchrun --standalone --nnodes=1 --nproc-per-node=2 tests/kernel/test_linear_cross_entropy.py + + # Check if running with torchrun (distributed mode) + is_distributed = False + if "RANK" in os.environ and "WORLD_SIZE" in os.environ: + is_distributed = True + print(f"[INFO]: Running in {'distributed' if is_distributed else 'non-distributed'} mode") + torch.manual_seed(233376 + int(os.environ.get("RANK", 0))) + # set_backward_method(BackwardEnum._Total_Fuse_MN) - test = TestLinearCrossEntropy() + if not is_distributed: + test = TestLinearCrossEntropy() + + test.verify_correctness() + test.check_storage_all() + else: + test = TestLinearCrossEntropy_TensorParallel() + + test.verify_torch_itself() + test.check_torch_storage() + test.verify_kernel_correctness() + test.check_kernel_storage() - test.verify_correctness(100) - test.check_storage_all() + test.shutdown() diff --git a/verl/utils/kernel/kernels.py b/verl/utils/kernel/kernels.py index 5fb57fb170b..edd8e78d55f 100644 --- a/verl/utils/kernel/kernels.py +++ b/verl/utils/kernel/kernels.py @@ -35,6 +35,7 @@ import typing from dataclasses import dataclass import torch +import torch.distributed as dist import triton import triton.language as tl @@ -111,6 +112,7 @@ def set_backward_method(backward_method: BackwardEnum): ) @triton.jit def efficient_entropy_kernel_general_mainloop( + rank, hidden_ptr, weight_ptr, labels_ptr, @@ -134,9 +136,6 @@ def efficient_entropy_kernel_general_mainloop( global_logprobs_ptr, stride_global_logprobs, global_logprobs_scalar_ptr, - d_scale_non_reduced_ptr, - stride_d_scale_non_reduced_m, - stride_d_scale_non_reduced_n, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, @@ -168,7 +167,6 @@ def efficient_entropy_kernel_general_mainloop( _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) _logprobs = 
tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - _scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for n in range(0, num_pid_n): offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) @@ -205,12 +203,9 @@ def efficient_entropy_kernel_general_mainloop( _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) - label_mask = offs_bn[None, :] == labels[:, None] + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] _logprobs += tl.sum(logits * label_mask, axis=1) - # preprocess for backward - _scale = coeff * _scale + tl.sum(exp_logits * logits, axis=1) - # store maximum offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_max_n = pid_n @@ -224,28 +219,23 @@ def efficient_entropy_kernel_general_mainloop( tl.store(entropy_b_ptrs, _entropy_b, mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) # store logprobs - mask = (labels >= pid_n * vocab_per_split) & (labels < min((pid_n + 1) * vocab_per_split, vocab_size)) + vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size + vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size + mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) mask &= (offs_am < num_tokens) global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) tl.store(global_logprobs_ptrs, _logprobs, mask=mask) - # store d_scale_non_reduced - tl.store(d_scale_non_reduced_ptr + offs_max_n * stride_d_scale_non_reduced_n + - offs_max_m * stride_d_scale_non_reduced_m, - _scale, - mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits)) - @triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) @triton.jit def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n, num_tokens, 
num_splits, global_max_ptr, stride_global_max, accu_ptr, stride_accu_m, stride_accu_n, global_accu_ptr, stride_global_accu, entropy_b_ptr, stride_entropy_b_m, - stride_entropy_b_n, global_entropy_ptr, stride_global_entropy, - global_logprobs_ptr, stride_global_logprobs, global_logprobs_scalar_ptr, - reduction: int, d_scale_non_reduced_ptr, stride_d_scale_non_reduced_m, - stride_d_scale_non_reduced_n, d_scale_ptr, stride_d_scale, + stride_entropy_b_n, global_entropy_b_ptr, stride_global_entropy_b, + global_entropy_ptr, stride_global_entropy, global_logprobs_ptr, + stride_global_logprobs, global_logprobs_scalar_ptr, reduction: int, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): """ foward epilogue @@ -256,7 +246,6 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) - global_d_scale = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) max_ptrs = max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n @@ -281,24 +270,20 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) - # preprocess for backward - d_scale = tl.load(d_scale_non_reduced_ptr + offs_m[:, None] * stride_d_scale_non_reduced_m + - offs_n[None, :] * stride_d_scale_non_reduced_n, - mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), - other=0.0) - global_d_scale = _coeff * global_d_scale + tl.sum(_scale * d_scale, axis=1) - # store maximum_ptrs = global_max_ptr + offs_m * stride_global_max tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + # store entropy_b 
+ global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + # store entropy global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) - global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b - global_entropy_b = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy - tl.store(global_entropy_ptrs, global_entropy_b, mask=offs_m < num_tokens) + tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens) # update logprobs global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) @@ -314,15 +299,102 @@ def efficient_entropy_triton_kernel_epilogue(max_ptr, stride_max_m, stride_max_n global_logprobs_scalar = tl.sum(global_logprobs, axis=0) / num_tokens.to(tl.float32) tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) - # store d_scale - d_scale_ptrs = d_scale_ptr + offs_m * stride_d_scale - tl.store(d_scale_ptrs, global_d_scale, mask=offs_m < num_tokens) + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], key=["num_tokens", "num_splits"]) +@triton.jit +def efficient_entropy_triton_kernel_epilogue_tp( + num_tokens, num_splits, reduced_max_ptr, stride_reduced_max_m, stride_reduced_max_n, original_max_ptr, + stride_original_max_m, stride_original_max_n, accu_ptr, stride_accu_m, stride_accu_n, entropy_b_ptr, + stride_entropy_b_m, stride_entropy_b_n, global_max_ptr, stride_global_max, global_accu_ptr, stride_global_accu, + global_entropy_b_ptr, stride_global_entropy_b, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): + pid_m = 
tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load(reduced_max_ptr + offs_m[:, None] * stride_reduced_max_m + + offs_n[None, :] * stride_reduced_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0) + _original_max = tl.load(original_max_ptr + offs_m[:, None] * stride_original_max_m + + offs_n[None, :] * stride_original_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0) + _accu = tl.load(accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0) + + # local reduce-max + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # update entropy_b + _entropy_b = tl.load(entropy_b_ptr + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0) + global_entropy_b = _coeff * global_entropy_b + tl.sum(_scale * _entropy_b, axis=1) + + # store + tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + tl.store(global_entropy_b_ptr + offs_m * stride_global_entropy_b, global_entropy_b, mask=offs_m < num_tokens) + + 
+@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def efficient_entropy_triton_epilogue_tp_update(num_tokens, logprobs_ptr, stride_logprobs, maximum_ptr, stride_maximum, + accumulate_ptr, stride_accumulate, entropy_b_ptr, stride_entropy_b, + entropy_ptr, stride_entropy, logprobs_scalar_ptr, reduction: int, + BLOCK_SIZE_M: tl.constexpr): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + + entropy_b = tl.load(entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens) + entropy_b = tl.fdiv(entropy_b, accumulate) + tl.store(entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens) + + entropy = tl.log(accumulate) + maximum - entropy_b + tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) + + logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + logprobs = maximum + tl.log(accumulate) - logprobs + + logprobs = -1 * logprobs + if reduction == 0: + tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + elif reduction == 1: + logprobs_scalar = tl.sum(logprobs, axis=0) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + elif reduction == 2: + logprobs_scalar = tl.sum(logprobs, axis=0) / num_tokens.to(tl.float32) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + + +_dedicated_stream, _dedicated_events = None, None -def efficient_entropy_foward(hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: +def efficient_entropy_forward( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[int] = 2, + dist_process_group: 
typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: """ forward host function """ @@ -332,6 +404,15 @@ def efficient_entropy_foward(hidden: torch.Tensor, assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[0] + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + + if dist_process_group is not None and not hasattr(efficient_entropy_forward, "_initialized"): + global _dedicated_stream, _dedicated_events + _dedicated_stream = torch.cuda.Stream(hidden.device) + _dedicated_events = [torch.cuda.Event() for _ in range(2)] + efficient_entropy_forward._initialized = True + num_tokens, hidden_size = hidden.shape num_tokens = labels.shape[0] hidden_size, vocab_size = weight.shape @@ -341,7 +422,10 @@ def efficient_entropy_foward(hidden: torch.Tensor, REDUCTION = get_entropy_reduction_enum(reduction) if REDUCTION == EntropyReductionEnum._None: - logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) elif REDUCTION in (EntropyReductionEnum._Sum, EntropyReductionEnum._Mean): logprobs = torch.empty((), device=hidden.device, dtype=torch.float32) else: @@ -351,8 +435,11 @@ def efficient_entropy_foward(hidden: torch.Tensor, assert logprobs.is_contiguous() and entropy.is_contiguous() maximum = torch.empty_like(entropy) - acc = torch.empty_like(entropy) - assert maximum.is_contiguous() and acc.is_contiguous() + accumulate_and_entropy_b = torch.empty((num_tokens * 2,), device=hidden.device, dtype=torch.float32) + accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) + accumulate = 
accumulate_and_entropy_b_view[0, :] + entropy_b = accumulate_and_entropy_b_view[1, :] + assert maximum.is_contiguous() and accumulate.is_contiguous() and entropy_b.is_contiguous() vocab_per_split = 1024 assert vocab_per_split % 128 == 0 @@ -370,40 +457,59 @@ def efficient_entropy_foward(hidden: torch.Tensor, assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda - # preprocess for backward - _d_scale_non_reduced = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32) - _d_scale = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) - assert _d_scale_non_reduced.is_contiguous() and _d_scale_non_reduced.is_cuda - # 1D kernel launch, then split the tile def mainloop_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) - efficient_entropy_kernel_general_mainloop[mainloop_grid](hidden, weight, labels, num_tokens, hidden_size, + efficient_entropy_kernel_general_mainloop[mainloop_grid](_rank, hidden, weight, labels, num_tokens, hidden_size, vocab_size, vocab_per_split, hidden.stride(0), hidden.stride(1), weight.stride(0), weight.stride(1), _max, _max.stride(0), _max.stride(1), _accu, _accu.stride(0), _accu.stride(1), _entropy_b, _entropy_b.stride(0), _entropy_b.stride(1), _logprobs, _logprobs.stride(0), - logprobs, _d_scale_non_reduced, - _d_scale_non_reduced.stride(0), - _d_scale_non_reduced.stride(1)) + logprobs) # reduction on maximum and maximum_indices def epilogue_grid(meta): return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) - efficient_entropy_triton_kernel_epilogue[epilogue_grid](_max, _max.stride(0), _max.stride(1), num_tokens, - num_splits, maximum, maximum.stride(0), _accu, - _accu.stride(0), _accu.stride(1), acc, - acc.stride(0), _entropy_b, _entropy_b.stride(0), - _entropy_b.stride(1), entropy, entropy.stride(0), _logprobs, - _logprobs.stride(0), logprobs, REDUCTION, - _d_scale_non_reduced, 
_d_scale_non_reduced.stride(0), - _d_scale_non_reduced.stride(1), _d_scale, - _d_scale.stride(0)) + if dist_process_group is None: + efficient_entropy_triton_kernel_epilogue[epilogue_grid](_max, _max.stride(0), _max.stride(1), num_tokens, + num_splits, maximum, maximum.stride(0), _accu, + _accu.stride(0), _accu.stride(1), accumulate, + accumulate.stride(0), _entropy_b, _entropy_b.stride(0), + _entropy_b.stride(1), entropy_b, entropy_b.stride(0), + entropy, entropy.stride(0), _logprobs, + _logprobs.stride(0), logprobs, REDUCTION) + else: + # tensor-parallel + _max_backup = _max.clone() + dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) + + torch.cuda.current_stream().record_event(_dedicated_events[0]) + with torch.cuda.stream(_dedicated_stream): + _dedicated_stream.wait_event(_dedicated_events[0]) + dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group) + _dedicated_stream.record_event(_dedicated_events[1]) + + efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid](num_tokens, num_splits, _max, _max.stride(0), + _max.stride(1), _max_backup, _max_backup.stride(0), + _max_backup.stride(1), _accu, _accu.stride(0), + _accu.stride(1), _entropy_b, _entropy_b.stride(0), + _entropy_b.stride(1), maximum, maximum.stride(0), + accumulate, accumulate.stride(0), entropy_b, + entropy_b.stride(0)) + torch.cuda.current_stream().wait_event(_dedicated_events[1]) + + dist.all_reduce(accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group) - return (logprobs, entropy, maximum, acc, _d_scale) + # update logprobs & entropy + efficient_entropy_triton_epilogue_tp_update[epilogue_grid](num_tokens, _logprobs, _logprobs.stride(0), maximum, + maximum.stride(0), accumulate, accumulate.stride(0), + entropy_b, entropy_b.stride(0), entropy, + entropy.stride(0), logprobs, REDUCTION) + + return (logprobs, entropy, maximum, accumulate, entropy_b) # NOTE: merge d_weight & d_hidden here, split along M & N @@ -422,11 +528,12 @@ def 
epilogue_grid(meta): ) @triton.jit def efficient_entropy_backward_kernel_general_mainloop_MN( - num_tokens: int, hidden_size: int, vocab_size: int, hidden_ptr, stride_hidden_m, stride_hidden_k, weight_ptr, - stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr, stride_accu, - d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, d_scale_ptr, stride_d_scale, - d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, d_weight_ptr, stride_d_weight_k, stride_d_weight_n, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + num_tokens: int, hidden_size: int, vocab_size: int, rank: int, hidden_ptr, stride_hidden_m, stride_hidden_k, + weight_ptr, stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr, + stride_accu, d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, entropy_b_ptr, + stride_entropy_b, d_hidden_ptr, stride_d_hidden_m, stride_d_hidden_k, d_weight_ptr, stride_d_weight_k, + stride_d_weight_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr): """ backward mainloop, where d_logits & d_hidden & d_weight are fused """ @@ -469,7 +576,8 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) d_logprobs = -1 * d_logprobs - d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, mask=offs_am < num_tokens, other=0.0) + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) @@ -497,13 +605,9 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( 
exp_logits = tl.exp(logits - maximum[:, None]) - d_pd = logits * -d_entropy[:, None] - mask = offs_bn[None, :] == labels[:, None] - d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:, None], exp_logits) * mask - - coeff = d_scale * d_entropy * accu_rcp * accu_rcp + d_logprobs * accu_rcp - d_logits = exp_logits * coeff[:, None] - d_logits += exp_logits * d_pd * accu_rcp[:, None] + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) # loop for d_weight & d_hidden for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): @@ -545,11 +649,11 @@ def efficient_entropy_backward_kernel_general_mainloop_MN( ) @triton.jit def efficient_entropy_backward_kernel_general_d_logits( - num_tokens: int, hidden_size: int, vocab_size: int, hidden_ptr, stride_hidden_m, stride_hidden_k, weight_ptr, - stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr, stride_accu, - d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, d_scale_ptr, stride_d_scale, - d_logits_ptr, stride_d_logits_m, stride_d_logits_n, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + num_tokens: int, hidden_size: int, vocab_size: int, rank: int, hidden_ptr, stride_hidden_m, stride_hidden_k, + weight_ptr, stride_weight_k, stride_weight_n, labels_ptr, stride_labels, maximum_ptr, stride_maximum, accu_ptr, + stride_accu, d_entropy_ptr, stride_d_entropy, d_logprobs_ptr, stride_d_logprobs, reduction: int, entropy_b_ptr, + stride_entropy_b, d_logits_ptr, stride_d_logits_m, stride_d_logits_n, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """ backward d_logits """ @@ -592,16 +696,8 @@ def efficient_entropy_backward_kernel_general_d_logits( 
d_logprobs = tl.broadcast_to(d_logprobs, (BLOCK_SIZE_M,)) d_logprobs = -1 * d_logprobs - d_scale = tl.load(d_scale_ptr + offs_am * stride_d_scale, mask=offs_am < num_tokens, other=0.0) - - # d_acc_exp_logits = d_scale * d_entropy * accu_rcp * accu_rcp - # d_acc_exp_logits += d_logprobs * accu_rcp - # d_acc_exp_logits += d_entropy * accu_rcp - - # These equal to d_max = d_entropy - # d_max = d_scale * -d_entropy * accu_rcp - # d_max -= d_logprobs - # d_max += accu * d_acc_exp_logits + entropy_b_ptrs = entropy_b_ptr + offs_am * stride_entropy_b + entropy_b = tl.load(entropy_b_ptrs, mask=offs_am < num_tokens, other=0.0) hidden_ptrs = hidden_ptr + (offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k) weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) @@ -626,20 +722,9 @@ def efficient_entropy_backward_kernel_general_d_logits( exp_logits = tl.exp(logits - maximum[:, None]) - d_pd = logits * -d_entropy[:, None] - mask = offs_bn[None, :] == labels[:, None] - d_pd += tl.fdiv((-1.0 * d_logprobs * accu)[:, None], exp_logits) * mask - - coeff = d_scale * d_entropy * accu_rcp * accu_rcp + d_logprobs * accu_rcp - d_logits = exp_logits * coeff[:, None] - d_logits += exp_logits * d_pd * accu_rcp[:, None] - # d_logits += exp_logits * logits * (-d_entropy * accu_rcp)[:,None] - # d_logits -= tl.where(mask, d_logprobs[:,None], 0.0) - - # d_max is always zeros - # d_max = d_entropy - d_max - # mask = offs_bn[None,:] == maximum_indices[:,None] - # d_logits += tl.where(mask, d_max[:,None], 0.0) + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += d_entropy[:, None] * (-exp_logits * accu_rcp[:, None]) * (logits - entropy_b[:, None]) # store d_logits d_logits_ptrs = d_logits_ptr + offs_am[:, None] * stride_d_logits_m + offs_bn[None, :] * stride_d_logits_n @@ -648,15 +733,17 @@ def 
efficient_entropy_backward_kernel_general_d_logits( mask=(offs_am[:, None] < num_tokens) & (offs_bn[None, :] < vocab_size)) -def efficient_entropy_backward(dlogprobs: torch.Tensor, - dentropy: torch.Tensor, - hidden: torch.Tensor, - weight: torch.Tensor, - labels: torch.Tensor, - maximum: torch.Tensor, - acc: torch.Tensor, - d_scale: torch.Tensor, - reduction: typing.Optional[int] = 2) -> typing.List[torch.Tensor]: +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: typing.Optional[int] = 2, + dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: """ backward host function """ @@ -666,6 +753,9 @@ def efficient_entropy_backward(dlogprobs: torch.Tensor, assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[0] + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + num_tokens, hidden_size = hidden.shape num_tokens = labels.shape[0] hidden_size, vocab_size = weight.shape @@ -702,8 +792,8 @@ def efficient_entropy_backward(dlogprobs: torch.Tensor, assert vocab_per_split % 128 == 0 num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split - assert d_scale.is_contiguous() and d_scale.is_cuda - assert d_scale.shape == (num_tokens,) + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) if _BACKWARD == BackwardEnum._Total_Fuse_MN: @@ -714,6 +804,7 @@ def mainloop_grid(meta): num_tokens, hidden_size, vocab_size, + _rank, hidden, hidden.stride(0), hidden.stride(1), @@ -731,8 +822,8 @@ def mainloop_grid(meta): dlogprobs, dlogprobs.stride(0) if REDUCTION == 
EntropyReductionEnum._None else 0, REDUCTION, - d_scale, - d_scale.stride(0), + entropy_b, + entropy_b.stride(0), d_hidden, d_hidden.stride(0), d_hidden.stride(1), @@ -750,6 +841,7 @@ def d_logits_grid(meta): num_tokens, hidden_size, vocab_size, + _rank, hidden, hidden.stride(0), hidden.stride(1), @@ -767,8 +859,8 @@ def d_logits_grid(meta): dlogprobs, dlogprobs.stride(0) if REDUCTION == EntropyReductionEnum._None else 0, REDUCTION, - d_scale, - d_scale.stride(0), + entropy_b, + entropy_b.stride(0), _d_logits, _d_logits.stride(0), _d_logits.stride(1), diff --git a/verl/utils/kernel/linear_cross_entropy.py b/verl/utils/kernel/linear_cross_entropy.py index 1a8788ef831..20c6f2cbf4f 100644 --- a/verl/utils/kernel/linear_cross_entropy.py +++ b/verl/utils/kernel/linear_cross_entropy.py @@ -30,6 +30,7 @@ # limitations under the License. import torch +import torch.distributed as dist import typing from . import kernels @@ -41,28 +42,32 @@ def forward(ctx, hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, - reduction: typing.Optional[str] = "mean") -> typing.List[torch.Tensor]: + reduction: typing.Optional[str] = "mean", + dist_process_group: typing.Optional[dist.ProcessGroup] = None) -> typing.List[torch.Tensor]: with torch.cuda.nvtx.range("LinearCrossEntropy-forward"): REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) - logprobs, entropy, _maximum, _acc, _d_scale =\ - kernels.efficient_entropy_foward(hidden, weight, labels, REDUCTION) + logprobs, entropy, _maximum, _accumulate, _entropy_b =\ + kernels.efficient_entropy_forward(hidden, weight, labels, REDUCTION, + dist_process_group) - ctx.save_for_backward(hidden, weight, labels, _maximum, _acc, _d_scale) + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) ctx.REDUCTION = REDUCTION - + ctx.dist_process_group = dist_process_group return logprobs, entropy @staticmethod def backward(ctx, dlogprobs: torch.Tensor, dentropy: torch.Tensor) -> 
typing.List[torch.Tensor]: with torch.cuda.nvtx.range("LinearCrossEntropy-backward"): - (hidden, weight, labels, _maximum, _acc, _d_scale) = ctx.saved_tensors + (hidden, weight, labels, _maximum, _accumulate, _entropy_b) = ctx.saved_tensors REDUCTION = ctx.REDUCTION + dist_process_group = ctx.dist_process_group d_hidden, d_weight = kernels.efficient_entropy_backward(dlogprobs, dentropy, hidden, weight, labels, - _maximum, _acc, _d_scale, REDUCTION) + _maximum, _accumulate, _entropy_b, REDUCTION, + dist_process_group) - return (d_hidden, d_weight, None, None) + return (d_hidden, d_weight, None, None, None) linear_cross_entropy = LinearCrossEntropy.apply From c6907ffd5aefd1cc2b5dc9c5357e40f468d95ec6 Mon Sep 17 00:00:00 2001 From: Jianbing Dong Date: Tue, 25 Mar 2025 02:26:49 -0700 Subject: [PATCH 13/15] use max_memory_allocated() Signed-off-by: Jianbing Dong --- tests/kernel/test_linear_cross_entropy.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py index b7cde24b1c6..e48d8c4864d 100644 --- a/tests/kernel/test_linear_cross_entropy.py +++ b/tests/kernel/test_linear_cross_entropy.py @@ -296,7 +296,7 @@ def check_storage(self, method_name, run_forward, reduction="none"): torch.cuda.reset_peak_memory_stats() (logprobs, entropy) = run_forward(hidden, weight, labels, reduction) torch.cuda.synchronize() - torch_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB") g_entropy, g_logprobs = self.generate_backward_inputs() @@ -306,7 +306,7 @@ def check_storage(self, method_name, run_forward, reduction="none"): (g_entropy, g_logprobs), retain_graph=False) torch.cuda.synchronize() - torch_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + torch_backward_max_memory = 
torch.cuda.max_memory_allocated() / 1024 / 1024 print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB") def check_storage_all(self): @@ -422,7 +422,7 @@ def check_torch_storage(self): torch.cuda.reset_peak_memory_stats() (tp_logprobs, tp_entropy) = run_torch_entropy_tp(hidden, weight, labels, self.group) torch.cuda.synchronize() - forward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + forward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 g_entropy, g_logprobs = self.generate_backward_inputs() # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group @@ -434,7 +434,7 @@ def check_torch_storage(self): (g_entropy, g_logprobs), retain_graph=False) torch.cuda.synchronize() - backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 # NOTE: all-reduce on hidden is conducted outside the kernel dist.all_reduce(d_tp_hidden, op=dist.ReduceOp.SUM, group=self.group) @@ -536,7 +536,7 @@ def check_kernel_storage(self): torch.cuda.reset_peak_memory_stats() (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none", self.group) torch.cuda.synchronize() - kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + kernel_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 g_entropy, g_logprobs = self.generate_backward_inputs() # NOTE: we need to manually synchronize g_entropy and g_logprobs among Process Group @@ -548,7 +548,7 @@ def check_kernel_storage(self): (g_entropy, g_logprobs), retain_graph=False) torch.cuda.synchronize() - kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024 + kernel_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 # NOTE: all-reduce on hidden is conducted outside the kernel dist.all_reduce(d_kernel_hidden, op=dist.ReduceOp.SUM, group=self.group) From 5564356ac2ad02b1aea6e55a8740f91467b51d35 
Mon Sep 17 00:00:00 2001 From: gaoziyuan Date: Wed, 26 Mar 2025 20:47:16 +0800 Subject: [PATCH 14/15] sign where should be fused with linear_cross_entropy --- .../llama/megatron/modeling_llama_megatron.py | 188 ++++++++++++++---- .../qwen2/megatron/modeling_qwen2_megatron.py | 187 ++++++++++++++--- verl/models/transformers/common.py | 11 + verl/models/transformers/llama.py | 94 ++++++++- verl/models/transformers/qwen2.py | 89 ++++++++- verl/models/transformers/qwen2_vl.py | 169 +++++++++++++++- verl/trainer/config/ppo_megatron_trainer.yaml | 2 - verl/trainer/config/ppo_trainer.yaml | 1 - verl/workers/actor/dp_actor.py | 40 ++-- verl/workers/actor/megatron_actor.py | 30 +-- 10 files changed, 690 insertions(+), 121 deletions(-) create mode 100644 verl/models/transformers/common.py diff --git a/verl/models/llama/megatron/modeling_llama_megatron.py b/verl/models/llama/megatron/modeling_llama_megatron.py index e9598202826..196900aae3d 100644 --- a/verl/models/llama/megatron/modeling_llama_megatron.py +++ b/verl/models/llama/megatron/modeling_llama_megatron.py @@ -35,6 +35,8 @@ from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config from .layers import ParallelLlamaDecoderLayer, ParallelLlamaRMSNorm, ParallelLlamaDecoderLayerRmPad +from verl.utils.kernel import linear_cross_entropy +from verl.models.transformers.common import FusedCausalLMOutputWithPast """ TODO: 1. Add weight initialization. Here we need to be careful on TP weight init. 
@@ -180,7 +182,10 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, + ) -> Union[Tuple, FusedCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -199,17 +204,45 @@ def forward( ) hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) @@ -315,7 +348,10 @@ def forward( input_ids: 
torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, + ) -> Union[Tuple, FusedCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -346,25 +382,62 @@ def forward( max_seqlen_in_batch=max_seqlen_in_batch) hidden_states = outputs + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back, maybe done later + # move outside + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) - logits = self._forward_head(hidden_states) - - # remove padding 
from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - return CausalLMOutputWithPast( + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) @@ -391,8 +464,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids, attention_mask, position_ids) + output = super().forward(input_ids, attention_mask, position_ids, labels, temperature, fuse_entropy_logprobs) output.logits = torch.squeeze(output.logits, dim=-1) return output @@ -578,6 +654,9 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -612,24 +691,62 @@ def forward( if self.post_process: hidden_states = outputs - # print(f'hidden_states.shape = 
{hidden_states.shape}') # torch.Size([4, 32, 4096]) - logits = self._forward_head(hidden_states) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back + # move outside + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, seqlen=sequence_length) + else: + logits = 
self._forward_head(hidden_states) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) else: return outputs @@ -659,8 +776,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, temperature=temperature, fuse_entropy_logprobs=fuse_entropy_logprobs) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) return output diff --git a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py index 28af135ea5d..e64563cce1d 100644 --- a/verl/models/qwen2/megatron/modeling_qwen2_megatron.py +++ b/verl/models/qwen2/megatron/modeling_qwen2_megatron.py @@ -33,6 +33,8 @@ from verl.utils.megatron import sequence_parallel as sp_utils from verl.utils.megatron import tensor_parallel as tp_utils from verl.utils.megatron_utils import TransformerConfig, convert_config +from verl.utils.kernel import linear_cross_entropy +from verl.models.transformers.common import FusedCausalLMOutputWithPast from .layers import ParallelQwen2DecoderLayer, 
ParallelQwen2RMSNorm, ParallelQwen2DecoderLayerRmPad """ TODO: @@ -179,6 +181,9 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -198,20 +203,48 @@ def forward( ) hidden_states = outputs - logits = self.lm_head(hidden_states)[0] - - logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) - - logits = logits.float() - return CausalLMOutputWithPast( + + logits = None + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + logits = self.lm_head(hidden_states)[0] + logits = tensor_parallel.gather_from_tensor_model_parallel_region(logits) + logits = logits.float() + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy ) + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa @@ -314,7 +347,10 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: 
Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - ) -> Union[Tuple, CausalLMOutputWithPast]: + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, + ) -> Union[Tuple, FusedCausalLMOutputWithPast]: r""" Args: labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -345,25 +381,64 @@ def forward( max_seqlen_in_batch=max_seqlen_in_batch) hidden_states = outputs + + log_probs = None + entropy = None + logits = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back, move to later + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, + seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, + seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) - logits = self._forward_head(hidden_states) - - # remove padding from sequence parallel - if 
self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) - logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension - # add removed padding back - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension + # add removed padding back + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - return CausalLMOutputWithPast( + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy, ) @@ -390,8 +465,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids, attention_mask, position_ids) + output = super().forward(input_ids, attention_mask, position_ids, labels, temperature, fuse_entropy_logprobs) output.logits = torch.squeeze(output.logits, dim=-1) return output @@ -626,6 +704,9 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -660,23 +741,62 @@ def forward( if self.post_process: hidden_states = outputs - logits = self._forward_head(hidden_states) - logits = 
torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) - - # remove padding from sequence parallel - if self.megatron_config.sequence_parallel: - totol_nnz = cu_seqlens[-1] - logits = logits[:totol_nnz] # (total_nnz_padded) - # add removed padding back. If input is already rmpad, we let the caller pad_input - logits = pad_input(logits, indices, batch_size, - seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) - - return CausalLMOutputWithPast( + + log_probs = None + entropy = None + + if self.training and fuse_entropy_logprobs: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + responses = data['responses'] + response_length = responses.size(1) + + # labels is responses + labels = responses + label_length = labels.size(1) + label_mask = attention_mask[:, -label_length:] + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. 
If input is already rmpad, we let the caller pad_input + # move to outside + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + logits = logits[:, -label_length - 1:-1].contiguous() + logits = logits.div_(temperature) + log_prob = vocab_parallel_log_probs_from_logits(logits, labels) + entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=label_mask) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + log_probs = pad_input(log_probs, indices, batch_size, + seqlen=sequence_length) + entropy = pad_input(entropy, indices, batch_size, + seqlen=sequence_length) + else: + logits = self._forward_head(hidden_states) + logits = torch.squeeze(logits, dim=1) # remove the artificial batch dimension # torch.Size([8, 32, 16]) + + # remove padding from sequence parallel + if self.megatron_config.sequence_parallel: + totol_nnz = cu_seqlens[-1] + logits = logits[:totol_nnz] # (total_nnz_padded) + # add removed padding back. 
If input is already rmpad, we let the caller pad_input + logits = pad_input(logits, indices, batch_size, + seqlen=sequence_length) # (batch_size, sequence_length, vocab_size) + + return FusedCausalLMOutputWithPast( loss=None, logits=logits, past_key_values=None, hidden_states=None, attentions=None, + log_probs=log_probs, + entropy=entropy, ) else: return outputs @@ -706,8 +826,11 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, + labels: Optional[torch.LongTensor] = None, + temperature: float = 1.0, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, CausalLMOutputWithPast]: - output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + output = super().forward(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels, temperature=temperature, fuse_entropy_logprobs=fuse_entropy_logprobs) if self.post_process: output.logits = torch.squeeze(output.logits, dim=-1) return output diff --git a/verl/models/transformers/common.py b/verl/models/transformers/common.py new file mode 100644 index 00000000000..c5b9216d62e --- /dev/null +++ b/verl/models/transformers/common.py @@ -0,0 +1,11 @@ +import torch +from transformers.models.llama.modeling_llama import CausalLMOutputWithPast +from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast + +class FusedCausalLMOutputWithPast(CausalLMOutputWithPast): + log_probs: torch.Tensor + entropy: torch.Tensor + +class FusedQwen2VLCausalLMOutputWithPast(Qwen2VLCausalLMOutputWithPast): + log_probs: torch.Tensor + entropy: torch.Tensor \ No newline at end of file diff --git a/verl/models/transformers/llama.py b/verl/models/transformers/llama.py index 886ccb67d68..134e95bfaab 100644 --- a/verl/models/transformers/llama.py +++ b/verl/models/transformers/llama.py @@ -13,7 +13,7 @@ # limitations under the License. 
import torch -from typing import Optional, Tuple, Callable +from typing import Optional, Tuple, Callable, Union, List import sys if sys.version_info >= (3, 11): from typing import Unpack @@ -26,6 +26,8 @@ from transformers.modeling_flash_attention_utils import _flash_attention_forward from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from verl.utils.kernel import linear_cross_entropy +from .common import FusedCausalLMOutputWithPast logger = logging.get_logger(__name__) @@ -224,3 +226,93 @@ def llama_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + +def llama_fused_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + temperature: Optional[float] = None, + **kwargs, +) -> Union[Tuple, FusedCausalLMOutputWithPast]: + """ + Codes patch to huggingface/transformers LlamaForCausalLM for Fused lmhead/Entropy/CrossEntropy. 
+ """ + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + + # DO not support TP here + + logits = None + loss = None + log_probs = None + entropy = None + + if self.training: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + # Inferencce mode + logits = self.lm_head(hidden_states) + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return FusedCausalLMOutputWithPast( + 
loss=loss, + logits=logits, + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + \ No newline at end of file diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py index 63d9ae98b5e..c3ee1ede766 100644 --- a/verl/models/transformers/qwen2.py +++ b/verl/models/transformers/qwen2.py @@ -13,7 +13,7 @@ # limitations under the License. import torch -from typing import Optional, Tuple, Callable +from typing import Optional, Tuple, Callable, List, Union from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv from transformers.cache_utils import Cache @@ -22,6 +22,8 @@ from transformers.processing_utils import Unpack from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from verl.utils.kernel import linear_cross_entropy +from .common import FusedCausalLMOutputWithPast logger = logging.get_logger(__name__) @@ -224,3 +226,88 @@ def qwen2_attn_forward( attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) return attn_output, attn_weights + +def qwen2_fused_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs, +) -> Union[Tuple, FusedCausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not 
None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + # DO not support TP here + + logits = None + loss = None + log_probs = None + entropy = None + + if self.training: + # TOCHECK: whether labels is not None is needed + """ + To Squeeze: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + # Inferencce mode + logits = self.lm_head(hidden_states) + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return FusedCausalLMOutputWithPast( + loss=loss, + logits=logits, + log_probs=log_probs, + entropy=entropy, + 
past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + \ No newline at end of file diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 718b9ca6f5b..3b58fa10598 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Union import inspect import torch +from torch.nn import CrossEntropyLoss import os from transformers.utils import is_flash_attn_greater_or_equal from transformers.modeling_flash_attention_utils import _flash_attention_forward from verl.utils.ulysses import gather_heads_scatter_seq, gather_seq_scatter_heads, \ get_ulysses_sequence_parallel_world_size, validate_ulysses_config +from verl.utils.kernel import linear_cross_entropy +from .common import FusedCausalLMOutputWithPast try: from flash_attn import flash_attn_func, flash_attn_varlen_func @@ -288,3 +291,167 @@ def ulysses_flash_attn_forward( attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) return attn_output, None, None + +def qwen2_fused_forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + pixel_values: Optional[torch.Tensor] = None, + pixel_values_videos: Optional[torch.FloatTensor] = None, + image_grid_thw: Optional[torch.LongTensor] = None, + video_grid_thw: 
Optional[torch.LongTensor] = None, + rope_deltas: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, +) -> Union[Tuple, FusedCausalLMOutputWithPast]: + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is None: + inputs_embeds = self.model.embed_tokens(input_ids) + if pixel_values is not None: + pixel_values = pixel_values.type(self.visual.get_dtype()) + image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + n_image_tokens = (input_ids == self.config.image_token_id).sum().item() + n_image_features = image_embeds.shape[0] + if n_image_tokens != n_image_features: + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + image_mask = ( + (input_ids == self.config.image_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.type(self.visual.get_dtype()) + video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) + n_video_tokens = (input_ids == self.config.video_token_id).sum().item() + n_video_features = video_embeds.shape[0] + if n_video_tokens != n_video_features: + raise ValueError( + f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" + ) + video_mask = ( + (input_ids == self.config.video_token_id) + .unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + video_embeds = video_embeds.to(inputs_embeds.device, 
inputs_embeds.dtype) + inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) + + if attention_mask is not None: + attention_mask = attention_mask.to(inputs_embeds.device) + + # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme + if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): + # calculate RoPE index once per generation in the pre-fill stage only + if ( + (cache_position is not None and cache_position[0] == 0) + or self.rope_deltas is None + or (past_key_values is None or past_key_values.get_seq_length() == 0) + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, image_grid_thw, video_grid_thw, attention_mask + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + batch_size, seq_length, _ = inputs_embeds.shape + delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + delta = delta.to(position_ids.device) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + ) + + hidden_states = outputs[0] + + # DO not support TP here + + logits = None + loss = None + log_probs = None + entropy = None + + if self.training: + # TOCHECK: whether 
labels is not None is needed + """ + To Squeeze: + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) + + logits_rmpad.div_(temperature) + # compute entropy + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) + """ + log_probs, entropy = linear_cross_entropy(hidden_states, self.lm_head.weights, labels, reduction="none") + else: + # Inferencce mode + logits = self.lm_head(hidden_states) + if labels is not None: + # Upcast to float if we need to compute the loss to avoid potential precision issues + logits = logits.float() + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return FusedCausalLMOutputWithPast( + loss=loss, + logits=logits, + log_probs=log_probs, + entropy=entropy, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + \ No newline at end of file diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 9513e391bba..afd697faaae 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -34,7 +34,6 @@ actor_rollout_ref: 
kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: True - use_fused_kernels: True optim: lr: 1e-6 clip_grad: 1.0 @@ -65,7 +64,6 @@ actor_rollout_ref: param_offload: False log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu log_prob_micro_batch_size_per_gpu: null - use_fused_kernels: True rollout: name: vllm temperature: 1.0 diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 7467cae167c..07eeea37c54 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -67,7 +67,6 @@ actor_rollout_ref: log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size - use_fused_kernels: True rollout: name: vllm temperature: 1.0 diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 2b74479857c..5a4c8666581 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -26,11 +26,10 @@ from verl.trainer.ppo import core_algos from verl.workers.actor import BasePPOActor from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import logprobs_from_logits, masked_mean +from verl.utils.torch_functional import masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F -from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis @@ -53,12 +52,6 @@ def __init__( print(f'Actor use_remove_padding={self.use_remove_padding}') self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size self.use_ulysses_sp = 
self.ulysses_sequence_parallel_size > 1 - self.use_fused_kernels = config.use_fused_kernels - - self.compute_entropy_from_logits = ( - torch.compile(verl_F.entropy_from_logits, dynamic=True) - if self.config.get('use_torch_compile', True) # use torch compile by default - else verl_F.entropy_from_logits) def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -113,19 +106,14 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, attention_mask=None, position_ids=position_ids_rmpad, **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating - logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size) - - if not self.use_fused_kernels: - logits_rmpad.div_(temperature) - # compute entropy - entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad) + use_cache=False, + labels=input_ids_rmpad_rolled, + temperature=temperature, + output_logits=False) # prevent model thinks we are generating + entropy_rmpad = output.entropy - # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) - log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) - else: - weights = torch.eye(logits_rmpad.size(-1), device=logits_rmpad.device) / temperature - log_probs, entropy_rmpad = linear_cross_entropy(logits_rmpad, weights, input_ids_rmpad_rolled) + # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen) + log_probs = output.log_probs # gather log_prob if sp > 1 if self.use_ulysses_sp: @@ -154,12 +142,12 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, attention_mask=attention_mask, position_ids=position_ids, **multi_modal_inputs, - use_cache=False) # prevent model thinks we are generating - logits = output.logits - logits.div_(temperature) - logits = logits[:, -response_length - 1:-1, :] # (bsz, response_length, vocab_size) - log_probs = logprobs_from_logits(logits, 
micro_batch['responses']) - entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length) + use_cache=False, + labels=micro_batch['responses'], + temperature=temperature, + output_logits=False) # prevent model thinks we are generating + entropy = output.entropy + log_probs = output.log_probs return entropy, log_probs diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index 2c52c3e5971..79cdda3f240 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -159,14 +159,6 @@ def compute_log_prob(self, data: DataProto) -> torch.Tensor: """ data.batch = data.batch.contiguous() - def compute_logprobs_fn(output, data): - response = data['responses'] - response_length = response.size(1) - logits = output['logits'] - logits = logits[:, -response_length - 1:-1].contiguous() - log_probs = vocab_parallel_log_probs_from_logits(logits, response) - return {'log_probs': log_probs} - # We make recompute_old_log_prob by default here. # TODO (zhangchi.usc1992): actually, this function should only return log_prob and this logic should be handled by user outside recompute_old_log_prob = self.config.get('recompute_old_log_prob', True) @@ -179,7 +171,7 @@ def compute_logprobs_fn(output, data): response = batch['responses'] response_length = response.size(1) with torch.no_grad(): - output = self.forward_backward_batch(data, forward_only=True, post_process_fn=compute_logprobs_fn) + output = self.forward_backward_batch(data, forward_only=True) if mpu.is_pipeline_last_stage(ignore_virtual=True): # only on last rank. 
It should be on every tp rank log_probs = torch.cat([o['log_probs'] for o in output], dim=0) # (bs, seq_size) @@ -230,7 +222,7 @@ def make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: epochs=self.config.ppo_epochs, dataloader_kwargs={'shuffle': self.config.shuffle}) - def forward_backward_batch(self, data: DataProto, forward_only=False, post_process_fn=None): + def forward_backward_batch(self, data: DataProto, forward_only=False): """ We assume: - The model takes input: (input_ids, attention_mask, position_ids). No rmpad for the input @@ -262,12 +254,6 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce forward_backward_func = get_forward_backward_func() def loss_func(output, data, meta_info): - if forward_only: - if post_process_fn is None: - return 1.0, {'logits': output.logits} - else: - return 1.0, post_process_fn(output, data) - responses = data['responses'] response_length = responses.size(1) attention_mask = data['attention_mask'] @@ -279,17 +265,13 @@ def loss_func(output, data, meta_info): entropy_coeff = meta_info['entropy_coeff'] # compute policy loss - logits = output.logits - logits = logits[:, -response_length - 1:-1].contiguous() - logits_back = logits.clone() - log_prob = vocab_parallel_log_probs_from_logits(logits, responses) - logits = logits_back + log_prob = output.log_probs + entropy_loss = output.entropy pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob, log_prob=log_prob, advantages=advantages, eos_mask=response_mask, cliprange=clip_ratio) - entropy_loss = vocab_parallel_compute_entropy_loss(logits, eos_mask=response_mask) policy_loss = pg_loss - entropy_loss * entropy_coeff metrics = {} @@ -320,7 +302,9 @@ def forward_step(batch_iter, model): input_ids = batch['input_ids'] attention_mask = batch['attention_mask'] position_ids = batch['position_ids'] - output = model(input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids) + 
responses = batch['responses']  # NOTE(review): labels must come from this micro-batch; `data` is the full DataProto and its batch size differs
a/verl/models/transformers/monkey_patch.py +++ b/verl/models/transformers/monkey_patch.py @@ -90,14 +90,30 @@ def apply_monkey_patch(model: PreTrainedModel): # TODO: VLM models only, unify monkey patch to LLM models. if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope - from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward - from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 - from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2 + from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward, qwen2_vl_fused_forward + from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2, Qwen2ForCasalLM + from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2, Qwen2_5_VLForConditionalGeneration + + Qwen2_5_VLForConditionalGeneration.forward = qwen2_vl_fused_forward + Qwen2ForCasalLM.forward = qwen2_fused_forward Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward + print("Monkey patch FlashAttention2.forward in Qwen2VL") return + elif model.config.model_type in ("llama", "qwen2"): + from verl.models.transformers.qwen2 import qwen2_fused_forward + from verl.models.transformers.llama import llama_fused_forward + + from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM + from transformers.models.llama.modeling_llama import LlamaFroCausalLM + + LlamaFroCausalLM.forward = llama_fused_forward + Qwen2ForCausalLM.forward = qwen2_fused_forward + + print("Monkey patch forward in Qwen2 and Llama") + # transformers<=4.47.1 if hasattr(module, "_flash_attention_forward"): diff --git a/verl/models/transformers/qwen2.py b/verl/models/transformers/qwen2.py index c3ee1ede766..295537a9c4a 100644 --- a/verl/models/transformers/qwen2.py +++ b/verl/models/transformers/qwen2.py @@ -241,6 +241,7 @@ def 
qwen2_fused_forward( return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, + fuse_entropy_logprobs: bool = False, **kwargs, ) -> Union[Tuple, FusedCausalLMOutputWithPast]: @@ -273,7 +274,7 @@ def qwen2_fused_forward( log_probs = None entropy = None - if self.training: + if self.training and fuse_entropy_logprobs: # TOCHECK: whether labels is not None is needed """ To Squeeze: @@ -294,8 +295,9 @@ def qwen2_fused_forward( else: # Inferencce mode logits = self.lm_head(hidden_states) - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) + # loss is not needed + # if labels is not None: + # loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) if not return_dict: output = (logits,) + outputs[1:] diff --git a/verl/models/transformers/qwen2_vl.py b/verl/models/transformers/qwen2_vl.py index 3b58fa10598..e3de116fba7 100644 --- a/verl/models/transformers/qwen2_vl.py +++ b/verl/models/transformers/qwen2_vl.py @@ -292,7 +292,7 @@ def ulysses_flash_attn_forward( attn_output = self.o_proj(attn_output) return attn_output, None, None -def qwen2_fused_forward( +def qwen2_vl_fused_forward( self, input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, @@ -310,6 +310,7 @@ def qwen2_fused_forward( video_grid_thw: Optional[torch.LongTensor] = None, rope_deltas: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, + fuse_entropy_logprobs: bool = False, ) -> Union[Tuple, FusedCausalLMOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -406,13 +407,13 @@ def qwen2_fused_forward( log_probs = None entropy = None - if self.training: + if self.training and fuse_entropy_logprobs: # TOCHECK: whether labels is not None is needed """ To Squeeze: slice_indices = 
slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) + logits = self.lm_head(hidden_states) logits_rmpad = logits.squeeze(0) # (total_nnz, vocab_size) @@ -427,19 +428,20 @@ def qwen2_fused_forward( else: # Inferencce mode logits = self.lm_head(hidden_states) - if labels is not None: - # Upcast to float if we need to compute the loss to avoid potential precision issues - logits = logits.float() - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) + # loss may not needed + # if labels is not None: + # # Upcast to float if we need to compute the loss to avoid potential precision issues + # logits = logits.float() + # # Shift so that tokens < n predict n + # shift_logits = logits[..., :-1, :].contiguous() + # shift_labels = labels[..., 1:].contiguous() + # # Flatten the tokens + # loss_fct = CrossEntropyLoss() + # shift_logits = shift_logits.view(-1, self.config.vocab_size) + # shift_labels = shift_labels.view(-1) + # # Enable model parallelism + # shift_labels = shift_labels.to(shift_logits.device) + # loss = loss_fct(shift_logits, shift_labels) if not return_dict: output = (logits,) + outputs[1:]