diff --git a/csrc/gpu/append_attn/append_attention_c16_impl.cuh b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
index 3244bd1129ca..f790eb2574c9 100644
--- a/csrc/gpu/append_attn/append_attention_c16_impl.cuh
+++ b/csrc/gpu/append_attn/append_attention_c16_impl.cuh
@@ -1429,7 +1429,6 @@ void MultiQueryAppendAttention(
             static_cast<float*>(tmp_d->ptr()),
             reinterpret_cast<NV_TYPE*>(out->data<T>()),
             speculate_max_draft_token_num);
-    // merge
     constexpr int vec_size = num_elems_per_128b<NV_TYPE>();
     if (is_decoder) {
diff --git a/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu b/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu
index 4d5ef4a9c30c..7f98e7945431 100644
--- a/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu
+++ b/csrc/gpu/append_attn/get_block_shape_and_split_kv_block.cu
@@ -219,7 +219,6 @@ std::vector<paddle::Tensor> GetBlockShapeAndSplitKVBlock(
   );
   auto max_len_kv_cpu = max_len_kv.copy_to(paddle::CPUPlace(), false);
 
-  // decoder
   int max_dec_len_this_time_data = max_dec_len_this_time.data<int>()[0];
   if (max_dec_len_this_time_data > 0) {
diff --git a/csrc/gpu/append_attn/utils.cuh b/csrc/gpu/append_attn/utils.cuh
index 980071f93755..0784ee4c9f7f 100644
--- a/csrc/gpu/append_attn/utils.cuh
+++ b/csrc/gpu/append_attn/utils.cuh
@@ -284,6 +284,16 @@ __forceinline__ __host__ __device__ void vec_cast(
       __VA_ARGS__                                        \
       break;                                             \
     }                                                    \
+    case 256: {                                          \
+      constexpr size_t HEAD_DIM = 256;                   \
+      __VA_ARGS__                                        \
+      break;                                             \
+    }                                                    \
+    case 512: {                                          \
+      constexpr size_t HEAD_DIM = 512;                   \
+      __VA_ARGS__                                        \
+      break;                                             \
+    }                                                    \
     default: {                                           \
       PD_THROW("not support the head_dim: ", head_dim);  \
     }                                                    \
@@ -377,6 +387,9 @@ __forceinline__ __host__ __device__ void vec_cast(
   } else if (group_size == 8) {                          \
     constexpr size_t GROUP_SIZE = 8;                     \
     __VA_ARGS__                                          \
+  } else if (group_size == 64) {                         \
+    constexpr size_t GROUP_SIZE = 64;                    \
+    __VA_ARGS__                                          \
   } else if (group_size == 16) {                         \
     constexpr size_t GROUP_SIZE = 16;                    \
     __VA_ARGS__                                          \
diff --git a/csrc/gpu/simple_append_attention.cu b/csrc/gpu/simple_append_attention.cu
new file mode 100644
index 000000000000..15170286312b
--- /dev/null
+++ b/csrc/gpu/simple_append_attention.cu
@@ -0,0 +1,594 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
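+
+// SimpleAppendAttention: a trimmed, self-contained variant of the existing
+// append_attention op. It applies RoPE and appends the new tokens' K/V to
+// the paged cache, then runs the cascade append-attention kernels for the
+// prefill (encoder) and decode (decoder) tokens, overlapping the decode
+// pass on a secondary CUDA stream when both kinds of work are present.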
+ +#include "append_attn/append_attention_kernel.h" +#include "append_attn/decoder_write_cache_with_rope_kernel.h" +#include "append_attn/speculate_write_cache_with_rope_kernel.h" +#include "append_attn/encoder_write_cache_with_rope_kernel.h" + +template +std::vector SimpleAppendAttentionKernel( + const AppendAttnMetaData& meta_data, + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& padding_offsets, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& block_tables, + const paddle::Tensor& encoder_batch_ids, + const paddle::Tensor& encoder_tile_ids_per_batch, + const paddle::Tensor& encoder_num_blocks, + const paddle::Tensor& kv_batch_ids, + const paddle::Tensor& kv_tile_ids_per_batch, + const paddle::Tensor& kv_num_blocks, + const paddle::Tensor& decoder_batch_ids, + const paddle::Tensor& decoder_tile_ids_per_batch, + const paddle::Tensor& decoder_num_blocks, + const paddle::Tensor& max_enc_len_this_time, + const paddle::Tensor& max_dec_len_this_time, + const paddle::Tensor& max_len_kv, + const paddle::optional& rotary_embs, + const paddle::optional& attn_mask, + const paddle::optional& qkv_bias, + const paddle::optional& qkv_out_scales, + const paddle::optional& cache_k_quant_scales, + const paddle::optional& cache_v_quant_scales, + const paddle::optional& cache_k_dequant_scales, + const paddle::optional& cache_v_dequant_scales, + const paddle::optional& cache_k_zp, + const paddle::optional& cache_v_zp, + const paddle::optional& out_linear_shifts, + const paddle::optional& out_linear_smooths, + const std::string& cache_quant_type_str, + const bool use_neox_rotary_style, + const int max_input_length, + const float softmax_scale, + const float quant_max_bound, + const float quant_min_bound, + const float out_linear_in_scale, + const int speculate_max_draft_token_num, + const bool causal, + const bool speculate_decoder) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + int encoder_num_blocks_data = encoder_num_blocks.data()[0]; + int kv_num_blocks_data = kv_num_blocks.data()[0]; + int decoder_num_blocks_data = decoder_num_blocks.data()[0]; + int max_enc_len_this_time_data = max_enc_len_this_time.data()[0]; + int max_dec_len_this_time_data = max_dec_len_this_time.data()[0]; + int max_len_kv_data = max_len_kv.data()[0]; + const int encoder_block_shape_q = get_encoder_block_shape_q(); + const int decoder_block_shape_q = get_decoder_block_shape_q(); + auto main_stream = qkv.stream(); + static cudaEvent_t main_event; + static cudaEvent_t decoder_event; + static cudaStream_t decoder_stream; + static bool init_flag = false; + if (max_enc_len_this_time_data > 0 && max_dec_len_this_time_data > 0 && + !init_flag) { + cudaEventCreateWithFlags(&main_event, cudaEventDisableTiming); + cudaEventCreateWithFlags(&decoder_event, cudaEventDisableTiming); + cudaStreamCreateWithFlags(&decoder_stream, cudaStreamNonBlocking); + init_flag = true; + } + + paddle::Tensor qkv_out; + if (qkv_out_scales) { + qkv_out = GetEmptyTensor(qkv.dims(), D, qkv.place()); + } else { + qkv_out = qkv; + } + paddle::Tensor fmha_out; + if (out_linear_in_scale > 0.0) { + if (fabs(quant_max_bound - 127.0f) < 0.000001) { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, + paddle::DataType::INT8, + 
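+
+  // Prefill (encoder) pass on the main stream: rotate and append the
+  // prompt tokens' K/V into the paged cache, then run cascade attention
+  // over the encoder tiles.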
qkv.place()); + } + else if (fabs(quant_max_bound - 448.0f) < 0.000001) { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, + paddle::DataType::FLOAT8_E4M3FN, + qkv.place()); + }else{ + PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0']."); + } + } else { + fmha_out = GetEmptyTensor( + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, + D, + qkv.place()); + } + + if (max_enc_len_this_time_data > 0) { + EncoderWriteCacheWithRopeKernel( + meta_data, + qkv_out, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + padding_offsets, + cum_offsets, + block_tables, + kv_batch_ids, + kv_tile_ids_per_batch, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + kv_num_blocks_data, + max_input_length, + use_neox_rotary_style, + main_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache)); + + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + encoder_batch_ids, + encoder_tile_ids_per_batch, + cache_quant_type_str, + encoder_num_blocks_data, + encoder_block_shape_q, + max_input_length, + max_enc_len_this_time_data, + softmax_scale, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + false, + true, + main_stream, + &fmha_out); + } + + if (max_dec_len_this_time_data > 0) { + cudaStream_t exec_stream; + if (max_enc_len_this_time_data > 0) { + cudaStreamWaitEvent(decoder_stream, main_event); + exec_stream = decoder_stream; + } else { + exec_stream = main_stream; + } + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache)); + CascadeAppendAttentionKernel( + meta_data, + qkv_out, + key_cache, + value_cache, + attn_mask, + cache_k_dequant_scales, + cache_v_dequant_scales, + cache_k_zp, + cache_v_zp, + out_linear_shifts, + out_linear_smooths, + seq_lens_this_time, + seq_lens_decoder, + seq_lens_encoder, + padding_offsets, + cum_offsets, + block_tables, + decoder_batch_ids, + decoder_tile_ids_per_batch, + cache_quant_type_str, + decoder_num_blocks_data, + decoder_block_shape_q, + max_input_length, + max_len_kv_data, + softmax_scale, + quant_max_bound, + quant_min_bound, + out_linear_in_scale, + speculate_max_draft_token_num, + causal, + !speculate_decoder, + !speculate_decoder, + exec_stream, + &fmha_out); + + if (max_enc_len_this_time_data > 0) { + cudaEventRecord(decoder_event, exec_stream); + cudaStreamWaitEvent(main_stream, decoder_event); + } + } + + return {fmha_out, qkv_out}; +} + +std::vector SimpleAppendAttention( + const paddle::Tensor& qkv, + const paddle::Tensor& key_cache, + const paddle::Tensor& value_cache, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& 
+std::vector<paddle::Tensor> SimpleAppendAttention(
+    const paddle::Tensor& qkv,
+    const paddle::Tensor& key_cache,
+    const paddle::Tensor& value_cache,
+    const paddle::Tensor& seq_lens_encoder,
+    const paddle::Tensor& seq_lens_decoder,
+    const paddle::Tensor& seq_lens_this_time,
+    const paddle::Tensor& padding_offsets,
+    const paddle::Tensor& cum_offsets,
+    const paddle::Tensor& block_tables,
+    const paddle::Tensor& encoder_batch_ids,
+    const paddle::Tensor& encoder_tile_ids_per_batch,
+    const paddle::Tensor& encoder_num_blocks,
+    const paddle::Tensor& kv_batch_ids,
+    const paddle::Tensor& kv_tile_ids_per_batch,
+    const paddle::Tensor& kv_num_blocks,
+    const paddle::Tensor& decoder_batch_ids,
+    const paddle::Tensor& decoder_tile_ids_per_batch,
+    const paddle::Tensor& decoder_num_blocks,
+    const paddle::Tensor& max_enc_len_this_time,
+    const paddle::Tensor& max_dec_len_this_time,
+    const paddle::Tensor& max_len_kv,
+    const paddle::optional<paddle::Tensor>& rotary_embs,
+    const paddle::optional<paddle::Tensor>& attn_mask,
+    const paddle::optional<paddle::Tensor>& qkv_bias,
+    const paddle::optional<paddle::Tensor>& qkv_out_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_quant_scales,
+    const paddle::optional<paddle::Tensor>& cache_v_quant_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_dequant_scales,
+    const paddle::optional<paddle::Tensor>& cache_v_dequant_scales,
+    const paddle::optional<paddle::Tensor>& cache_k_zp,
+    const paddle::optional<paddle::Tensor>& cache_v_zp,
+    const paddle::optional<paddle::Tensor>& out_linear_shifts,
+    const paddle::optional<paddle::Tensor>& out_linear_smooths,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const int max_input_length,
+    const float softmax_scale,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  AppendAttnMetaData meta_data;
+
+  const auto& qkv_dims = qkv.dims();
+  const auto& key_cache_dims = key_cache.dims();
+  meta_data.token_nums = qkv_dims[0];
+  meta_data.kv_num_heads = key_cache_dims[1];
+  meta_data.head_dims = key_cache_dims[3];
+  meta_data.head_dims_v = value_cache.dims()[3];
+  // qkv is packed along the last axis, so the q hidden size is what remains
+  // after the k and v slices are removed.
+  const int q_hidden_size =
+      qkv_dims[qkv_dims.size() - 1] -
+      meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v);
+  meta_data.q_num_heads = q_hidden_size / meta_data.head_dims;
+
+  meta_data.max_blocks_per_seq = block_tables.dims()[1];
+  meta_data.block_size = key_cache.dims()[2];
+  meta_data.batch_size = cum_offsets.dims()[0];
+
+  switch (qkv.dtype()) {
+    case paddle::DataType::FLOAT16: {
+      return SimpleAppendAttentionKernel<paddle::DataType::FLOAT16>(
+          meta_data,
+          qkv,
+          key_cache,
+          value_cache,
+          seq_lens_encoder,
+          seq_lens_decoder,
+          seq_lens_this_time,
+          padding_offsets,
+          cum_offsets,
+          block_tables,
+          encoder_batch_ids,
+          encoder_tile_ids_per_batch,
+          encoder_num_blocks,
+          kv_batch_ids,
+          kv_tile_ids_per_batch,
+          kv_num_blocks,
+          decoder_batch_ids,
+          decoder_tile_ids_per_batch,
+          decoder_num_blocks,
+          max_enc_len_this_time,
+          max_dec_len_this_time,
+          max_len_kv,
+          rotary_embs,
+          attn_mask,
+          qkv_bias,
+          qkv_out_scales,
+          cache_k_quant_scales,
+          cache_v_quant_scales,
+          cache_k_dequant_scales,
+          cache_v_dequant_scales,
+          cache_k_zp,
+          cache_v_zp,
+          out_linear_shifts,
+          out_linear_smooths,
+          cache_quant_type_str,
+          use_neox_rotary_style,
+          max_input_length,
+          softmax_scale,
+          quant_max_bound,
+          quant_min_bound,
+          out_linear_in_scale,
+          speculate_max_draft_token_num,
+          causal,
+          speculate_decoder);
+    }
+    case paddle::DataType::BFLOAT16: {
+      return SimpleAppendAttentionKernel<paddle::DataType::BFLOAT16>(
+          meta_data,
+          qkv,
+          key_cache,
+          value_cache,
+          seq_lens_encoder,
+          seq_lens_decoder,
+          seq_lens_this_time,
+          padding_offsets,
+          cum_offsets,
+          block_tables,
+          encoder_batch_ids,
+          encoder_tile_ids_per_batch,
+          encoder_num_blocks,
+          kv_batch_ids,
+          kv_tile_ids_per_batch,
+          kv_num_blocks,
+          decoder_batch_ids,
+          decoder_tile_ids_per_batch,
+          decoder_num_blocks,
+          max_enc_len_this_time,
+          max_dec_len_this_time,
+          max_len_kv,
+          rotary_embs,
+          attn_mask,
+          qkv_bias,
+          qkv_out_scales,
+          cache_k_quant_scales,
+          cache_v_quant_scales,
+          cache_k_dequant_scales,
+          cache_v_dequant_scales,
+          cache_k_zp,
+          cache_v_zp,
+          out_linear_shifts,
+          out_linear_smooths,
+          cache_quant_type_str,
+          use_neox_rotary_style,
+          max_input_length,
+          softmax_scale,
+          quant_max_bound,
+          quant_min_bound,
+          out_linear_in_scale,
+          speculate_max_draft_token_num,
+          causal,
+          speculate_decoder);
+    }
+    default: {
+      PD_THROW(
+          "NOT supported data type. "
+          "Only float16 and bfloat16 are supported. ");
+      break;
+    }
+  }
+  return {paddle::Tensor{}};
+}
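+
+// InferShape: fmha_out is [token_num, q_num_heads * head_dim_v] and qkv_out
+// keeps the input qkv shape; q_num_heads is recovered from the packed qkv
+// hidden size exactly as in SimpleAppendAttention above.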
+std::vector<std::vector<int64_t>> SimpleAppendAttentionInferShape(
+    const std::vector<int64_t>& qkv_shape,
+    const std::vector<int64_t>& key_cache_shape,
+    const std::vector<int64_t>& value_cache_shape,
+    const std::vector<int64_t>& seq_lens_encoder_shape,
+    const std::vector<int64_t>& seq_lens_decoder_shape,
+    const std::vector<int64_t>& seq_lens_this_time_shape,
+    const std::vector<int64_t>& padding_offsets_shape,
+    const std::vector<int64_t>& cum_offsets_shape,
+    const std::vector<int64_t>& block_tables_shape,
+    const std::vector<int64_t>& encoder_batch_ids_shape,
+    const std::vector<int64_t>& encoder_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& encoder_num_blocks_shape,
+    const std::vector<int64_t>& kv_batch_ids_shape,
+    const std::vector<int64_t>& kv_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& kv_num_blocks_shape,
+    const std::vector<int64_t>& decoder_batch_ids_shape,
+    const std::vector<int64_t>& decoder_tile_ids_per_batch_shape,
+    const std::vector<int64_t>& decoder_num_blocks_shape,
+    const std::vector<int64_t>& max_enc_len_this_time_shape,
+    const std::vector<int64_t>& max_dec_len_this_time_shape,
+    const std::vector<int64_t>& max_len_kv_shape,
+    const paddle::optional<std::vector<int64_t>>& rotary_embs_shape,
+    const paddle::optional<std::vector<int64_t>>& attn_mask_shape,
+    const paddle::optional<std::vector<int64_t>>& qkv_bias_shape,
+    const paddle::optional<std::vector<int64_t>>& qkv_out_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_quant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_quant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_dequant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_dequant_scales_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_k_zp_shape,
+    const paddle::optional<std::vector<int64_t>>& cache_v_zp_shape,
+    const paddle::optional<std::vector<int64_t>>& out_linear_shifts_shape,
+    const paddle::optional<std::vector<int64_t>>& out_linear_smooths_shape) {
+  const int token_num = qkv_shape[0];
+  const int kv_num_heads = key_cache_shape[1];
+  const int head_dim_qk = key_cache_shape[3];
+  const int head_dim_v = value_cache_shape[3];
+  const int q_hidden_size =
+      qkv_shape[qkv_shape.size() - 1] - kv_num_heads * (head_dim_qk + head_dim_v);
+  const int num_heads = q_hidden_size / head_dim_qk;
+  return {{token_num, num_heads * head_dim_v}, qkv_shape};
+}
+
+std::vector<paddle::DataType> SimpleAppendAttentionInferDtype(
+    const paddle::DataType& qkv_dtype,
+    const paddle::DataType& key_cache_dtype,
+    const paddle::DataType& value_cache_dtype,
+    const paddle::DataType& seq_lens_encoder_dtype,
+    const paddle::DataType& seq_lens_decoder_dtype,
+    const paddle::DataType& seq_lens_this_time_dtype,
+    const paddle::DataType& padding_offsets_dtype,
+    const paddle::DataType& cum_offsets_dtype,
+    const paddle::DataType& block_tables_dtype,
+    const paddle::DataType& encoder_batch_ids_dtype,
+    const paddle::DataType& encoder_tile_ids_per_batch_dtype,
+    const paddle::DataType& encoder_num_blocks_dtype,
+    const paddle::DataType& kv_batch_ids_dtype,
+    const paddle::DataType& kv_tile_ids_per_batch_dtype,
+    const paddle::DataType& kv_num_blocks_dtype,
+    const paddle::DataType& decoder_batch_ids_dtype,
+    const paddle::DataType& decoder_tile_ids_per_batch_dtype,
+    const paddle::DataType& decoder_num_blocks_dtype,
+    const paddle::DataType& max_enc_len_this_time_dtype,
+    const paddle::DataType& max_dec_len_this_time_dtype,
+    const paddle::DataType& max_len_kv_dtype,
+    const paddle::optional<paddle::DataType>& rotary_embs_dtype,
+    const paddle::optional<paddle::DataType>& attn_mask_dtype,
+    const paddle::optional<paddle::DataType>& qkv_bias_dtype,
+    const paddle::optional<paddle::DataType>& qkv_out_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_quant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_quant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_dequant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_dequant_scales_dtype,
+    const paddle::optional<paddle::DataType>& cache_k_zp_dtype,
+    const paddle::optional<paddle::DataType>& cache_v_zp_dtype,
+    const paddle::optional<paddle::DataType>& out_linear_shifts_dtype,
+    const paddle::optional<paddle::DataType>& out_linear_smooths_dtype,
+    const std::string& compute_dtype,
+    const std::string& cache_quant_type_str,
+    const bool use_neox_rotary_style,
+    const int max_input_length,
+    const float softmax_scale,
+    const float quant_max_bound,
+    const float quant_min_bound,
+    const float out_linear_in_scale,
+    const int speculate_max_draft_token_num,
+    const bool causal,
+    const bool speculate_decoder) {
+  if (compute_dtype == "bf16") {
+    if (out_linear_in_scale > 0.0) {
+      if (fabs(quant_max_bound - 127.0f) < 0.000001) {
+        return {paddle::DataType::INT8, paddle::DataType::BFLOAT16};
+      } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
+        return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::BFLOAT16};
+      } else {
+        PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
+      }
+    } else {
+      return {paddle::DataType::BFLOAT16, paddle::DataType::BFLOAT16};
+    }
+  } else if (compute_dtype == "fp16") {
+    if (out_linear_in_scale > 0.0) {
+      if (fabs(quant_max_bound - 127.0f) < 0.000001) {
+        return {paddle::DataType::INT8, paddle::DataType::FLOAT16};
+      } else if (fabs(quant_max_bound - 448.0f) < 0.000001) {
+        return {paddle::DataType::FLOAT8_E4M3FN, paddle::DataType::FLOAT16};
+      } else {
+        PD_THROW("Only supported attr of quant_max_bound in ['127.0', '448.0'].");
+      }
+    } else {
+      return {paddle::DataType::FLOAT16, paddle::DataType::FLOAT16};
+    }
+  } else {
+    PD_THROW("Only supported attr of compute_dtype in ['fp16', 'bf16'].");
+  }
+}
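+
+// Register the custom op. key_cache/value_cache are declared inplace, so
+// they come back as key_cache_out/value_cache_out without a copy.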
+PD_BUILD_OP(simple_append_attention)
+    .Inputs({"qkv",
+             "key_cache",
+             "value_cache",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "seq_lens_this_time",
+             "padding_offsets",
+             "cum_offsets",
+             "block_tables",
+             "encoder_batch_ids",
+             "encoder_tile_ids_per_batch",
+             "encoder_num_blocks",
+             "kv_batch_ids",
+             "kv_tile_ids_per_batch",
+             "kv_num_blocks",
+             "decoder_batch_ids",
+             "decoder_tile_ids_per_batch",
+             "decoder_num_blocks",
+             "max_enc_len_this_time",
+             "max_dec_len_this_time",
+             "max_len_kv",
+             paddle::Optional("rotary_embs"),
+             paddle::Optional("attn_mask"),
+             paddle::Optional("qkv_bias"),
+             paddle::Optional("qkv_out_scales"),
+             paddle::Optional("cache_k_quant_scales"),
+             paddle::Optional("cache_v_quant_scales"),
+             paddle::Optional("cache_k_dequant_scales"),
+             paddle::Optional("cache_v_dequant_scales"),
+             paddle::Optional("cache_k_zp"),
+             paddle::Optional("cache_v_zp"),
+             paddle::Optional("out_linear_shifts"),
+             paddle::Optional("out_linear_smooths")})
+    .Outputs({"fmha_out", "qkv_out", "key_cache_out", "value_cache_out"})
+    .SetInplaceMap({{"key_cache", "key_cache_out"},
+                    {"value_cache", "value_cache_out"}})
+    .Attrs({"compute_type: std::string",
+            "cache_quant_type: std::string",
+            "use_neox_rotary_style: bool",
+            "max_input_length: int",
+            "softmax_scale: float",
+            "quant_max_bound: float",
+            "quant_min_bound: float",
+            "out_linear_in_scale: float",
+            "speculate_max_draft_token_num: int",
+            "causal: bool",
+            "speculate_decoder: bool"})
+    .SetKernelFn(PD_KERNEL(SimpleAppendAttention))
+    .SetInferShapeFn(PD_INFER_SHAPE(SimpleAppendAttentionInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(SimpleAppendAttentionInferDtype));
\ No newline at end of file
float", + "quant_max_bound: float", + "quant_min_bound: float", + "out_linear_in_scale: float", + "speculate_max_draft_token_num: int", + "causal: bool", + "speculate_decoder: bool"}) + .SetKernelFn(PD_KERNEL(SimpleAppendAttention)) + .SetInferShapeFn(PD_INFER_SHAPE(SimpleAppendAttentionInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(SimpleAppendAttentionInferDtype)); \ No newline at end of file diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index a37b44584caa..da88b76fcba1 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -153,7 +153,7 @@ def get_gencode_flags(): if cc >= 80: sources += ["gpu/int8_gemm_with_cutlass/gemm_dequant.cu"] - sources += ["./gpu/append_attention.cu", "./gpu/multi_head_latent_attention.cu"] + sources += ["./gpu/append_attention.cu", "./gpu/multi_head_latent_attention.cu", "./gpu/simple_append_attention.cu"] # add this kernel sources += find_end_files("./gpu/append_attn", ".cu") sources += find_end_files("./gpu/append_attn/template_instantiation", ".cu") diff --git a/llm/predict/output.json b/llm/predict/output.json new file mode 100644 index 000000000000..2f90bdb04a96 --- /dev/null +++ b/llm/predict/output.json @@ -0,0 +1,2 @@ +{"src": "2014年3月,大范围雾霾天气长时间影响我国东部地区,严重危害人体健康。造成雾霾天气的人为原因有____\r\n①工业生产中使用矿物作为燃料,大量排放污染物 ②汽车尾气的大量排放 \r\n③风力小,空气流动不畅 ④冬季取暖排放粉尘\nA. ①②③\nB. ②③④\nC. ①③④\nD. ①②④", "tgt": "", "output": "pad>", "tgt": "", "output": "pad>" * predictor_args.src_length ] * predictor_args.batch_size target_texts = [""] * predictor_args.batch_size diff --git a/llm/predict/run.sh b/llm/predict/run.sh new file mode 100644 index 000000000000..186b649bfa31 --- /dev/null +++ b/llm/predict/run.sh @@ -0,0 +1,26 @@ +#!/bin/bash +export CUDA_VISIBLE_DEVICES=3 + +export PYTHONPATH=/root/paddlejob/workspace/env_run/output/changwenbin/dongyazhu/PaddleNLP:$PYTHONPATH + +# python ./predictor.py --model_name_or_path Qwen/Qwen2.5-14B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 + +seqlen=1024 +nsys profile -o seqlen-1.5-linear-${seqlen} python ./predictor.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 --src_length ${seqlen} --min_length 100 --total_max_length 8400 + +seqlen=4096 +nsys profile -o seqlen-1.5-linear-${seqlen} python ./predictor.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 --src_length ${seqlen} --min_length 100 --total_max_length 8400 + +# seqlen=8192 +# nsys profile -o seqlen${seqlen} python ./predictor.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 --src_length ${seqlen} --min_length 100 --total_max_length 16384 + +# seqlen=4096 +# python ./predictor.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 + +# python ./predictor.py --model_name_or_path Qwen/Qwen2-1.5B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 + +seqlen=1024 +nsys profile -o seqlen-40-linear-${seqlen} python ./predictor.py --model_name_or_path Qwen/Qwen2.5-14B-Instruct --dtype bfloat16 --mode dynamic --inference_model 1 --append_attn 1 --batch_size 2 --src_length ${seqlen} --min_length 100 --total_max_length 8400 + +seqlen=4096 +nsys profile -o seqlen-40-linear-${seqlen} python ./predictor.py --model_name_or_path Qwen/Qwen2.5-14B-Instruct --dtype bfloat16 --mode 
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index a7f0d4bd6e4f..530de25fea69 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -34,6 +34,7 @@
 from paddlenlp.utils.import_utils import is_paddlenlp_ops_available
 from paddlenlp.utils.log import logger
+import nvtx
 
 if not is_paddlenlp_ops_available():
     logger.warning(
@@ -1175,7 +1176,11 @@ def compute_qkv(self, src, residual_input, i):
         if self.config.mla_config.use_absorb():
             qkv_out = ln_out
         else:
+            paddle.device.synchronize()
+            qkv_linear_nvtx = nvtx.start_range(message="qkv_linear", color="blue")
             qkv_out = self.compute_qkv_linear(ln_out, i)
+            paddle.device.synchronize()
+            nvtx.end_range(qkv_linear_nvtx)
 
         return qkv_out, residual_input
 
@@ -1714,6 +1719,7 @@ def forward(
         )
         residual_input = src
         for i in range(self.num_layers):
+            # print(666)
             qkv_out, residual_input = self.compute_qkv(src, residual_input, i)
             fmha_out = self.compute_attn(
                 time_step,
@@ -3307,6 +3313,12 @@ def compute_attn(
         if self.config.append_attn:
             from paddlenlp_ops import append_attention
 
+            paddle.device.synchronize()
+            transformer_nvtx = nvtx.start_range(message="GQA_Append_attn", color="red")
+
+            # q: [bsz, seq_len, q_head, head_dim]   q_head  -> 12
+            # k: [bsz, seq_len, kv_head, head_dim]  kv_head -> 2
+            # v: [bsz, seq_len, kv_head, head_dim]  kv_head -> 2
             fmha_out = append_attention(
                 qkv_out,
                 caches[2 * i],
@@ -3353,6 +3365,8 @@ def compute_attn(
                 True,  # causal
                 self.config.speculate_config.speculate_method is not None,  # speculate_decoder
             )[0]
+            paddle.device.synchronize()
+            nvtx.end_range(transformer_nvtx)
         else:
             if paddle.is_compiled_with_xpu():
                 from paddlenlp_ops import mla_block_multihead_attention_xpu
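
For reference, a minimal sketch of how the newly registered op could be driven from Python once the custom ops are compiled. Everything below is illustrative: the tensor names are placeholders, paddlenlp_ops must have been built with simple_append_attention.cu included (per the setup_cuda.py change above), and the argument order simply follows the PD_BUILD_OP registration, mirroring the existing append_attention call in fused_transformer_layers.py:

    # Sketch only: assumes paddlenlp_ops was built with this patch applied and
    # that the block-table / split-kv tensors were produced by
    # get_block_shape_and_split_kv_block, exactly as for the existing
    # append_attention op. All names on the right are placeholders.
    from paddlenlp_ops import simple_append_attention

    outs = simple_append_attention(
        qkv,                      # packed [token_num, (q_heads + 2 * kv_heads) * head_dim]
        key_cache,                # paged [max_blocks, kv_heads, block_size, head_dim]
        value_cache,
        seq_lens_encoder, seq_lens_decoder, seq_lens_this_time,
        padding_offsets, cum_offsets, block_tables,
        encoder_batch_ids, encoder_tile_ids_per_batch, encoder_num_blocks,
        kv_batch_ids, kv_tile_ids_per_batch, kv_num_blocks,
        decoder_batch_ids, decoder_tile_ids_per_batch, decoder_num_blocks,
        max_enc_len_this_time, max_dec_len_this_time, max_len_kv,
        rotary_embs,              # the 12 optional inputs may be passed as None
        None, None, None, None, None, None, None, None, None, None, None,
        "bf16",                   # compute_type
        "none",                   # cache_quant_type
        False,                    # use_neox_rotary_style
        max_input_length,
        float(head_dim) ** -0.5,  # softmax_scale
        127.0,                    # quant_max_bound
        -127.0,                   # quant_min_bound
        -1.0,                     # out_linear_in_scale <= 0 disables output quant
        5,                        # speculate_max_draft_token_num
        True,                     # causal
        False,                    # speculate_decoder
    )
    fmha_out = outs[0]            # outs[1] is qkv_out; the caches are updated in place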