Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/cicd-main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -361,10 +361,11 @@ jobs:
- script: L2_Launch_models_nemotron_vl
- script: L2_Launch_models_olmoe
- script: L2_Launch_models_qwen
# - script: L2_Launch_models_qwen_quantization
- script: L2_Launch_models_qwen_quantization
- script: L2_Launch_models_qwen_vl
- script: L2_Launch_recipes_gemma_vl
- script: L2_Launch_recipes_gpt_oss
- script: L2_Launch_models_qwen_vl_quantization
- script: L2_Launch_recipes_llama_1b
- script: L2_Launch_recipes_llama_3b
- script: L2_Launch_recipes_llama_distill
Expand Down
17 changes: 16 additions & 1 deletion examples/quantization/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@

from megatron.bridge import AutoBridge
from megatron.bridge.models.decorators import torchrun_main
from megatron.bridge.models.hf_pretrained.utils import is_safe_repo


warnings.filterwarnings("ignore")
Expand All @@ -61,6 +62,7 @@ def main(
export_dir: str = "./hf_export",
export_extra_modules: bool = False,
dtype: str = "bfloat16",
trust_remote_code: bool | None = None,
) -> None:
"""Export a quantized Megatron-LM checkpoint to HuggingFace format on multiple GPUs."""
if os.environ.get("WORLD_SIZE") is None:
Expand All @@ -78,7 +80,13 @@ def main(
sys.exit(1)

# Initialize bridge from HF model to get tokenizer and model structure
bridge = AutoBridge.from_hf_pretrained(hf_model_id)
bridge = AutoBridge.from_hf_pretrained(
hf_model_id,
trust_remote_code=is_safe_repo(
trust_remote_code=trust_remote_code,
hf_path=hf_model_id,
),
)

# Get model provider and configure for multi-GPU execution
model_provider = bridge.to_megatron_provider(load_weights=False)
Expand Down Expand Up @@ -152,6 +160,7 @@ def main(
export_extra_modules=export_extra_modules_flag,
dtype=torch_dtype,
export_dir=export_dir,
trust_remote_code=is_safe_repo(trust_remote_code=trust_remote_code, hf_path=hf_model_id),
)

if is_rank_0:
Expand Down Expand Up @@ -195,6 +204,11 @@ def main(
choices=["bfloat16", "float16", "float32"],
help="Data type for export",
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="if trust_remote_code",
)

args = parser.parse_args()
main(
Expand All @@ -207,6 +221,7 @@ def main(
args.export_dir,
args.export_extra_modules,
args.dtype,
args.trust_remote_code,
)

if torch.distributed.is_initialized():
Expand Down
41 changes: 30 additions & 11 deletions examples/quantization/ptq_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
If someone accidentally breaks the quantization loading logic (e.g., in
has_modelopt_state or build_and_load_model), this check will catch it.

We check for QuantRowParallelLinear and QuantColumnParallelLinear as these
are present in all quantized model architectures (GPT, Llama, Qwen, Nemotron-H, etc).
We check for quantized layer types that indicate successful quantization:
- Local spec: QuantRowParallelLinear, QuantColumnParallelLinear
- TE spec: QuantTERowParallelLinear, QuantTELayerNormColumnParallelLinear

Args:
model: The unwrapped model to validate
Expand All @@ -68,25 +69,36 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
Raises:
RuntimeError: If the model doesn't contain expected quantized layers
"""
# Check for quantized layer types that are universal across all architectures
model_str = str(model)

required_quant_layers = [
# Local spec quantized layers
local_spec_layers = [
"QuantRowParallelLinear",
"QuantColumnParallelLinear",
]

missing_layers = [layer for layer in required_quant_layers if layer not in model_str]
# TE spec quantized layers
te_spec_layers = [
"QuantTERowParallelLinear",
"QuantTELayerNormColumnParallelLinear",
]

# Check if model has local spec quantized layers
has_local_spec = all(layer in model_str for layer in local_spec_layers)

# Check if model has TE spec quantized layers
has_te_spec = all(layer in model_str for layer in te_spec_layers)

if missing_layers:
if not has_local_spec and not has_te_spec:
error_msg = (
f"\n{'=' * 80}\n"
f"QUANTIZATION VALIDATION FAILED!\n"
f"{'=' * 80}\n"
f"Expected quantized layers not found in the loaded model.\n"
f"This indicates the quantized checkpoint was not loaded correctly.\n\n"
f"Missing: {missing_layers}\n"
f"Expected: {required_quant_layers}\n\n"
f"Expected one of:\n"
f" - Local spec: {local_spec_layers}\n"
f" - TE spec: {te_spec_layers}\n\n"
f"This is likely due to a bug in the checkpoint loading logic.\n"
f"{'=' * 80}\n"
)
Expand All @@ -95,9 +107,16 @@ def _validate_quantized_model(model: torch.nn.Module, is_rank_0: bool) -> None:
raise RuntimeError(error_msg)

if is_rank_0:
console.print(
"[green]✓ Quantization validation passed: Found QuantRowParallelLinear and QuantColumnParallelLinear[/green]"
)
if has_te_spec:
console.print(
"[green]✓ Quantization validation passed: Found TE spec quantized layers "
"(QuantTERowParallelLinear, QuantTELayerNormColumnParallelLinear)[/green]"
)
else:
console.print(
"[green]✓ Quantization validation passed: Found local spec quantized layers "
"(QuantRowParallelLinear, QuantColumnParallelLinear)[/green]"
)


@torchrun_main
Expand Down
Loading
Loading