
Commit 28cc566

refactor pt loading
1 parent 08ca0f6 commit 28cc566

33 files changed: +1109 −859 lines

fastdeploy/engine/async_llm.py

Lines changed: 1 addition & 0 deletions
@@ -722,6 +722,7 @@ def _setting_environ_variables(self):
             "FLAGS_use_append_attn": 1,
             "NCCL_ALGO": "Ring",
             "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
+            "OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),
         }
         # environment variables needed by Dy2St
         variables.update(
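
engine.py below picks up the same OMP_NUM_THREADS default. As an aside, the int(os.getenv(...)) idiom used here works because os.getenv returns the exported string when the variable is set and the given default otherwise; a minimal standalone sketch (env_int is a hypothetical helper, not part of FastDeploy):

import os

def env_int(name, default):
    # os.getenv returns a str when the variable is set, else the default as given;
    # int() normalizes both cases to an integer.
    return int(os.getenv(name, default))

# Mirrors the new entry above: 3 unless the caller exported OMP_NUM_THREADS.
omp_threads = env_int("OMP_NUM_THREADS", 3)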

fastdeploy/engine/engine.py

Lines changed: 1 addition & 0 deletions
@@ -442,6 +442,7 @@ def _setting_environ_variables(self):
             "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
             "NCCL_ALGO": "Ring",
             "FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
+            "OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),
         }
         # environment variables needed by Dy2St
         variables.update(

fastdeploy/model_executor/layers/backends/metax/moe/fused_moe_triton_metax_backend.py

Lines changed: 90 additions & 61 deletions
@@ -22,7 +22,13 @@
 from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
 from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
-from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
+from fastdeploy.model_executor.utils import (
+    TensorTracker,
+    free_tensor,
+    process_weight_transpose,
+    set_weight_attrs,
+    weight_fully_copied,
+)
 from fastdeploy.utils import ceil_div

 from .triton_moe_kernels import fused_moe_kernel_paddle
@@ -69,32 +75,50 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         ]
         # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
         if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
+            if self.model_format != "torch":
+                up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
+                down_proj_weight_shape = self.down_proj_weight_shape
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
+                }
+            else:
+                up_gate_proj_weight_shape = self.up_gate_proj_weight_shape[::-1]
+                down_proj_weight_shape = self.down_proj_weight_shape[::-1]
+                up_gate_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                }
+                down_proj_attrs = {
+                    **extra_weight_attrs,
+                    "tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
+                    "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
+                }
+
             layer.up_gate_proj_weight = layer.create_parameter(
-                shape=self.up_gate_proj_weight_shape,
+                shape=up_gate_proj_weight_shape,
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )

             layer.down_proj_weight = layer.create_parameter(
-                shape=self.down_proj_weight_shape,
+                shape=down_proj_weight_shape,
                 dtype=layer.weight_dtype,
                 default_initializer=paddle.nn.initializer.Constant(0),
             )
-            extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"

             set_weight_attrs(
                 layer.up_gate_proj_weight,
-                {
-                    **extra_weight_attrs,
-                    "tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
-                },
+                up_gate_proj_attrs,
             )
             set_weight_attrs(
                 layer.down_proj_weight,
-                {
-                    **extra_weight_attrs,
-                    "tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
-                },
+                down_proj_attrs,
             )
         else:
             setattr(
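
The torch branch above reverses the declared weight shapes because a "torch"-format checkpoint stores Linear weights as [out_features, in_features], whereas Paddle expects [in_features, out_features]; the loaded tensor is transposed back later (process_weight_transpose, called in the next hunk). A minimal sketch of that layout relationship, with made-up shapes rather than the layer's real dimensions:

import paddle

in_features, out_features = 8, 16
paddle_shape = [in_features, out_features]      # layout Paddle layers expect
torch_shape = paddle_shape[::-1]                # layout a PyTorch checkpoint provides

w_from_pt_checkpoint = paddle.randn(torch_shape, dtype="float32")
w_for_paddle = w_from_pt_checkpoint.T           # one transpose at load time restores Paddle layout
assert list(w_for_paddle.shape) == paddle_shape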
@@ -181,59 +205,64 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
     @paddle.no_grad()
     def process_weights_after_loading(self, layer):
         """ """
-        if not self.quant_config.is_checkpoint_bf16:
-            return
+
+        def _process_quantize(weight_idx):
+            max_bound = 127
+            # weight
+            weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
+            # scale
+            scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
+
+            weight_tensor = getattr(layer, weight_name)
+            quanted_weight_scale = weight_tensor.abs().max(axis=1)
+            quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
+            quanted_weight = paddle.round(quanted_weight).astype("int8")
+            quanted_weight_scale = quanted_weight_scale / max_bound
+
+            free_tensor(getattr(layer, weight_name))
+
+            # create weight
+            setattr(
+                layer,
+                weight_name,
+                layer.create_parameter(
+                    shape=weight_tensor.shape,
+                    dtype=quanted_weight.dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            # create scale
+            setattr(
+                layer,
+                scale_name,
+                layer.create_parameter(
+                    shape=quanted_weight_scale.shape,
+                    dtype=quanted_weight_scale.dtype,
+                    default_initializer=paddle.nn.initializer.Constant(0),
+                ),
+            )
+            getattr(layer, weight_name).copy_(quanted_weight, False)
+            getattr(layer, scale_name).copy_(quanted_weight_scale, False)

         algo = layer.quant_method.quant_config.name()
         assert algo == "wint8"
-        max_bound = 127
-        weight_id_map = {"gate_up": 0, "down": 1}
-        if (
-            hasattr(layer.up_gate_proj_weight, "tensor_track")
-            and layer.up_gate_proj_weight.tensor_track is not None
-            and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
-        ):
-            weight_type = "gate_up"
-            layer.up_gate_proj_weight.tensor_track = None
+        if self.quant_config.is_checkpoint_bf16:
+
+            weight_id_map = {"gate_up": 0, "down": 1}
+            if weight_fully_copied(layer.up_gate_proj_weight):
+                weight_type = "gate_up"
+            else:
+                weight_type = "down"
+
+            if self.model_format == "torch":
+                # pt model
+                unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
+                    "quant_weight", "weight"
+                )
+                process_weight_transpose(layer, unquantized_weight_name)
+            _process_quantize(weight_id_map[weight_type])
         else:
-            weight_type = "down"
-            layer.down_proj_weight.tensor_track = None
-
-        # weight
-        weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
-        # scale
-        scale_name = self.added_scale_attrs[weight_id_map[weight_type]]
-
-        weight_tensor = getattr(layer, weight_name)
-        quanted_weight_scale = weight_tensor.abs().max(axis=1)
-        quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
-        quanted_weight = paddle.round(quanted_weight).astype("int8")
-        quanted_weight_scale = quanted_weight_scale / max_bound
-
-        getattr(layer, weight_name).value().get_tensor()._clear()
-
-        # create weight
-        setattr(
-            layer,
-            weight_name,
-            layer.create_parameter(
-                shape=weight_tensor.shape,
-                dtype=quanted_weight.dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        # create scale
-        setattr(
-            layer,
-            scale_name,
-            layer.create_parameter(
-                shape=quanted_weight_scale.shape,
-                dtype=quanted_weight_scale.dtype,
-                default_initializer=paddle.nn.initializer.Constant(0),
-            ),
-        )
-        getattr(layer, weight_name).copy_(quanted_weight, False)
-        getattr(layer, scale_name).copy_(quanted_weight_scale, False)
+            return

     @paddle.no_grad()
     def apply(
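
For reference, the refactored _process_quantize keeps the existing wint8 recipe: take the absolute maximum along axis 1 per output channel, scale the weights into the int8 range with max_bound = 127, and keep scale = absmax / 127 for dequantization. A self-contained sketch of that math in Paddle (toy shapes, not the layer's real experts or dimensions):

import paddle

# Toy [num_experts, in_dim, out_dim] weight, mirroring the axis=1 reduction above.
w = paddle.randn([2, 8, 4], dtype="float32")

max_bound = 127
absmax = w.abs().max(axis=1)                       # per (expert, out_channel) absmax, shape [2, 4]
q = paddle.round(w / absmax[:, None, :] * max_bound).astype("int8")
scale = absmax / max_bound                         # dequantize with q.astype("float32") * scale[:, None, :]

# Round-trip check: dequantized weights approximate the originals.
w_hat = q.astype("float32") * scale[:, None, :]
print(float((w - w_hat).abs().max()))              # small, bounded by roughly scale / 2 per element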

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 3 additions & 3 deletions
@@ -23,7 +23,7 @@
 from paddle.distributed import fleet

 from fastdeploy.config import FDConfig
-from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn
+from fastdeploy.model_executor.utils import h2d_copy, set_weight_attrs, slice_fn

 from .utils import (
     DEFAULT_VOCAB_PADDING_SIZE,
@@ -273,10 +273,10 @@ def weight_loader(self, param, loaded_weight, shard_id=None):
         shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)

         if output_dim == 0:
-            param[: shard_weight.shape[0]].copy_(shard_weight, False)
+            h2d_copy(param[: shard_weight.shape[0]], shard_weight)
             param[shard_weight.shape[0] :].fill_(0)
         else:
-            param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
+            h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
             param[:, shard_weight.shape[1] :].fill_(0)

     def forward(self, ids_remove_padding=None) -> paddle.Tensor:
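
Both branches now route the copy through h2d_copy from fastdeploy.model_executor.utils instead of calling Tensor.copy_ directly. The helper's body is not part of this diff; the following is only a sketch of the host-to-device copy pattern it presumably wraps (hypothetical name h2d_copy_sketch, assumed behavior):

import paddle

def h2d_copy_sketch(dst_slice, src):
    # Place the (possibly CPU-resident) loaded shard on the destination's device
    # and dtype, then copy it into the parameter slice without blocking the host.
    src = paddle.to_tensor(src, dtype=dst_slice.dtype, place=dst_slice.place)
    dst_slice.copy_(src, False)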
