Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ def _derive_model_shapes(self):
self.num_attention_layers = self.num_hidden_layers
if "LongcatFlashForCausalLM" in self.hf_config.architectures:
self.num_attention_layers = self.num_hidden_layers * 2
if "IQuestLoopCoderForCausalLM" in self.hf_config.architectures:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This if condition should likely be an elif. If hf_config.architectures could contain both LongcatFlashForCausalLM and IQuestLoopCoderForCausalLM, the current logic would have the calculation for IQuestLoopCoderForCausalLM overwrite the one for LongcatFlashForCausalLM. Using elif would make the choices mutually exclusive, which seems more correct for different model architectures.

loop_num = getattr(self.hf_text_config, "loop_num", 1)
self.num_attention_layers = int(self.num_hidden_layers * int(loop_num))
self.num_nextn_predict_layers = getattr(
self.hf_text_config, "num_nextn_predict_layers", None
)
Expand Down
33 changes: 26 additions & 7 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -760,14 +760,33 @@ def forward_extend(
logits_soft_cap = layer.logit_cap

q = q.contiguous()
if not self.forward_metadata.use_ragged:
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
# For cross-layer KV cache sharing (e.g., LoopCoder), k and v can be None
# When k is None, we read from KV cache instead of computing attention
if k is None:
# Read from KV cache (similar to decode mode)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
causal=not layer.is_cross_attention,
sm_scale=layer.scaling,
window_left=(
layer.sliding_window_size
if not (
self.forward_metadata.multi_item_params
and self.forward_metadata.multi_item_params.is_enabled()
)
else -1
),
logits_soft_cap=logits_soft_cap,
k_scale=layer.k_scale_float,
v_scale=layer.v_scale_float,
)
elif not self.forward_metadata.use_ragged:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
o = prefill_wrapper_paged.forward(
q.view(-1, layer.tp_q_head_num, layer.head_dim),
forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
Expand Down
Loading