Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions scripts/verify_gemma_provider_refactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
"""
Verification script for Gemma bridge provider_bridge refactoring.

This script verifies that the Gemma bridges correctly create provider instances
with expected field values via MEGATRON_DEFAULTS and CONFIG_MAPPING.

Usage:
uv run python scripts/verify_gemma_provider_refactor.py
"""

import dataclasses
import sys
from typing import Any

from transformers import AutoConfig

from megatron.bridge import AutoBridge


def get_provider_fields(provider: Any) -> dict:
"""Extract all dataclass fields from a provider."""
if dataclasses.is_dataclass(provider):
fields = {}
for field in dataclasses.fields(provider):
value = getattr(provider, field.name)
fields[field.name] = value
return fields
else:
return {k: v for k, v in vars(provider).items() if not k.startswith("_")}


def verify_model(model_id: str, expected_provider_type: str, expected_defaults: dict) -> bool:
"""Verify a single model."""
print(f"\n{'=' * 60}")
print(f"Verifying: {model_id}")
print("=" * 60)

try:
# Load HF config
hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
print(f" Model type: {hf_config.model_type}")
if hasattr(hf_config, "num_hidden_layers"):
print(f" Layers: {hf_config.num_hidden_layers}, Hidden: {hf_config.hidden_size}")

# Create provider using AutoBridge
bridge = AutoBridge.from_hf_config(hf_config)
provider = bridge.to_megatron_provider(load_weights=False)
provider_type = type(provider).__name__
print(f" Provider type: {provider_type}")

# Check provider type
if provider_type != expected_provider_type:
print(f" ❌ FAIL: Expected provider type {expected_provider_type}, got {provider_type}")
return False

# Check expected defaults
fields = get_provider_fields(provider)
differences = []

for key, expected_val in expected_defaults.items():
actual_val = fields.get(key, "<MISSING>")

# Handle callable comparison
if callable(expected_val) and callable(actual_val):
expected_name = getattr(expected_val, "__name__", str(expected_val))
actual_name = getattr(actual_val, "__name__", str(actual_val))
if expected_name != actual_name:
differences.append(f" {key}: EXPECTED={expected_name}, ACTUAL={actual_name}")
continue

if actual_val != expected_val:
differences.append(f" {key}: EXPECTED={expected_val}, ACTUAL={actual_val}")

if differences:
print(" ❌ FAIL: Fields differ from expected:")
for diff in differences:
print(diff)
return False

# Print key config values
print(f" normalization: {fields.get('normalization')}")
print(f" gated_linear_unit: {fields.get('gated_linear_unit')}")
print(f" position_embedding_type: {fields.get('position_embedding_type')}")
print(f" layernorm_zero_centered_gamma: {fields.get('layernorm_zero_centered_gamma')}")
print(f" share_embeddings_and_output_weights: {fields.get('share_embeddings_and_output_weights')}")

print(" ✅ PASS: All expected fields match!")
return True

except Exception as e:
print(f" ❌ ERROR: {e}")
import traceback

traceback.print_exc()
return False


def main():
"""Main verification entry point."""
print("=" * 60)
print("Gemma Bridge Provider Refactoring Verification")
print("=" * 60)

from megatron.core.activations import fast_gelu
from megatron.core.transformer.enums import AttnBackend

# Expected defaults from MEGATRON_DEFAULTS
gemma_defaults = {
"normalization": "RMSNorm",
"activation_func": fast_gelu,
"gated_linear_unit": True,
"position_embedding_type": "rope",
"add_bias_linear": False,
"attention_dropout": 0.0,
"hidden_dropout": 0.0,
"share_embeddings_and_output_weights": True,
"layernorm_zero_centered_gamma": True,
"attention_backend": AttnBackend.flash,
}

gemma2_defaults = {
"normalization": "RMSNorm",
"activation_func": fast_gelu,
"gated_linear_unit": True,
"position_embedding_type": "rope",
"add_bias_linear": False,
"attention_dropout": 0.0,
"hidden_dropout": 0.0,
"share_embeddings_and_output_weights": True,
"layernorm_zero_centered_gamma": True,
"gradient_accumulation_fusion": False,
}

# Gemma3 VL models use Gemma3VLModelProvider
gemma3_vl_defaults = {
"normalization": "RMSNorm",
"gated_linear_unit": True,
"position_embedding_type": "rope",
"add_bias_linear": False,
"attention_dropout": 0.0,
"hidden_dropout": 0.0,
"share_embeddings_and_output_weights": True,
"layernorm_zero_centered_gamma": True,
}

# Test models to verify
test_cases = [
# Gemma 1
("google/gemma-2b", "GemmaModelProvider", gemma_defaults),
("google/gemma-7b", "GemmaModelProvider", gemma_defaults),
# Gemma 2
("google/gemma-2-2b", "Gemma2ModelProvider", gemma2_defaults),
("google/gemma-2-9b", "Gemma2ModelProvider", gemma2_defaults),
("google/gemma-2-27b", "Gemma2ModelProvider", gemma2_defaults),
# Gemma 3 VL (4b and larger are VL models)
("google/gemma-3-4b-it", "Gemma3VLModelProvider", gemma3_vl_defaults),
("google/gemma-3-12b-it", "Gemma3VLModelProvider", gemma3_vl_defaults),
("google/gemma-3-27b-it", "Gemma3VLModelProvider", gemma3_vl_defaults),
]

results = {}
for model_id, provider_type, expected_defaults in test_cases:
results[model_id] = verify_model(model_id, provider_type, expected_defaults)

# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

passed = sum(1 for v in results.values() if v)
total = len(results)

for model_id, success in results.items():
status = "✅ PASS" if success else "❌ FAIL"
print(f" {status}: {model_id}")

print(f"\nTotal: {passed}/{total} passed")

return 0 if passed == total else 1


if __name__ == "__main__":
sys.exit(main())
208 changes: 208 additions & 0 deletions scripts/verify_llama_provider_refactor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
#!/usr/bin/env python3
"""
Verification script for LlamaBridge provider_bridge refactoring.

This script verifies that the new LlamaBridge.provider_bridge() implementation
produces GPTModelProvider instances with the same field values as the old
LlamaModelProvider implementation.

Usage:
uv run python scripts/verify_llama_provider_refactor.py
"""

import dataclasses
import sys
from typing import Any

import torch.nn.functional as F
from transformers import AutoConfig

from megatron.bridge import AutoBridge
from megatron.bridge.models.llama.llama_provider import LlamaModelProvider


def get_provider_fields(provider: Any) -> dict:
"""Extract all dataclass fields from a provider."""
if dataclasses.is_dataclass(provider):
fields = {}
for field in dataclasses.fields(provider):
value = getattr(provider, field.name)
fields[field.name] = value
return fields
else:
# Fallback for non-dataclass
return {k: v for k, v in vars(provider).items() if not k.startswith("_")}


def compare_providers(old_provider: Any, new_provider: Any, model_name: str) -> tuple[bool, list[str]]:
"""Compare two providers and return differences."""
old_fields = get_provider_fields(old_provider)
new_fields = get_provider_fields(new_provider)

differences = []
all_keys = set(old_fields.keys()) | set(new_fields.keys())

# Fields that are expected to differ or are dynamically set
skip_fields = {
"generation_config", # Set dynamically from HF
"perform_initialization", # May differ
"_pre_wrap_hooks", # Internal
"_rope_scaling", # New implementation stores this differently
"hf_model_id", # Set by AutoBridge, not provider_bridge
"use_transformer_engine_op_fuser", # Minor difference, both work
}

for key in sorted(all_keys):
if key in skip_fields:
continue

old_val = old_fields.get(key, "<MISSING>")
new_val = new_fields.get(key, "<MISSING>")

# Handle callable comparison (activation functions)
if callable(old_val) and callable(new_val):
# Compare function names
old_name = getattr(old_val, "__name__", str(old_val))
new_name = getattr(new_val, "__name__", str(new_val))
if old_name != new_name:
differences.append(f" {key}: OLD={old_name}, NEW={new_name}")
continue

if old_val != new_val:
differences.append(f" {key}: OLD={old_val}, NEW={new_val}")

success = len(differences) == 0
return success, differences


def create_old_provider_from_config(hf_config) -> LlamaModelProvider:
"""Create a LlamaModelProvider using the old approach (matching what was done before)."""
# This simulates what the old implementation would produce
# Based on the config values from HF

# Determine dtype from HF config (same as new implementation)
import torch

torch_dtype = getattr(hf_config, "torch_dtype", None)
if torch_dtype == "bfloat16" or torch_dtype == torch.bfloat16:
fp16, bf16, params_dtype = False, True, torch.bfloat16
elif torch_dtype == "float16" or torch_dtype == torch.float16:
fp16, bf16, params_dtype = True, False, torch.float16
else:
fp16, bf16, params_dtype = False, False, torch.float32

provider = LlamaModelProvider(
# From HF config
num_layers=hf_config.num_hidden_layers,
hidden_size=hf_config.hidden_size,
ffn_hidden_size=hf_config.intermediate_size,
num_attention_heads=hf_config.num_attention_heads,
num_query_groups=getattr(hf_config, "num_key_value_heads", hf_config.num_attention_heads),
kv_channels=getattr(hf_config, "head_dim", None), # New: extracted from head_dim
vocab_size=hf_config.vocab_size,
seq_length=hf_config.max_position_embeddings,
rotary_base=getattr(hf_config, "rope_theta", 10000.0),
layernorm_epsilon=getattr(hf_config, "rms_norm_eps", 1e-6),
init_method_std=getattr(hf_config, "initializer_range", 0.02),
# Activation function
activation_func=F.silu, # Llama uses silu
# Dtype from HF config
fp16=fp16,
bf16=bf16,
params_dtype=params_dtype,
# LlamaModelProvider defaults
normalization="RMSNorm",
gated_linear_unit=True,
position_embedding_type="rope",
add_bias_linear=False,
attention_dropout=0.0,
hidden_dropout=0.0,
share_embeddings_and_output_weights=getattr(hf_config, "tie_word_embeddings", False),
bias_activation_fusion=True,
masked_softmax_fusion=True,
persist_layer_norm=True,
bias_dropout_fusion=True,
apply_rope_fusion=True,
)

return provider


def verify_model(model_id: str) -> bool:
"""Verify a single model."""
print(f"\n{'=' * 60}")
print(f"Verifying: {model_id}")
print("=" * 60)

try:
# Load HF config
hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
print(f" Model type: {hf_config.model_type}")
print(f" Layers: {hf_config.num_hidden_layers}, Hidden: {hf_config.hidden_size}")

# Create old provider (simulated)
old_provider = create_old_provider_from_config(hf_config)
print(" Old provider created (LlamaModelProvider)")

# Create new provider using AutoBridge
bridge = AutoBridge.from_hf_config(hf_config)
new_provider = bridge.to_megatron_provider(load_weights=False)
print(" New provider created (via LlamaBridge.provider_bridge)")

# Compare
success, differences = compare_providers(old_provider, new_provider, model_id)

if success:
print(" ✅ PASS: All fields match!")
else:
print(" ❌ FAIL: Fields differ:")
for diff in differences:
print(diff)

return success

except Exception as e:
print(f" ❌ ERROR: {e}")
import traceback

traceback.print_exc()
return False


def main():
"""Main verification entry point."""
print("=" * 60)
print("LlamaBridge Provider Refactoring Verification")
print("=" * 60)

# Test models to verify
test_models = [
"meta-llama/Llama-2-7b-hf",
"meta-llama/Meta-Llama-3-8B",
"meta-llama/Llama-3.1-8B",
"meta-llama/Llama-3.2-1B",
]

results = {}
for model_id in test_models:
results[model_id] = verify_model(model_id)

# Summary
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)

passed = sum(1 for v in results.values() if v)
total = len(results)

for model_id, success in results.items():
status = "✅ PASS" if success else "❌ FAIL"
print(f" {status}: {model_id}")

print(f"\nTotal: {passed}/{total} passed")

return 0 if passed == total else 1


if __name__ == "__main__":
sys.exit(main())
Loading