python/sglang/srt/models/qwen2.py (20 additions, 3 deletions)
@@ -16,7 +16,7 @@
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Tuple, Union, List
 
 import torch
 from torch import nn
@@ -425,7 +425,6 @@ def __init__(
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
-
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
@@ -470,11 +469,14 @@ def forward(
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
 
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
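
The tuple unpacking in this hunk assumes the inner `Qwen2Model.forward` returns `(hidden_states, aux_hidden_states)` whenever `layers_to_capture` is non-empty, and a plain tensor otherwise. A minimal sketch of that contract, with a hypothetical `DecoderStack` standing in for the real decoder loop (not this PR's code):

import torch
from torch import nn
from typing import List, Tuple, Union


class DecoderStack(nn.Module):
    """Hypothetical stand-in for the Qwen2Model decoder loop."""

    def __init__(self, layers: nn.ModuleList):
        super().__init__()
        self.layers = layers
        self.layers_to_capture: List[int] = []

    def forward(
        self, hidden_states: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        aux_hidden_states: List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            # Snapshot the residual stream entering layer i; the `val + 1`
            # offset in set_eagle3_layers_to_capture below lines external
            # layer ids up with this indexing convention.
            if i in self.layers_to_capture:
                aux_hidden_states.append(hidden_states)
            hidden_states = layer(hidden_states)
        if not self.layers_to_capture:
            return hidden_states  # unchanged single-tensor return path
        return hidden_states, aux_hidden_states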
@@ -612,6 +614,21 @@ def set_embed_and_head(self, embed, head):
 
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2ForCausalLM
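
The new hook defaults to tapping an early, a middle, and a late decoder layer. A hypothetical call site, sketching how a speculative-decoding setup might drive it; `target_model` and `draft_layer_ids` are illustrative names, not SGLang API:

# `target_model` is assumed to be a loaded Qwen2ForCausalLM instance.
target_model.set_eagle3_layers_to_capture()  # default taps: 2, mid, last-3

# Or mirror the layer ids advertised by an EAGLE3 draft checkpoint; the
# method shifts each id by +1 internally to match the capture convention.
draft_layer_ids = [1, 13, 24]
target_model.set_eagle3_layers_to_capture(draft_layer_ids)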
python/sglang/srt/models/qwen2_moe.py (19 additions, 2 deletions)
@@ -19,7 +19,7 @@
 import logging
 from dataclasses import dataclass
 from enum import Enum, auto
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, Optional, Tuple, Union, List
 
 import torch
 import torch.nn.functional as F
@@ -548,9 +548,12 @@ def forward(
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -702,5 +705,19 @@ def get_model_config_for_expert_location(cls, config):
             num_groups=None,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
 
 EntryClass = Qwen2MoeForCausalLM
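
The qwen2_moe.py additions mirror qwen2.py line for line. As a usage note, EAGLE3-style draft heads commonly fuse the captured streams by concatenating them along the hidden dimension and projecting back down before the draft decoder; a sketch under that assumption (the fusion layer here is illustrative, not this PR's code):

import torch
from torch import nn

hidden_size, k = 4096, 3  # k = number of captured layers (early, mid, late)
fuse = nn.Linear(k * hidden_size, hidden_size)

# aux_hidden_states as returned alongside the final hidden state above.
aux_hidden_states = [torch.randn(8, hidden_size) for _ in range(k)]
draft_input = fuse(torch.cat(aux_hidden_states, dim=-1))  # shape: [8, hidden_size]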