
Commit cf41174

nits + rebase
1 parent 53c1f49 commit cf41174

File tree

2 files changed: +88 −95 lines changed


mlx_lm/models/granitemoehybrid.py

Lines changed: 60 additions & 95 deletions
@@ -44,11 +44,10 @@ class ModelArgs(BaseModelArgs):
     mamba_n_groups: int
     mamba_conv_bias: bool
 
-    # Other parameters
     layer_types: List[str]
     rms_norm_eps: float
     rope_theta: float
-    position_embedding_type: str = "rope"  # Can be "rope", "nope", etc.
+    position_embedding_type: str = "rope"
     tie_word_embeddings: bool = True
     time_step_limit: Tuple[float, float] = (0.001, 100.0)
 
@@ -108,20 +107,18 @@ def __init__(self, args: ModelArgs):
     def _apply_conv(
         self, conv_input: mx.array, cache: Optional[MambaCache] = None
     ) -> mx.array:
-        if cache is not None:
-            if cache[0] is None:
-                conv_state = mx.zeros(
-                    (conv_input.shape[0], self.conv_kernel_size - 1, self.conv_dim),
-                    dtype=conv_input.dtype,
-                )
-            else:
-                conv_state = cache[0]
-            padded_input = mx.concatenate([conv_state, conv_input], axis=1)
-            cache[0] = padded_input[:, -(self.conv_kernel_size - 1) :, :]
-        else:
-            padded_input = mx.pad(
-                conv_input, [(0, 0), (self.conv_kernel_size - 1, 0), (0, 0)]
+        if cache is None or cache[0] is None:
+            conv_state = mx.zeros(
+                (conv_input.shape[0], self.conv_kernel_size - 1, self.conv_dim),
+                dtype=conv_input.dtype,
             )
+        else:
+            conv_state = cache[0]
+
+        padded_input = mx.concatenate([conv_state, conv_input], axis=1)
+
+        if cache is not None:
+            cache[0] = padded_input[:, -(self.conv_kernel_size - 1) :]
 
         conv_output = self.conv1d(padded_input)
         return nn.silu(conv_output)
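
The refactored `_apply_conv` always builds a `(kernel_size - 1)`-frame prefix (zeros on the first call, the cached tail afterwards), runs the convolution over the padded sequence, and writes the last `kernel_size - 1` frames back into the cache. A minimal sketch of that state handoff, with a toy window-sum standing in for `self.conv1d`; the names `apply_conv_window` and `toy_conv` are invented for the example and are not mlx_lm code:

# Illustrative sketch only -- not part of the diff. Only the padding and
# cache-tail logic mirrors the new _apply_conv.
import mlx.core as mx

def apply_conv_window(x, state, kernel_size):
    # x: (batch, seq, channels); state: last (kernel_size - 1) frames from the
    # previous call, or None on the first call.
    if state is None:
        state = mx.zeros((x.shape[0], kernel_size - 1, x.shape[-1]), dtype=x.dtype)
    padded = mx.concatenate([state, x], axis=1)
    new_state = padded[:, -(kernel_size - 1):]  # carried into the next call
    return padded, new_state

def toy_conv(padded, kernel_size):
    # Stand-in for self.conv1d: sum over each causal window of length kernel_size.
    n = padded.shape[1] - kernel_size + 1
    return mx.stack([padded[:, t:t + kernel_size].sum(axis=1) for t in range(n)], axis=1)

K = 4
x = mx.random.normal((1, 6, 8))

# Whole sequence at once (cache is None: zero prefix).
full, _ = apply_conv_window(x, None, K)
ref = toy_conv(full, K)

# Same sequence in two chunks, threading the cached tail through.
p1, s = apply_conv_window(x[:, :3], None, K)
p2, _ = apply_conv_window(x[:, 3:], s, K)
out = mx.concatenate([toy_conv(p1, K), toy_conv(p2, K)], axis=1)

assert mx.allclose(ref, out).item()
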
@@ -224,7 +221,7 @@ def __init__(self, args: ModelArgs):
 
         # Check if RoPE should be used based on position_embedding_type
         # If position_embedding_type is "nope", don't use RoPE
-        use_rope = getattr(args, "position_embedding_type", "rope") != "nope"
+        use_rope = args.position_embedding_type != "nope"
         if use_rope:
             self.rope = initialize_rope(
                 self.head_dim,
@@ -283,7 +280,7 @@ def __call__(self, hidden_states: mx.array):
             ..., -self.top_k :
         ]
         top_k_logits = mx.take_along_axis(logits, top_k_idx, axis=-1)
-        top_k_gates = mx.softmax(top_k_logits.astype(mx.float32), axis=-1)
+        top_k_gates = mx.softmax(top_k_logits, precise=True, axis=-1)
         return top_k_idx, top_k_gates
 
 
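
With `precise=True`, `mx.softmax` does the reduction in higher precision while the gates stay in the input dtype, which lines up with dropping the `.astype(y.dtype)` cast in the MoE block below. A rough equivalence check (illustrative only; the tolerance is arbitrary and the exact numerics depend on the MLX version):

# Illustrative check, not part of the diff.
import mlx.core as mx

logits = mx.random.normal((2, 8)).astype(mx.float16)
old_gates = mx.softmax(logits.astype(mx.float32), axis=-1)   # float32 output
new_gates = mx.softmax(logits, precise=True, axis=-1)        # keeps float16
assert mx.allclose(old_gates.astype(mx.float16), new_gates, atol=1e-3).item()
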
@@ -305,14 +302,18 @@ def __init__(self, args: ModelArgs):
     def __call__(self, x: mx.array) -> mx.array:
         token_ids, gates = self.router(x)
         y = self.switch_mlp(x, token_ids)
-        return (y * gates[..., None]).sum(axis=-2).astype(y.dtype)
+        return (y * gates[..., None]).sum(axis=-2)
 
 
 class GraniteMoeHybridSharedMLP(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
-        self.input_linear = nn.Linear(args.hidden_size, args.shared_intermediate_size * 2, bias=False)
-        self.output_linear = nn.Linear(args.shared_intermediate_size, args.hidden_size, bias=False)
+        self.input_linear = nn.Linear(
+            args.hidden_size, args.shared_intermediate_size * 2, bias=False
+        )
+        self.output_linear = nn.Linear(
+            args.shared_intermediate_size, args.hidden_size, bias=False
+        )
 
     def __call__(self, x: mx.array) -> mx.array:
         gate, up = mx.split(self.input_linear(x), 2, axis=-1)
@@ -331,19 +332,14 @@ def __init__(self, args: ModelArgs, layer_type: str):
             self.mamba = GraniteMoeHybridMamba2Mixer(args)
         elif layer_type == "attention":
             self.self_attn = GraniteMoeHybridAttention(args)
-            self.post_attention_layernorm = nn.RMSNorm(
-                args.hidden_size, eps=args.rms_norm_eps
-            )
-            self.block_sparse_moe = GraniteMoeHybridMoE(args)
         else:
             raise ValueError(f"Unknown layer type: {layer_type}")
 
         self.shared_mlp = GraniteMoeHybridSharedMLP(args)
         self.block_sparse_moe = GraniteMoeHybridMoE(args)
-        if not hasattr(self, "post_attention_layernorm"):
-            self.post_attention_layernorm = nn.RMSNorm(
-                args.hidden_size, eps=args.rms_norm_eps
-            )
+        self.post_attention_layernorm = nn.RMSNorm(
+            args.hidden_size, eps=args.rms_norm_eps
+        )
 
     def __call__(
         self,
@@ -362,11 +358,10 @@ def __call__(
 
         hidden_states = residual + hidden_states * self.residual_multiplier
 
-        # Second block: MoE + shared_mlp (for ALL layers)
+        # Second block: MoE + shared_mlp
         residual = hidden_states
         normed = self.post_attention_layernorm(hidden_states)
 
-        # Apply both sparse MoE and shared MLP, then sum them
         moe_out = self.block_sparse_moe(normed)
         shared_out = self.shared_mlp(normed)
         mlp_out = moe_out + shared_out
@@ -382,58 +377,29 @@ def __init__(self, args: ModelArgs):
         self.args = args
         self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
         self.layers = [
-            GraniteMoeHybridLayer(args, layer_type)
-            for layer_type in args.layer_types
+            GraniteMoeHybridLayer(args, layer_type) for layer_type in args.layer_types
         ]
         self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
         self.embedding_multiplier = args.embedding_multiplier
-
-        # Find first attention layer index for mask creation
-        self.fa_idx = 0
-        for layer_type in args.layer_types:
-            if layer_type == "attention":
-                break
-            elif layer_type == "mamba":
-                self.fa_idx += 1
+        self.fa_idx = args.layer_types.index("attention")
+        self.layer_types = args.layer_types
 
     def __call__(
         self,
         inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
         hidden_states = self.embed_tokens(inputs) * self.embedding_multiplier
 
-        if mask is None:
-            # Create mask using first attention layer cache
-            attn_cache = None
-            if cache is not None:
-                cache_idx = 0
-                for layer_type in self.args.layer_types:
-                    if layer_type == "attention":
-                        attn_cache = cache[cache_idx]
-                        break
-                    elif layer_type == "mamba":
-                        cache_idx += 1
-            attn_mask = create_attention_mask(hidden_states, [attn_cache] if attn_cache else None)
-
         if cache is None:
             cache = [None] * len(self.layers)
 
-        cache_counter = 0
-        for layer in self.layers:
-            if layer.layer_type in ["mamba", "attention"]:
-                c = cache[cache_counter]
-                cache_counter += 1
-            else:
-                c = None
-
-            if layer.layer_type == "attention":
-                mask_to_use = attn_mask
-            else:
-                mask_to_use = None
+        attn_mask = create_attention_mask(hidden_states, cache[self.fa_idx])
 
-            hidden_states = layer(hidden_states, mask=mask_to_use, cache=c)
+        cache_counter = 0
+        for layer, c, layer_type in zip(self.layers, cache, self.layer_types):
+            mask = attn_mask if layer.layer_type == "attention" else None
+            hidden_states = layer(hidden_states, mask=mask, cache=c)
 
         return self.norm(hidden_states)
 
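
The simplified mask plumbing relies on the per-layer cache list lining up with `layer_types`: the first attention layer's cache sits at `layer_types.index("attention")` and only attention layers consume the mask. A toy illustration with stand-in values (the strings below are not real cache or mask objects):

# Toy illustration only.
layer_types = ["mamba", "attention", "mamba", "attention"]
cache = ["mamba-cache-0", "kv-cache-1", "mamba-cache-2", "kv-cache-3"]

fa_idx = layer_types.index("attention")   # 1 -> cache handed to create_attention_mask
assert cache[fa_idx] == "kv-cache-1"

attn_mask = "causal-mask"
masks = [attn_mask if t == "attention" else None for t in layer_types]
assert masks == [None, "causal-mask", None, "causal-mask"]
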
@@ -442,6 +408,7 @@ class Model(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
+        self.model_type = args.model_type
         self.model = GraniteMoeHybridModel(args)
         if not args.tie_word_embeddings:
             self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)
@@ -450,10 +417,9 @@ def __init__(self, args: ModelArgs):
     def __call__(
         self,
        inputs: mx.array,
-        mask: Optional[mx.array] = None,
         cache: Optional[Any] = None,
     ) -> mx.array:
-        out = self.model(inputs, mask=mask, cache=cache)
+        out = self.model(inputs, cache=cache)
 
         if self.args.tie_word_embeddings:
             out = self.model.embed_tokens.as_linear(out)
@@ -476,37 +442,36 @@ def make_cache(self):
         return caches
 
     def sanitize(self, weights):
-        # Handle conv1d weights (similar to nemotron_h)
+        # Handle conv1d weights
         for k, v in weights.items():
             if "conv1d.weight" in k and v.shape[-1] != 1:
                 weights[k] = v.moveaxis(2, 1)
 
-        # Handle MoE weight transformation from 3D expert weights to SwitchGLU format
+        # Handle MoE weight transformation to SwitchGLU format
         if "model.layers.0.block_sparse_moe.input_linear.weight" in weights:
             for l in range(self.args.num_hidden_layers):
                 prefix = f"model.layers.{l}.block_sparse_moe"
 
-                # Transform input_linear: from (num_experts, expert_hidden, input) to SwitchGLU format
-                input_key = f"{prefix}.input_linear.weight"
-                if input_key in weights:
-                    # The weight is (num_experts, expert_hidden, input_size)
-                    # For (62, 1024, 1536): expert_hidden=1024, so gate/up should be 512 each
-                    input_weight = weights.pop(input_key)
-                    _, expert_hidden, _ = input_weight.shape
-
-                    # Split into gate and up projections (each half of expert_hidden)
-                    gate_proj = input_weight[:, :expert_hidden//2, :]  # (num_experts, 512, 1536)
-                    up_proj = input_weight[:, expert_hidden//2:, :]  # (num_experts, 512, 1536)
-
-                    weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
-                    weights[f"{prefix}.switch_mlp.up_proj.weight"] = up_proj
-
-                # Transform output_linear: from (num_experts, input, expert_hidden/2) to down_proj
-                output_key = f"{prefix}.output_linear.weight"
-                if output_key in weights:
-                    output_weight = weights.pop(output_key)
-                    # Shape should be (num_experts, input_size, expert_hidden/2) = (62, 1536, 512)
-                    # This is already in the right format for down_proj
-                    weights[f"{prefix}.switch_mlp.down_proj.weight"] = output_weight
-
-        return weights
+                input_weight = weights.pop(f"{prefix}.input_linear.weight")
+                _, expert_hidden, _ = input_weight.shape
+
+                # Split into gate and up projections (each half of expert_hidden)
+                gate_proj = input_weight[:, : expert_hidden // 2, :]
+                up_proj = input_weight[:, expert_hidden // 2 :, :]
+                weights[f"{prefix}.switch_mlp.gate_proj.weight"] = gate_proj
+                weights[f"{prefix}.switch_mlp.up_proj.weight"] = up_proj
+
+                weights[f"{prefix}.switch_mlp.down_proj.weight"] = weights.pop(
+                    f"{prefix}.output_linear.weight"
+                )
+
+        return weights
+
+    @property
+    def quant_predicate(self):
+        def predicate(path, _):
+            if path.endswith("router.layer"):
+                return {"group_size": 64, "bits": 8}
+            return True
+
+        return predicate
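
The simplified `sanitize` leans on the checkpoint layout: each expert's `input_linear` fuses the gate and up projections along its second axis, and `output_linear` already has the `down_proj` shape. A shape walkthrough with toy sizes (these dimensions are made up for the example, not taken from a real checkpoint):

# Toy shapes only.
import mlx.core as mx

num_experts, d_model, d_ff = 4, 16, 8
input_linear = mx.zeros((num_experts, 2 * d_ff, d_model))   # fused gate + up
output_linear = mx.zeros((num_experts, d_model, d_ff))      # already down_proj-shaped

_, expert_hidden, _ = input_linear.shape
gate_proj = input_linear[:, : expert_hidden // 2, :]
up_proj = input_linear[:, expert_hidden // 2 :, :]

assert tuple(gate_proj.shape) == (num_experts, d_ff, d_model)
assert tuple(up_proj.shape) == (num_experts, d_ff, d_model)
assert tuple(output_linear.shape) == (num_experts, d_model, d_ff)

The new `quant_predicate` keeps router layers at 8-bit with group size 64 and leaves everything else to the default quantization settings.
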

tests/test_models.py

Lines changed: 28 additions & 0 deletions
@@ -1683,6 +1683,34 @@ def test_all_models(self):
                 "rope_theta": 1000,
                 "layer_norm_eps": 1e-5,
             },
+            {
+                "model_type": "granitemoehybrid",
+                "vocab_size": 1000,
+                "hidden_size": 128,
+                "intermediate_size": 128,
+                "num_hidden_layers": 4,
+                "max_position_embeddings": 1000,
+                "num_attention_heads": 8,
+                "num_key_value_heads": 4,
+                "attention_bias": False,
+                "embedding_multiplier": 1.0,
+                "attention_multiplier": 1.0,
+                "logits_scaling": 1.0,
+                "residual_multiplier": 1.0,
+                "num_local_experts": 8,
+                "num_experts_per_tok": 2,
+                "shared_intermediate_size": 128,
+                "mamba_n_heads": 8,
+                "mamba_d_head": 16,
+                "mamba_proj_bias": False,
+                "mamba_d_state": 128,
+                "mamba_d_conv": 4,
+                "mamba_n_groups": 1,
+                "mamba_conv_bias": False,
+                "layer_types": ["mamba", "attention", "mamba", "attention"],
+                "rms_norm_eps": 1e-5,
+                "rope_theta": 1000.0,
+            },
         ]
         for config in test_configs:
             model_type = config["model_type"]
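
The added config alternates "mamba" and "attention" layer types so both cache paths get exercised. For context, a hedged sketch of roughly what the test harness does with such an entry; the helper names follow common mlx_lm conventions and may differ from the actual `test_all_models` flow, and `config` stands for the dict added above:

# Hedged sketch -- approximates a tiny forward pass over the new config.
import mlx.core as mx
from mlx_lm.models import granitemoehybrid

args = granitemoehybrid.ModelArgs.from_dict(config)   # config: dict added above
model = granitemoehybrid.Model(args)
tokens = mx.zeros((1, 8), dtype=mx.int32)
logits = model(tokens)                                # cache defaults to None
assert tuple(logits.shape) == (1, 8, args.vocab_size)
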
