Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions modelopt/torch/quantization/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def get_score(self, recipe: QuantRecipe) -> float:
importance,
[
parallel_state.tensor_parallel_group,
parallel_state.expert_model_parallel_group,
parallel_state.data_parallel_group,
parallel_state.expert_model_parallel_group,
],
sum,
)
Expand Down Expand Up @@ -722,6 +722,15 @@ def _get_total_weight_size(modules):
for module in modules
)

@staticmethod
def _get_total_weight_size_from_candidate_stats(candidate_stats):
    """Return the summed cost of the no-quantization recipe over all candidates.

    Args:
        candidate_stats: Mapping whose values each provide a ``"formats"``
            sequence and a ``"costs"`` sequence; the position of the
            no-quantization recipe in ``"formats"`` is used to index ``"costs"``.

    Returns:
        The sum, over every entry, of the cost stored at the position where
        ``QuantRecipe(quant_cfg=None)`` is found in that entry's ``"formats"``.
        ``0`` when ``candidate_stats`` has no entries.

    Raises:
        ValueError: Propagated from the ``.index`` lookup when an entry's
            ``"formats"`` contains no element equal to the no-quantization recipe.
    """
    # Built once and reused as the lookup key for every entry.
    unquantized = QuantRecipe(quant_cfg=None)
    return sum(
        stats["costs"][stats["formats"].index(unquantized)]
        for stats in candidate_stats.values()
    )

def _get_constraints_for_search(self, max_weight_size, lower_bound=None):
constraints = {
"weight_size_after_compression": (
Expand All @@ -744,7 +753,7 @@ def run_search(self):
)

compression = self._get_formatted_weight_compression_constraint()
total_weight_size = self._get_total_weight_size(self.model.modules())
total_weight_size = self._get_total_weight_size_from_candidate_stats(self.candidate_stats)
max_weight_size = total_weight_size * compression

# Run the search with stats to get the best recipe and whether the constraints are satisfied
Expand All @@ -754,12 +763,16 @@ def run_search(self):
best_recipe = {}
best_constraints, best_scores = 0, 0
for name, best_hparam_recipe_info in best_recipe_info.items():
# Solvers could give different solutions for the same layer across DP/TP groups even though
# the scores and costs are the same. Lets make sure the same recipe is selected across DP/TP
# Solvers could give different solutions for the same layer across DP/TP/EP groups even though
# the scores and costs are the same. Let's make sure the same recipe is selected across DP/TP/EP.
_ps = self.model.get_submodule(name.split(".quant_recipe")[0]).parallel_state
best_format = DistributedProcessGroup.get_dist_syncd_obj(
best_hparam_recipe_info["format"],
[_ps.data_parallel_group, _ps.tensor_parallel_group],
[
_ps.data_parallel_group,
_ps.tensor_parallel_group,
_ps.expert_model_parallel_group,
],
lambda a: a[0],
)

Expand Down Expand Up @@ -1379,7 +1392,9 @@ def _resolve_best_recipe(search_state, constraints, verbose=False):
effective_bits = constraints["effective_bits"]
compression = effective_bits / 16.0
candidate_stats = search_state["candidate_stats"]
total_weight_size = sum(s["costs"][-1] for s in candidate_stats.values())
total_weight_size = _AutoQuantizeBaseSearcher._get_total_weight_size_from_candidate_stats(
candidate_stats
)
max_weight_size = total_weight_size * compression
method = search_state["method"]

Expand Down
31 changes: 31 additions & 0 deletions modelopt/torch/quantization/plugins/megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import logging
import types
import warnings
from contextlib import contextmanager
from typing import Any

import megatron.core.parallel_state as mcore_parallel
Expand Down Expand Up @@ -775,3 +776,33 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
# Affine KVCache Quant bias vector.
state_dict = self.state_dict(prefix="", keep_vars=True)
return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)


def _is_supported_megatron_model(model: torch.nn.Module) -> bool:
    """Return ``True`` when ``model`` is an instance of ``MegatronModule``."""
    is_megatron = isinstance(model, MegatronModule)
    return is_megatron


@contextmanager
def _megatron_grad_ckpt_context(model: torch.nn.Module):
# Megatron configures activation recompute at model build time via TransformerConfig,
# so there is no runtime flag to flip here.
yield


def _is_param_grad_enabled_for_megatron(pname: str, model: torch.nn.Module) -> bool:
return "embed" in pname


def _register_auto_quantize_support() -> None:
    """Register this plugin's three Megatron hooks with ``AutoQuantizeGradientSearcher``.

    The hooks are passed positionally to ``register_custom_support`` in this
    order: the model-support predicate, the gradient-checkpointing context
    manager, and the per-parameter grad-enabled predicate.
    """
    # Local import breaks the circular path where algorithms imports model_calib,
    # which imports _check_static_block_tp_supported from this plugin.
    from ..algorithms import AutoQuantizeGradientSearcher as _Searcher

    megatron_hooks = (
        _is_supported_megatron_model,
        _megatron_grad_ckpt_context,
        _is_param_grad_enabled_for_megatron,
    )
    _Searcher.register_custom_support(*megatron_hooks)


# Invoke the registration helper so the Megatron hooks are registered with
# AutoQuantizeGradientSearcher when this statement executes.
_register_auto_quantize_support()
24 changes: 19 additions & 5 deletions tests/_test_utils/torch/quantization/quantize_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,14 +249,28 @@ def forward_loop(model):
)


def auto_quantize_helper(model):
def auto_quantize_helper(
model,
data_loader=None,
forward_step=None,
forward_backward_step=None,
quantization_formats=None,
):
if data_loader is None:
data_loader = [model.get_dummy_input().cuda() for _ in range(2)]
if forward_step is None:
forward_step = lambda model, batch: model(batch) # noqa: E731
if forward_backward_step is None:
forward_backward_step = lambda model, batch: model(batch).sum().backward() # noqa: E731
if quantization_formats is None:
quantization_formats = [mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG]
model, search_state = mtq.auto_quantize(
model,
constraints={"effective_bits": 8.0},
quantization_formats=[mtq.INT4_BLOCKWISE_WEIGHT_ONLY_CFG, mtq.INT8_DEFAULT_CFG],
data_loader=[model.get_dummy_input().cuda() for _ in range(2)],
forward_step=lambda model, batch: model(batch),
forward_backward_step=lambda model, batch: model(batch).sum().backward(),
quantization_formats=quantization_formats,
data_loader=data_loader,
forward_step=forward_step,
forward_backward_step=forward_backward_step,
num_calib_steps=2,
num_score_steps=2,
verbose=True,
Expand Down
40 changes: 40 additions & 0 deletions tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from _test_utils.torch.megatron.utils import (
compare_amax_sync_across_expert_parallel,
copy_weights_from_grouped_to_non_grouped,
get_batch,
get_forward,
initialize_for_megatron,
run_mcore_inference,
Expand Down Expand Up @@ -694,6 +695,45 @@ def test_te_grouped_vs_sequential_quantize(dist_workers_size_4, quant_cfg):
)


def _test_auto_quantize_moe_ep_helper(rank, size):
    """Run ``auto_quantize_helper`` on a small MoE GPT model with expert parallelism.

    Args:
        rank: Rank argument supplied by the caller; not used in the body.
        size: Used as both the expert-model-parallel size for
            ``initialize_for_megatron`` and the ``ep_size`` of the model.
    """
    initialize_for_megatron(
        tensor_model_parallel_size=1,
        expert_model_parallel_size=size,
        seed=SEED,
    )
    moe_model = _gpt_model_provider(
        tp_size=1,
        ep_size=size,
        hidden_size=32,
        num_moe_experts=4,
        moe_grouped_gemm=False,
        transformer_impl="modelopt",
    )

    def forward_step(model, batch):
        # The batch is unpacked positionally in this fixed order.
        input_ids, labels, position_ids, attention_mask, loss_mask = batch
        return model.forward(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            labels=labels,
            loss_mask=loss_mask,
        )

    def forward_backward_step(m, b):
        # Reduce the forward output with mean() and backpropagate it.
        return forward_step(m, b).mean().backward()

    batches = [get_batch(moe_model, batch_size=2) for _ in range(2)]
    auto_quantize_helper(
        moe_model,
        data_loader=batches,
        forward_step=forward_step,
        forward_backward_step=forward_backward_step,
        quantization_formats=[mtq.NVFP4_DEFAULT_CFG, mtq.FP8_DEFAULT_CFG],
    )


def test_auto_quantize_moe_ep(dist_workers_size_2):
    """auto_quantize must sum score/cost across EP ranks and pick a consistent recipe.

    Dispatches ``_test_auto_quantize_moe_ep_helper`` through the
    ``dist_workers_size_2`` fixture's ``run`` method.
    """
    worker_pool = dist_workers_size_2
    worker_pool.run(_test_auto_quantize_moe_ep_helper)


@pytest.mark.parametrize("ep_size", [1, 2])
@pytest.mark.parametrize("sync_weight_amax", [True, False])
def test_layer_sync_moe_local_experts_amax(dist_workers, ep_size, sync_weight_amax):
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/torch/quantization/test_autoquant.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.algorithms import (
AutoQuantizeGradientSearcher,
QuantRecipe,
QuantRecipeHparam,
_AutoQuantizeBaseSearcher,
estimate_quant_compression,
)
from modelopt.torch.quantization.config import _base_disable_all, _default_disabled_quantizer_cfg
Expand Down Expand Up @@ -305,6 +307,39 @@ def test_data_parallel_auto_quantize(skip_on_windows):
spawn_multiprocess_job(4, _test_data_parallel_auto_quantize, backend="gloo")


def test_auto_quantize_budget_uses_no_quant_candidate_cost(monkeypatch):
    """``run_search`` must size its budget from the no-quant candidate cost.

    ``_get_total_weight_size`` is patched to fail the test if it is called, and
    ``run_search_with_stats`` is overridden to record the ``max_weight_size``
    it receives.
    """

    def _fail_if_called(modules):
        pytest.fail("run_search should derive total weight size from candidate costs")

    class _BudgetCaptureSearcher(AutoQuantizeGradientSearcher):
        def run_search_with_stats(self, max_weight_size, verbose=False):
            # Record the budget handed to the solver and report an empty result.
            self.max_weight_size = max_weight_size
            return {}, True

    monkeypatch.setattr(
        _AutoQuantizeBaseSearcher,
        "_get_total_weight_size",
        staticmethod(_fail_if_called),
    )

    searcher = _BudgetCaptureSearcher()
    searcher.reset_search()
    searcher.model = torch.nn.Module()
    searcher.config = {"verbose": False}
    searcher.constraints = {"effective_bits": 8.0}

    # One candidate layer: the no-quant recipe is deliberately placed last with cost 100.0.
    local_expert_stats = {
        "formats": [QuantRecipe(mtq.NVFP4_DEFAULT_CFG), QuantRecipe(None)],
        "scores": [1.0, 0.0],
        "costs": [25.0, 100.0],
    }
    searcher.candidate_stats = {"local_expert.quant_recipe": local_expert_stats}

    searcher.run_search()

    # 50.0 is half of the 100.0 no-quant cost; a budget based on the 25.0 cost would differ.
    assert searcher.max_weight_size == 50.0


def test_estimate_quant_compression():
nvfp4_affine_kv_cfg = mtq.config.QuantizeConfig(**mtq.NVFP4_AFFINE_KV_CFG)
assert estimate_quant_compression(nvfp4_affine_kv_cfg) == 0.25
Expand Down
Loading