-
Notifications
You must be signed in to change notification settings - Fork 1k
[Quantization] Enable FP8 online quantization for Z-image text encoder #1338
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e790e33
0fe72c7
51070f4
4bb6989
f6ebce1
8a6d59f
59eb19d
c0b8d16
d37a560
e0810eb
fd75eae
d35ba01
e81520f
53c28f5
b644f56
75fd6f7
805507a
bfd7cd2
44b6e06
865218b
9cb7a5f
5f12f2a
65e8f4d
2248cce
0b6fe43
570bbd3
97f2466
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,132 @@ | |
|
|
||
| import json | ||
| import os | ||
| from typing import TYPE_CHECKING, Literal | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| from vllm.model_executor.layers.linear import ( | ||
| ColumnParallelLinear, | ||
| ReplicatedLinear, | ||
| RowParallelLinear, | ||
| ) | ||
| from vllm.model_executor.models.transformers.utils import init_on_device_without_buffers | ||
| from vllm.model_executor.models.utils import maybe_prefix | ||
|
|
||
| from vllm_omni.diffusion.data import OmniDiffusionConfig | ||
| from vllm_omni.diffusion.quantization import get_vllm_quant_config_for_layers | ||
|
|
||
| if TYPE_CHECKING: | ||
| from transformers import PretrainedConfig, PreTrainedModel | ||
| from transformers.models.auto.auto_factory import _BaseAutoModelClass | ||
| from vllm.model_executor.layers.quantization.base_config import ( | ||
| QuantizationConfig, | ||
| ) | ||
|
|
||
|
|
||
# Names of the tensor-parallel styles accepted by `replace_linear_class`.
| Style = Literal["colwise", "colwise_rep", "rowwise", "rowwise_rep", "replicate"] | ||
|
|
||
|
|
||
def replace_linear_class(
    linear: nn.Linear,
    style: Style = "replicate",
    quant_config: QuantizationConfig | None = None,
    *,
    prefix: str = "",
) -> ColumnParallelLinear | RowParallelLinear | ReplicatedLinear:
    """
    Replace nn.Linear with one of vLLM's tensor parallel linear classes.

    Only the shape (`in_features`, `out_features`) and the presence of a bias
    are read from `linear`; its weights are not copied into the new layer.

    Args:
        linear: `nn.Linear` to be replaced.
        style: Tensor parallel style of the new linear, e.g. "colwise".
        quant_config: Quantization config for the new linear.
        prefix: Qualified module name passed to the new linear as `prefix`.
    Returns:
        The new linear.
    Raises:
        ValueError: If `style` is not a string.
        KeyError: If `style` is a string but not a supported style name.
    """
    if not isinstance(style, str):
        raise ValueError(f"Unsupported parallel style type {type(style)}, expected str")

    # Each style maps to the vLLM class plus the extra constructor kwargs
    # that distinguish the "_rep" variants from the plain ones.
    vllm_linear_maps = {
        "colwise": (ColumnParallelLinear, {}),
        "colwise_rep": (ColumnParallelLinear, {"gather_output": True}),
        "rowwise": (RowParallelLinear, {}),
        "rowwise_rep": (RowParallelLinear, {"input_is_parallel": False}),
        "replicate": (ReplicatedLinear, {}),
    }
    try:
        vllm_linear_cls, vllm_linear_kwargs = vllm_linear_maps[style]
    except KeyError:
        # Same exception type as the plain lookup, but tell the caller what
        # the valid choices are.
        raise KeyError(
            f"Unsupported parallel style {style!r}, expected one of {sorted(vllm_linear_maps)}"
        ) from None

    return vllm_linear_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=linear.bias is not None,
        quant_config=quant_config,
        prefix=prefix,
        return_bias=False,
        **vllm_linear_kwargs,
    )
|
|
||
|
|
||
def recursive_replace_linear(model: nn.Module, od_config: OmniDiffusionConfig):
    """Swap every `nn.Linear` found under `model` for a vLLM linear layer.

    The module tree is walked depth-first and modified in place. Each
    `nn.Linear` child is rebuilt through `replace_linear_class` with the
    "replicate" style, the quantization config derived from
    `od_config.quantization_config`, and its dotted qualified name as prefix.
    Modules that are not `nn.Linear` are left untouched and descended into.
    """
    quant_config = get_vllm_quant_config_for_layers(od_config.quantization_config)

    def _walk(parent: nn.Module, parent_prefix: str):
        for name, child in parent.named_children():
            full_name = maybe_prefix(parent_prefix, name)
            if not isinstance(child, nn.Linear):
                # Container or other layer: keep it and look inside.
                _walk(child, full_name)
                continue
            replacement = replace_linear_class(child, "replicate", quant_config, prefix=full_name)
            setattr(parent, name, replacement)

    _walk(model, "")
|
|
||
|
|
||
| def init_parameters( | ||
| module: nn.Module, | ||
| dtype: torch.dtype | None, | ||
| device: torch.device | None = None, | ||
| ): | ||
| for name, param in module.named_parameters(recurse=False): | ||
| if param.device == torch.device("meta"): | ||
| new_param = nn.Parameter( | ||
| torch.empty_like( | ||
| param.data, | ||
|
Isotr0py marked this conversation as resolved.
|
||
| dtype=dtype, | ||
| device=device, | ||
| ), | ||
| requires_grad=param.requires_grad, | ||
| ) | ||
| setattr(module, name, new_param) | ||
| for child in module.children(): | ||
| init_parameters(child, dtype, device) | ||
|
|
||
|
|
||
def create_transformers_model(
    auto_cls: _BaseAutoModelClass,
    od_config: OmniDiffusionConfig,
    hf_config: PretrainedConfig,
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
) -> PreTrainedModel:
    """Build a HuggingFace model from `hf_config` with vLLM linear layers.

    The model is instantiated through `auto_cls.from_config` on the meta
    device, its `nn.Linear` modules are swapped by `recursive_replace_linear`,
    and the parameters still on the meta device are then allocated by
    `init_parameters`. No weights are loaded here.

    Args:
        auto_cls: Auto class providing `from_config`.
        od_config: Diffusion config; supplies the fallback dtype and the
            quantization settings used by `recursive_replace_linear`.
        hf_config: HuggingFace config describing the model to build.
        dtype: Parameter dtype; falls back to `od_config.dtype`.
        device: Target device; falls back to `torch.get_default_device()`.
    Returns:
        The constructed model.
    """
    if not dtype:
        dtype = od_config.dtype
    if not device:
        device = torch.get_default_device()

    # NOTE(review): assumes only `from_config` runs inside the meta-device
    # context and that replacement and allocation happen after it. Confirm.
    with init_on_device_without_buffers("meta"):
        hf_model = auto_cls.from_config(hf_config)

    recursive_replace_linear(hf_model, od_config)
    init_parameters(hf_model, dtype=dtype, device=device)
    return hf_model
|
|
||
|
|
||
| def _load_json(model_path: str, filename: str, local_files_only: bool = True) -> dict: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.