[DCU] fix DCU w8a8c8 GEMM shape
YanhuiDua committed Sep 11, 2024
1 parent 2f31866 commit 81c30d8
Showing 5 changed files with 20 additions and 17 deletions.
2 changes: 1 addition & 1 deletion llm/predict/export_model.py
@@ -18,7 +18,7 @@

import paddle
from paddle.distributed import fleet
from predict.predictor import ModelArgument, PredictorArgument, create_predictor
from predictor import ModelArgument, PredictorArgument, create_predictor

from paddlenlp.trainer import PdArgumentParser
from paddlenlp.utils import llm_utils
23 changes: 13 additions & 10 deletions paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -50,16 +50,19 @@
from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused as fp8_gemm_fused
else:
from paddle.linalg import fp8_fp8_half_gemm_fused as fp8_gemm_fused
from paddlenlp_ops import (
dequant_int8,
encode_rotary_qk,
gemm_dequant,
qkv_transpose_split,
quant_int8,
rebuild_padding,
transpose_remove_padding,
write_cache_kv,
)
try:
    from paddlenlp_ops import (
        dequant_int8,
        encode_rotary_qk,
        gemm_dequant,
        qkv_transpose_split,
        quant_int8,
        rebuild_padding,
        transpose_remove_padding,
        write_cache_kv,
    )
except:
    pass

__all__ = [
"MoeConfig",
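The hunk above wraps the paddlenlp_ops custom-op imports in a try/except so that fused_transformer_layers.py can still be imported when these custom kernels are unavailable. A minimal sketch of the same guard pattern, assuming a module-level availability flag; the name _PADDLENLP_OPS_AVAILABLE and the shortened import list are illustrative and not part of this commit:

try:
    # Custom kernels; only importable when the paddlenlp_ops package is built.
    from paddlenlp_ops import (
        dequant_int8,
        gemm_dequant,
        quant_int8,
    )

    _PADDLENLP_OPS_AVAILABLE = True
except ImportError:
    # Fall back gracefully: callers consult the flag before invoking the custom kernels.
    _PADDLENLP_OPS_AVAILABLE = False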
4 changes: 2 additions & 2 deletions paddlenlp/experimental/transformers/llama/modeling.py
@@ -674,7 +674,7 @@ def __init__(self, config: LlamaConfig):
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
rank_id=config.tensor_parallel_rank,
trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
)

self.set_transformer_block(transformer_config)
@@ -861,7 +861,7 @@ def set_state_dict(self, state_dict):
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
"llama.layers.{}.self_attn.v_proj.weight".format(idx)
]
if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
concated_qkv_weight = np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
4 changes: 2 additions & 2 deletions paddlenlp/experimental/transformers/mixtral/modeling.py
@@ -338,7 +338,7 @@ def __init__(self, config: MixtralConfig):
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
rank_id=config.tensor_parallel_rank,
trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
moe_config=moe_config,
)

@@ -527,7 +527,7 @@ def set_state_dict(self, state_dict):
unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
"mixtral.layers.{}.self_attn.v_proj.weight".format(idx)
]
if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
concated_qkv_weight = np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
4 changes: 2 additions & 2 deletions paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -372,7 +372,7 @@ def __init__(self, config: Qwen2Config):
use_neox_rotary_style=self.use_neox,
cachekv_int8_type=config.cachekv_int8_type,
rank_id=config.tensor_parallel_rank,
trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
)

self.set_transformer_block(transformer_config)
@@ -433,7 +433,7 @@ def set_state_dict(self, state_dict):
unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[
"qwen2.layers.{}.self_attn.v_proj.weight".format(idx)
]
if paddle.is_compiled_with_rocm() and (self.quant_type == "a8w8" or self.quant_type == "a8w8c8"):
if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
concated_qkv_weight = np.concatenate(
[
unfused_state_dict["self_attn.q_proj.weight"],
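Across the llama, mixtral, and qwen2 modeling files, the fix replaces exact string comparisons on self.quant_type with a substring check, so that both the a8w8 and a8w8c8 quantization types take the ROCm (DCU) branch: trans_qkvw becomes False and the ROCm-specific QKV weight concatenation in set_state_dict is used. A small illustrative check, assuming quant_type carries values like those in the diff; the helper name uses_dcu_a8w8_path is hypothetical and not part of this commit:

import paddle


def uses_dcu_a8w8_path(quant_type: str) -> bool:
    # Membership test covers "a8w8", "a8w8c8", and similar variants,
    # mirroring the `"a8w8" in self.quant_type` checks introduced above.
    return paddle.is_compiled_with_rocm() and "a8w8" in quant_type


# On a ROCm/DCU build, both a8w8 variants now take the DCU code path;
# on non-ROCm builds the condition is False for every quant type.
for qt in ("a8w8", "a8w8c8", "weight_only_int8"):
    print(qt, uses_dcu_a8w8_path(qt))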
