From bea86d4b17ef98d1a94cdeb3677480447a986b9f Mon Sep 17 00:00:00 2001 From: realAsma Date: Mon, 18 May 2026 00:40:44 +0000 Subject: [PATCH] Support auto_quantize for Megatron expert parallelism Signed-off-by: realAsma --- modelopt/torch/quantization/algorithms.py | 43 +++++++++++++------ .../torch/quantization/plugins/megatron.py | 31 +++++++++++++ .../torch/quantization/quantize_common.py | 24 ++++++++--- .../quantization/plugins/test_megatron.py | 40 +++++++++++++++++ .../unit/torch/quantization/test_autoquant.py | 35 +++++++++++++++ 5 files changed, 154 insertions(+), 19 deletions(-) diff --git a/modelopt/torch/quantization/algorithms.py b/modelopt/torch/quantization/algorithms.py index 28086a1e96e..99cee86dcb2 100644 --- a/modelopt/torch/quantization/algorithms.py +++ b/modelopt/torch/quantization/algorithms.py @@ -291,13 +291,14 @@ def get_score(self, recipe: QuantRecipe) -> float: total_score += importance.cpu().item() continue - if parallel_state.expert_model_parallel_group.is_initialized(): - # TODO: Support expert model parallelism for score estimation - warnings.warn("AutoQuantize does not support expert model parallelism yet.") importance = importance.cpu() importance = DistributedProcessGroup.get_dist_syncd_obj( importance, - [parallel_state.tensor_parallel_group, parallel_state.data_parallel_group], + [ + parallel_state.tensor_parallel_group, + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + ], sum, ) total_score += importance.item() @@ -318,13 +319,12 @@ def get_cost(self, recipe: QuantRecipe) -> float: cost += weight_size * recipe.compression continue - if parallel_state.expert_model_parallel_group.is_initialized(): - # TODO: Support expert model parallelism - warnings.warn("AutoQuantize does not support expert model parallelism yet.") - weight_size = DistributedProcessGroup.get_dist_syncd_obj( weight_size, - [parallel_state.tensor_parallel_group], + [ + parallel_state.tensor_parallel_group, + parallel_state.expert_model_parallel_group, + ], sum, ) @@ -722,6 +722,15 @@ def _get_total_weight_size(modules): for module in modules ) + @staticmethod + def _get_total_weight_size_from_candidate_stats(candidate_stats): + no_quant_recipe = QuantRecipe(quant_cfg=None) + total_weight_size = 0 + for candidate_stat in candidate_stats.values(): + no_quant_idx = candidate_stat["formats"].index(no_quant_recipe) + total_weight_size += candidate_stat["costs"][no_quant_idx] + return total_weight_size + def _get_constraints_for_search(self, max_weight_size, lower_bound=None): constraints = { "weight_size_after_compression": ( @@ -744,7 +753,7 @@ def run_search(self): ) compression = self._get_formatted_weight_compression_constraint() - total_weight_size = self._get_total_weight_size(self.model.modules()) + total_weight_size = self._get_total_weight_size_from_candidate_stats(self.candidate_stats) max_weight_size = total_weight_size * compression # Run the search with stats to get the best recipe and whether the constraints are satisfied @@ -754,12 +763,16 @@ def run_search(self): best_recipe = {} best_constraints, best_scores = 0, 0 for name, best_hparam_recipe_info in best_recipe_info.items(): - # Solvers could give different solutions for the same layer across DP/TP groups even though - # the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP + # Solvers could give different solutions for the same layer across DP/TP/EP groups even though + # the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP/EP _ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state best_format = DistributedProcessGroup.get_dist_syncd_obj( best_hparam_recipe_info["format"], - [_ps.data_parallel_group, _ps.tensor_parallel_group], + [ + _ps.data_parallel_group, + _ps.tensor_parallel_group, + _ps.expert_model_parallel_group, + ], lambda a: a[0], ) @@ -1379,7 +1392,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False): effective_bits = constraints["effective_bits"] compression = effective_bits / 16.0 candidate_stats = search_state["candidate_stats"] - total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values()) + total_weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size_from_candidate_stats( + candidate_stats + ) max_weight_size = total_weight_size * compression method = search_state["method"] diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py index 33f11a4491b..f90d2862aef 100644 --- a/modelopt/torch/quantization/plugins/megatron.py +++ b/modelopt/torch/quantization/plugins/megatron.py @@ -18,6 +18,7 @@ import logging import types import warnings +from contextlib import contextmanager from typing import Any import megatron.core.parallel_state as mcore_parallel @@ -775,3 +776,33 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None): # Affine KVCache Quant bias vector. state_dict = self.state_dict(prefix="", keep_vars=True) return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets) + + +def _is_supported_megatron_model(model: torch.nn.Module) -> bool: + return isinstance(model, MegatronModule) + + +@contextmanager +def _megatron_grad_ckpt_context(model: torch.nn.Module): + # Megatron configures activation recompute at model build time via TransformerConfig, + # so there is no runtime flag to flip here. + yield + + +def _is_param_grad_enabled_for_megatron(pname: str, model: torch.nn.Module) -> bool: + return "embed" in pname + + +def _register_auto_quantize_support() -> None: + # Local import breaks the circular path where algorithms imports model_calib, + # which imports _check_static_block_tp_supported from this plugin. + from ..algorithms import AutoQuantizeGradientSearcher + + AutoQuantizeGradientSearcher.register_custom_support( + _is_supported_megatron_model, + _megatron_grad_ckpt_context, + _is_param_grad_enabled_for_megatron, + ) + + +_register_auto_quantize_support() diff --git a/tests/_test_utils/torch/quantization/quantize_common.py b/tests/_test_utils/torch/quantization/quantize_common.py index 7af33fa599f..46259203b24 100644 --- a/tests/_test_utils/torch/quantization/quantize_common.py +++ b/tests/_test_utils/torch/quantization/quantize_common.py @@ -249,14 +249,28 @@ def forward_loop(model): ) -def auto_quantize_helper(model): +def auto_quantize_helper( + model, + data_loader=None, + forward_step=None, + forward_backward_step=None, + quantization_formats=None, +): + if data_loader is None: + data_loader = [model.get_dummy_input().cuda() for _ in range(2)] + if forward_step is None: + forward_step = lambda model, batch: model(batch) # noqa: E731 + if forward_backward_step is None: + forward_backward_step = lambda model, batch: model(batch).sum().backward() # noqa: E731 + if quantization_formats is None: + quantization_formats = [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG] model, search_state = mtq.auto_quantize( model, constraints={"effective_bits": 8.0}, - quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG], - data_loader=[model.get_dummy_input().cuda() for _ in range(2)], - forward_step=lambda model, batch: model(batch), - forward_backward_step=lambda model, batch: model(batch).sum().backward(), + quantization_formats=quantization_formats, + data_loader=data_loader, + forward_step=forward_step, + forward_backward_step=forward_backward_step, num_calib_steps=2, num_score_steps=2, verbose=True, diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py index c34fb2df376..3b7985c817b 100644 --- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py +++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py @@ -26,6 +26,7 @@ from _test_utils.torch.megatron.utils import ( compare_amax_sync_across_expert_parallel, copy_weights_from_grouped_to_non_grouped, + get_batch, get_forward, initialize_for_megatron, run_mcore_inference, @@ -694,6 +695,45 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg): ) +def _test_auto_quantize_moe_ep_helper(rank, size): + initialize_for_megatron( + tensor_model_parallel_size=1, + expert_model_parallel_size=size, + seed=SEED, + ) + model = _gpt_model_provider( + tp_size=1, + ep_size=size, + hidden_size=32, + num_moe_experts=4, + moe_grouped_gemm=False, + transformer_impl="modelopt", + ) + + def forward_step(model, batch): + input_ids, labels, position_ids, attention_mask, loss_mask = batch + return model.forward( + input_ids=input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + labels=labels, + loss_mask=loss_mask, + ) + + auto_quantize_helper( + model, + data_loader=[get_batch(model, batch_size=2) for _ in range(2)], + forward_step=forward_step, + forward_backward_step=lambda m, b: forward_step(m, b).mean().backward(), + quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG], + ) + + +def test_auto_quantize_moe_ep(dist_workers_size_2): + """auto_quantize must sum score/cost across EP ranks and pick a consistent recipe.""" + dist_workers_size_2.run(_test_auto_quantize_moe_ep_helper) + + @pytest.mark.parametrize("ep_size", [1, 2]) @pytest.mark.parametrize("sync_weight_amax", [True, False]) def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax): diff --git a/tests/unit/torch/quantization/test_autoquant.py b/tests/unit/torch/quantization/test_autoquant.py index 87ec73291e7..f7b4965cf8f 100644 --- a/tests/unit/torch/quantization/test_autoquant.py +++ b/tests/unit/torch/quantization/test_autoquant.py @@ -24,8 +24,10 @@ import modelopt.torch.opt as mto import modelopt.torch.quantization as mtq from modelopt.torch.quantization.algorithms import ( + AutoQuantizeGradientSearcher, QuantRecipe, QuantRecipeHparam, + _AutoQuantizeBaseSearcher, estimate_quant_compression, ) from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg @@ -305,6 +307,39 @@ def test_data_parallel_auto_quantize(skip_on_windows): spawn_multiprocess_job(4, _test_data_parallel_auto_quantize, backend="gloo") +def test_auto_quantize_budget_uses_no_quant_candidate_cost(monkeypatch): + class _BudgetCaptureSearcher(AutoQuantizeGradientSearcher): + def run_search_with_stats(self, max_weight_size, verbose=False): + self.max_weight_size = max_weight_size + return {}, True + + def _raise_local_total_weight_size(modules): + pytest.fail("run_search should derive total weight size from candidate costs") + + monkeypatch.setattr( + _AutoQuantizeBaseSearcher, + "_get_total_weight_size", + staticmethod(_raise_local_total_weight_size), + ) + + searcher = _BudgetCaptureSearcher() + searcher.reset_search() + searcher.model = torch.nn.Module() + searcher.config = {"verbose": False} + searcher.constraints = {"effective_bits": 8.0} + searcher.candidate_stats = { + "local_expert.quant_recipe": { + "formats": [QuantRecipe(mtq.NVFP4_DEFAULT_CFG), QuantRecipe(None)], + "scores": [1.0, 0.0], + "costs": [25.0, 100.0], + } + } + + searcher.run_search() + + assert searcher.max_weight_size == 50.0 + + def test_estimate_quant_compression(): nvfp4_affine_kv_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_AFFINE_KV_CFG) assert estimate_quant_compression(nvfp4_affine_kv_cfg) == 0.25