2828
2929from .utils import CpuGuard
3030
31+ from paddleformers .utils .log import logger # 确保 logger 被导入
32+
# Import the tensor-statistics debug printer.
try:
    from fastdeploy.model_executor.models.minimax_m1 import print_tensor_stats
except ImportError:
    # The real helper could not be imported; define a minimal stand-in so the
    # call sites in this module keep working instead of crashing.
    def print_tensor_stats(tensor, name):
        """Log a short summary of ``tensor`` under the label ``name``.

        Fallback used only when the real ``print_tensor_stats`` cannot be
        imported. It logs a header line, then either the tensor's ``shape``
        and ``dtype`` or a note that the tensor is ``None``.

        Args:
            tensor: Object exposing ``shape`` and ``dtype``, or ``None``.
            name: Label included in the header log line.
        """
        logger.info(f"--- [FD DEBUG] {name} --- (print_tensor_stats not found, simple log)")
        if tensor is not None:
            logger.info(f"Shape: {tensor.shape} , DType: {tensor.dtype} ")
        else:
            logger.info("Tensor is None")
45+
3146
3247class ErnieRotaryEmbedding :
3348 def __init__ (self , rotary_dim , base , partial_rotary_factor ):
@@ -79,29 +94,82 @@ def __call__(self, position_ids):
7994 return rot_emb
8095
8196
97+ # class GlmRotaryEmbedding:
98+ # def __init__(self, rotary_dim, base, partial_rotary_factor):
99+ # """
100+ # Pre-calculate rotary position embedding for position_ids.
101+ # """
102+ # self.rotary_dim = rotary_dim
103+ # self.base = base
104+ # if partial_rotary_factor < 1.0:
105+ # self.rotary_dim = int(self.rotary_dim * partial_rotary_factor)
106+
107+ # def __call__(self, position_ids):
108+ # bsz, max_seq_len = position_ids.shape[:2]
109+ # inv_freq = self.base ** (-paddle.arange(0, self.rotary_dim, 2, dtype="float32") / self.rotary_dim)
110+ # freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)
111+ # # shape: [B, S, D/2]
112+ # rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, self.rotary_dim // 2), dtype="float32")
113+ # emb = paddle.stack([freqs], axis=-1).reshape((bsz, max_seq_len, self.rotary_dim // 2))
114+ # # shape: [B, S, 1, D]
115+ # emb = paddle.unsqueeze(emb, 2)
116+ # rot_emb[0] = paddle.cos(emb)
117+ # rot_emb[1] = paddle.sin(emb)
118+ # return rot_emb
119+
120+
121+
class GlmRotaryEmbedding:
    """Builds the cos/sin rotary position embedding table for given positions.

    Calling an instance with ``position_ids`` returns a float32 tensor of
    shape ``(2, bsz, max_seq_len, 1, rotary_dim // 2)``; index 0 along the
    first axis holds the cosines and index 1 the sines.
    """

    def __init__(self, rotary_dim, base, partial_rotary_factor):
        """
        Store the frequency base and the effective rotary dimension.

        Args:
            rotary_dim: Rotary dimension before partial scaling.
            base: Base raised to the negative, normalised dimension index.
            partial_rotary_factor: When below 1.0, ``rotary_dim`` is multiplied
                by this factor and truncated with ``int``; otherwise
                ``rotary_dim`` is used unchanged.
        """
        # --- verbose debug logging ---
        logger.info(">>>> [GlmRotaryEmbedding.__init__] <<<<")
        logger.info(f" - Received rotary_dim (as head_dim): {rotary_dim} ")
        logger.info(f" - Received partial_rotary_factor: {partial_rotary_factor} ")

        self.base = base

        # Core computation: shrink the rotary dimension only for partial rotary.
        use_partial = partial_rotary_factor < 1.0
        self.rotary_dim = int(rotary_dim * partial_rotary_factor) if use_partial else rotary_dim

        logger.info(f" - Calculated final self.rotary_dim: {self.rotary_dim} ")
        # --- end of debug logging ---

    def __call__(self, position_ids):
        """
        Compute the rotary embedding table for ``position_ids``.

        Args:
            position_ids: Tensor whose first two dimensions are unpacked as
                ``(bsz, max_seq_len)``; it is cast to float32 before use.

        Returns:
            float32 tensor of shape
            ``(2, bsz, max_seq_len, 1, self.rotary_dim // 2)`` with cosines at
            index 0 and sines at index 1.
        """
        # --- verbose debug logging ---
        logger.info(">>>> [GlmRotaryEmbedding.__call__] <<<<")
        logger.info(f" - Using self.rotary_dim: {self.rotary_dim} ")
        logger.info(f" - Using self.base: {self.base} ")

        batch_size, seq_len = position_ids.shape[:2]

        # Report the upper bound handed to arange.
        upper_bound = self.rotary_dim
        logger.info(f" - paddle.arange upper bound is: {upper_bound} ")

        # Key step: the even dimension indices 0, 2, 4, ... below the bound.
        dim_indices = paddle.arange(0, upper_bound, 2, dtype="float32")
        # This shape shows the final size of the frequency axis.
        logger.info(f" - Shape of inv_freq_dims (from arange): {dim_indices.shape} ")

        exponents = -dim_indices / self.rotary_dim
        inv_freq = self.base ** exponents
        # Outer product of every position with every inverse frequency.
        angles = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq)

        half_dim = self.rotary_dim // 2
        rot_emb = paddle.zeros((2, batch_size, seq_len, 1, half_dim), dtype="float32")

        stacked = paddle.stack([angles], axis=-1)
        flattened = stacked.reshape((batch_size, seq_len, half_dim))
        # Insert a singleton axis at position 2 to match rot_emb's layout.
        emb = paddle.unsqueeze(flattened, 2)

        rot_emb[0] = paddle.cos(emb)
        rot_emb[1] = paddle.sin(emb)

        logger.info(f" - Final returned rot_emb shape: {rot_emb.shape} ")
        logger.info(">>>> [GlmRotaryEmbedding.__call__ END] <<<<")

        return rot_emb
105173
106174class QwenRotaryEmbedding :
107175 def __init__ (self , rotary_dim , base , partial_rotary_factor ):
@@ -131,7 +199,6 @@ def __call__(self, position_ids):
131199
132200 return rot_emb
133201
134-
135202def yarn_get_mscale (scale = 1 , mscale = 1 ):
136203 """ """
137204 if scale <= 1 :
@@ -332,11 +399,14 @@ def get_rope_impl(
332399 """
333400 The real implementation of get_rope
334401 """
402+ print_tensor_stats (position_ids [:, :16 ], "ROPE_IMPL_INPUT:position_ids[:, :16]" )
335403
336404 architecture = model_config .architectures [0 ]
405+ # if architecture.startswith("Qwen") or architecture.startswith("MiniMaxM1"):
337406 if architecture .startswith ("Qwen" ):
338407 rotary_emb_layer = QwenRotaryEmbedding (rotary_dim , base , partial_rotary_factor )
339408 rotary_emb = rotary_emb_layer (position_ids )
409+ # elif architecture.startswith("Glm"):
340410 elif architecture .startswith ("Glm" ) or architecture .startswith ("MiniMaxM1" ):
341411 rotary_emb_layer = GlmRotaryEmbedding (rotary_dim , base , partial_rotary_factor )
342412 rotary_emb = rotary_emb_layer (position_ids )
@@ -354,6 +424,15 @@ def get_rope_impl(
354424 else :
355425 rotary_emb_layer = ErnieRotaryEmbedding (rotary_dim , base , partial_rotary_factor )
356426 rotary_emb = rotary_emb_layer (position_ids )
427+
428+ # if rotary_emb.ndim == 5:
429+ # logger.info(f">>>> [ROPE RESHAPE] Squeezing rotary_emb from {rotary_emb.shape} <<<<")
430+ # rotary_emb = paddle.squeeze(rotary_emb, axis=[1, 3])
431+ # logger.info(f">>>> [ROPE RESHAPE] New shape is {rotary_emb.shape} <<<<")
432+
433+ # ... (之前的日志打印)
434+ print_tensor_stats (rotary_emb [0 , :16 ], "ROPE_IMPL_OUTPUT:cos_emb[:16]" )
435+ print_tensor_stats (rotary_emb [1 , :16 ], "ROPE_IMPL_OUTPUT:sin_emb[:16]" )
357436 return rotary_emb
358437
359438
0 commit comments