diff --git a/.github/workflows/kernels.yml b/.github/workflows/kernels.yml
new file mode 100644
index 00000000000..987800221ba
--- /dev/null
+++ b/.github/workflows/kernels.yml
@@ -0,0 +1,58 @@
+name: kernels
+# latest version: Megatron-LM core_r0.11.0 https://github.com/NVIDIA/Megatron-LM/tree/core_r0.11.0
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the main and release branches
+  push:
+    branches:
+      - main
+      - v0.2.x
+    paths:
+      - "**/*.py"
+      - .github/workflows/kernels.yml
+  pull_request:
+    branches:
+      - main
+      - v0.2.x
+    paths:
+      - "**/*.py"
+      - "verl/trainer/config/*.yaml"
+      - .github/workflows/kernels.yml
+      - "tests/kernel/*.sh"
+
+# Cancel jobs on the same ref if a new one is triggered
+concurrency:
+  group: ${{ github.workflow }}-${{ github.ref }}
+  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
+
+# Declare read-only permissions for repository contents
+permissions:
+  contents: read
+
+jobs:
+  kernel_tests:
+    runs-on: [self-hosted, l20-0]
+    timeout-minutes: 40 # Increase this timeout value as needed
+    env:
+      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+      NO_PROXY: "localhost,127.0.0.1"
+      HF_HUB_ENABLE_HF_TRANSFER: 1
+    container:
+      image: whatcanyousee/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te2.0-megatron0.11.0-v0.0.6
+      options: --gpus all --shm-size=10g
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          fetch-depth: 0
+      - name: Install the current repository
+        run: |
+          pip3 install hf_transfer
+          pip3 install -e .[test]
+      - name: Testing LinearCrossEntropy Correctness, Computation Time and Memory Consumption
+        run: |
+          python3 tests/kernel/test_linear_cross_entropy.py
+      - name: Testing VocabParallelEntropy
+        run: |
+          bash tests/kernel/run_vocab_parallel_entropy.sh
\ No newline at end of file
diff --git a/tests/kernel/run_vocab_parallel_entropy.sh b/tests/kernel/run_vocab_parallel_entropy.sh
new file mode 100644
index 00000000000..ee7cd3a7568
--- /dev/null
+++ b/tests/kernel/run_vocab_parallel_entropy.sh
@@ -0,0 +1,18 @@
+#!/usr/bin/env bash
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+set -e -x
+
+torchrun --nproc-per-node=8 --standalone tests/kernel/test_vocab_parallel_entropy.py
\ No newline at end of file
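The benchmark in the test below times each implementation with CUDA events and discards the first, warm-up iteration before averaging. A minimal standalone sketch of that timing pattern (the matmul and sizes are placeholders; assumes a CUDA device is available):

```python
import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
x = torch.randn(4096, 4096, device="cuda")

latency_ms = []
for _ in range(5):
    start.record()
    y = x @ x
    end.record()
    torch.cuda.synchronize()  # elapsed_time is only valid once both events have completed
    latency_ms.append(start.elapsed_time(end))

latency_ms = latency_ms[1:]  # drop the warm-up iteration, as the test does
print(f"avg: {sum(latency_ms) / len(latency_ms):.2f} ms")
```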
diff --git a/tests/kernel/test_linear_cross_entropy.py b/tests/kernel/test_linear_cross_entropy.py
index 37069457fb9..cafe2b49c54
--- a/tests/kernel/test_linear_cross_entropy.py
+++ b/tests/kernel/test_linear_cross_entropy.py
@@ -42,18 +42,38 @@
 finally:
     from verl.utils.kernel import linear_cross_entropy, set_backward_method, BackwardEnum
 
+import verl.utils.torch_functional as verl_F
+from verl.utils.torch_functional import logprobs_from_logits
 
-def run_torch_entropy(hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor) -> typing.List[torch.Tensor]:
+compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
+
+
+def run_torch_entropy(hidden: torch.Tensor,
+                      weight: torch.Tensor,
+                      labels: torch.Tensor,
+                      reduction="none") -> typing.List[torch.Tensor]:
     logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
     pd = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
     entropy_a = torch.logsumexp(logits, dim=-1)  # [num_tokens]
     entropy_b = torch.sum(pd * logits, dim=-1)  # [num_tokens]
     entropy = entropy_a - entropy_b
-    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction="none")  # [num_tokens]
+    logprobs = torch.nn.functional.cross_entropy(logits, labels, reduction=reduction)  # [num_tokens]
     logprobs = torch.neg(logprobs)
     return logprobs, entropy
 
 
+def run_verl_actor_entropy(hidden: torch.Tensor,
+                           weight: torch.Tensor,
+                           labels: torch.Tensor,
+                           reduction="none") -> typing.List[torch.Tensor]:
+    logits = torch.matmul(hidden.to(torch.float32), weight.to(torch.float32))  # [num_tokens, vocab_size]
+    # compute entropy and logprobs the same way the verl actor does
+    entropy = compute_entropy_from_logits(logits)  # [num_tokens]
+    # note: reduction is accepted for signature parity with the other implementations but unused here
+    logprobs = logprobs_from_logits(logits=logits, labels=labels)  # [num_tokens]
+    return logprobs, entropy
+
+
 class TestLinearCrossEntropy:
 
     def cleanup(self):
@@ -82,14 +102,14 @@ def generate_backward_inputs(self):
         g_logprobs = (torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1))
         return g_entropy, g_logprobs
 
-    def verify_correctness(self):
+    def verify_correctness(self, iterations=5):
         self.cleanup()
         self.generate_hyper()
 
-        iterations = 5
-
         torch_forward_latency = list()
         torch_backward_latency = list()
+        verl_forward_latency = list()
+        verl_backward_latency = list()
         kernel_forward_latency = list()
         kernel_backward_latency = list()
 
@@ -97,7 +117,7 @@
         end_event = torch.cuda.Event(enable_timing=True)
 
         for i in range(iterations):
-            print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r")
+            print(f"[INFO]: Iteration {i + 1} / {iterations}...", end='\r')
 
             hidden, weight, labels = self.generate_forward_inputs()
 
             start_event.record()
@@ -106,14 +126,24 @@
             torch.cuda.synchronize()
             torch_forward_latency.append(start_event.elapsed_time(end_event))
 
+            start_event.record()
+            (verl_logprobs, verl_entropy) = run_verl_actor_entropy(hidden, weight, labels)
+            end_event.record()
+            torch.cuda.synchronize()
+            verl_forward_latency.append(start_event.elapsed_time(end_event))
+
             start_event.record()
             (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none")
             end_event.record()
             torch.cuda.synchronize()
             kernel_forward_latency.append(start_event.elapsed_time(end_event))
 
+            torch.testing.assert_close(torch_logprobs, verl_logprobs, atol=1e-4, rtol=1e-4)
+            torch.testing.assert_close(torch_entropy, verl_entropy, atol=1e-4, rtol=1e-4)
             torch.testing.assert_close(torch_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4)
             torch.testing.assert_close(torch_entropy, kernel_entropy, atol=1e-4, rtol=1e-4)
+            torch.testing.assert_close(verl_logprobs, kernel_logprobs, atol=1e-4, rtol=1e-4)
+            torch.testing.assert_close(verl_entropy, kernel_entropy, atol=1e-4, rtol=1e-4)
 
             # backward
             g_entropy, g_logprobs = self.generate_backward_inputs()
@@ -126,6 +156,14 @@
             torch.cuda.synchronize()
             torch_backward_latency.append(start_event.elapsed_time(end_event))
 
+            start_event.record()
+            (d_verl_hidden, d_verl_weight) = torch.autograd.grad((verl_entropy, verl_logprobs), (hidden, weight),
+                                                                 (g_entropy, g_logprobs),
+                                                                 retain_graph=False)
+            end_event.record()
+            torch.cuda.synchronize()
+            verl_backward_latency.append(start_event.elapsed_time(end_event))
+
             start_event.record()
             (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight),
                                                                      (g_entropy, g_logprobs),
@@ -134,12 +172,18 @@
             torch.cuda.synchronize()
             kernel_backward_latency.append(start_event.elapsed_time(end_event))
 
+            torch.testing.assert_close(d_torch_hidden, d_verl_hidden, atol=1e-2, rtol=1e-4)
+            torch.testing.assert_close(d_torch_weight, d_verl_weight, atol=1e-2, rtol=1e-4)
             torch.testing.assert_close(d_torch_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4)
             torch.testing.assert_close(d_torch_weight, d_kernel_weight, atol=1e-2, rtol=1e-4)
+            torch.testing.assert_close(d_verl_hidden, d_kernel_hidden, atol=1e-2, rtol=1e-4)
+            torch.testing.assert_close(d_verl_weight, d_kernel_weight, atol=1e-2, rtol=1e-4)
 
         # remove first latency
         torch_forward_latency = torch_forward_latency[1:]
         torch_backward_latency = torch_backward_latency[1:]
+        verl_forward_latency = verl_forward_latency[1:]
+        verl_backward_latency = verl_backward_latency[1:]
         kernel_forward_latency = kernel_forward_latency[1:]
         kernel_backward_latency = kernel_backward_latency[1:]
 
@@ -149,54 +193,41 @@
         print(f"[INFO]: Forward pass: torch implementation average time: "
              f"{sum(torch_forward_latency) / len(torch_forward_latency):.2f} ms")
        print(f"[INFO]: Backward pass: torch implementation average time: "
              f"{sum(torch_backward_latency) / len(torch_backward_latency):.2f} ms")
+        print(f"[INFO]: Forward pass: VeRL implementation average time: "
+              f"{sum(verl_forward_latency) / len(verl_forward_latency):.2f} ms")
+        print(f"[INFO]: Backward pass: VeRL implementation average time: "
+              f"{sum(verl_backward_latency) / len(verl_backward_latency):.2f} ms")
         print(f"[INFO]: Forward pass: Kernel implementation average time: "
               f"{sum(kernel_forward_latency) / len(kernel_forward_latency):.2f} ms")
         print(f"[INFO]: Backward pass: kernel implementation average time: "
               f"{sum(kernel_backward_latency) / len(kernel_backward_latency):.2f} ms")
 
-    def check_torch_storage(self):
+    def check_storage(self, method_name, run_forward, reduction="none"):
         self.cleanup()
         self.generate_hyper()
 
         hidden, weight, labels = self.generate_forward_inputs()
 
         torch.cuda.reset_peak_memory_stats()
-        (torch_logprobs, torch_entropy) = run_torch_entropy(hidden, weight, labels)
+        (logprobs, entropy) = run_forward(hidden, weight, labels, reduction)
         torch.cuda.synchronize()
         torch_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024
-        print(f"[INFO]: Torch Forward pass peak memory: {torch_max_memory:.2f} MB")
+        print(f"[INFO]: {method_name} Forward pass peak memory: {torch_max_memory:.2f} MB")
 
         g_entropy, g_logprobs = self.generate_backward_inputs()
 
         torch.cuda.reset_peak_memory_stats()
-        (d_torch_hidden, d_torch_weight) = torch.autograd.grad((torch_entropy, torch_logprobs), (hidden, weight),
+        (d_torch_hidden, d_torch_weight) = torch.autograd.grad((entropy, logprobs), (hidden, weight),
                                                                (g_entropy, g_logprobs),
                                                                retain_graph=False)
         torch.cuda.synchronize()
         torch_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024
-        print(f"[INFO]: Torch Backward pass peak memory: {torch_backward_max_memory:.2f} MB")
-
-    def check_kernel_storage(self):
-        self.cleanup()
-        self.generate_hyper()
-
-        hidden, weight, labels = self.generate_forward_inputs()
-
-        torch.cuda.reset_peak_memory_stats()
-        (kernel_logprobs, kernel_entropy) = linear_cross_entropy(hidden, weight, labels, "none")
-        torch.cuda.synchronize()
-        kernel_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024
-        print(f"[INFO]: Kernel Forward pass peak memory: {kernel_max_memory:.2f} MB")
-
-        g_entropy, g_logprobs = self.generate_backward_inputs()
-
-        torch.cuda.reset_peak_memory_stats()
-        (d_kernel_hidden, d_kernel_weight) = torch.autograd.grad((kernel_entropy, kernel_logprobs), (hidden, weight),
-                                                                 (g_entropy, g_logprobs),
-                                                                 retain_graph=False)
-        torch.cuda.synchronize()
-        kernel_backward_max_memory = torch.cuda.max_memory_reserved() / 1024 / 1024
-        print(f"[INFO]: Kernel Backward pass peak memory: {kernel_backward_max_memory:.2f} MB")
+        print(f"[INFO]: {method_name} Backward pass peak memory: {torch_backward_max_memory:.2f} MB")
+
+    def check_storage_all(self):
+        self.check_storage("Torch", run_torch_entropy)
+        self.check_storage("VeRL", run_verl_actor_entropy)
+        self.check_storage("Kernel", linear_cross_entropy)
 
 
 if __name__ == "__main__":
@@ -204,6 +235,5 @@ def check_kernel_storage(self):
 
     test = TestLinearCrossEntropy()
 
-    test.verify_correctness()
-    test.check_torch_storage()
-    test.check_kernel_storage()
+    test.verify_correctness(100)
+    test.check_storage_all()
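For reference, the identity behind `run_torch_entropy` is that softmax entropy can be computed as logsumexp(z) minus the softmax-weighted sum of the logits, without ever materializing log-probabilities. A minimal CPU check of that identity (sizes are arbitrary illustration, not from the test):

```python
import torch

z = torch.randn(4, 32000, dtype=torch.float64)  # [num_tokens, vocab_size]
p = torch.softmax(z, dim=-1)

# what run_torch_entropy computes: logsumexp(z) - sum(softmax(z) * z)
entropy_lse = torch.logsumexp(z, dim=-1) - (p * z).sum(dim=-1)
# textbook definition: -sum(p * log p)
entropy_direct = -(p * p.log()).sum(dim=-1)

torch.testing.assert_close(entropy_lse, entropy_direct)
```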
diff --git a/tests/kernel/test_vocab_parallel_entropy.py b/tests/kernel/test_vocab_parallel_entropy.py
new file mode 100644
index 00000000000..5fd899f2d62
--- /dev/null
+++ b/tests/kernel/test_vocab_parallel_entropy.py
@@ -0,0 +1,118 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+os.environ['NCCL_DEBUG'] = 'WARN'
+
+import torch
+import torch.distributed
+
+from verl.utils.megatron.tensor_parallel import vocab_parallel_entropy
+from verl.utils.torch_functional import entropy_from_logits
+
+from verl.utils.debug import log_gpu_memory_usage
+
+from megatron.core import mpu
+
+
+class Utils:
+    world_size = torch.cuda.device_count()
+    rank = int(os.environ.get('LOCAL_RANK', '0'))
+
+    @staticmethod
+    def initialize_distributed():
+        print(f'Initializing torch.distributed with rank: {Utils.rank}, world_size: {Utils.world_size}')
+        torch.cuda.set_device(Utils.rank % torch.cuda.device_count())
+        init_method = 'tcp://'
+        master_ip = os.getenv('MASTER_ADDR', 'localhost')
+        master_port = os.getenv('MASTER_PORT', '7000')
+        init_method += master_ip + ':' + master_port
+        torch.distributed.init_process_group(backend='nccl',
+                                             world_size=Utils.world_size,
+                                             rank=Utils.rank,
+                                             init_method=init_method)
+        print('Successfully created process group')
+
+    @staticmethod
+    def destroy_model_parallel():
+        mpu.destroy_model_parallel()
+        torch.distributed.barrier()
+        torch.distributed.destroy_process_group()
+
+    @staticmethod
+    def initialize_model_parallel(tensor_model_parallel_size=1,
+                                  pipeline_model_parallel_size=1,
+                                  virtual_pipeline_model_parallel_size=None,
+                                  pipeline_model_parallel_split_rank=None):
+        mpu.destroy_model_parallel()
+        if not torch.distributed.is_initialized():
+            Utils.initialize_distributed()
+        mpu.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size,
+                                      virtual_pipeline_model_parallel_size, pipeline_model_parallel_split_rank)
+
+
+def test_vocab_parallel_entropy():
+    # check vocab_parallel_entropy
+    Utils.world_size = 8
+    Utils.initialize_model_parallel(8, 1)
+
+    batch_size = 2
+    seqlen = 128
+    vocab_size = 155136
+
+    logits = torch.randn(batch_size * seqlen, vocab_size, device='cuda', requires_grad=True)
+    target = torch.randint(low=0, high=vocab_size, size=(batch_size * seqlen,), device='cuda', dtype=torch.int64)
+
+    # broadcast across tp
+    torch.distributed.broadcast(logits,
+                                mpu.get_tensor_model_parallel_src_rank(),
+                                group=mpu.get_tensor_model_parallel_group())
+    torch.distributed.broadcast(target,
+                                mpu.get_tensor_model_parallel_src_rank(),
+                                group=mpu.get_tensor_model_parallel_group())
+
+    tp_rank = mpu.get_tensor_model_parallel_rank()
+    vocab_size_per_tp = vocab_size // mpu.get_tensor_model_parallel_world_size()
+
+    # get the local logits shard of each tp rank
+    vocab_parallel_logits = logits.clone().detach()[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) *
+                                                    vocab_size_per_tp].requires_grad_()
+    logits.grad = None
+    vocab_parallel_logits.grad = None
+
+    log_gpu_memory_usage('begin')
+    output_entropy = vocab_parallel_entropy(vocab_parallel_logits)
+    log_gpu_memory_usage('after forward')
+    grad_output = torch.randn_like(output_entropy)
+    output_entropy.backward(grad_output)
+    log_gpu_memory_usage('after backward')
+
+    target_entropy = entropy_from_logits(logits)
+    torch.testing.assert_close(output_entropy, target_entropy)
+    target_entropy.backward(grad_output)
+    torch.testing.assert_close(logits.grad[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp],
+                               vocab_parallel_logits.grad)
+    # make sure logits is not altered
+    torch.testing.assert_close(logits[:, tp_rank * vocab_size_per_tp:(tp_rank + 1) * vocab_size_per_tp],
+                               vocab_parallel_logits)
+
+    if mpu.get_tensor_model_parallel_rank() == 0:
+        print('test_vocab_parallel_entropy passes')
+
+    Utils.destroy_model_parallel()
+
+
+if __name__ == '__main__':
+    test_vocab_parallel_entropy()
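The math this distributed test exercises can be sketched in a single process: with the vocab dimension split across TP ranks, entropy only needs three per-token reductions over the shards (max, sum of exponentials, and sum of softmax times logits), which is exactly what the all_reduce calls in `_VocabParallelEntropy` provide. A single-process illustration, where `chunk` stands in for the per-rank shards and the shard count and sizes are arbitrary:

```python
import torch

tp, num_tokens, vocab = 4, 8, 1024
logits = torch.randn(num_tokens, vocab, dtype=torch.float64)
shards = logits.chunk(tp, dim=-1)  # stand-in for each rank's vocab_parallel_logits

m = torch.stack([s.max(dim=-1).values for s in shards]).max(dim=0).values  # all_reduce(MAX)
sum_exp = sum((s - m[:, None]).exp().sum(dim=-1) for s in shards)          # all_reduce(SUM)
dot = sum(((s - m[:, None]).exp() / sum_exp[:, None] * s).sum(dim=-1)      # all_reduce(SUM)
          for s in shards)
entropy = m + sum_exp.log() - dot

# agrees with the unsharded computation used as the reference in the test
p = torch.softmax(logits, dim=-1)
torch.testing.assert_close(entropy, torch.logsumexp(logits, dim=-1) - (p * logits).sum(dim=-1))
```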
diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index c903934b8bb..a9f9134b5ea 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -34,6 +34,7 @@ actor_rollout_ref:
       kl_loss_type: low_var_kl # for grpo
       ppo_epochs: 1
       shuffle: True
+      use_fused_kernels: True
     optim:
       lr: 1e-6
       clip_grad: 1.0
diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml
index eab987d3aad..09e5e702f60 100644
--- a/verl/trainer/config/ppo_trainer.yaml
+++ b/verl/trainer/config/ppo_trainer.yaml
@@ -39,6 +39,7 @@ actor_rollout_ref:
       ppo_epochs: 1
       shuffle: False
       ulysses_sequence_parallel_size: 1 # sp size
+      use_fused_kernels: True
     optim:
       lr: 1e-6
       lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
diff --git a/verl/utils/megatron/tensor_parallel.py b/verl/utils/megatron/tensor_parallel.py
index 032cabbedbe..bdbee9e04f5 100644
--- a/verl/utils/megatron/tensor_parallel.py
+++ b/verl/utils/megatron/tensor_parallel.py
@@ -101,19 +101,21 @@ class _VocabParallelEntropy(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
+
+        @torch.compile(dynamic=True)
+        def mul_reduce(a, b):
+            return (a * b).sum(dim=-1, keepdim=True)
+
         logits_max = vocab_parallel_logits.max(dim=-1, keepdim=True).values
         dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=mpu.get_tensor_model_parallel_group())
         normalized_vocab_parallel_logits = vocab_parallel_logits - logits_max
-        normalized_exp_logits = normalized_vocab_parallel_logits.exp()
+        normalized_exp_logits = normalized_vocab_parallel_logits.exp_()
         normalized_sum_exp_logits = normalized_exp_logits.sum(dim=-1, keepdim=True)
         dist.all_reduce(normalized_sum_exp_logits, group=mpu.get_tensor_model_parallel_group())
-        softmax_logits = normalized_exp_logits / normalized_sum_exp_logits
+        softmax_logits = normalized_exp_logits.div(normalized_sum_exp_logits)
         # The naive elementwise product below consumes too much VRAM and can cause OOM
         # sum_softmax_times_logits = (softmax_logits * vocab_parallel_logits).sum(dim=-1, keepdim=True)
-        original_shape = softmax_logits.shape
-        sum_softmax_times_logits = torch.bmm(softmax_logits.view(-1, 1, original_shape[-1]),
-                                             vocab_parallel_logits.view(-1, original_shape[-1],
-                                                                        1)).view(original_shape[:-1] + (1,))
+        sum_softmax_times_logits = mul_reduce(softmax_logits, vocab_parallel_logits)
         dist.all_reduce(sum_softmax_times_logits, group=mpu.get_tensor_model_parallel_group())
         entropy = logits_max + normalized_sum_exp_logits.log() - sum_softmax_times_logits
         ctx.save_for_backward(vocab_parallel_logits, softmax_logits, sum_softmax_times_logits)
@@ -122,8 +124,14 @@ def forward(ctx, vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
 
     @staticmethod
     def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
         vocab_parallel_logits, softmax_logits, sum_softmax_times_logits = ctx.saved_tensors
-        grad_input = grad_output.unsqueeze(dim=-1) * softmax_logits * (sum_softmax_times_logits - vocab_parallel_logits)
-        return grad_input
+        # reuse softmax_logits as the grad buffer instead of allocating a new tensor
+        vocab_parallel_logits.sub_(sum_softmax_times_logits)
+        softmax_logits.mul_(vocab_parallel_logits)
+        softmax_logits.mul_(grad_output.unsqueeze(dim=-1))
+        # recover vocab_parallel_logits
+        vocab_parallel_logits.add_(sum_softmax_times_logits)
+        softmax_logits.mul_(-1)
+        return softmax_logits
 
 
 def vocab_parallel_entropy(vocab_parallel_logits: torch.Tensor) -> torch.Tensor:
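The rewritten backward avoids allocating a fresh gradient tensor by mutating the saved softmax buffer in place (and temporarily shifting, then restoring, the saved logits). It relies on the closed form dH/dz = softmax(z) * (sum(softmax(z) * z) - z), the same expression the old `grad_input` line computed out of place. A small check of that formula against autograd (toy sizes, runs on CPU, no TP involved):

```python
import torch

z = torch.randn(4, 512, dtype=torch.float64, requires_grad=True)
p = torch.softmax(z, dim=-1)
entropy = torch.logsumexp(z, dim=-1) - (p * z).sum(dim=-1)

g = torch.randn_like(entropy)  # plays the role of grad_output
(autograd_grad,) = torch.autograd.grad(entropy, z, g)

with torch.no_grad():
    s = (p * z).sum(dim=-1, keepdim=True)         # sum_softmax_times_logits
    manual_grad = g.unsqueeze(-1) * p * (s - z)   # the closed form backward() builds in place

torch.testing.assert_close(autograd_grad, manual_grad)
```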
diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py
index 32bfda2278f..4dc954c5c92 100644
--- a/verl/utils/torch_functional.py
+++ b/verl/utils/torch_functional.py
@@ -25,9 +25,9 @@
 try:
     from flash_attn.ops.triton.cross_entropy import cross_entropy_loss
 
-    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True
+    FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = True
 except ImportError:
-    FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False
+    FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE = False
 
 
 def gather_from_labels(data, label):
@@ -49,7 +49,7 @@ def logprobs_from_logits(logits, labels):
     """
     See: https://github.com/pytorch/pytorch/issues/563#issuecomment-330103591
     """
-    if FLAH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
+    if FLASH_ATTN_CROSS_ENTROPY_LOSS_AVAILABLE:
         batch_dim = logits.shape[:-1]
         last_dim = logits.shape[-1]
         logits = logits.reshape(-1, last_dim)
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index eff5ffa26d9..2b74479857c 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -30,6 +30,7 @@
 from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad
 from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx
 import verl.utils.torch_functional as verl_F
+from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
 
 from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
 
@@ -52,6 +53,7 @@ def __init__(
             print(f'Actor use_remove_padding={self.use_remove_padding}')
         self.ulysses_sequence_parallel_size = self.config.ulysses_sequence_parallel_size
         self.use_ulysses_sp = self.ulysses_sequence_parallel_size > 1
+        self.use_fused_kernels = config.use_fused_kernels
 
         self.compute_entropy_from_logits = (
             torch.compile(verl_F.entropy_from_logits, dynamic=True)
@@ -114,13 +116,16 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
                                          use_cache=False)  # prevent model thinks we are generating
                 logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
 
-                logits_rmpad.div_(temperature)
+                if not self.use_fused_kernels:
+                    logits_rmpad.div_(temperature)
+                    # compute entropy
+                    entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
 
-                # compute entropy
-                entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
-
-                # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
-                log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
+                    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
+                    log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
+                else:
+                    weights = torch.eye(logits_rmpad.size(-1), device=logits_rmpad.device) / temperature
+                    log_probs, entropy_rmpad = linear_cross_entropy(logits_rmpad, weights, input_ids_rmpad_rolled)
 
                 # gather log_prob if sp > 1
                 if self.use_ulysses_sp:
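In the fused path above, `dp_actor` folds temperature scaling into `linear_cross_entropy` by passing an identity matrix divided by the temperature as the weight, so the kernel's matmul stage reduces to `logits_rmpad / temperature`. A shape-level sketch of that wiring, with toy sizes (note this is placeholder wiring rather than the kernel's intended hidden-states plus lm_head-weight inputs: for a real vocabulary the (V, V) identity is enormous, e.g. roughly 96 GB in fp32 at V = 155136; assumes verl is installed and a CUDA device is available):

```python
import torch
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

total_nnz, vocab_size, temperature = 16, 128, 0.7
logits_rmpad = torch.randn(total_nnz, vocab_size, device="cuda")
input_ids_rmpad_rolled = torch.randint(0, vocab_size, (total_nnz,), device="cuda")

# identity / temperature makes the "linear" stage a pure temperature rescale:
# logits_rmpad @ (I / T) == logits_rmpad / T
weights = torch.eye(vocab_size, device=logits_rmpad.device) / temperature
log_probs, entropy_rmpad = linear_cross_entropy(logits_rmpad, weights, input_ids_rmpad_rolled)
```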