Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/advanced_features/server_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
| `--enable-return-hidden-states` | Enable returning hidden states with responses. | `False` | bool flag (set to enable) |
| `--scheduler-recv-interval` | The interval to poll requests in scheduler. Can be set to >1 to reduce the overhead of this. | `1` | Type: int |
| `--numa-node` | Sets the numa node for the subprocesses. i-th element corresponds to i-th subprocess. | `None` | List[int] |
| `--enable-attn-tp-input-scattered` | Allow the attention input to remain scattered when using only tensor parallelism, reducing the computational load of operations such as the QKV latent projection. | `False` | bool flag (set to enable) |
## Debug tensor dumps
| Argument | Description | Defaults | Options |
Expand Down
199 changes: 190 additions & 9 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import logging
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import Dict, List, Optional
from typing import Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
Expand Down Expand Up @@ -59,9 +61,10 @@
prepare_weight_cache,
)

_is_cuda = is_cuda()
_is_flashinfer_available = is_flashinfer_available()
_is_sm90_supported = is_cuda() and is_sm90_supported()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_is_sm90_supported = _is_cuda and is_sm90_supported()
_is_sm100_supported = _is_cuda and is_sm100_supported()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()

Expand Down Expand Up @@ -92,6 +95,119 @@ def model_input_output():
return ScatterMode.TP_ATTN_FULL


class AttentionInputs:
    """Lazily materialized attention inputs for the attn-TP scattered path.

    Holds the rank-local hidden states and, on first use, computes and caches
    the qkv latent and/or the full hidden states. When input scattering is
    active, the cached results are all-gathered back to the full (padded)
    token count before being returned.
    """

    def __init__(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        qkv_latent_func: Callable,
    ):
        # Rank-local (possibly scattered) hidden states.
        self.hidden_states_local = hidden_states
        self.forward_batch = forward_batch
        # Callable producing the qkv latent from the local hidden states.
        self.qkv_latent_func = qkv_latent_func
        # Lazily-computed caches.
        self.hidden_states_ = None
        self.qkv_latent_ = None

    def tp_all_gather_hidden_states(self, hidden_states, forward_batch):
        # Gather rank-local shards back to the full token count implied by
        # the (padded) input_ids of the batch.
        total_tokens = forward_batch.input_ids.shape[0]
        gathered = hidden_states.new_empty((total_tokens, hidden_states.shape[-1]))
        get_tp_group().all_gather_into_tensor(gathered, hidden_states)
        return gathered

    def fetch_qkv_latent(self):
        """Return the qkv latent, computing and caching it on first call."""
        if self.qkv_latent_ is None:
            assert self.qkv_latent_func is not None
            latent = self.qkv_latent_func(self.hidden_states_local, self.forward_batch)
            if get_attn_tp_context().input_scattered:
                latent = self.tp_all_gather_hidden_states(latent, self.forward_batch)
            self.qkv_latent_ = latent
        return self.qkv_latent_

    def fetch_hidden_states(self):
        """Return the full hidden states, gathering and caching on first call."""
        if self.hidden_states_ is None:
            states = self.hidden_states_local
            if get_attn_tp_context().input_scattered:
                states = self.tp_all_gather_hidden_states(states, self.forward_batch)
            self.hidden_states_ = states
        return self.hidden_states_


class AttnTpContext:
    """Global state for the attn-TP input-scattered optimization.

    When ``--enable-attn-tp-input-scattered`` is set and the model/config
    conditions in :meth:`init_context` hold, attention inputs may remain
    scattered across plain-TP ranks during extend batches, reducing the cost
    of per-token work such as the qkv latent projection.
    """

    def __init__(self):
        # Static eligibility, resolved once by init_context().
        self.allow_input_scattered = False
        # Per-forward flag, toggled inside maybe_input_scattered().
        self.input_scattered_ = False
        # Lazily-fetched attention inputs for the current forward pass.
        # (String annotation: AttentionInputs is a forward reference here.)
        self.attn_inputs_: Optional["AttentionInputs"] = None

    def init_context(self, q_lora_rank, is_nsa):
        """Resolve whether input scattering is allowed for this model/config.

        Only plain TP (>1 ranks) with an MLA-style q_lora_rank, without NSA,
        DP attention, MoE A2A, piecewise CUDA graph, or EAGLE3 speculation.
        """
        self.allow_input_scattered = (
            get_global_server_args().enable_attn_tp_input_scattered
            and _is_cuda
            and q_lora_rank is not None
            and not is_nsa
            and get_tensor_model_parallel_world_size() > 1
            and not is_dp_attention_enabled()
            and get_moe_a2a_backend().is_none()
            and not enable_moe_dense_fully_dp()
            and not get_global_server_args().enable_piecewise_cuda_graph
            and get_global_server_args().speculative_algorithm != "EAGLE3"
        )
        if get_global_server_args().enable_attn_tp_input_scattered:
            if not self.allow_input_scattered:
                logging.info(
                    "attn_tp_input_scattered was requested but is disabled "
                    "because other conditions are not met"
                )
            else:
                logging.info("attn_tp_input_scattered is enabled")

    def use_input_scattered(self, forward_batch: "ForwardBatch"):
        """Whether this forward batch should run with scattered attn input."""
        return (
            self.allow_input_scattered
            and forward_batch.forward_mode.is_extend()
            and not forward_batch.forward_mode.is_target_verify()
            and not forward_batch.forward_mode.is_draft_extend()
            and forward_batch.input_ids is not None
            and not forward_batch.can_run_tbo
        )

    @property
    def input_scattered(self):
        # True only inside a maybe_input_scattered() scope that enabled it.
        return self.input_scattered_

    def set_attn_inputs(self, attn_inputs: "AttentionInputs"):
        self.attn_inputs_ = attn_inputs

    def fetch_qkv_latent(self):
        assert self.attn_inputs_ is not None
        return self.attn_inputs_.fetch_qkv_latent()

    def fetch_hidden_states(self):
        assert self.attn_inputs_ is not None
        return self.attn_inputs_.fetch_hidden_states()

    @contextmanager
    def maybe_input_scattered(self, forward_batch: "ForwardBatch"):
        """Scope in which input_scattered reflects this batch's eligibility.

        Uses try/finally so the previous flag is restored and the cached
        attention inputs are dropped even if the forward pass raises; the
        original code skipped cleanup on exceptions, leaving the flag stuck.
        """
        old_flag = self.input_scattered_
        self.input_scattered_ = self.use_input_scattered(forward_batch)
        try:
            yield
        finally:
            self.input_scattered_ = old_flag
            self.attn_inputs_ = None

# Process-wide singleton tracking the attn-TP input-scattered state.
ATTN_TP_CONTEXT = AttnTpContext()


def get_attn_tp_context():
    """Return the global AttnTpContext singleton."""
    return ATTN_TP_CONTEXT


@dataclass
class _LayerModeComputationContext:
num_layers: int
Expand Down Expand Up @@ -188,12 +304,14 @@ def __init__(
# Reduce scatter requires skipping all-reduce in model code after MoE/MLP, so only enable for models which have that implemented. Remove flag once done for all models that use LayerCommunicator.
allow_reduce_scatter: bool = False,
is_last_layer: bool = False,
qkv_latent_func: Optional[Callable] = None,
):
self.layer_scatter_modes = layer_scatter_modes
self.input_layernorm = input_layernorm
self.post_attention_layernorm = post_attention_layernorm
self.allow_reduce_scatter = allow_reduce_scatter
self.is_last_layer = is_last_layer
self.qkv_latent_func = qkv_latent_func

self._context = CommunicateContext.init_new()
self._communicate_simple_fn = CommunicateSimpleFn.get_fn(
Expand Down Expand Up @@ -252,6 +370,11 @@ def prepare_attn(
forward_batch: ForwardBatch,
quant_format: str = "",
):
if get_attn_tp_context().input_scattered:
hidden_states, residual = self._tp_reduce_scatter(
hidden_states,
residual,
)
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
Expand Down Expand Up @@ -335,9 +458,32 @@ def prepare_attn(
forward_batch=forward_batch,
context=self._context,
)

if self.qkv_latent_func is not None:
attn_inputs = AttentionInputs(
hidden_states, forward_batch, self.qkv_latent_func
)
get_attn_tp_context().set_attn_inputs(attn_inputs)
return hidden_states, residual

def _tp_reduce_scatter(
    self,
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Reduce-scatter hidden_states across TP ranks; slice residual to match.

    Requires the token count to be a multiple of tp_size (padding is done
    upstream — see prepare_attn_tp_scatter_input).
    """
    total_tokens = hidden_states.shape[0]
    if total_tokens == 0:
        # Empty batch: nothing to scatter; reuse the tensor as residual too.
        return hidden_states, hidden_states
    tp_size = self._context.tp_size
    assert (
        total_tokens % tp_size == 0
    ), f"Expected total tokens {total_tokens} % tp_size {tp_size} to be 0"
    scattered = hidden_states.new_empty(
        total_tokens // tp_size, *hidden_states.shape[1:]
    )
    get_tp_group().reduce_scatter_tensor(scattered, hidden_states)
    if residual is not None:
        # Keep only this rank's slice of the residual.
        residual = residual.tensor_split(tp_size)[self._context.tp_rank]
    return scattered, residual

Comment on lines +468 to +486
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot Oct 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Critical: reduce_scatter with tensor_split breaks for L % tp_size != 0

NCCL reduce_scatter requires equal chunk sizes; tensor_split yields uneven chunks when total_tokens is not divisible by tp_size, causing runtime errors.

Apply a pad-to-equal-chunk fallback (or use reduce_scatterv if available):

 def _tp_reduce_scatter(
     self,
     hidden_states: torch.Tensor,
     residual: torch.Tensor,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    if hidden_states.shape[0] == 0:
-        return hidden_states, hidden_states
-
-    inputs = list(hidden_states.tensor_split(self._context.tp_size))
-    scattered_local_tokens = inputs[self._context.tp_rank]
-    hidden_states = get_tp_group().reduce_scatter(scattered_local_tokens, inputs)
-
-    if residual is not None:
-        residual = residual.tensor_split(self._context.tp_size)[
-            self._context.tp_rank
-        ]
-    return hidden_states, residual
+    total = hidden_states.shape[0]
+    if total == 0:
+        return hidden_states, hidden_states
+    tp_size = self._context.tp_size
+    rank = self._context.tp_rank
+    # Equal-size path
+    if total % tp_size == 0:
+        chunk = total // tp_size
+        inputs = list(hidden_states.split(chunk, dim=0))
+        out = torch.empty_like(inputs[rank])
+        get_tp_group().reduce_scatter(out, inputs)
+        hidden_states = out
+        if residual is not None:
+            residual = residual.split(chunk, dim=0)[rank]
+        return hidden_states, residual
+    # Fallback: pad to equal chunks, then slice local
+    max_chunk = (total + tp_size - 1) // tp_size
+    pad = max_chunk * tp_size - total
+    if pad:
+        pad_shape = (pad,) + hidden_states.shape[1:]
+        hidden_states_padded = torch.cat(
+            [hidden_states, hidden_states.new_zeros(pad_shape)], dim=0
+        )
+    else:
+        hidden_states_padded = hidden_states
+    inputs = list(hidden_states_padded.split(max_chunk, dim=0))
+    out = torch.empty_like(inputs[rank])
+    get_tp_group().reduce_scatter(out, inputs)
+    local_len = total // tp_size + (1 if rank < (total % tp_size) else 0)
+    hidden_states = out[:local_len]
+    if residual is not None:
+        residual = residual.tensor_split(tp_size)[rank]
+    return hidden_states, residual

If reduce_scatterv is available in GroupCoordinator, prefer it; otherwise keep this padding path.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _tp_reduce_scatter(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
return hidden_states, hidden_states
inputs = list(hidden_states.tensor_split(self._context.tp_size))
scattered_local_tokens = inputs[self._context.tp_rank]
hidden_states = get_tp_group().reduce_scatter(scattered_local_tokens, inputs)
if residual is not None:
residual = residual.tensor_split(self._context.tp_size)[
self._context.tp_rank
]
return hidden_states, residual
def _tp_reduce_scatter(
self,
hidden_states: torch.Tensor,
residual: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
total = hidden_states.shape[0]
if total == 0:
return hidden_states, hidden_states
tp_size = self._context.tp_size
rank = self._context.tp_rank
# Equal-size path
if total % tp_size == 0:
chunk = total // tp_size
inputs = list(hidden_states.split(chunk, dim=0))
out = torch.empty_like(inputs[rank])
get_tp_group().reduce_scatter(out, inputs)
hidden_states = out
if residual is not None:
residual = residual.split(chunk, dim=0)[rank]
return hidden_states, residual
# Fallback: pad to equal chunks, then slice local
max_chunk = (total + tp_size - 1) // tp_size
pad = max_chunk * tp_size - total
if pad:
pad_shape = (pad,) + hidden_states.shape[1:]
hidden_states_padded = torch.cat(
[hidden_states, hidden_states.new_zeros(pad_shape)], dim=0
)
else:
hidden_states_padded = hidden_states
inputs = list(hidden_states_padded.split(max_chunk, dim=0))
out = torch.empty_like(inputs[rank])
get_tp_group().reduce_scatter(out, inputs)
local_len = total // tp_size + (1 if rank < (total % tp_size) else 0)
hidden_states = out[:local_len]
if residual is not None:
residual = residual.tensor_split(tp_size)[rank]
return hidden_states, residual
🤖 Prompt for AI Agents
In python/sglang/srt/layers/communicator.py around lines 314 to 331, the current
use of tensor_split and reduce_scatter fails when hidden_states.shape[0] %
tp_size != 0 because NCCL reduce_scatter requires equal chunk sizes; update the
method to first branch: if the GroupCoordinator exposes a reduce_scatterv (or
similar variable-length API) call that with the original uneven splits;
otherwise compute a uniform chunk_size = ceil(L / tp_size), pad hidden_states
and residual on the sequence dimension to chunk_size * tp_size (with a safe pad
value) then reshape/split into equal chunks and call reduce_scatter; after the
operation trim the padded tokens from the returned hidden_states and residual
for the local rank so downstream shapes match the original L; ensure padding is
applied consistently on both tensors and removed before returning.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PyTorch docs for `torch.distributed.reduce_scatter(output, input_list, op=<RedOpType.SUM: 0>, group=None, async_op=False)` describe `input_list` as: "(list) List of tensors to reduce and scatter." In my reading, uneven-sized tensors are supported. Self-verification on local datasets such as GSM8K and MMLU shows normal results.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧩 Analysis chain

🌐 Web query:

Does PyTorch torch.distributed.reduce_scatter support uneven sized tensors in input_list?

💡 Result:

Short answer: No — reduce_scatter (and reduce_scatter_multigpu / reduce_scatter_tensor) expects the input pieces to be sized consistently for the scatter and the output tensor size must be the same across ranks; uneven-sized tensors in the input_list are not supported. See the PyTorch docs stating the input must be the concatenation/stack of equal-sized per-rank outputs and that list lengths/sizes must match across processes. [1][2]

Sources:

  • torch.distributed.reduce_scatter / reduce_scatter_tensor docs (requires output size same across ranks and input sized = world_size * output_size). [1]
  • torch.distributed.reduce_scatter_multigpu / reduce_scatter (input_tensor_lists / input_list sizes must match across processes). [2]

🌐 Web query:

PyTorch tensor_split with reduce_scatter uneven chunks behavior

💡 Result:

Short answer: reduce_scatter requires each rank's output (the scattered piece) to have the same shape/size across ranks, so you cannot directly use torch.tensor_split to create uneven chunks and feed them into torch.distributed.reduce_scatter (or reduce_scatter_tensor) — the input must be a concatenation/stack of equal-sized output pieces. For uneven-shard use-cases you must either (a) pad chunks to equal size before reduce_scatter and then trim after, or (b) use higher-level DTensor/Shard abstractions (uneven sharding is experimental) or a custom collective that supports per-rank sizes. [1][2]

References

  • reduce_scatter / reduce_scatter_tensor docs: output must have same size across ranks; input is concatenation/stack of those outputs. [1]
  • DTensor docs: uneven sharding behavior is experimental (use caution / different semantics). [2]

Sources
[1] PyTorch distributed reduce_scatter / reduce_scatter_tensor documentation. (pytorch.cadn.net.cn)
[2] PyTorch DTensor (redistribute / sharding notes about uneven sharding being experimental). (docs.pytorch.org)

Citations:


@xu-yfei, I appreciate you sharing the documentation reference. However, according to the official PyTorch documentation [1], reduce_scatter requires that the output tensor size must be the same across all ranks, meaning the input chunks must be equal-sized. The documentation states: "output must have the same size across ranks" and "input is the concatenation/stack of those outputs".

Your tests on GSM8K and MMLU may be passing because:

  1. The sequence lengths in your test batches happen to divide evenly by tp_size=8 (e.g., 1024, 2048, 4096)
  2. Accuracy metrics might not catch subtle numerical errors from misaligned tensor operations

To verify this, could you please test the specific failing case mentioned by @miter6:

  • Sequence length: 1023 tokens with tp_size=8

With tensor_split, this creates uneven chunks: [128, 128, 128, 128, 128, 128, 128, 127], which violates reduce_scatter's requirement. You can verify this with:

import torch
hidden_states = torch.randn(1023, 7168)
chunks = list(hidden_states.tensor_split(8))
print([c.shape[0] for c in chunks])  # Will show [128, 128, 128, 128, 128, 128, 128, 127]

Recommendation: Implement the padding approach suggested in the original review comment to handle cases where total_tokens % tp_size != 0, or verify that your get_tp_group().reduce_scatter() wrapper provides additional handling for uneven chunks (which the standard PyTorch API does not).

Could you confirm by running a test with exactly 1023 tokens on TP8?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The bench_serving verification is normal. One of the logs shows an input length of 1021:
Prefill batch. #new-seq: 1, #new-token: 1021, #cached-token: 3, token usage: 0.00, #running-req: 0, #queue-req: 0

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

def prepare_mlp(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -371,12 +517,17 @@ def postprocess_layer(
)

def should_use_reduce_scatter(self, forward_batch: ForwardBatch):
    """Whether the post-MLP all-reduce can be replaced by a reduce-scatter."""
    if not self.allow_reduce_scatter:
        return False
    # DP-attention path: the scatter fn is selected and the batch is padded
    # to the max length, so every rank holds equally sized chunks.
    dp_scatter_ready = (
        self._communicate_summable_tensor_pair_fn
        is CommunicateSummableTensorPairFn._scatter_hidden_states
        and forward_batch.dp_padding_mode.is_max_len()
    )
    # Attn-TP scattered-input path: all layers but the last keep hidden
    # states scattered, so a reduce-scatter is valid there as well.
    attn_tp_scatter_ready = (
        get_attn_tp_context().input_scattered and not self.is_last_layer
    )
    return dp_scatter_ready or attn_tp_scatter_ready

def should_fuse_mlp_allreduce_with_next_layer(
self, forward_batch: ForwardBatch
Expand All @@ -388,6 +539,9 @@ def should_fuse_mlp_allreduce_with_next_layer(
):
return False

if get_attn_tp_context().input_scattered:
return False

batch_size = (
forward_batch.input_ids.shape[0]
if hasattr(forward_batch, "input_ids")
Expand Down Expand Up @@ -422,6 +576,7 @@ class CommunicateContext:
attn_dp_size: int
tp_size: int
cache = None
tp_rank: int

def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
return self.process_group_sizes[a] == self.process_group_sizes[b]
Expand All @@ -432,6 +587,7 @@ def init_new(cls):
attn_tp_size = get_attention_tp_size()
attn_dp_size = get_attention_dp_size()
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
process_group_sizes = {
ScatterMode.SCATTERED: 1,
ScatterMode.TP_ATTN_FULL: attn_tp_size,
Expand All @@ -444,6 +600,7 @@ def init_new(cls):
attn_tp_size=attn_tp_size,
attn_dp_size=attn_dp_size,
tp_size=tp_size,
tp_rank=tp_rank,
)


Expand Down Expand Up @@ -566,6 +723,14 @@ def _gather_hidden_states_and_residual(
*,
residual_input_mode,
):
if get_attn_tp_context().input_scattered:
return CommunicateWithAllReduceAndLayerNormFn._tp_all_reduce_with_scattered_residual(
hidden_states,
residual,
layernorm,
context,
)

if residual_input_mode == ScatterMode.SCATTERED and context.attn_tp_size > 1:
residual, local_residual = (
get_local_dp_buffer(),
Expand Down Expand Up @@ -637,6 +802,22 @@ def _scatter_hidden_states_and_residual(
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual

@staticmethod
def _tp_all_reduce_with_scattered_residual(
    hidden_states: torch.Tensor,
    residual: torch.Tensor,
    layernorm: torch.nn.Module,
    context: CommunicateContext,
):
    # Recombine the rank-local residual with the full hidden states using a
    # single all-reduce, then apply layernorm.
    #
    # `hidden_states` is expected to hold the full (padded) token count on
    # every rank, while `residual` holds only this rank's token slice.
    if hidden_states.shape[0] == 0:
        return hidden_states, hidden_states

    # tensor_split returns views, so the in-place += below writes this
    # rank's residual directly into its slice of `hidden_states`.
    scattered_states = hidden_states.tensor_split(context.tp_size)[context.tp_rank]
    scattered_states += residual
    # The all-reduce sums across ranks: each slice receives its owning
    # rank's residual exactly once on top of the reduced hidden states.
    residual = tensor_model_parallel_all_reduce(hidden_states)
    # Single-argument layernorm call: normalize without a fused add.
    hidden_states = layernorm(residual)
    return hidden_states, residual


class CommunicateSummableTensorPairFn:
"""It is allowed to make (hidden_states, residual) := (hidden_states + residual, None) if needed."""
Expand Down
10 changes: 5 additions & 5 deletions python/sglang/srt/layers/vocab_parallel_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
use_symmetric_memory,
)
from sglang.srt.layers.amx_utils import PackWeightMethod
from sglang.srt.layers.communicator import get_attn_tp_context
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
from sglang.srt.layers.parameter import BasevLLMParameter
from sglang.srt.layers.quantization.base_config import (
Expand Down Expand Up @@ -478,11 +479,10 @@ def forward(self, input_):
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
# Reduce across all the model parallel GPUs.
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
if not get_attn_tp_context().input_scattered:
# Reduce across all the model parallel GPUs.
output_parallel = tensor_model_parallel_all_reduce(output_parallel)
return output_parallel

def extra_repr(self) -> str:
s = f"num_embeddings={self.num_embeddings_per_partition}"
Expand Down
28 changes: 24 additions & 4 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@
import triton
import triton.language as tl

from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.distributed.parallel_state import (
get_moe_expert_parallel_world_size,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import (
DpPaddingMode,
Expand Down Expand Up @@ -766,6 +769,13 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
else:
bs = self.batch_size = num_tokens

# padding
self._pad_inputs_to_size(model_runner, num_tokens, bs)
self.global_num_tokens_cpu = global_num_tokens
global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True)
self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True)

def _pad_inputs_to_size(self, model_runner: ModelRunner, num_tokens, bs):
# padding
self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens)
self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs)
Expand All @@ -788,9 +798,6 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
if self.encoder_lens is not None:
self.encoder_lens = self._pad_tensor_to_size(self.encoder_lens, bs)
self.positions = self._pad_tensor_to_size(self.positions, num_tokens)
self.global_num_tokens_cpu = global_num_tokens
global_num_tokens_pinned = torch.tensor(global_num_tokens, pin_memory=True)
self.global_num_tokens_gpu.copy_(global_num_tokens_pinned, non_blocking=True)

if self.mrope_positions is not None:
self.mrope_positions = self._pad_tensor_to_size(self.mrope_positions, bs)
Expand Down Expand Up @@ -818,6 +825,19 @@ def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
spec_info.hidden_states, num_tokens
)

def prepare_attn_tp_scatter_input(self, model_runner: ModelRunner):
    """Pad token-level inputs so they divide evenly across TP ranks.

    No-op unless the attn-TP input-scattered path is active for this batch.
    """
    # Local import avoids a circular dependency with the communicator module.
    from sglang.srt.layers.communicator import get_attn_tp_context

    if not get_attn_tp_context().use_input_scattered(self):
        return
    assert self.forward_mode.is_extend()
    tp_size = get_tensor_model_parallel_world_size()
    num_tokens = self.input_ids.shape[0]
    # Round the token count up to the next multiple of tp_size.
    padded_num_tokens = -(-num_tokens // tp_size) * tp_size
    self._pad_inputs_to_size(model_runner, padded_num_tokens, self.batch_size)

def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput):

self.forward_mode = getattr(self, "_original_forward_mode", self.forward_mode)
Expand Down
Loading
Loading