@@ -44,11 +44,10 @@ class ModelArgs(BaseModelArgs):
     mamba_n_groups: int
     mamba_conv_bias: bool
 
-    # Other parameters
     layer_types: List[str]
     rms_norm_eps: float
     rope_theta: float
-    position_embedding_type: str = "rope"  # Can be "rope", "nope", etc.
+    position_embedding_type: str = "rope"
     tie_word_embeddings: bool = True
     time_step_limit: Tuple[float, float] = (0.001, 100.0)
 
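The `position_embedding_type` field keeps its `"rope"` default; the deleted comment only noted that `"nope"` is also accepted. A made-up config snippet carrying these fields, purely for illustration (real values normally come from the checkpoint's config.json):

config = {
    "layer_types": ["mamba", "mamba", "attention", "mamba"],
    "rms_norm_eps": 1e-5,
    "rope_theta": 10000.0,
    "position_embedding_type": "nope",  # "rope" is the default
    "tie_word_embeddings": True,
    "time_step_limit": (0.001, 100.0),
}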
@@ -108,20 +107,18 @@ def __init__(self, args: ModelArgs):
     def _apply_conv(
         self, conv_input: mx.array, cache: Optional[MambaCache] = None
     ) -> mx.array:
-        if cache is not None:
-            if cache[0] is None:
-                conv_state = mx.zeros(
-                    (conv_input.shape[0], self.conv_kernel_size - 1, self.conv_dim),
-                    dtype=conv_input.dtype,
-                )
-            else:
-                conv_state = cache[0]
-            padded_input = mx.concatenate([conv_state, conv_input], axis=1)
-            cache[0] = padded_input[:, -(self.conv_kernel_size - 1) :, :]
-        else:
-            padded_input = mx.pad(
-                conv_input, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]
+        if cache is None or cache[0] is None:
+            conv_state = mx.zeros(
+                (conv_input.shape[0], self.conv_kernel_size - 1, self.conv_dim),
+                dtype=conv_input.dtype,
             )
+        else:
+            conv_state = cache[0]
+
+        padded_input = mx.concatenate([conv_state, conv_input], axis=1)
+
+        if cache is not None:
+            cache[0] = padded_input[:, -(self.conv_kernel_size - 1) :]
 
         conv_output = self.conv1d(padded_input)
         return nn.silu(conv_output)
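For context, a standalone sketch (not part of the diff) of the rolling conv-state pattern this refactor converges on: the last conv_kernel_size - 1 positions are kept as left context, so prefill and single-token decoding share one concatenate path. The shapes and loop below are invented.

import mlx.core as mx

kernel_size, conv_dim = 4, 8
conv_state = mx.zeros((1, kernel_size - 1, conv_dim))  # empty left context

for _ in range(3):
    token = mx.random.normal((1, 1, conv_dim))           # one decode step
    padded = mx.concatenate([conv_state, token], axis=1)  # (1, k-1+T, conv_dim)
    conv_state = padded[:, -(kernel_size - 1):]           # roll the window
    print(padded.shape, conv_state.shape)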
@@ -224,7 +221,7 @@ def __init__(self, args: ModelArgs):
 
         # Check if RoPE should be used based on position_embedding_type
         # If position_embedding_type is "nope", don't use RoPE
-        use_rope = getattr(args, "position_embedding_type", "rope") != "nope"
+        use_rope = args.position_embedding_type != "nope"
         if use_rope:
             self.rope = initialize_rope(
                 self.head_dim,
@@ -283,7 +280,7 @@ def __call__(self, hidden_states: mx.array):
             ..., -self.top_k :
         ]
         top_k_logits = mx.take_along_axis(logits, top_k_idx, axis=-1)
-        top_k_gates = mx.softmax(top_k_logits.astype(mx.float32), axis=-1)
+        top_k_gates = mx.softmax(top_k_logits, precise=True, axis=-1)
         return top_k_idx, top_k_gates
 
 
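A small self-contained sketch of that router change; the dimensions, values, and argsort-based index selection are assumptions, not taken from the file. `precise=True` accumulates the softmax in higher precision internally while returning the input dtype, which replaces the explicit float32 upcast (and is why the `.astype(y.dtype)` cast downstream of the MoE could also be dropped).

import mlx.core as mx

logits = mx.random.normal((2, 8)).astype(mx.float16)  # (tokens, experts)
top_k = 2
top_k_idx = mx.argsort(logits, axis=-1)[..., -top_k:]
top_k_logits = mx.take_along_axis(logits, top_k_idx, axis=-1)

old_gates = mx.softmax(top_k_logits.astype(mx.float32), axis=-1)  # previous path
new_gates = mx.softmax(top_k_logits, precise=True, axis=-1)       # stays float16
print(old_gates.dtype, new_gates.dtype)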
@@ -305,14 +302,18 @@ def __init__(self, args: ModelArgs):
     def __call__(self, x: mx.array) -> mx.array:
         token_ids, gates = self.router(x)
         y = self.switch_mlp(x, token_ids)
-        return (y * gates[..., None]).sum(axis=-2).astype(y.dtype)
+        return (y * gates[..., None]).sum(axis=-2)
 
 
 class GraniteMoeHybridSharedMLP(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.input_linear = nn.Linear(args.hidden_size, args.shared_intermediate_size * 2, bias=False)
-        self.output_linear = nn.Linear(args.shared_intermediate_size, args.hidden_size, bias=False)
+        self.input_linear = nn.Linear(
+            args.hidden_size, args.shared_intermediate_size * 2, bias=False
+        )
+        self.output_linear = nn.Linear(
+            args.shared_intermediate_size, args.hidden_size, bias=False
+        )
 
     def __call__(self, x: mx.array) -> mx.array:
         gate, up = mx.split(self.input_linear(x), 2, axis=-1)
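The reflowed `nn.Linear` calls above do not change behavior: the shared MLP is a fused gate/up projection. A toy sketch of that pattern with placeholder sizes, assuming a silu(gate) * up combination (the return line falls outside the hunk):

import mlx.core as mx
import mlx.nn as nn

hidden_size, shared_intermediate_size = 16, 32
input_linear = nn.Linear(hidden_size, shared_intermediate_size * 2, bias=False)
output_linear = nn.Linear(shared_intermediate_size, hidden_size, bias=False)

x = mx.random.normal((1, 4, hidden_size))
gate, up = mx.split(input_linear(x), 2, axis=-1)  # two (1, 4, 32) halves
y = output_linear(nn.silu(gate) * up)             # back to (1, 4, 16)
print(y.shape)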
@@ -331,19 +332,14 @@ def __init__(self, args: ModelArgs, layer_type: str):
             self.mamba = GraniteMoeHybridMamba2Mixer(args)
         elif layer_type == "attention":
             self.self_attn = GraniteMoeHybridAttention(args)
-            self.post_attention_layernorm = nn.RMSNorm(
-                args.hidden_size, eps=args.rms_norm_eps
-            )
-            self.block_sparse_moe = GraniteMoeHybridMoE(args)
         else:
             raise ValueError(f"Unknown layer type: {layer_type}")
 
         self.shared_mlp = GraniteMoeHybridSharedMLP(args)
         self.block_sparse_moe = GraniteMoeHybridMoE(args)
-        if not hasattr(self, "post_attention_layernorm"):
-            self.post_attention_layernorm = nn.RMSNorm(
-                args.hidden_size, eps=args.rms_norm_eps
-            )
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
 
     def __call__(
         self,
@@ -362,11 +358,10 @@ def __call__(
 
         hidden_states = residual + hidden_states * self.residual_multiplier
 
-        # Second block: MoE + shared_mlp (for ALL layers)
+        # Second block: MoE + shared_mlp
         residual = hidden_states
         normed = self.post_attention_layernorm(hidden_states)
 
-        # Apply both sparse MoE and shared MLP, then sum them
         moe_out = self.block_sparse_moe(normed)
         shared_out = self.shared_mlp(normed)
         mlp_out = moe_out + shared_out
@@ -382,58 +377,29 @@ def __init__(self, args: ModelArgs):
         self.args = args
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            GraniteMoeHybridLayer(args, layer_type)
-            for layer_type in args.layer_types
+            GraniteMoeHybridLayer(args, layer_type) for layer_type in args.layer_types
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.embedding_multiplier = args.embedding_multiplier
-
-        # Find first attention layer index for mask creation
-        self.fa_idx = 0
-        for layer_type in args.layer_types:
-            if layer_type == "attention":
-                break
-            elif layer_type == "mamba":
-                self.fa_idx += 1
+        self.fa_idx = args.layer_types.index("attention")
+        self.layer_types = args.layer_types
 
     def __call__(
         self,
         inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
         hidden_states = self.embed_tokens(inputs) * self.embedding_multiplier
 
-        if mask is None:
-            # Create mask using first attention layer cache
-            attn_cache = None
-            if cache is not None:
-                cache_idx = 0
-                for layer_type in self.args.layer_types:
-                    if layer_type == "attention":
-                        attn_cache = cache[cache_idx]
-                        break
-                    elif layer_type == "mamba":
-                        cache_idx += 1
-            attn_mask = create_attention_mask(hidden_states, [attn_cache] if attn_cache else None)
-
         if cache is None:
             cache = [None] * len(self.layers)
 
-        cache_counter = 0
-        for layer in self.layers:
-            if layer.layer_type in ["mamba", "attention"]:
-                c = cache[cache_counter]
-                cache_counter += 1
-            else:
-                c = None
-
-            if layer.layer_type == "attention":
-                mask_to_use = attn_mask
-            else:
-                mask_to_use = None
+        attn_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
 
-            hidden_states = layer(hidden_states, mask=mask_to_use, cache=c)
+        cache_counter = 0
+        for layer, c, layer_type in zip(self.layers, cache, self.layer_types):
+            mask = attn_mask if layer.layer_type == "attention" else None
+            hidden_states = layer(hidden_states, mask=mask, cache=c)
 
         return self.norm(hidden_states)
 
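For orientation, a sketch of how the simplified per-layer dispatch fits with hybrid caches. The `make_cache` body is not shown in this diff, so the list comprehension below is an assumption; `KVCache` and `MambaCache` are the standard mlx-lm cache classes, and `fa_idx` picks the attention cache consulted when building the mask.

from mlx_lm.models.cache import KVCache, MambaCache

layer_types = ["mamba", "mamba", "attention", "mamba"]  # example ordering
caches = [MambaCache() if t == "mamba" else KVCache() for t in layer_types]

fa_idx = layer_types.index("attention")
attn_cache = caches[fa_idx]  # what the mask creation looks at

for t, c in zip(layer_types, caches):
    mask = "causal" if t == "attention" else None  # stand-in for the real mask
    print(t, type(c).__name__, mask)               # each layer gets (mask, cache)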
@@ -442,6 +408,7 @@ class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
+        self.model_type = args.model_type
         self.model = GraniteMoeHybridModel(args)
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -450,10 +417,9 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
         inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
-        out = self.model(inputs, mask=mask, cache=cache)
+        out = self.model(inputs, cache=cache)
 
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
@@ -476,37 +442,36 @@ def make_cache(self):
         return caches
 
     def sanitize(self, weights):
-        # Handle conv1d weights (similar to nemotron_h)
+        # Handle conv1d weights
         for k, v in weights.items():
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
 
-        # Handle MoE weight transformation from 3D expert weights to SwitchGLU format
+        # Handle MoE weight transformation to SwitchGLU format
         if "model.layers.0.block_sparse_moe.input_linear.weight" in weights:
             for l in range(self.args.num_hidden_layers):
                 prefix = f"model.layers.{l}.block_sparse_moe"
 
-                # Transform input_linear: from (num_experts, expert_hidden, input) to SwitchGLU format
-                input_key = f"{prefix}.input_linear.weight"
-                if input_key in weights:
-                    # The weight is (num_experts, expert_hidden, input_size)
-                    # For (62, 1024, 1536): expert_hidden=1024, so gate/up should be 512 each
-                    input_weight = weights.pop(input_key)
-                    _, expert_hidden, _ = input_weight.shape
-
-                    # Split into gate and up projections (each half of expert_hidden)
-                    gate_proj = input_weight[:, :expert_hidden // 2, :]  # (num_experts, 512, 1536)
-                    up_proj = input_weight[:, expert_hidden // 2:, :]  # (num_experts, 512, 1536)
-
-                    weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
-                    weights[f"{prefix}.switch_mlp.up_proj.weight"] = up_proj
-
-                # Transform output_linear: from (num_experts, input, expert_hidden/2) to down_proj
-                output_key = f"{prefix}.output_linear.weight"
-                if output_key in weights:
-                    output_weight = weights.pop(output_key)
-                    # Shape should be (num_experts, input_size, expert_hidden/2) = (62, 1536, 512)
-                    # This is already in the right format for down_proj
-                    weights[f"{prefix}.switch_mlp.down_proj.weight"] = output_weight
-
-        return weights
+                input_weight = weights.pop(f"{prefix}.input_linear.weight")
+                _, expert_hidden, _ = input_weight.shape
+
+                # Split into gate and up projections (each half of expert_hidden)
+                gate_proj = input_weight[:, : expert_hidden // 2, :]
+                up_proj = input_weight[:, expert_hidden // 2 :, :]
+                weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
+                weights[f"{prefix}.switch_mlp.up_proj.weight"] = up_proj
+
+                weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop(
+                    f"{prefix}.output_linear.weight"
+                )
+
+        return weights
+
+    @property
+    def quant_predicate(self):
+        def predicate(path, _):
+            if path.endswith("router.layer"):
+                return {"group_size": 64, "bits": 8}
+            return True
+
+        return predicate
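A shape-only illustration of that sanitize step with made-up small dimensions (the removed comments cite (62, 1024, 1536) and (62, 1536, 512) for the real checkpoint): the fused input_linear weight is split along its second axis into equal gate and up halves, while output_linear is already in down_proj layout and is renamed as-is.

import mlx.core as mx

num_experts, expert_hidden, hidden_size = 4, 8, 6
input_weight = mx.zeros((num_experts, expert_hidden, hidden_size))
output_weight = mx.zeros((num_experts, hidden_size, expert_hidden // 2))

gate_proj = input_weight[:, : expert_hidden // 2, :]  # (4, 4, 6)
up_proj = input_weight[:, expert_hidden // 2 :, :]    # (4, 4, 6)
down_proj = output_weight                             # (4, 6, 4), used as-is
print(gate_proj.shape, up_proj.shape, down_proj.shape)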