
Commit

Merge branch 'develop' into fix1
co63oc committed Feb 24, 2025
2 parents c726c61 + 30df8b6 commit 5e34b75
Showing 213 changed files with 14,499 additions and 3,558 deletions.
79 changes: 42 additions & 37 deletions README.md

Large diffs are not rendered by default.

37 changes: 27 additions & 10 deletions csrc/gpu/append_attention.cu
@@ -56,6 +56,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
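
For orientation, a minimal sketch (all names invented, not this kernel's code) of where a `softmax_scale` attribute acts in scaled dot-product attention: it multiplies each Q·K logit before the softmax, with 1/sqrt(head_dim) as the conventional default, and passing it explicitly lets callers decouple the scale from the cache head dims.

```cpp
// Illustrative only: how a softmax_scale factor is typically applied.
#include <cmath>

float scaled_logit(const float* q, const float* k, int head_dim_qk,
                   float softmax_scale) {
  float dot = 0.0f;
  for (int i = 0; i < head_dim_qk; ++i) dot += q[i] * k[i];
  return dot * softmax_scale;  // applied before the softmax over keys
}

// The conventional default when no explicit scale is supplied.
float default_scale(int head_dim_qk) {
  return 1.0f / sqrtf(static_cast<float>(head_dim_qk));
}
```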
@@ -97,21 +98,21 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
   if (out_linear_in_scale > 0.0) {
     if (fabs(quant_max_bound - 127.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::INT8,
           qkv.place());
     }
     else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
       fmha_out = GetEmptyTensor(
-          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+          {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
           paddle::DataType::FLOAT8_E4M3FN,
           qkv.place());
     }else{
       PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
     }
   } else {
     fmha_out = GetEmptyTensor(
-        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims},
+        {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v},
         D,
         qkv.place());
   }
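
The shape change above follows from attention math: output rows are weighted sums of V vectors, so once K and V may use different head dims, fmha_out's width must track head_dims_v rather than head_dims. A hypothetical numeric check (sizes invented for illustration, e.g. an MLA-style layout):

```cpp
#include <cassert>

int main() {
  const int q_num_heads = 32, head_dims = 192, head_dims_v = 128;
  // The attention output concatenates one V-sized vector per query head,
  // so its hidden size follows head_dims_v, not head_dims.
  assert(q_num_heads * head_dims_v == 4096);
  assert(q_num_heads * head_dims == 6144);  // the old, incorrect width
  return 0;
}
```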
@@ -203,6 +204,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         encoder_block_shape_q,
         max_input_length,
         max_enc_len_this_time_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -240,6 +242,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         encoder_block_shape_q,
         max_input_length,
         max_enc_len_this_time_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -282,6 +285,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         encoder_block_shape_q,
         max_input_length,
         max_enc_len_this_time_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -428,6 +432,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         decoder_block_shape_q,
         max_input_length,
         max_len_kv_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -465,6 +470,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         decoder_block_shape_q,
         max_input_length,
         max_len_kv_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -508,6 +514,7 @@ std::vector<paddle::Tensor> AppendAttentionKernel(
         decoder_block_shape_q,
         max_input_length,
         max_len_kv_data,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -565,6 +572,7 @@ std::vector<paddle::Tensor> AppendAttention(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -578,9 +586,10 @@ std::vector<paddle::Tensor> AppendAttention(
   meta_data.token_nums = qkv_dims[0];
   meta_data.kv_num_heads = key_cache_dims[1];
   meta_data.head_dims = key_cache_dims[3];
-  const int total_num_head =
-      qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims;
-  meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads;
+  meta_data.head_dims_v = value_cache.dims()[3];
+  const int q_hidden_size =
+      qkv_dims[qkv_dims.size() - 1] - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
 
   meta_data.max_blocks_per_seq = block_tables.dims()[1];
   meta_data.block_size = key_cache.dims()[2];
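
The rewritten derivation assumes the fused QKV tensor packs [Q | K | V] along its last axis, with Q and K using head_dims and V using head_dims_v; stripping the K and V spans and dividing by the QK head dim recovers the query head count. A self-contained sketch of that arithmetic (sizes invented):

```cpp
#include <cassert>

int main() {
  const int kv_num_heads = 8, head_dims = 192, head_dims_v = 128;
  const int true_q_num_heads = 32;
  // Fused QKV hidden size under the assumed [Q | K | V] packing.
  const int qkv_last = true_q_num_heads * head_dims +
                       kv_num_heads * (head_dims + head_dims_v);
  // Patched logic: remove the K and V spans, then divide by head_dims.
  const int q_hidden_size =
      qkv_last - kv_num_heads * (head_dims + head_dims_v);
  assert(q_hidden_size / head_dims == true_q_num_heads);
  // The old logic, qkv_last / head_dims - 2 * kv_num_heads, is only
  // correct when head_dims_v == head_dims.
  return 0;
}
```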
@@ -626,6 +635,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -672,6 +682,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -719,6 +730,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -764,6 +776,7 @@ std::vector<paddle::Tensor> AppendAttention(
         cache_quant_type_str,
         use_neox_rotary_style,
         max_input_length,
+        softmax_scale,
         quant_max_bound,
         quant_min_bound,
         out_linear_in_scale,
@@ -821,10 +834,12 @@ std::vector<std::vector<int64_t>> AppendAttentionInferShape(
     const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
   const int token_num = qkv_shape[0];
   const int kv_num_heads = key_cache_shape[1];
-  const int head_dim = key_cache_shape[3];
-  const int total_num_head = qkv_shape[qkv_shape.size() - 1] / head_dim;
-  const int num_heads = total_num_head - 2 * kv_num_heads;
-  return {{token_num, num_heads * head_dim}, qkv_shape};
+  const int head_dim_qk = key_cache_shape[3];
+  const int head_dim_v = value_cache_shape[3];
+  const int q_hidden_size =
+      qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+  const int num_heads = q_hidden_size / head_dim_qk;
+  return {{token_num, num_heads * head_dim_v}, qkv_shape};
 }
 
 std::vector<paddle::DataType> AppendAttentionInferDtype(
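
Mirroring the same arithmetic at shape-inference time, with the same invented sizes, the advertised fmha_out width now follows head_dim_v:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative walk-through of the patched InferShape (values invented):
// token_num = 256, kv_num_heads = 8, head_dim_qk = 192, head_dim_v = 128.
int main() {
  const std::vector<int64_t> qkv_shape = {256, 8704};
  const int kv_num_heads = 8, head_dim_qk = 192, head_dim_v = 128;
  const int q_hidden_size =
      qkv_shape.back() - kv_num_heads * (head_dim_qk + head_dim_v);  // 6144
  const int num_heads = q_hidden_size / head_dim_qk;                 // 32
  assert(num_heads * head_dim_v == 4096);  // fmha_out: {256, 4096}
  return 0;
}
```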
@@ -865,6 +880,7 @@ std::vector<paddle::DataType> AppendAttentionInferDtype(
     const std::string& cache_quant_type_str,
     const bool use_neox_rotary_style,
     const int max_input_length,
+    const float softmax_scale,
     const float quant_max_bound,
     const float quant_min_bound,
     const float out_linear_in_scale,
@@ -941,6 +957,7 @@ PD_BUILD_OP(append_attention)
             "cache_quant_type: std::string",
             "use_neox_rotary_style: bool",
             "max_input_length: int",
+            "softmax_scale: float",
             "quant_max_bound: float",
             "quant_min_bound: float",
             "out_linear_in_scale: float",
