Commit ff3b7b2

[Metax] support default_v1 loader based #4988
1 parent 5b24013 · commit ff3b7b2

File tree

2 files changed: +119 -86 lines


fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_cutlass_metax_backend.py

Lines changed: 117 additions & 84 deletions
```diff
@@ -23,14 +23,21 @@
 from fastdeploy.distributed.communication import tensor_model_parallel_all_reduce
 from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
+from fastdeploy.model_executor.layers.quantization.weight_only import WeightOnlyConfig
 from fastdeploy.model_executor.layers.utils import get_tensor
 from fastdeploy.model_executor.ops.gpu import (
     fused_expert_moe,
     moe_expert_dispatch,
     moe_expert_ffn,
     moe_expert_reduce,
 )
-from fastdeploy.model_executor.utils import TensorTracker, free_tensor, set_weight_attrs
+from fastdeploy.model_executor.utils import (
+    TensorTracker,
+    free_tensor,
+    process_weight_transpose,
+    set_weight_attrs,
+    weight_fully_copied,
+)
 
 
 class MetaxCutlassMoEMethod(MoEMethodBase):
```
```diff
@@ -142,18 +149,11 @@ def apply_tp(
                 1.0,
             )
         else:
-            added_weight_attrs0 = getattr(layer, self.added_weight_attrs[0])
-            added_weight_attrs1 = getattr(layer, self.added_weight_attrs[1])
-
-            if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
-                added_weight_attrs0 = paddle.transpose(added_weight_attrs0, perm=[0, 2, 1])
-                added_weight_attrs1 = paddle.transpose(added_weight_attrs1, perm=[0, 2, 1])
-
             fused_moe_out = fused_expert_moe(
                 x,
                 gate.weight,
-                added_weight_attrs0,
-                added_weight_attrs1,
+                getattr(layer, self.added_weight_attrs[0]),
+                getattr(layer, self.added_weight_attrs[1]),
                 None,
                 (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None),
                 None,
```
```diff
@@ -177,7 +177,10 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
 
     def __init__(self, quant_config):
         super().__init__(quant_config)
-        self.quant_config = quant_config
+        if quant_config is None:
+            self.quant_config = WeightOnlyConfig(algo="weight_only_int8", is_checkpoint_bf16=True)
+        else:
+            self.quant_config = quant_config
         self.moe_quant_type = self.quant_config.algo
         self.pack_num = 1
         self.weight_only_linear_arch = os.getenv("FLAGS_weight_only_linear_arch")
```
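The constructor fallback lets this backend be selected for an unquantized bf16 checkpoint with no quant config at all: when `quant_config` is `None`, weight-only int8 is assumed (the `get_moe_method` change in the moe.py hunk below now returns the class, so callers construct it themselves). A minimal standalone sketch of the pattern, using stand-in classes rather than the real FastDeploy types:

```python
# Stand-in sketch (not the FastDeploy API) of the quant_config fallback above:
# with no config, a weight-only int8 config flagged as a bf16 checkpoint is assumed.
from dataclasses import dataclass


@dataclass
class WeightOnlyConfig:  # stand-in for the imported WeightOnlyConfig
    algo: str
    is_checkpoint_bf16: bool = False


class WeightOnlyMoEMethod:  # stand-in for MetaxCutlassWeightOnlyMoEMethod
    def __init__(self, quant_config=None):
        if quant_config is None:
            # bf16 checkpoint loaded via default_v1: quantize at load time
            quant_config = WeightOnlyConfig(algo="weight_only_int8", is_checkpoint_bf16=True)
        self.quant_config = quant_config
        self.moe_quant_type = quant_config.algo


method = WeightOnlyMoEMethod()  # no config supplied: falls back to weight_only_int8
assert method.quant_config.is_checkpoint_bf16
```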
```diff
@@ -252,33 +255,61 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         ]
         self.up_gate_proj_scale_shape = [layer.num_local_experts, layer.moe_intermediate_size * 2]
         self.down_proj_scale_shape = [layer.num_local_experts, layer.hidden_size]
+        self.model_format = extra_weight_attrs.get("model_format")
         # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
         if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
+            if self.model_format != "torch":
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.hidden_size,
+                    layer.moe_intermediate_size * 2,
+                ]
+                down_proj_weight_shape = [layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
+                }
+            else:
+                up_gate_proj_weight_shape = [
+                    layer.num_local_experts,
+                    layer.moe_intermediate_size * 2,
+                    layer.hidden_size,
+                ]
+                down_proj_weight_shape = [layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                }
+
             layer.up_gate_proj_weight = layer.create_parameter(
-                shape=[layer.num_local_experts, layer.hidden_size, layer.moe_intermediate_size * 2],
+                shape=up_gate_proj_weight_shape,
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
 
             layer.down_proj_weight = layer.create_parameter(
-                shape=[layer.num_local_experts, layer.moe_intermediate_size, layer.hidden_size],
+                shape=down_proj_weight_shape,
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
-            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
+            # extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
             set_weight_attrs(
                 layer.up_gate_proj_weight,
-                {
-                    **extra_weight_attrs,
-                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
-                },
+                up_gate_proj_attrs,
             )
             set_weight_attrs(
                 layer.down_proj_weight,
-                {
-                    **extra_weight_attrs,
-                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
-                },
+                down_proj_attrs,
             )
         else:
             self.weight_dtype = "int8"
```
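The branch above picks one of two expert-weight layouts by checkpoint format: for `up_gate_proj` the input dim is `hidden_size` and the output dim is `moe_intermediate_size * 2`, so the non-torch layout is `[E, in, out]` and the torch layout is `[E, out, in]`, with the `TensorTracker` output dim and shard map flipped to match. A small sketch of that selection with made-up sizes (the helper is illustrative, not part of the diff):

```python
# Illustrative helper (not in the diff): the two expert-weight layouts chosen
# by create_weights. E = num_local_experts; for up_gate_proj, in = hidden_size
# and out = moe_intermediate_size * 2.
def moe_weight_shapes(num_experts, hidden_size, inter_size, model_format):
    if model_format != "torch":
        # paddle-style checkpoint: [E, in, out], tracked along the output dim
        up_gate = [num_experts, hidden_size, inter_size * 2]
        down = [num_experts, inter_size, hidden_size]
    else:
        # torch-style checkpoint: [E, out, in], tracked along the input dim
        up_gate = [num_experts, inter_size * 2, hidden_size]
        down = [num_experts, hidden_size, inter_size]
    return up_gate, down


print(moe_weight_shapes(8, 4096, 1024, "paddle"))  # ([8, 4096, 2048], [8, 1024, 4096])
print(moe_weight_shapes(8, 4096, 1024, "torch"))   # ([8, 2048, 4096], [8, 4096, 1024])
```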
```diff
@@ -325,7 +356,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
                 default_initializer=paddle.nn.initializer.Constant(0),
             ),
         )
-        extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
+        # extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
         moe_extra_weight_attrs = {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}
         set_weight_attrs(layer.up_gate_proj_weight, moe_extra_weight_attrs)
         set_weight_attrs(layer.down_proj_weight, moe_extra_weight_attrs)
```
```diff
@@ -337,69 +368,71 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         set_weight_attrs(layer.down_proj_weight_scale, scale_extra_weight_attrs)
 
     def process_weights_after_loading(self, layer):
-        """ """
-        if not self.quant_config.is_checkpoint_bf16:
-            return
-        weight_id_map = {"gate_up": 0, "down": 1}
-        if (
-            hasattr(layer.up_gate_proj_weight, "tensor_track")
-            and layer.up_gate_proj_weight.tensor_track is not None
-            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
-        ):
-            weight_type = "gate_up"
-        else:
-            weight_type = "down"
-
-        # 1.init shape and type
-        # weight
-        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
-        unquantized_weight_name = weight_name.replace("quant_weight", "weight")
-        weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
-        weight_shape[1], weight_shape[2] = weight_shape[2], weight_shape[1]
-        weight_dtype = "int8"
-        # scale
-        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
-        scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
-        scale_dtype = self.default_dtype
-
-        # 2.crate tmp tensor
-
-        weight = paddle.empty(weight_shape, dtype=weight_dtype)
-        scale = paddle.empty(scale_shape, dtype=scale_dtype)
-
-        # 3.quantize weight
-
-        for expert_id in range(layer.num_local_experts):
-            weight[expert_id], scale[expert_id] = weight_quantize(
-                getattr(layer, unquantized_weight_name)[expert_id],
-                algo=self.moe_quant_type,
-                arch=self.weight_only_linear_arch,
-            )
+        def _process_quantize(weight_idx):
+            # 1.init shape and type
+            weight_name = self.added_weight_attrs[weight_idx]
+            unquantized_weight_name = weight_name.replace("quant_weight", "weight")
+            weight_shape = self.up_gate_proj_weight_shape if weight_type == "gate_up" else self.down_proj_weight_shape
+            transposed_weight_shape = [weight_shape[0], weight_shape[2], weight_shape[1]]
+            weight_dtype = "int8"
+            # scale
+            scale_name = self.added_scale_attrs[weight_idx]
+            scale_shape = self.up_gate_proj_scale_shape if weight_type == "gate_up" else self.down_proj_scale_shape
+            scale_dtype = self.default_dtype
+
+            # 2.crate tmp tensor
+
+            weight = paddle.empty(transposed_weight_shape, dtype=weight_dtype)
+            scale = paddle.empty(scale_shape, dtype=scale_dtype)
+
+            # 3.quantize weight
+
+            for expert_id in range(layer.num_local_experts):
+                weight[expert_id], scale[expert_id] = weight_quantize(
+                    getattr(layer, unquantized_weight_name)[expert_id],
+                    algo=self.moe_quant_type,
+                    arch=self.weight_only_linear_arch,
+                )
 
-        free_tensor(getattr(layer, unquantized_weight_name))
+            free_tensor(getattr(layer, unquantized_weight_name))
 
-        # create weight
-        setattr(
-            layer,
-            weight_name,
-            layer.create_parameter(
-                shape=weight_shape,
-                dtype=weight_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # create scale
-        setattr(
-            layer,
-            scale_name,
-            layer.create_parameter(
-                shape=scale_shape,
-                dtype=scale_dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        getattr(layer, weight_name).copy_(weight, False)
-        getattr(layer, scale_name).copy_(scale, False)
+            setattr(
+                layer,
+                weight_name,
+                layer.create_parameter(
+                    shape=weight_shape,
+                    dtype=weight_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # create scale
+            setattr(
+                layer,
+                scale_name,
+                layer.create_parameter(
+                    shape=scale_shape,
+                    dtype=scale_dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            getattr(layer, weight_name).copy_(weight.transpose([0, 2, 1]), False)
+            getattr(layer, scale_name).copy_(scale, False)
+
+        if self.quant_config.is_checkpoint_bf16:
+            weight_id_map = {"gate_up": 0, "down": 1}
+            if weight_fully_copied(layer.up_gate_proj_weight):
+                weight_type = "gate_up"
+            else:
+                weight_type = "down"
+
+            if self.model_format == "torch":
+                unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
+                    "quant_weight", "weight"
+                )
+                process_weight_transpose(layer, unquantized_weight_name)
+            _process_quantize(weight_id_map[weight_type])
+        else:
+            return
 
     def process_loaded_weights(self, layer: nn.Layer, state_dict):
         """
```

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -62,7 +62,7 @@ def get_moe_method():
             MetaxCutlassWeightOnlyMoEMethod,
         )
 
-        return MetaxCutlassWeightOnlyMoEMethod(None)
+        return MetaxCutlassWeightOnlyMoEMethod
     raise NotImplementedError
 
 
```
```diff
@@ -227,7 +227,7 @@ def weight_loader(
             return
         if hasattr(param, "SHARD_ID_TO_SHARDED_DIM"):
             SHARD_ID_TO_SHARDED_DIM = param.SHARD_ID_TO_SHARDED_DIM
-        elif current_platform.is_cuda() or current_platform.is_iluvatar():
+        elif current_platform.is_cuda() or current_platform.is_iluvatar() or current_platform.is_maca():
             SHARD_ID_TO_SHARDED_DIM = {"gate": 1, "down": 0, "up": 1}
         else:
             SHARD_ID_TO_SHARDED_DIM = {"gate": 0, "down": 1, "up": 0}
```