[Quantization] - Added uses_meta_device_weights to quant config #34645
vllm-bot merged 7 commits into vllm-project:main
Conversation
Signed-off-by: Josephasafg <ajgard7@gmail.com>
Code Review
This pull request introduces a uses_meta_device_weights method to the quantization configuration, providing a more robust way to handle online quantization. The change refactors a hardcoded check for fp8 quantization to a more generic mechanism. The implementation is sound, but there's a critical issue where a None value for model_config.quantization could cause a crash. I've provided a suggestion to handle this case gracefully.
@vkuzo Thanks for the review! Who should trigger the CI?
@Josephasafg @vkuzo I would prefer to keep the information on the linear method itself rather than on the top-level quant config. What do you think about this proposal?
Something like this:

```python
def initialize_dummy_weights(model, model_config, ...):
    meta_device_params: set[int] = set()
    for module in model.modules():
        qm = getattr(module, "quant_method", None)
        if qm is not None and getattr(qm, "uses_meta_device", False):
            for param in module.parameters(recurse=False):
                meta_device_params.add(id(param))
    for param in model.state_dict().values():
        if id(param) in meta_device_params \
                and param.device == torch.device("meta"):
            continue
        initialize_single_dummy_weight(param, low, high, seed)
```
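A torch-free sketch of the per-parameter bookkeeping this proposal describes. All class names here are illustrative stand-ins for the real torch/vLLM types, and the string `"meta"` stands in for `torch.device("meta")`.

```python
class Param:
    def __init__(self, device="cpu"):
        self.device = device
        self.initialized = False

class MetaQuantMethod:
    # The flag the proposal attaches to the linear method itself.
    uses_meta_device = True

class Module:
    def __init__(self, params, quant_method=None):
        self.params = params
        self.quant_method = quant_method

def initialize_dummy_weights(modules):
    # Pass 1: record ids of params owned by meta-device quant methods.
    meta_device_params = set()
    for module in modules:
        qm = getattr(module, "quant_method", None)
        if qm is not None and getattr(qm, "uses_meta_device", False):
            for param in module.params:
                meta_device_params.add(id(param))
    # Pass 2: initialize everything except the recorded meta params.
    for module in modules:
        for param in module.params:
            if id(param) in meta_device_params and param.device == "meta":
                continue  # deferred to process_weights_after_loading
            param.initialized = True

dense = Module([Param()])
quantized = Module([Param(device="meta")], MetaQuantMethod())
initialize_dummy_weights([dense, quantized])
```

After the call, the dense layer's parameter is initialized while the quantized layer's meta parameter is left untouched for `process_weights_after_loading`.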
@mgoin @vkuzo I made the change but made it a little simpler. How does this look?

```python
def initialize_dummy_weights(
    model: torch.nn.Module,
    model_config: ModelConfig,
    low: float = -1e-3,
    high: float = 1e-3,
    seed: int = 1234,
) -> None:
    def uses_meta_device(module: torch.nn.Module) -> bool:
        quant_method = getattr(module, "quant_method", None)
        return getattr(quant_method, "uses_meta_device", False)

    has_online_quant = any(uses_meta_device(m) for m in model.modules())
    for param in model.state_dict().values():
        if has_online_quant and param.device == torch.device("meta"):
            # For online quantization, weights are created on meta device and
            # dummy weight init will happen in `process_weights_after_loading`.
            continue
        initialize_single_dummy_weight(param, low, high, seed)
```
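The simplified version can be exercised with the same kind of torch-free stand-ins. Again, `Param` and `Layer` are illustrative assumptions; the real code walks `model.state_dict()` and compares against `torch.device("meta")`.

```python
class Param:
    def __init__(self, device="cpu"):
        self.device = device
        self.filled = False

class OnlineQuantMethod:
    uses_meta_device = True

class Layer:
    def __init__(self, param, quant_method=None):
        self.param = param
        self.quant_method = quant_method

def initialize_dummy_weights(layers):
    def uses_meta_device(layer):
        qm = getattr(layer, "quant_method", None)
        return getattr(qm, "uses_meta_device", False)

    # One model-wide flag instead of per-parameter bookkeeping.
    has_online_quant = any(uses_meta_device(l) for l in layers)
    for layer in layers:
        if has_online_quant and layer.param.device == "meta":
            continue  # filled later in process_weights_after_loading
        layer.param.filled = True

dense = Layer(Param())
quantized = Layer(Param(device="meta"), OnlineQuantMethod())
initialize_dummy_weights([dense, quantized])
```

The trade-off versus the earlier proposal: once any module opts in, every meta-device parameter in the model is skipped, which is simpler but coarser than tracking individual parameter ids.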
mgoin left a comment
Nice work, I'm quite happy with this!
Purpose
As more quant methods start to support online quantization, we need a more robust way to check that they load dummy weights consistently using `process_weights_after_loading`.
Test Plan
Test Result
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.