@@ -21,7 +21,6 @@
 
 import torch
 from torch import nn
-from torch.nn import LayerNorm
 from transformers import Glm4Config
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -40,8 +39,6 @@
 from sglang.srt.models.llama import LlamaMLP as Glm4MLP
 from sglang.srt.utils import add_prefix, make_layers
 
-LoraConfig = None
-
 
 class Glm4Attention(nn.Module):
     def __init__(
@@ -220,7 +217,6 @@ def __init__(
         )
 
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.layers_to_capture = []
 
     @torch.no_grad()
     def forward(
@@ -234,12 +230,8 @@ def forward(
             hidden_states = self.embed_tokens(input_ids)
         else:
             hidden_states = input_embeds
-        aux_hidden_states = []
         residual = None
-        for i in range(len(self.layers)):
-            if i in self.layers_to_capture:
-                aux_hidden_states.append(hidden_states)
-            layer = self.layers[i]
+        for layer in self.layers:
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
@@ -248,34 +240,7 @@ def forward(
             )
         hidden_states, _ = self.norm(hidden_states, residual)
 
-        if len(aux_hidden_states) == 0:
-            return hidden_states
-
-        return hidden_states, aux_hidden_states
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-            quantization_param_path,
-            tp_rank,
-            tp_size,
-            self.config.num_hidden_layers,
-            self.config.__class__.model_type,
-        ):
-            if not isinstance(self.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.layers[layer_idx].self_attn
-
-            if hasattr(layer_self_attn.attn, "k_scale"):
-                layer_self_attn.attn.k_scale = scaling_factor
-                layer_self_attn.attn.v_scale = scaling_factor
-            else:
-                raise RuntimeError(
-                    "Self attention has no KV cache scaling " "factor attribute!"
-                )
+        return hidden_states
 
 
 class Glm4ForCausalLM(nn.Module):
@@ -325,14 +290,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         for name, loaded_weight in weights:
             if self.config.tie_word_embeddings and "lm_head.weight" in name:
                 continue
-            # Handle FP8 kv-scale remapping
-            if "scale" in name:
-                name = maybe_remap_kv_scale_name(name, params_dict)
-                if name is None:
-                    continue
-            # Skip loading kv_scale from ckpts towards new design.
-            if name.endswith(".kv_scale") and name not in params_dict:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -342,12 +299,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
-                    continue
-                # Skip loading kv_scale from ckpts towards new design.
-                if name.endswith(".kv_scale") and name not in params_dict:
-                    continue
                 if name in params_dict.keys():
                     param = params_dict[name]
                     weight_loader = getattr(
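
Note on the for/else dispatch visible in load_weights above: the else branch of a Python for loop runs only when the loop finishes without hitting break, so checkpoint weights that map onto a fused (stacked) parameter are consumed inside the loop, and everything else falls through to the default path. Below is a minimal, self-contained sketch of that pattern; the mapping entries and helper names are illustrative, not taken from this file.

import torch
from torch import nn

# Illustrative mapping: (fused param name, checkpoint weight name, shard id).
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

def default_weight_loader(param: nn.Parameter, loaded_weight: torch.Tensor) -> None:
    # Simplest possible loader: copy the checkpoint tensor into the parameter.
    param.data.copy_(loaded_weight)

def load_one(name: str, loaded_weight: torch.Tensor, params_dict: dict) -> None:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # The checkpoint weight is one shard of a fused parameter; in the real
        # model that parameter carries a custom weight_loader which knows how
        # to place the shard.
        param = params_dict[name.replace(weight_name, param_name)]
        param.weight_loader(param, loaded_weight, shard_id)
        break
    else:
        # No break above: a plain 1:1 parameter, loaded with the default loader.
        if name in params_dict:
            param = params_dict[name]
            getattr(param, "weight_loader", default_weight_loader)(param, loaded_weight)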