
Commit ea55a97

Remove unused code.
1 parent 3e9cf8e commit ea55a97

2 files changed (+3, -52 lines)

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ runtime_common = [
     "python-multipart",
     "pyzmq>=25.1.2",
     "soundfile==0.13.1",
-    "torchao>=0.9.0",
+    "torchao==0.9.0",
     "transformers==4.51.3",
     "uvicorn",
     "uvloop",

python/sglang/srt/models/glm4.py

Lines changed: 2 additions & 51 deletions
@@ -21,7 +21,6 @@

 import torch
 from torch import nn
-from torch.nn import LayerNorm
 from transformers import Glm4Config

 from sglang.srt.distributed import get_tensor_model_parallel_world_size

@@ -40,8 +39,6 @@
 from sglang.srt.models.llama import LlamaMLP as Glm4MLP
 from sglang.srt.utils import add_prefix, make_layers

-LoraConfig = None
-

 class Glm4Attention(nn.Module):
     def __init__(

@@ -220,7 +217,6 @@ def __init__(
         )

         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.layers_to_capture = []

     @torch.no_grad()
     def forward(

@@ -234,12 +230,8 @@ def forward(
             hidden_states = self.embed_tokens(input_ids)
         else:
             hidden_states = input_embeds
-        aux_hidden_states = []
         residual = None
-        for i in range(len(self.layers)):
-            if i in self.layers_to_capture:
-                aux_hidden_states.append(hidden_states)
-            layer = self.layers[i]
+        for layer in self.layers:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,

@@ -248,34 +240,7 @@ def forward(
             )
         hidden_states, _ = self.norm(hidden_states, residual)

-        if len(aux_hidden_states) == 0:
-            return hidden_states
-
-        return hidden_states, aux_hidden_states
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-            quantization_param_path,
-            tp_rank,
-            tp_size,
-            self.config.num_hidden_layers,
-            self.config.__class__.model_type,
-        ):
-            if not isinstance(self.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.layers[layer_idx].self_attn
-
-            if hasattr(layer_self_attn.attn, "k_scale"):
-                layer_self_attn.attn.k_scale = scaling_factor
-                layer_self_attn.attn.v_scale = scaling_factor
-            else:
-                raise RuntimeError(
-                    "Self attention has no KV cache scaling " "factor attribute!"
-                )
+        return hidden_states


 class Glm4ForCausalLM(nn.Module):

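Net effect of the glm4.py hunks so far: the unused LayerNorm import, the LoraConfig placeholder, the layers_to_capture / aux_hidden_states bookkeeping, and load_kv_cache_scales are all removed, and forward() now iterates over the decoder layers directly instead of indexing into them. A standalone sketch of the two iteration styles, using toy torch modules rather than the sglang classes:

import torch
from torch import nn

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(4)])
x = torch.randn(2, 8)

# Index-based iteration, as in the removed code path.
for i in range(len(layers)):
    x = layers[i](x)

# Direct iteration, as in the simplified forward(); same traversal, no index bookkeeping.
y = torch.randn(2, 8)
for layer in layers:
    y = layer(y)
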
@@ -325,14 +290,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
                 continue
-            # Handle FP8 kv-scale remapping
-            if "scale" in name:
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-            # Skip loading kv_scale from ckpts towards new design.
-            if name.endswith(".kv_scale") and name not in params_dict:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue

@@ -342,12 +299,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader(param, loaded_weight, shard_id)
                     break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip loading kv_scale from ckpts towards new design.
-                if name.endswith(".kv_scale") and name not in params_dict:
-                    continue
                 if name in params_dict.keys():
                     param = params_dict[name]
                     weight_loader = getattr(
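
The two load_weights hunks both delete branches that sat under Python's for/else: the else block runs only when the loop over stacked_params_mapping completes without hitting break, i.e. when no stacked-parameter pattern matched the weight name. A small standalone illustration of that control flow (toy names, not the sglang mapping):

patterns = ["qkv_proj", "gate_up_proj"]

for name in ["model.layers.0.qkv_proj.weight", "model.norm.weight"]:
    for pattern in patterns:
        if pattern in name:
            print(f"{name}: handled by the stacked-parameter branch ({pattern})")
            break
    else:
        # Runs only when no pattern matched (no break); this is the branch
        # the removed .bias / .kv_scale skips used to guard.
        print(f"{name}: falls through to the generic loader")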
