1 change: 1 addition & 0 deletions fastdeploy/engine/async_llm.py
@@ -722,6 +722,7 @@ def _setting_environ_variables(self):
"FLAGS_use_append_attn": 1,
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),
}
# environment variables needed by Dy2St
variables.update(
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
@@ -442,6 +442,7 @@ def _setting_environ_variables(self):
"PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
"NCCL_ALGO": "Ring",
"FLAGS_max_partition_size": int(os.getenv("FLAGS_max_partition_size", 1024)),
"OMP_NUM_THREADS": int(os.getenv("OMP_NUM_THREADS", 3)),
}
# environment variables needed by Dy2St
variables.update(
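Both engine entry points (`async_llm.py` above and `engine.py` here) now seed `OMP_NUM_THREADS` the same way: a value already exported in the environment wins, otherwise the worker defaults to 3 OpenMP threads. A minimal sketch of how the fallback resolves (the default of 3 is the value from the diff; nothing else is assumed):

```python
import os

# Resolve the OpenMP thread count the same way the engine does:
# honour an operator-provided OMP_NUM_THREADS, otherwise fall back to 3.
omp_num_threads = int(os.getenv("OMP_NUM_THREADS", 3))
print(omp_num_threads)  # prints 3 unless OMP_NUM_THREADS was exported first
```

So pinning a different value is just a matter of exporting the variable (e.g. `OMP_NUM_THREADS=8`) before the server process starts.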
@@ -22,7 +22,13 @@
from fastdeploy.model_executor.layers.moe.moe import get_moe_scores
from fastdeploy.model_executor.layers.quantization.quant_base import QuantMethodBase
from fastdeploy.model_executor.ops.gpu import tritonmoe_preprocess
from fastdeploy.model_executor.utils import TensorTracker, set_weight_attrs
from fastdeploy.model_executor.utils import (
TensorTracker,
free_tensor,
process_weight_transpose,
set_weight_attrs,
weight_fully_copied,
)
from fastdeploy.utils import ceil_div

from .triton_moe_kernels import fused_moe_kernel_paddle
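The new imports pull in several loader helpers (`free_tensor`, `process_weight_transpose`, `weight_fully_copied`) alongside `TensorTracker`. Their implementations are not part of this diff; as a rough mental model of the tracking side only, here is a toy sketch (not FastDeploy's code) of the mark-and-check contract that `create_weights` and `process_weights_after_loading` rely on:

```python
import numpy as np

# Toy stand-in for the TensorTracker / weight_fully_copied pattern used below:
# each shard copy marks the region it filled, and post-load processing
# (transposition, quantization) only runs once the whole weight has arrived.
class ToyTensorTracker:
    def __init__(self, shape, output_dim):
        self.filled = np.zeros(shape, dtype=bool)
        self.output_dim = output_dim  # which dim carries output channels

    def mark(self, slices):
        self.filled[slices] = True

    def is_fully_copied(self):
        return bool(self.filled.all())

tracker = ToyTensorTracker(shape=(4, 8), output_dim=True)
tracker.mark(np.s_[:, :4])
tracker.mark(np.s_[:, 4:])
assert tracker.is_fully_copied()
```

The real helpers presumably track filled ranges rather than a dense boolean mask, but the contract the code below depends on is just "mark each copied shard, then check whether everything has been copied".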
@@ -69,32 +75,52 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
]
# TODO(bukejiyu): remove v1 loader check when v0 loader is removed
if self.quant_config.is_checkpoint_bf16 and layer.fd_config.load_config.load_choices == "default_v1":
if self.model_format != "torch":
up_gate_proj_weight_shape = self.up_gate_proj_weight_shape
down_proj_weight_shape = self.down_proj_weight_shape
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 1, "down": 0, "up": 1},
}
else:
up_gate_proj_weight_shape = self.up_gate_proj_weight_shape[::-1]
down_proj_weight_shape = self.down_proj_weight_shape[::-1]
up_gate_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=up_gate_proj_weight_shape, output_dim=False),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}
down_proj_attrs = {
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=down_proj_weight_shape, output_dim=True),
"SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0},
}

layer.up_gate_proj_weight = layer.create_parameter(
shape=self.up_gate_proj_weight_shape,
shape=up_gate_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)

layer.down_proj_weight = layer.create_parameter(
shape=self.down_proj_weight_shape,
shape=down_proj_weight_shape,
dtype=layer.weight_dtype,
default_initializer=paddle.nn.initializer.Constant(0),
)
extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"

set_weight_attrs(
layer.up_gate_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.up_gate_proj_weight.shape, output_dim=True),
},
up_gate_proj_attrs,
)
set_weight_attrs(
layer.down_proj_weight,
{
**extra_weight_attrs,
"tensor_track": TensorTracker(shape=layer.down_proj_weight.shape, output_dim=False),
},
down_proj_attrs,
)
else:
setattr(
@@ -181,59 +207,64 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict):
@paddle.no_grad()
def process_weights_after_loading(self, layer):
""" """
if not self.quant_config.is_checkpoint_bf16:
return

def _process_quantize(weight_idx):
max_bound = 127
# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]

weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound

free_tensor(getattr(layer, weight_name))

# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)

algo = layer.quant_method.quant_config.name()
assert algo == "wint8"
max_bound = 127
weight_id_map = {"gate_up": 0, "down": 1}
if (
hasattr(layer.up_gate_proj_weight, "tensor_track")
and layer.up_gate_proj_weight.tensor_track is not None
and layer.up_gate_proj_weight.tensor_track.is_fully_copied()
):
weight_type = "gate_up"
layer.up_gate_proj_weight.tensor_track = None
if self.quant_config.is_checkpoint_bf16:

weight_id_map = {"gate_up": 0, "down": 1}
if weight_fully_copied(layer.up_gate_proj_weight):
weight_type = "gate_up"
else:
weight_type = "down"

if self.model_format == "torch":
# pt model
unquantized_weight_name = self.added_weight_attrs[weight_id_map[weight_type]].replace(
"quant_weight", "weight"
)
process_weight_transpose(layer, unquantized_weight_name)
_process_quantize(weight_id_map[weight_type])
else:
weight_type = "down"
layer.down_proj_weight.tensor_track = None

# weight
weight_name = self.added_weight_attrs[weight_id_map[weight_type]]
# scale
scale_name = self.added_scale_attrs[weight_id_map[weight_type]]

weight_tensor = getattr(layer, weight_name)
quanted_weight_scale = weight_tensor.abs().max(axis=1)
quanted_weight = weight_tensor / quanted_weight_scale[:, None, :] * max_bound
quanted_weight = paddle.round(quanted_weight).astype("int8")
quanted_weight_scale = quanted_weight_scale / max_bound

getattr(layer, weight_name).value().get_tensor()._clear()

# create weight
setattr(
layer,
weight_name,
layer.create_parameter(
shape=weight_tensor.shape,
dtype=quanted_weight.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
# create scale
setattr(
layer,
scale_name,
layer.create_parameter(
shape=quanted_weight_scale.shape,
dtype=quanted_weight_scale.dtype,
default_initializer=paddle.nn.initializer.Constant(0),
),
)
getattr(layer, weight_name).copy_(quanted_weight, False)
getattr(layer, scale_name).copy_(quanted_weight_scale, False)
return

@paddle.no_grad()
def apply(
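For reference, the arithmetic inside `_process_quantize` is the usual per-channel abs-max int8 scheme. A standalone sketch with toy shapes — the `[num_experts, K, N]` layout is an assumption inferred from the axis-1 reduction; the real tensors come from the layer:

```python
import paddle

# Standalone re-derivation of the wint8 math used above, with toy shapes.
# Assumed layout: [num_experts, K, N]; the reduction over axis 1 matches the diff.
max_bound = 127
weight = paddle.rand([2, 64, 32], dtype="float32") - 0.5   # toy stand-in for the bf16 weight

scale = weight.abs().max(axis=1)                            # [num_experts, N]
qweight = paddle.round(weight / scale[:, None, :] * max_bound).astype("int8")
scale = scale / max_bound                                   # stored so that W ≈ qweight * scale

# Sanity check: dequantized weight stays close to the original.
recovered = qweight.astype("float32") * scale[:, None, :]
print(float((recovered - weight).abs().max()))              # small quantization error
```

`free_tensor` appears to replace the old inline `.value().get_tensor()._clear()` call, releasing the bf16 weight before the int8 parameter and its scale are created, and for torch-format checkpoints `process_weight_transpose` is applied to the unquantized weight first.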
6 changes: 3 additions & 3 deletions fastdeploy/model_executor/layers/embeddings.py
@@ -23,7 +23,7 @@
from paddle.distributed import fleet

from fastdeploy.config import FDConfig
from fastdeploy.model_executor.utils import set_weight_attrs, slice_fn
from fastdeploy.model_executor.utils import h2d_copy, set_weight_attrs, slice_fn

from .utils import (
DEFAULT_VOCAB_PADDING_SIZE,
@@ -273,10 +273,10 @@ def weight_loader(self, param, loaded_weight, shard_id=None):
shard_weight = slice_fn(loaded_weight, output_dim, start_idx, end_idx)

if output_dim == 0:
param[: shard_weight.shape[0]].copy_(shard_weight, False)
h2d_copy(param[: shard_weight.shape[0]], shard_weight)
param[shard_weight.shape[0] :].fill_(0)
else:
param[:, : shard_weight.shape[1]].copy_(shard_weight, False)
h2d_copy(param[:, : shard_weight.shape[1]], shard_weight)
param[:, shard_weight.shape[1] :].fill_(0)

def forward(self, ids_remove_padding=None) -> paddle.Tensor:
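`h2d_copy` itself is not defined in this diff; judging by the name and the call sites, it wraps a host-to-device copy into a slice of the GPU parameter, replacing the raw `copy_(…, False)` calls. A hypothetical stand-in with the same call shape, purely to illustrate that assumption:

```python
import paddle

# Hypothetical stand-in for h2d_copy (the real helper lives in
# fastdeploy.model_executor.utils and is not shown in this diff).
def h2d_copy_sketch(dst, src):
    # Bring the shard onto the destination device, then copy it into the
    # parameter slice non-blocking, as the previous inline copy_ calls did.
    src = paddle.to_tensor(src, place=dst.place)
    dst.copy_(src, False)
```

Called the way the diff does, e.g. `h2d_copy(param[: shard_weight.shape[0]], shard_weight)`, with the remainder of the padded vocab rows still zero-filled afterwards.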