Skip to content

Conversation

@gabe-l-hart
Copy link
Contributor

Description

Addresses #256

This PR adds support for the GraniteMoeHybrid model architecture. It was heavily written using Claude Code with my input and guidance. I have deep knowledge of the model architecture having implemented it for llama.cpp along with similar architectures like nemotron_h, so I gave Claude Code very specific guidance (though after-the fact I found several gaps in my guidance that Claude Code was able to work around).

Claude Code Artifacts

I used this as an opportunity to test Claude Code's capabilities, so I'm recording as much of the session here as possible for posterity.

Input Prompt

I've got an in-depth feature request for you to add. I need you to add support for the GraniteMoeHybrid architecture to the mlx-lm project. The task is to extend the existing set of model architecture implementations in mlx_lm/models by adding a new module named granitemoehybrid.py. Here are a few key pointers on this model architecture:

The goal of this project is to create a fully working local implementation of the model in mlx_lm. You can find a local model to test with at /Users/ghart/models/granite-4.0-tiny-preview/. You can find a version of the nemotron_h model to test with at /Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/. To accomplish this project, you'll need to take the following steps:

  1. Get a development environment working (you can use uv to manage your virtual env) and install the necessary dependencies
  2. Run a sample inference with a model that is already known to work (eg /Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/)
  3. Create the new module at mlx_lm/models/granitemoehybrid.py
  4. Implement the model architecture, test, and iterate until you've got things working locally

Once you've got it working, let me know and I'll review and commit

tmp_mlx.py
#!/usr/bin/env python3

import mlx.core as mx
from mlx_lm import load
import json

def debug_mlx():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    prompt = 'The'
    
    print("=== MLX DEBUG ===")
    
    # Load model and tokenizer
    print("Loading model...")
    model, tokenizer = load(model_path)
    
    print(f"Model type: {model.args.model_type}")
    
    # Display config
    print("\nKey config values:")
    config_attrs = [
        'embedding_multiplier', 'attention_multiplier', 'logits_scaling',
        'residual_multiplier', 'position_embedding_type', 'layer_types'
    ]
    for attr in config_attrs:
        if hasattr(model.args, attr):
            value = getattr(model.args, attr)
            if attr == 'layer_types':
                print(f"  {attr}: {value[:10]}...")  # Show first 10
            else:
                print(f"  {attr}: {value}")
    
    # Tokenize
    tokens = tokenizer.encode(prompt)
    print(f"\nPrompt: '{prompt}'")
    print(f"Tokens: {tokens}")
    
    # Forward pass
    inputs = mx.array([tokens])
    print(f"Input shape: {inputs.shape}")
    
    # Get embeddings
    embeddings = model.model.embed_tokens(inputs)
    print(f"\nEmbedding analysis:")
    print(f"  Raw embeddings shape: {embeddings.shape}")
    print(f"  Raw embeddings mean: {mx.mean(embeddings).item():.6f}")
    print(f"  Raw embeddings std: {mx.std(embeddings).item():.6f}")
    
    # Apply embedding multiplier
    scaled_embeddings = embeddings * model.model.embedding_multiplier
    print(f"  After embedding_multiplier ({model.model.embedding_multiplier}):")
    print(f"    mean: {mx.mean(scaled_embeddings).item():.6f}")
    print(f"    std: {mx.std(scaled_embeddings).item():.6f}")
    
    # Full forward pass
    logits = model(inputs)
    
    print(f"\nLogits analysis:")
    print(f"  Logits shape: {logits.shape}")
    print(f"  Logits mean: {mx.mean(logits).item():.6f}")
    print(f"  Logits std: {mx.std(logits).item():.6f}")
    
    # Check logits before scaling
    pre_scaled_logits = logits * model.logits_scaling
    print(f"  Logits before /logits_scaling ({model.logits_scaling}):")
    print(f"    mean: {mx.mean(pre_scaled_logits).item():.6f}")
    print(f"    std: {mx.std(pre_scaled_logits).item():.6f}")
    
    # Top predictions
    last_logits = logits[0, -1, :]
    probs = mx.softmax(last_logits)
    top_indices = mx.argpartition(-last_logits, 10)[:10]
    
    print(f"\nTop 10 predictions:")
    for i, idx in enumerate(top_indices):
        idx_val = idx.item()
        prob = probs[idx].item()
        token_text = repr(tokenizer.decode([idx_val]))
        print(f"  {i+1:2d}. Token {idx_val:5d}: {token_text:20s} (prob: {prob:.6f})")
    
    # Layer analysis
    print(f"\nLayer analysis:")
    print(f"Number of layers: {len(model.layers)}")
    print(f"First 10 layer types: {[layer.layer_type for layer in model.layers[:10]]}")
    
    # Debug specific layers
    print(f"\nLayer 0 analysis (should be mamba):")
    layer0 = model.layers[0]
    print(f"  Layer type: {layer0.layer_type}")
    if layer0.layer_type == 'mamba':
        print(f"  Mamba config:")
        mamba = layer0.mamba
        print(f"    num_heads: {mamba.num_heads}")
        print(f"    hidden_size: {mamba.hidden_size}")
        print(f"    intermediate_size: {mamba.intermediate_size}")
        print(f"    conv_kernel_size: {mamba.conv_kernel_size}")
    
    # Test first layer forward pass
    print(f"\nFirst layer forward pass test:")
    x_after_embed = scaled_embeddings
    print(f"  Input to layer 0: mean={mx.mean(x_after_embed).item():.6f}, std={mx.std(x_after_embed).item():.6f}")
    
    # Apply first layer norm
    normed_input = layer0.input_layernorm(x_after_embed)
    print(f"  After input_layernorm: mean={mx.mean(normed_input).item():.6f}, std={mx.std(normed_input).item():.6f}")
    
    return {
        'tokens': tokens,
        'logits': logits,
        'top_indices': [idx.item() for idx in top_indices],
        'top_probs': [probs[idx].item() for idx in top_indices],
        'embeddings_stats': {
            'mean': mx.mean(embeddings).item(),
            'std': mx.std(embeddings).item()
        },
        'scaled_embeddings_stats': {
            'mean': mx.mean(scaled_embeddings).item(),
            'std': mx.std(scaled_embeddings).item()
        },
        'logits_stats': {
            'mean': mx.mean(logits).item(),
            'std': mx.std(logits).item()
        },
        'pre_scaled_logits_stats': {
            'mean': mx.mean(pre_scaled_logits).item(),
            'std': mx.std(pre_scaled_logits).item()
        }
    }

if __name__ == "__main__":
    results = debug_mlx()
    print(f"\n=== RESULTS SUMMARY ===")
    print(f"Top token: {results['top_indices'][0]} (prob: {results['top_probs'][0]:.6f})")
    print(f"Logits mean: {results['logits_stats']['mean']:.6f}")
    print(f"Logits std: {results['logits_stats']['std']:.6f}")
    print(f"Pre-scaled logits mean: {results['pre_scaled_logits_stats']['mean']:.6f}")
    print(f"Pre-scaled logits std: {results['pre_scaled_logits_stats']['std']:.6f}")
tmp_transformers.py
#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import json

def debug_transformers():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    prompt = 'The'
    
    print("=== TRANSFORMERS DEBUG ===")
    
    # Load model and tokenizer
    print("Loading model...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    print(f"Model type: {model.config.model_type}")
    print(f"Architecture: {model.config.architectures}")
    
    # Load and display config
    with open(f'{model_path}/config.json', 'r') as f:
        config = json.load(f)
    
    print("\nKey config values:")
    key_params = [
        'embedding_multiplier', 'attention_multiplier', 'logits_scaling', 
        'residual_multiplier', 'position_embedding_type', 'layer_types'
    ]
    for param in key_params:
        if param in config:
            if param == 'layer_types':
                print(f"  {param}: {config[param][:10]}...")  # Show first 10
            else:
                print(f"  {param}: {config[param]}")
    
    # Tokenize
    tokens = tokenizer.encode(prompt)
    print(f"\nPrompt: '{prompt}'")
    print(f"Tokens: {tokens}")
    
    # Forward pass
    inputs = torch.tensor([tokens])
    print(f"Input shape: {inputs.shape}")
    
    with torch.no_grad():
        # Get embeddings
        embeddings = model.model.embed_tokens(inputs)
        print(f"\nEmbedding analysis:")
        print(f"  Raw embeddings shape: {embeddings.shape}")
        print(f"  Raw embeddings mean: {embeddings.mean().item():.6f}")
        print(f"  Raw embeddings std: {embeddings.std().item():.6f}")
        
        # Apply embedding multiplier (if it exists in the model)
        if hasattr(model.config, 'embedding_multiplier'):
            scaled_embeddings = embeddings * model.config.embedding_multiplier
            print(f"  After embedding_multiplier ({model.config.embedding_multiplier}):")
            print(f"    mean: {scaled_embeddings.mean().item():.6f}")
            print(f"    std: {scaled_embeddings.std().item():.6f}")
        
        # Full forward pass
        outputs = model(inputs)
        logits = outputs.logits
        
        print(f"\nLogits analysis:")
        print(f"  Logits shape: {logits.shape}")
        print(f"  Logits mean: {logits.mean().item():.6f}")
        print(f"  Logits std: {logits.std().item():.6f}")
        
        # Top predictions
        last_logits = logits[0, -1, :]
        probs = torch.softmax(last_logits, dim=-1)
        top_probs, top_indices = torch.topk(probs, 10)
        
        print(f"\nTop 10 predictions:")
        for i, (idx, prob) in enumerate(zip(top_indices, top_probs)):
            token_text = repr(tokenizer.decode([idx.item()]))
            print(f"  {i+1:2d}. Token {idx.item():5d}: {token_text:20s} (prob: {prob.item():.6f})")
        
        # Layer-by-layer analysis for first few layers
        print(f"\nLayer-by-layer analysis (first 3 layers):")
        print(f"Expected layer types: {config['layer_types'][:3]}")
        
        # Note: This would require hooking into the model internals
        # For now, just show the layer structure
        print(f"Number of layers: {len(model.model.layers)}")
        
        return {
            'tokens': tokens,
            'logits': logits,
            'top_indices': top_indices.tolist(),
            'top_probs': top_probs.tolist(),
            'embeddings_stats': {
                'mean': embeddings.mean().item(),
                'std': embeddings.std().item()
            },
            'logits_stats': {
                'mean': logits.mean().item(), 
                'std': logits.std().item()
            }
        }

if __name__ == "__main__":
    results = debug_transformers()
    print(f"\n=== RESULTS SUMMARY ===")
    print(f"Top token: {results['top_indices'][0]} (prob: {results['top_probs'][0]:.6f})")
    print(f"Logits mean: {results['logits_stats']['mean']:.6f}")
    print(f"Logits std: {results['logits_stats']['std']:.6f}")
debug_layer_call.py
#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import mlx.core as mx
from mlx_lm import load

def debug_layer_calling():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    prompt = 'The'
    
    print("=== LAYER CALLING DEBUG ===")
    
    # Load both models
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
    hf_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
    
    mlx_model, mlx_tokenizer = load(model_path)
    
    tokens = hf_tokenizer.encode(prompt)
    hf_inputs = torch.tensor([tokens])
    mlx_inputs = mx.array([tokens])
    
    with torch.no_grad():
        # Get embeddings
        hf_embeddings = hf_model.model.embed_tokens(hf_inputs) * hf_model.config.embedding_multiplier
        mlx_embeddings = mlx_model.model.embed_tokens(mlx_inputs) * mlx_model.model.embedding_multiplier
        
        print(f"Embeddings:")
        print(f"  HF:  mean={hf_embeddings.mean().item():.6f}, std={hf_embeddings.std().item():.6f}")
        print(f"  MLX: mean={mx.mean(mlx_embeddings).item():.6f}, std={mx.std(mlx_embeddings).item():.6f}")
        
        print(f"\\n=== LAYER 0 DETAILED ANALYSIS ===")
        
        # HF layer 0 - step by step
        hf_layer0 = hf_model.model.layers[0]
        print(f"HF Layer 0 attributes: {[attr for attr in dir(hf_layer0) if not attr.startswith('_')]}")
        
        # Input layernorm
        hf_normed = hf_layer0.input_layernorm(hf_embeddings)
        mlx_normed = mlx_model.layers[0].input_layernorm(mlx_embeddings)
        
        print(f"\\nAfter input_layernorm:")
        print(f"  HF:  mean={hf_normed.mean().item():.6f}, std={hf_normed.std().item():.6f}")
        print(f"  MLX: mean={mx.mean(mlx_normed).item():.6f}, std={mx.std(mlx_normed).item():.6f}")
        
        # Mamba block
        if hasattr(hf_layer0, 'mamba'):
            print(f"\\nHF has 'mamba' attribute")
            hf_mamba_out = hf_layer0.mamba(hf_normed)
            print(f"HF mamba output: mean={hf_mamba_out.mean().item():.6f}, std={hf_mamba_out.std().item():.6f}")
        else:
            print(f"\\nHF layer attributes: {dir(hf_layer0)}")
            
        mlx_mamba_out = mlx_model.layers[0].mamba(mlx_normed, cache=None)
        print(f"MLX mamba output: mean={mx.mean(mlx_mamba_out).item():.6f}, std={mx.std(mlx_mamba_out).item():.6f}")
        
        # Residual connection
        if hasattr(hf_layer0, 'mamba'):
            hf_residual = hf_embeddings + hf_mamba_out * hf_model.config.residual_multiplier
            print(f"\\nHF residual: mean={hf_residual.mean().item():.6f}, std={hf_residual.std().item():.6f}")
        
        mlx_residual = mlx_embeddings + mlx_mamba_out * mlx_model.args.residual_multiplier
        print(f"MLX residual: mean={mx.mean(mlx_residual).item():.6f}, std={mx.std(mlx_residual).item():.6f}")
        
        # Compare with calling the full layer
        print(f"\\n=== FULL LAYER CALL COMPARISON ===")
        
        # HF full layer call
        hf_full_out = hf_layer0(hf_embeddings)
        if isinstance(hf_full_out, tuple):
            hf_full_out = hf_full_out[0]  # Get just the hidden states
        print(f"HF full layer: mean={hf_full_out.mean().item():.6f}, std={hf_full_out.std().item():.6f}")
        
        # MLX full layer call
        mlx_full_out = mlx_model.layers[0](mlx_embeddings, mask=None, cache=None)
        print(f"MLX full layer: mean={mx.mean(mlx_full_out).item():.6f}, std={mx.std(mlx_full_out).item():.6f}")
        
        # Check residual multiplier values
        print(f"\\n=== CONFIG CHECK ===")
        print(f"HF residual_multiplier: {hf_model.config.residual_multiplier}")
        print(f"MLX residual_multiplier: {mlx_model.args.residual_multiplier}")

if __name__ == "__main__":
    debug_layer_calling()
debug_layers.py
#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import mlx.core as mx
from mlx_lm import load

def debug_multiple_layers():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    prompt = 'The'
    
    print("=== MULTI-LAYER DEBUGGING ===")
    
    # Load both models
    print("Loading models...")
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
    hf_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
    
    mlx_model, mlx_tokenizer = load(model_path)
    
    # Get layer types
    layer_types = mlx_model.args.layer_types[:10]  # First 10 layers
    print(f"Layer types: {layer_types}")
    
    # Tokenize
    tokens = hf_tokenizer.encode(prompt)
    hf_inputs = torch.tensor([tokens])
    mlx_inputs = mx.array([tokens])
    
    with torch.no_grad():
        # Start with embeddings
        hf_embeddings = hf_model.model.embed_tokens(hf_inputs) * hf_model.config.embedding_multiplier
        mlx_embeddings = mlx_model.model.embed_tokens(mlx_inputs) * mlx_model.model.embedding_multiplier
        
        print(f"\\nInitial embeddings:")
        print(f"  HF:  mean={hf_embeddings.mean().item():.6f}, std={hf_embeddings.std().item():.6f}")
        print(f"  MLX: mean={mx.mean(mlx_embeddings).item():.6f}, std={mx.std(mlx_embeddings).item():.6f}")
        
        # Process through layers one by one
        hf_hidden = hf_embeddings
        mlx_hidden = mlx_embeddings
        
        for i in range(min(6, len(layer_types))):  # Check first 6 layers (including attention at layer 5)
            layer_type = layer_types[i]
            print(f"\\n--- Layer {i} ({layer_type}) ---")
            
            # HF layer
            hf_layer = hf_model.model.layers[i]
            hf_hidden = hf_layer(hf_hidden)[0]  # HF returns tuple
            
            # MLX layer  
            mlx_layer = mlx_model.layers[i]
            mlx_hidden = mlx_layer(mlx_hidden, mask=None, cache=None)
            
            print(f"  HF output:  mean={hf_hidden.mean().item():.6f}, std={hf_hidden.std().item():.6f}")
            print(f"  MLX output: mean={mx.mean(mlx_hidden).item():.6f}, std={mx.std(mlx_hidden).item():.6f}")
            
            # Check if they're still close
            diff_mean = abs(hf_hidden.mean().item() - mx.mean(mlx_hidden).item())
            diff_std = abs(hf_hidden.std().item() - mx.std(mlx_hidden).item())
            
            if diff_mean > 0.01 or diff_std > 0.01:
                print(f"  🚨 DIVERGENCE DETECTED!")
                print(f"     Mean diff: {diff_mean:.6f}, Std diff: {diff_std:.6f}")
                break
            else:
                print(f"  ✅ Still matching (mean diff: {diff_mean:.6f}, std diff: {diff_std:.6f})")
        
        print(f"\\n=== FINAL COMPARISON ===")
        
        # Run full forward pass for comparison
        hf_full_output = hf_model.model(hf_inputs)[0]  # Get hidden states
        
        # MLX forward through model.model (not the full model with lm_head)
        mlx_full_output = mlx_model.model(mlx_inputs)
        
        print(f"Full model hidden states:")
        print(f"  HF:  mean={hf_full_output.mean().item():.6f}, std={hf_full_output.std().item():.6f}")
        print(f"  MLX: mean={mx.mean(mlx_full_output).item():.6f}, std={mx.std(mlx_full_output).item():.6f}")
        
        # Apply final norm
        hf_normed = hf_model.model.norm(hf_full_output)
        mlx_normed = mlx_model.model.norm(mlx_full_output)
        
        print(f"After final norm:")
        print(f"  HF:  mean={hf_normed.mean().item():.6f}, std={hf_normed.std().item():.6f}")
        print(f"  MLX: mean={mx.mean(mlx_normed).item():.6f}, std={mx.std(mlx_normed).item():.6f}")
        
        # Apply lm_head
        if hasattr(hf_model, 'lm_head'):
            hf_logits = hf_model.lm_head(hf_normed)
        else:
            hf_logits = hf_model.model.embed_tokens.weight @ hf_normed.transpose(-2, -1)
            hf_logits = hf_logits.transpose(-2, -1)
            
        if hasattr(mlx_model, 'lm_head'):
            mlx_logits = mlx_model.lm_head(mlx_normed)
        else:
            # Use tied embeddings
            mlx_logits = mlx_model.model.embed_tokens.as_linear(mlx_normed)
            
        # Apply logits scaling
        hf_final = hf_logits / hf_model.config.logits_scaling
        mlx_final = mlx_logits / mlx_model.logits_scaling
        
        print(f"Final logits:")
        print(f"  HF:  mean={hf_final.mean().item():.6f}, std={hf_final.std().item():.6f}")
        print(f"  MLX: mean={mx.mean(mlx_final).item():.6f}, std={mx.std(mlx_final).item():.6f}")

if __name__ == "__main__":
    debug_multiple_layers()
debug_mamba.py
#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import mlx.core as mx
from mlx_lm import load

def compare_mamba_layer():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    prompt = 'The'
    
    print("=== MAMBA LAYER COMPARISON ===")
    
    # Load both models
    print("Loading models...")
    hf_tokenizer = AutoTokenizer.from_pretrained(model_path)
    hf_model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
    if hf_tokenizer.pad_token is None:
        hf_tokenizer.pad_token = hf_tokenizer.eos_token
    
    mlx_model, mlx_tokenizer = load(model_path)
    
    # Tokenize
    tokens = hf_tokenizer.encode(prompt)
    print(f"Prompt: '{prompt}' -> Tokens: {tokens}")
    
    # Get initial embeddings (both should be identical)
    hf_inputs = torch.tensor([tokens])
    mlx_inputs = mx.array([tokens])
    
    with torch.no_grad():
        # HF embeddings
        hf_embeddings = hf_model.model.embed_tokens(hf_inputs)
        hf_scaled = hf_embeddings * hf_model.config.embedding_multiplier
        
        # MLX embeddings  
        mlx_embeddings = mlx_model.model.embed_tokens(mlx_inputs)
        mlx_scaled = mlx_embeddings * mlx_model.model.embedding_multiplier
        
        print(f"\nEmbedding verification:")
        print(f"  HF scaled: mean={hf_scaled.mean().item():.6f}, std={hf_scaled.std().item():.6f}")
        print(f"  MLX scaled: mean={mx.mean(mlx_scaled).item():.6f}, std={mx.std(mlx_scaled).item():.6f}")
        
        # Now test the first Mamba layer specifically
        print(f"\n=== LAYER 0 ANALYSIS (Mamba) ===")
        
        # HF first layer
        hf_layer0 = hf_model.model.layers[0]
        print(f"HF Layer 0 type: {type(hf_layer0).__name__}")
        
        # Apply input layernorm  
        hf_normed = hf_layer0.input_layernorm(hf_scaled)
        print(f"HF after input_layernorm: mean={hf_normed.mean().item():.6f}, std={hf_normed.std().item():.6f}")
        
        # Apply the Mamba block
        if hasattr(hf_layer0, 'mamba'):
            hf_mamba_out = hf_layer0.mamba(hf_normed)
            print(f"HF after mamba: mean={hf_mamba_out.mean().item():.6f}, std={hf_mamba_out.std().item():.6f}")
            
            # Add residual
            hf_residual = hf_scaled + hf_mamba_out * hf_model.config.residual_multiplier
            print(f"HF after residual: mean={hf_residual.mean().item():.6f}, std={hf_residual.std().item():.6f}")
        
        # MLX first layer
        mlx_layer0 = mlx_model.layers[0]
        print(f"\nMLX Layer 0 type: {mlx_layer0.layer_type}")
        
        # Apply input layernorm
        mlx_normed = mlx_layer0.input_layernorm(mlx_scaled)
        print(f"MLX after input_layernorm: mean={mx.mean(mlx_normed).item():.6f}, std={mx.std(mlx_normed).item():.6f}")
        
        # Apply the Mamba block
        mlx_mamba_out = mlx_layer0.mamba(mlx_normed, cache=None)
        print(f"MLX after mamba: mean={mx.mean(mlx_mamba_out).item():.6f}, std={mx.std(mlx_mamba_out).item():.6f}")
        
        # Add residual
        mlx_residual = mlx_scaled + mlx_mamba_out * mlx_model.args.residual_multiplier
        print(f"MLX after residual: mean={mx.mean(mlx_residual).item():.6f}, std={mx.std(mlx_residual).item():.6f}")
        
        # Compare the differences
        print(f"\n=== COMPARISON ===")
        print("Layer norm comparison:")
        print(f"  Difference in normalized outputs: This tells us if the issue is in normalization")
        
        print("Mamba output comparison:")
        print(f"  This will show if the Mamba implementation itself is wrong")
        
        print("Residual comparison:")
        print(f"  This shows the cumulative effect after one layer")

if __name__ == "__main__":
    compare_mamba_layer()
test_cache_mlx.py
#!/usr/bin/env python3

import mlx.core as mx
from mlx_lm import load

def test_mlx_cache():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    
    print("=== MLX CACHE TEST ===")
    
    # Load model
    model, tokenizer = load(model_path)
    
    # Test with common prefix
    prefix = "The weather is very"
    continuation1 = " hot today"
    continuation2 = " cold this winter"
    
    print(f"Prefix: '{prefix}'")
    print(f"Continuation 1: '{continuation1}'")
    print(f"Continuation 2: '{continuation2}'")
    
    # First inference - no cache
    full_text1 = prefix + continuation1
    tokens1 = tokenizer.encode(full_text1)
    inputs1 = mx.array([tokens1])
    
    print(f"\nFirst inference (no cache):")
    print(f"Full text: '{full_text1}'")
    print(f"Tokens: {tokens1}")
    
    outputs1 = model(inputs1)
    logits1 = outputs1
    
    print(f"Output logits shape: {logits1.shape}")
    
    # Second inference - with cache from prefix
    prefix_tokens = tokenizer.encode(prefix)
    cont2_tokens = tokenizer.encode(continuation2)
    # Remove BOS token from continuation if it exists
    if len(cont2_tokens) > 0 and cont2_tokens[0] == tokenizer.bos_token_id:
        cont2_tokens = cont2_tokens[1:]
    
    print(f"\nSecond inference (with cache):")
    print(f"Prefix tokens: {prefix_tokens}")
    print(f"Continuation tokens: {cont2_tokens}")
    
    # Create cache by running prefix
    cache = model.make_cache()
    prefix_inputs = mx.array([prefix_tokens])
    
    # Run prefix to populate cache
    prefix_outputs = model(prefix_inputs, cache=cache)
    
    print(f"Prefix cache created")
    print(f"Cache type: {type(cache)}")
    if hasattr(cache, 'kv_heads'):
        print(f"KV cache heads: {len(cache.kv_heads) if cache.kv_heads else 'None'}")
    if hasattr(cache, 'mamba_states'):
        print(f"Mamba cache states: {len(cache.mamba_states) if cache.mamba_states else 'None'}")
    
    # Run continuation with cache
    cont2_inputs = mx.array([cont2_tokens])
    cached_outputs = model(cont2_inputs, cache=cache)
    
    print(f"Cached inference logits shape: {cached_outputs.shape}")
    
    # Compare with full inference for the same text
    full_text2 = prefix + continuation2
    full_tokens2 = tokenizer.encode(full_text2)
    full_inputs2 = mx.array([full_tokens2])
    
    full_outputs2 = model(full_inputs2)
    
    print(f"\nComparison:")
    print(f"Full inference logits (last token): {full_outputs2[0, -1, :5].tolist()}")
    print(f"Cached inference logits (last token): {cached_outputs[0, -1, :5].tolist()}")
    
    # Check if they match
    diff = mx.abs(full_outputs2[0, -1, :] - cached_outputs[0, -1, :]).max()
    print(f"Max difference: {diff.item():.8f}")
    print(f"Cache working correctly: {diff.item() < 1e-4}")
    
    return {
        'cache_type': type(cache).__name__,
        'max_diff': diff.item(),
        'cache_working': diff.item() < 1e-4
    }

if __name__ == "__main__":
    results = test_mlx_cache()
    print(f"\n=== SUMMARY ===")
    print(f"Cache type: {results['cache_type']}")
    print(f"Max difference: {results['max_diff']:.8f}")
    print(f"Cache working: {results['cache_working']}")
test_cache_transformers.py
#!/usr/bin/env python3

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def test_transformers_cache():
    model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
    
    print("=== TRANSFORMERS CACHE TEST ===")
    
    # Load model
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float32)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Test with common prefix
    prefix = "The weather is very"
    continuation1 = " hot today"
    continuation2 = " cold this winter"
    
    print(f"Prefix: '{prefix}'")
    print(f"Continuation 1: '{continuation1}'")
    print(f"Continuation 2: '{continuation2}'")
    
    # First inference - no cache
    full_text1 = prefix + continuation1
    tokens1 = tokenizer.encode(full_text1)
    inputs1 = torch.tensor([tokens1])
    
    print(f"\nFirst inference (no cache):")
    print(f"Full text: '{full_text1}'")
    print(f"Tokens: {tokens1}")
    
    with torch.no_grad():
        outputs1 = model(inputs1, use_cache=True)
        logits1 = outputs1.logits
        past_key_values1 = outputs1.past_key_values
        
    print(f"Output logits shape: {logits1.shape}")
    print(f"Cache type: {type(past_key_values1)}")
    if past_key_values1 is not None:
        print(f"Cache length: {len(past_key_values1) if hasattr(past_key_values1, '__len__') else 'N/A'}")
        if hasattr(past_key_values1, 'get_seq_length'):
            print(f"Cache seq length: {past_key_values1.get_seq_length()}")
    
    # Second inference - with cache from prefix
    prefix_tokens = tokenizer.encode(prefix)
    cont2_tokens = tokenizer.encode(continuation2)[1:]  # Remove BOS if present
    
    print(f"\nSecond inference (with cache):")
    print(f"Prefix tokens: {prefix_tokens}")
    print(f"Continuation tokens: {cont2_tokens}")
    
    # First, run the prefix to get the cache
    prefix_inputs = torch.tensor([prefix_tokens])
    with torch.no_grad():
        prefix_outputs = model(prefix_inputs, use_cache=True)
        prefix_cache = prefix_outputs.past_key_values
        
    print(f"Prefix cache created, seq length: {prefix_cache.get_seq_length() if hasattr(prefix_cache, 'get_seq_length') else 'N/A'}")
    
    # Then run the continuation with the cache
    cont2_inputs = torch.tensor([cont2_tokens])
    with torch.no_grad():
        cached_outputs = model(cont2_inputs, past_key_values=prefix_cache, use_cache=True)
        cached_logits = cached_outputs.logits
        
    print(f"Cached inference logits shape: {cached_logits.shape}")
    
    # Compare with full inference for the same text
    full_text2 = prefix + continuation2
    full_tokens2 = tokenizer.encode(full_text2)
    full_inputs2 = torch.tensor([full_tokens2])
    
    with torch.no_grad():
        full_outputs2 = model(full_inputs2)
        full_logits2 = full_outputs2.logits
        
    print(f"\nComparison:")
    print(f"Full inference logits (last token): {full_logits2[0, -1, :5].tolist()}")
    print(f"Cached inference logits (last token): {cached_logits[0, -1, :5].tolist()}")
    
    # Check if they match
    diff = torch.abs(full_logits2[0, -1, :] - cached_logits[0, -1, :]).max()
    print(f"Max difference: {diff.item():.8f}")
    print(f"Cache working correctly: {diff.item() < 1e-4}")
    
    return {
        'cache_type': type(prefix_cache).__name__,
        'cache_seq_length': prefix_cache.get_seq_length() if hasattr(prefix_cache, 'get_seq_length') else None,
        'max_diff': diff.item(),
        'cache_working': diff.item() < 1e-4
    }

if __name__ == "__main__":
    results = test_transformers_cache()
    print(f"\n=== SUMMARY ===")
    print(f"Cache type: {results['cache_type']}")
    print(f"Cache seq length: {results['cache_seq_length']}")
    print(f"Max difference: {results['max_diff']:.8f}")
    print(f"Cache working: {results['cache_working']}")

claude-trace.jsonl.txt

@awni
Copy link
Member

awni commented Sep 10, 2025

Looks pretty good! Have you tested it? Is it functional?

@gabe-l-hart
Copy link
Contributor Author

🤦 I got so caught up dumping the Claude Code artifacts I forgot to add my testing results! Yes, this does work. Here's the simple example I use everywhere:

from mlx_lm import generate, load

model_path = '/Users/ghart/models/granite-4.0-tiny-preview/'
prompt = 'Tell me a story about a developer and their dog'
model, tokenizer = load(model_path)
result = generate(model, tokenizer, prompt=prompt, verbose=True)
==========
.

Once upon a time, there was a dedicated developer named Alex who spent countless hours crafting the perfect app. One day, Alex's loyal companion, a spirited dog named Max, decided to join them on their coding journey.

Max, a curious and energetic pup, would often find himself curled up next to Alex's laptop, occasionally nudging the keyboard with his nose. Alex, always appreciative of Max's unwavering support, would occasionally pause to pet their furry friend, who would then wag his tail in delight.

One fateful day, as Alex was working late on a crucial project, Max's curiosity got the better of him. He nudged the laptop, causing the screen to flicker and the project to crash. Panicked, Alex tried to restart the computer, but Max, in a moment of canine intuition, nudged the power button, causing it to turn off.

Realizing the importance of a well-deserved break, Alex decided to take a step back and spend quality time with Max. They went for a long walk, played fetch, and even tried out
==========
Prompt: 10 tokens, 104.643 tokens-per-sec
Generation: 256 tokens, 68.363 tokens-per-sec
Peak memory: 13.429 GB

For comparison, here's roughly the same thing with llama-cli:

./bin/llama-cli -m ~/models/granite-4.0-tiny-preview/Granite-4.0-Tiny-Preview-62x915M-F16.gguf -no-cnv -p "Tell me a story about a developer and their dog" --temp 0
Tell me a story about a developer and their dog.

Once upon a time, there was a dedicated developer named Alex who spent countless hours coding. One day, Alex's dog, Max, decided to join them at the desk. Max, a curious and energetic pup, would often find himself in the way, causing Alex to pause their work to gently move the dog aside.

One particular evening, Alex was stuck on a challenging coding problem. Frustrated, they decided to take a break and let Max curl up next to them. As Alex absentmindedly scratched Max behind the ears, a sudden realization struck them. The solution to their coding problem was not in the lines of code, but in the warmth and companionship of their furry friend.

From that day on, Alex made sure to include Max in their coding sessions, finding that the bond between them fueled their creativity and problem-solving abilities. Max, in turn, became an invaluable source of inspiration and motivation for Alex, proving that sometimes, the best ideas come from the most unexpected places.

And so, the developer and their dog continued to create beautiful code together, their bond growing stronger with each line of code they wrote. [end of text]


llama_perf_sampler_print:    sampling time =       7.33 ms /   276 runs   (    0.03 ms per token, 37643.21 tokens per second)
llama_perf_context_print:        load time =    5911.59 ms
llama_perf_context_print: prompt eval time =      65.06 ms /    10 tokens (    6.51 ms per token,   153.71 tokens per second)
llama_perf_context_print:        eval time =    6269.55 ms /   265 runs   (   23.66 ms per token,    42.27 tokens per second)

The results are not identical, so there are clearly some precision differences somewhere in the calculations, but the story of Alex and Max is consistent and coherent.

@gabe-l-hart
Copy link
Contributor Author

The specific model I'm testing with is: https://huggingface.co/ibm-granite/granite-4.0-tiny-preview

…d by Claude Code

This commit was entirely generated using Claude Code and the following
prompt:

---
I've got an in-depth feature request for you to add. I need you to add support for the GraniteMoeHybrid architecture to the `mlx-lm` project. The task is to extend the existing set of model architecture implementations in `mlx_lm/models` by adding a new module named `granitemoehybrid.py`. Here are a few key pointers on this model architecture:

* It is a hybrid-recurrent model that uses `mamba2` for some layers (recurrent) and `granitemoe` for some layers (attention)
* It is very similar to the `nemotron_h` architecture implemented in `mlx_lm/models/nemotron_h.py`, but with a few key differences
    * In `GraniteMoeHybrid`, each layer has either a `mamba2` block or a `granitemoe` attention block AND a MoE block, whereas in `nemotron_h`, each "layer" is a single block that is either `mamba2`, `attention` (llama), or `ffn` (not MoE).
    * The config for `GraniteMoeHybrid` uses the `layer_types` field to determine whether to use `mamba2` or `granitemoe` attention for each layer
* The `transformers` implementation can be found at https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py
    * The config can be found at https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/configuration_granitemoehybrid.py
* The PR adding support in `llama.cpp` is: ggml-org/llama.cpp#13550
    * NOTE: In `llama.cpp`, I made the architecture slightly more flexible such that each layer could use either a MoE block OR a fully-connected FFN block after the recurrent/attention block
* For the `granitemoe` attention, the architecture is very similar to standard `llama` attention, but it includes 4 additional scalar multipliers that are pulled from config:
    * `embedding_multiplier`:
        * Multiply the input embeddings by this scalar before the first layer
        * Used here in `transformers` https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L1347
    * `attention_multiplier`:
        * Used as the scaling factor in standard attention in place of the default 1/sqrt(n_embed_head)
        * Used here in `transformers`: https://github.com/huggingface/transformers/blob/main/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py#L217

The goal of this project is to create a fully working local implementation of the model in `mlx_lm`. You can find a local model to test with at /Users/ghart/models/granite-4.0-tiny-preview/. You can find a version of the `nemotron_h` model to test with at /Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/. To accomplish this project, you'll need to take the following steps:

1. Get a development environment working (you can use `uv` to manage your virtual env) and install the necessary dependencies
2. Run a sample inference with a model that is already known to work (eg `/Users/ghart/models/nvidia/NVIDIA-Nemotron-Nano-9B-v2/`)
3. Create the new module at `mlx_lm/models/granitemoehybrid.py`
4. Implement the model architecture, test, and iterate until you've got things working locally

Once you've got it working, let me know and I'll review and commit
---

Branch: GraniteHybrid

Signed-off-by: Gabe Goodhart <[email protected]>
Inference now matches transormers. Further refinement by me comming next.

Branch: GraniteHybrid

Signed-off-by: Gabe Goodhart <[email protected]>
…odels

This keeps the implementation of the attention block closer to GraniteMoe
for an easier diff view in the future. The functionality is identical.

Branch: GraniteHybrid

Signed-off-by: Gabe Goodhart <[email protected]>
Copy link
Member

@awni awni left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, thanks!

@gabe-l-hart
Copy link
Contributor Author

Thanks for the cleanup fixes!

@awni awni merged commit 1537efd into ml-explore:main Sep 10, 2025
4 checks passed
@gabe-l-hart gabe-l-hart deleted the GraniteHybrid branch September 10, 2025 15:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants