Commit f1df5db

[Misc] Update marlin to use vLLMParameters (#7803)

1 parent 35ee2ad commit f1df5db
File tree: 3 files changed, +41 -34 lines
tests/weight_loading/models.txt

+3 -1
@@ -15,4 +15,6 @@ compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
 compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
 awq, casperhansen/mixtral-instruct-awq, main
 awq_marlin, casperhansen/mixtral-instruct-awq, main
-fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
+fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main
+marlin, nm-testing/zephyr-beta-7b-marlin-g128, main
+marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main
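The two new entries give the reworked Marlin loader coverage for both grouped (g128) and channelwise scale layouts (the fp8 line is rewritten unchanged, likely a trailing-newline fix). Each entry pairs a quantization method with a checkpoint and a revision; a minimal sketch of reading one entry, assuming the simple comma-separated format shown above (the actual test harness may parse these differently):

# Hypothetical reader for a models.txt entry; names are illustrative only.
from typing import NamedTuple

class WeightLoadingCase(NamedTuple):
    quantization: str  # e.g. "marlin"
    model: str         # e.g. "nm-testing/zephyr-beta-7b-marlin-g128"
    revision: str      # e.g. "main"

def parse_entry(line: str) -> WeightLoadingCase:
    quantization, model, revision = (field.strip() for field in line.split(","))
    return WeightLoadingCase(quantization, model, revision)

case = parse_entry("marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main")
assert case.quantization == "marlin"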

vllm/model_executor/layers/linear.py

+2 -1
@@ -22,7 +22,8 @@
 
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
-    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod"
+    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
+    "MarlinLinearMethod"
 ]
 
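Membership in this list is what routes a layer onto the v2 loading path, so adding "MarlinLinearMethod" here is the switch that activates the rest of the commit. A simplified sketch of the dispatch (an assumption about the wiring; the real code lives in vLLM's linear layers):

# Simplified sketch: how a linear layer might choose between the legacy
# loader and the v2 loader based on this list.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod",
    "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod",
    "MarlinLinearMethod"
]

def select_weight_loader(layer):
    # v2 loaders delegate shard/packing bookkeeping to the vLLMParameter
    # subclasses themselves; the legacy path reads set_weight_attrs metadata.
    if type(layer.quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED:
        return layer.weight_loader_v2
    return layer.weight_loader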
vllm/model_executor/layers/quantization/marlin.py

+36 -32
@@ -9,7 +9,10 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
-from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.parameter import (BasevLLMParameter,
+                                           ChannelQuantScaleParameter,
+                                           GroupQuantScaleParameter,
+                                           PackedvLLMParameter)
 
 logger = init_logger(__name__)
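The import swap is the heart of the change: instead of plain torch Parameters annotated after the fact with set_weight_attrs, Marlin now builds parameters that carry their own loading metadata. A minimal sketch of the idea (an assumption for illustration; the real classes live in vllm/model_executor/parameter.py and differ in detail):

import torch
from torch.nn import Parameter

class SketchBasevLLMParameter(Parameter):
    """Illustrative stand-in for BasevLLMParameter: a frozen tensor that
    carries its own weight_loader callback."""

    def __new__(cls, data: torch.Tensor, **kwargs):
        return super().__new__(cls, data=data, requires_grad=False)

    def __init__(self, data: torch.Tensor, weight_loader):
        # Subclasses like PackedvLLMParameter additionally record packed_dim,
        # packed_factor and marlin_tile_size so shard offsets can be remapped
        # into the packed layout during loading.
        self.weight_loader = weight_loader

w = SketchBasevLLMParameter(data=torch.zeros(8),
                            weight_loader=lambda p, t: p.data.copy_(t))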
@@ -132,6 +135,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
 
         if params_dtype != torch.float16:
             raise ValueError(
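Pulling weight_loader out of extra_weight_attrs once, up front, lets every parameter constructed below receive the callback directly instead of having it attached afterwards. A small self-contained sketch of the hand-off (default_weight_loader's body here is an assumption):

def default_weight_loader(param, loaded_weight):
    # Simplest possible loader: copy the checkpoint tensor into place.
    param.data.copy_(loaded_weight)

extra_weight_attrs = {"weight_loader": default_weight_loader}
weight_loader = extra_weight_attrs["weight_loader"]  # as in the hunk above
assert callable(weight_loader)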
@@ -170,64 +174,64 @@ def create_weights(
                 "Each permutation group must reside on the same gpu")
 
         # Quantized 4Bit weights packed into Int32.
-        qweight = Parameter(
-            torch.empty(
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
                 input_size_per_partition // self.quant_config.tile_size,
                 output_size_per_partition * self.quant_config.tile_size //
                 self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            qweight,
-            {
-                "input_dim": 0,
-                "output_dim": 1,
-                "packed_dim": 1,
-                "pack_factor": self.quant_config.pack_factor,
-                "marlin_tile_size": self.quant_config.tile_size,
-            },
-        )
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader)
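To make the packed shape concrete: Marlin packs eight 4-bit weights into each int32 (pack_factor = 32/4 = 8) and re-tiles in 16x16 blocks (tile_size = 16); the constant values and layer sizes below are assumptions for illustration, not taken from the diff.

# Worked example with assumed sizes: K = N = 4096.
K, N = 4096, 4096
tile_size, pack_factor = 16, 8
qweight_shape = (K // tile_size, N * tile_size // pack_factor)
assert qweight_shape == (256, 8192)
# Sanity check: 256 * 8192 int32s hold 256 * 8192 * 8 = 4096 * 4096 weights.

The hunk continues with the scale parameters: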
         # Determine if channelwise or not
         input_groups = (1 if self.quant_config.group_size == -1 else
                         input_size_per_partition //
                         self.quant_config.group_size)
 
-        scales = Parameter(
+        weight_scale_args = {
+            "data":
             torch.empty(
                 input_groups,
                 output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
-            requires_grad=False,
-        )
-        set_weight_attrs(
-            scales,
-            {
-                "input_dim": None if input_groups == 1 else 0,
-                "output_dim": 1,
-            },
-        )
+            "weight_loader":
+            weight_loader
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1,
+                                                **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(output_dim=1,
+                                              input_dim=0,
+                                              **weight_scale_args)
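The rewrite makes the scale layout explicit in the parameter type: channelwise quantization (group_size == -1) stores one scale row shared by every input, while grouped quantization stores one row per group of input rows. With assumed sizes:

# Assumed sizes for illustration: K = 4096 inputs, N = 4096 outputs.
K, N = 4096, 4096
channelwise_shape = (1, N)            # ChannelQuantScaleParameter, output_dim=1
group_size = 128
grouped_shape = (K // group_size, N)  # GroupQuantScaleParameter, input_dim=0
assert grouped_shape == (32, 4096)

Next comes the workspace buffer: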
         # Allocate workspace (Used for internal locking mechanism)
         max_workspace_size = (
             output_size_per_partition //
             self.quant_config.min_n_threads) * self.quant_config.max_parallel
-        workspace = Parameter(torch.zeros(max_workspace_size,
-                                          device="cuda",
-                                          dtype=torch.int),
-                              requires_grad=False)
+
+        workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size,
+                                                       device="cuda",
+                                                       dtype=torch.int),
+                                      weight_loader=weight_loader)
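The workspace is a zero-initialized int buffer that the Marlin kernel uses as a global lock/counter array when thread blocks cooperate on the same output tiles; only its length matters to create_weights. Plugging in plausible values (the min_n_threads and max_parallel defaults below are assumptions):

output_size_per_partition = 4096      # assumed
min_n_threads, max_parallel = 64, 16  # assumed Marlin defaults
max_workspace_size = (output_size_per_partition // min_n_threads) * max_parallel
assert max_workspace_size == 1024     # a 4 KiB int32 buffer per layer

Finally, parameter registration and the new post-load hook: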
         layer.register_parameter("B", qweight)
-        set_weight_attrs(qweight, extra_weight_attrs)
         layer.register_parameter("s", scales)
-        set_weight_attrs(scales, extra_weight_attrs)
         layer.register_parameter("workspace", workspace)
-        set_weight_attrs(workspace, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = Parameter(layer.B.data, requires_grad=False)
+        layer.s = Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
 
     def apply(
         self,
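With the parameters carrying their own loaders, the per-parameter set_weight_attrs calls disappear, and the new process_weights_after_loading hook rewraps each tensor as a plain torch.nn.Parameter once loading finishes: the vLLMParameter metadata is only needed while weights stream in, and, as the diff's comment notes, torch.compile traces plain parameters more reliably. A self-contained sketch of the rewrap (TaggedParameter is a made-up stand-in for a vLLMParameter subclass):

import torch
from torch.nn import Parameter

class TaggedParameter(Parameter):
    """Stand-in for a parameter subclass used only during weight loading."""
    def __new__(cls, data):
        return super().__new__(cls, data=data, requires_grad=False)

p = TaggedParameter(torch.zeros(4))
plain = Parameter(p.data, requires_grad=False)  # what the hook does
# Same storage, but the subclass (and its metadata) is gone.
assert type(plain) is Parameter and plain.data_ptr() == p.data_ptr()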
