Skip to content
3 changes: 3 additions & 0 deletions python/sglang/srt/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,9 @@ def _derive_model_shapes(self):
self.num_attention_layers = self.num_hidden_layers
if "LongcatFlashForCausalLM" in self.hf_config.architectures:
self.num_attention_layers = self.num_hidden_layers * 2
if "IQuestLoopCoderForCausalLM" in self.hf_config.architectures:
loop_num = getattr(self.hf_text_config, "loop_num", 1)
self.num_attention_layers = int(self.num_hidden_layers * int(loop_num))
self.num_nextn_predict_layers = getattr(
self.hf_text_config, "num_nextn_predict_layers", None
)
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/layers/attention/flashinfer_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,13 @@ def forward_extend(
v_scale=layer.v_scale_float,
)
else:
# If `k`/`v` are not explicitly provided, fall back to the KV cache stored in
# `forward_batch.token_to_kv_pool` for this layer. This enables attention over
# previously cached context without re-materializing KV tensors (e.g., the
# IQuestLoopCoder path uses token_to_kv_pool as the KV source).
if k is None and v is None:
k = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)[0]
v = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)[1]
causal = True
if (
layer.is_cross_attention
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,12 @@ def initialize(self, min_per_gpu_memory: float):
self.start_layer = getattr(self.model, "start_layer", 0)
self.end_layer = getattr(self.model, "end_layer", model_num_layers)
self.num_effective_layers = self.end_layer - self.start_layer

# For LoopCoder models, each loop has its own layer_id, so we need to multiply by loop_num
loop_num = getattr(self.model_config.hf_config, "loop_num", 1)
if loop_num > 1:
self.num_effective_layers = self.num_effective_layers * loop_num

assert (
(not model_has_mtp_layers)
or (self.spec_algorithm.is_none())
Expand Down
Loading
Loading