 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)

 logger = init_logger(__name__)

@@ -132,6 +135,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]

         if params_dtype != torch.float16:
             raise ValueError(
@@ -170,64 +174,64 @@ def create_weights(
                 "Each permutation group must reside on the same gpu")

         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)

         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)

-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)

         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)

         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

     def apply(
         self,
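Note on the pattern this change adopts (a summary, not part of the diff): instead of allocating a bare `torch.nn.Parameter` and attaching sharding metadata afterwards with `set_weight_attrs`, the `BasevLLMParameter` subclasses carry that metadata, plus the `weight_loader` callable pulled from `extra_weight_attrs`, on the parameter object itself, so generic model-loading code can ask each parameter how to place a checkpoint tensor. The sketch below illustrates that idea in plain PyTorch; `ToyShardedParameter` and `load_column_shard` are invented, illustration-only names, and the real `PackedvLLMParameter` additionally tracks the packed dim, pack factor, and Marlin tile size shown in the diff.

```python
# Minimal sketch of the parameter-subclass idea in plain PyTorch.
# ToyShardedParameter and load_column_shard are hypothetical names used only
# for illustration; vLLM's real classes also handle packed dims, tiling, etc.
import torch


class ToyShardedParameter(torch.nn.Parameter):
    """A parameter that carries its own sharding metadata and loader."""

    def __new__(cls, data, **kwargs):
        # Weights are loaded from a checkpoint, not trained, so no gradients.
        return super().__new__(cls, data, requires_grad=False)

    def __init__(self, data, output_dim, weight_loader):
        self.output_dim = output_dim        # dim split across tensor-parallel ranks
        self.weight_loader = weight_loader  # called by generic model-loading code


def load_column_shard(param, loaded_weight, tp_rank):
    """Copy this rank's slice of a column-parallel checkpoint tensor."""
    shard_size = param.shape[param.output_dim]
    start = tp_rank * shard_size
    param.data.copy_(loaded_weight.narrow(param.output_dim, start, shard_size))


# Usage: two "ranks" each load their half of an 8-column checkpoint tensor.
full_weight = torch.arange(4 * 8, dtype=torch.float32).reshape(4, 8)
for rank in range(2):
    p = ToyShardedParameter(torch.empty(4, 4),
                            output_dim=1,
                            weight_loader=load_column_shard)
    p.weight_loader(p, full_weight, tp_rank=rank)
    assert torch.equal(p.data, full_weight[:, rank * 4:(rank + 1) * 4])
```

The new `process_weights_after_loading` hook then swaps these subclassed parameters back to plain `torch.nn.Parameter` objects once loading is finished, which the in-diff comment notes is required by `torch.compile`.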