@@ -67,13 +67,15 @@ def __init__(self,
6767 activation = config .hidden_act ,
6868 quant_config = quant_config )
6969
70- self .block_sparse_moe = GraniteMoeMoE (
71- num_experts = config .num_local_experts ,
72- top_k = config .num_experts_per_tok ,
73- hidden_size = config .hidden_size ,
74- intermediate_size = config .intermediate_size ,
75- quant_config = quant_config ,
76- prefix = f"{ prefix } .block_sparse_moe" )
70+ self .block_sparse_moe = None
71+ if getattr (config , "num_local_experts" , 0 ) > 0 :
72+ self .block_sparse_moe = GraniteMoeMoE (
73+ num_experts = config .num_local_experts ,
74+ top_k = config .num_experts_per_tok ,
75+ hidden_size = config .hidden_size ,
76+ intermediate_size = config .intermediate_size ,
77+ quant_config = quant_config ,
78+ prefix = f"{ prefix } .block_sparse_moe" )
7779
7880 self .shared_mlp = None if \
7981 getattr (config , 'shared_intermediate_size' , 0 ) == 0 \
@@ -105,13 +107,19 @@ def forward(
105107 residual = hidden_states
106108 hidden_states = self .post_attention_layernorm (hidden_states )
107109 if self .shared_mlp is None :
108- hidden_states = self .block_sparse_moe (hidden_states )
110+ if self .block_sparse_moe is not None :
111+ hidden_states = self .block_sparse_moe (hidden_states )
112+ # else: skip
109113 else :
110114 # create a copy since block_sparse_moe modifies in-place
111- moe_hidden_states = hidden_states .clone ()
112- moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
113- hidden_states = moe_hidden_states + self .shared_mlp (hidden_states )
114- del moe_hidden_states
115+ if self .block_sparse_moe is not None :
116+ moe_hidden_states = hidden_states .clone ()
117+ moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
118+ hidden_states = moe_hidden_states + self .shared_mlp (
119+ hidden_states )
120+ del moe_hidden_states
121+ else :
122+ hidden_states = self .shared_mlp (hidden_states )
115123 hidden_states = residual + hidden_states * self .residual_multiplier
116124
117125 return hidden_states , residual
@@ -137,13 +145,15 @@ def __init__(
137145 quant_config = quant_config ,
138146 prefix = f"{ prefix } .self_attn" )
139147
140- self .block_sparse_moe = GraniteMoeMoE (
141- num_experts = config .num_local_experts ,
142- top_k = config .num_experts_per_tok ,
143- hidden_size = config .hidden_size ,
144- intermediate_size = config .intermediate_size ,
145- quant_config = quant_config ,
146- prefix = f"{ prefix } .block_sparse_moe" )
148+ self .block_sparse_moe = None
149+ if getattr (config , "num_local_experts" , 0 ) > 0 :
150+ self .block_sparse_moe = GraniteMoeMoE (
151+ num_experts = config .num_local_experts ,
152+ top_k = config .num_experts_per_tok ,
153+ hidden_size = config .hidden_size ,
154+ intermediate_size = config .intermediate_size ,
155+ quant_config = quant_config ,
156+ prefix = f"{ prefix } .block_sparse_moe" )
147157
148158 self .shared_mlp = None if \
149159 getattr (config , 'shared_intermediate_size' , 0 ) == 0 \
@@ -178,13 +188,19 @@ def forward(
178188 residual = hidden_states
179189 hidden_states = self .post_attention_layernorm (hidden_states )
180190 if self .shared_mlp is None :
181- hidden_states = self .block_sparse_moe (hidden_states )
191+ if self .block_sparse_moe is not None :
192+ hidden_states = self .block_sparse_moe (hidden_states )
193+ # else: skip
182194 else :
183195 # create a copy since block_sparse_moe modifies in-place
184- moe_hidden_states = hidden_states .clone ()
185- moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
186- hidden_states = moe_hidden_states + self .shared_mlp (hidden_states )
187- del moe_hidden_states
196+ if self .block_sparse_moe is not None :
197+ moe_hidden_states = hidden_states .clone ()
198+ moe_hidden_states = self .block_sparse_moe (moe_hidden_states )
199+ hidden_states = moe_hidden_states + self .shared_mlp (
200+ hidden_states )
201+ del moe_hidden_states
202+ else :
203+ hidden_states = self .shared_mlp (hidden_states )
188204 hidden_states = residual + hidden_states * self .residual_multiplier
189205
190206 return hidden_states , residual
0 commit comments