diff --git a/mlx_lm/models/deepseek_v32.py b/mlx_lm/models/deepseek_v32.py index 7014dc261..935497fc7 100644 --- a/mlx_lm/models/deepseek_v32.py +++ b/mlx_lm/models/deepseek_v32.py @@ -210,7 +210,7 @@ def __call__( topk_indices = self.indexer(x, qr, mask, cache=cache[1]) if topk_indices is not None: shape = list(topk_indices.shape) - shape[-1] = keys.shape[2] + shape[-1] = kv_latent.shape[2] sparse_mask = mx.zeros(shape, dtype=mx.bool_) sparse_mask = mx.put_along_axis( sparse_mask, topk_indices, mx.array(True), axis=-1 @@ -521,6 +521,7 @@ def dequant(weight, scale_inv): for e in range(self.args.n_routed_experts) ] weights[f"{prefix}.mlp.switch_mlp.{m}.{k}"] = mx.stack(to_join) + prefix = f"model.layers.{l}.self_attn" if f"{prefix}.kv_b_proj.weight" in weights: layer = self.model.layers[l].self_attn.embed_q quantized = f"{prefix}.kv_b_proj.scales" in weights