From c5d8d5bb2dc3e45b59d4187e08dcd951d3d17cb7 Mon Sep 17 00:00:00 2001 From: RichardWooSJTU <37864677+RichardWooSJTU@users.noreply.github.com> Date: Wed, 10 Jan 2024 22:23:38 +0800 Subject: [PATCH] [LLM] Support block_attention/cachekv quant for llama (#7649) * support blha and cache kv quant * lint * fix unit test * fix infer when blha is on * code refine * add docs and fix ops * merge blha read res in predictor * finish docs * add docs and unittest * add unittest * migrate read res --- .github/codecov.yml | 2 +- csrc/generation/get_output.cc | 69 +++ csrc/generation/get_padding_offset_v2.cu | 96 ++++ csrc/generation/rebuild_padding_v2.cu | 131 +++++ csrc/generation/reset_need_stop_value.cc | 12 + csrc/generation/save_with_output_msg.cc | 58 ++ csrc/generation/set_value_by_flags_v2.cu | 58 ++ csrc/generation/step.cu | 354 ++++++++++++ .../stop_generation_multi_ends_v2.cu | 75 +++ csrc/generation/token_penalty_multi_scores.cu | 2 +- .../token_penalty_multi_scores_v2.cu | 250 ++++++++ csrc/generation/update_inputs.cu | 106 ++++ csrc/generation/write_int8_cache_kv.cu | 349 ++++++++++++ csrc/setup_cuda.py | 11 + llm/benchmark.sh | 2 +- llm/docs/inference.md | 69 ++- llm/export_model.py | 6 +- llm/llama/ptq_argument.json | 42 +- llm/predictor.py | 535 +++++++++++++++++- llm/utils.py | 32 ++ paddlenlp/experimental/model_utils.py | 24 + .../transformers/fused_transformer_layers.py | 274 ++++++++- .../transformers/generation_utils.py | 311 +++++++++- .../transformers/llama/modeling.py | 490 +++++++++++++++- .../transformers/llama/ptq_scales_map.json | 4 + .../llama/ptq_scales_map_shift_smooth.json | 4 + tests/fixtures/llm/ptq.yaml | 2 +- tests/llm/test_predictor.py | 78 +++ tests/llm/test_ptq.py | 13 + 29 files changed, 3380 insertions(+), 79 deletions(-) create mode 100644 csrc/generation/get_output.cc create mode 100644 csrc/generation/get_padding_offset_v2.cu create mode 100644 csrc/generation/rebuild_padding_v2.cu create mode 100644 csrc/generation/reset_need_stop_value.cc create mode 100644 csrc/generation/save_with_output_msg.cc create mode 100644 csrc/generation/set_value_by_flags_v2.cu create mode 100644 csrc/generation/step.cu create mode 100644 csrc/generation/stop_generation_multi_ends_v2.cu create mode 100644 csrc/generation/token_penalty_multi_scores_v2.cu create mode 100644 csrc/generation/update_inputs.cu create mode 100644 csrc/generation/write_int8_cache_kv.cu diff --git a/.github/codecov.yml b/.github/codecov.yml index 7a560bba78fd..c2151a19d38e 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -10,4 +10,4 @@ coverage: threshold: 1% # Allow the coverage to drop by 1%, and posting a success status. patch: default: - target: 80% # lines adjusted Coverage < 80% CI will fail \ No newline at end of file + target: 80% # lines adjusted Coverage < 80% CI will fail diff --git a/csrc/generation/get_output.cc b/csrc/generation/get_output.cc new file mode 100644 index 000000000000..87535e0a6362 --- /dev/null +++ b/csrc/generation/get_output.cc @@ -0,0 +1,69 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#define MAX_BSZ 512 + +struct msgdata { + long mtype; + int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens +}; + +void GetOutput(const paddle::Tensor& x, + int64_t rank_id, + bool wait_flag) { + if (rank_id > 0) return; + + static struct msgdata msg_rcv; + + static key_t key = ftok("./", 1); + + static int msgid = msgget(key, IPC_CREAT | 0666); + + int64_t *out_data = const_cast(x.data()); + int ret = -1; + if (!wait_flag) { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT); + } else { + ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0); + } + if(ret == -1) + { + // read none + out_data[0] = -2; + out_data[1] = 0; + return; + } + + int bsz = msg_rcv.mtext[1]; + + for (int64_t i = 0; i < bsz + 2; i++) { + out_data[i] = (int64_t)msg_rcv.mtext[i]; + } + return; +} + +PD_BUILD_OP(get_output) + .Inputs({"x"}) + .Attrs({"rank_id: int64_t", + "wait_flag: bool"}) + .Outputs({"x_out"}) + .SetInplaceMap({{"x", "x_out"}}) + .SetKernelFn(PD_KERNEL(GetOutput)); diff --git a/csrc/generation/get_padding_offset_v2.cu b/csrc/generation/get_padding_offset_v2.cu new file mode 100644 index 000000000000..3acfad6cb8a7 --- /dev/null +++ b/csrc/generation/get_padding_offset_v2.cu @@ -0,0 +1,96 @@ +#include "paddle/extension.h" + +__global__ void RemovePaddingV2(int64_t *output_data, + const int64_t *input_data, + const int *seq_lens, + const int *cum_offsets, + const int sequence_length) { + const int bi = blockIdx.x; + const int tid = threadIdx.x; + + for (int i = tid; i < seq_lens[bi]; i += blockDim.x) { + const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i; + const int src_seq_id = bi * sequence_length + i; + output_data[tgt_seq_id] = input_data[src_seq_id]; + } +} + +__global__ void GetPaddingOffsetKernelV2(int *padding_offset, + int *cum_offsets_out, + int *cu_seqlens_q, + int *cu_seqlens_k, + const int *cum_offsets, + const int *seq_lens, + const int max_seq_len) { + // get padding offset of each batch + const int bi = blockIdx.x; + const int ti = threadIdx.x; + int cum_offset = bi == 0 ? 
0 : cum_offsets[bi - 1]; + for (int i = ti; i < seq_lens[bi]; i += blockDim.x) { + padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset; + } + if (ti == 0) { + cum_offsets_out[bi] = cum_offset; + int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi]; + cu_seqlens_q[bi + 1] = cum_seq_len; + cu_seqlens_k[bi + 1] = cum_seq_len; + } +} + + +std::vector GetPaddingOffsetV2(const paddle::Tensor& input_ids, + const paddle::Tensor& cum_offsets, + const paddle::Tensor& token_num, + const paddle::Tensor& seq_len) { + auto cu_stream = input_ids.stream(); + std::vector input_ids_shape = input_ids.shape(); + const int bsz = seq_len.shape()[0]; + const int seq_length = input_ids_shape[1]; + auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false); + auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false); + + const int token_num_data = cpu_token_num.data()[0]; + auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place()); + auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place()); + int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128); + GetPaddingOffsetKernelV2<<>>( + padding_offset.data(), + cum_offsets_out.data(), + cu_seqlens_q.data(), + cu_seqlens_k.data(), + cum_offsets.data(), + seq_len.data(), + seq_length); + RemovePaddingV2<<>>( + x_remove_padding.data(), + input_ids.data(), + seq_len.data(), + cum_offsets_out.data(), + seq_length); + return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k}; // , enc_token_num, dec_token_num}; +} + +std::vector> GetPaddingOffsetV2InferShape(const std::vector& input_ids_shape, + const std::vector& cum_offsets_shape, + const std::vector& token_num_shape, + const std::vector& seq_len_shape) { + int64_t bsz = seq_len_shape[0]; + int64_t seq_len = input_ids_shape[1]; + return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}}; +} + +std::vector GetPaddingOffsetV2InferDtype(const paddle::DataType& input_ids_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& token_num_dtype, + const paddle::DataType& seq_len_dtype) { + return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype}; +} + +PD_BUILD_OP(get_padding_offset_v2) + .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"}) + .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"}) + .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2)) + .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetV2InferDtype)); \ No newline at end of file diff --git a/csrc/generation/rebuild_padding_v2.cu b/csrc/generation/rebuild_padding_v2.cu new file mode 100644 index 000000000000..4d61936952da --- /dev/null +++ b/csrc/generation/rebuild_padding_v2.cu @@ -0,0 +1,131 @@ +#include "helper.h" + +template +__global__ void RebuildPaddingV2Kernel(T *output_data, + const T *input_data, + const int *cum_offsets, + const int *seq_len_decoder, + const int *seq_len_encoder, + const int seq_len, + const int dim_embed, + const int elem_nums) { + using LoadT = AlignedVector; + LoadT src_vec; + const int global_idx = blockDim.x * blockIdx.x + threadIdx.x; + for (int i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) { + const int bi = i 
/ dim_embed; + const int bias_idx = i % dim_embed; + int seq_id = 0; + // just encoder or stop, get last token; just decoder, get first token. + if (seq_len_decoder[bi] == 0) { + if (seq_len_encoder[bi] != 0) { + seq_id = seq_len_encoder[bi] - 1; + } else { + return; + } + } + const int ori_token_idx = bi * seq_len - cum_offsets[bi] + seq_id; + const int src_offset = ori_token_idx * dim_embed + bias_idx; + Load(&input_data[src_offset], &src_vec); + Store(src_vec, &output_data[i]); + } +} + +template +std::vector rebuild_padding_v2(const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + int max_input_length) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + auto cu_stream = tmp_out.stream(); + std::vector tmp_out_shape = tmp_out.shape(); + const int token_num = tmp_out_shape[0]; + const int dim_embed = tmp_out_shape[1]; + const int bsz = cum_offsets.shape()[0]; + auto out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place()); + constexpr int PackSize = VEC_16B / sizeof(DataType_); + int elem_nums = out.numel(); + int pack_num = elem_nums / PackSize; + const int blocksize = 128; + const int grid_size = (pack_num + blocksize - 1) / blocksize; + RebuildPaddingV2Kernel<<>>( + reinterpret_cast(out.data()), + reinterpret_cast(const_cast(tmp_out.data())), + cum_offsets.data(), + seq_lens_decoder.data(), + seq_lens_encoder.data(), + max_input_length, + dim_embed, + elem_nums); + return {out}; +} + +std::vector RebuildPaddingV2(const paddle::Tensor& tmp_out, // [token_num, dim_embed] + const paddle::Tensor& cum_offsets, // [bsz, 1] + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& seq_lens_encoder, + int max_input_length) { + switch (tmp_out.type()) { + case paddle::DataType::BFLOAT16: { + return rebuild_padding_v2( + tmp_out, + cum_offsets, + seq_lens_decoder, + seq_lens_encoder, + max_input_length + ); + } + case paddle::DataType::FLOAT16: { + return rebuild_padding_v2( + tmp_out, + cum_offsets, + seq_lens_decoder, + seq_lens_encoder, + max_input_length + ); + } + case paddle::DataType::FLOAT32: { + return rebuild_padding_v2( + tmp_out, + cum_offsets, + seq_lens_decoder, + seq_lens_encoder, + max_input_length + ); + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16, bfloat16 and float32 are supported. 
"); + break; + } + } +} + +std::vector> RebuildPaddingV2InferShape(const std::vector& tmp_out_shape, + const std::vector& cum_offsets_shape, + const std::vector& seq_lens_decoder_shape, + const std::vector& seq_lens_encoder_shape) { + int64_t bsz = cum_offsets_shape[0]; + int64_t dim_embed = tmp_out_shape[1]; + return {{bsz, dim_embed}}; +} + +std::vector RebuildPaddingV2InferDtype(const paddle::DataType& tmp_out_dtype, + const paddle::DataType& cum_offsets_dtype, + const paddle::DataType& seq_lens_decoder_dtype, + const paddle::DataType& seq_lens_encoder_dtype) { + return {tmp_out_dtype}; +} + +PD_BUILD_OP(rebuild_padding_v2) + .Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder"}) + .Outputs({"out"}) + .Attrs({"max_input_length: int"}) + .SetKernelFn(PD_KERNEL(RebuildPaddingV2)) + .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingV2InferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingV2InferDtype)); \ No newline at end of file diff --git a/csrc/generation/reset_need_stop_value.cc b/csrc/generation/reset_need_stop_value.cc new file mode 100644 index 000000000000..07efb643d067 --- /dev/null +++ b/csrc/generation/reset_need_stop_value.cc @@ -0,0 +1,12 @@ +#include "paddle/extension.h" + +void SetStopValue(const paddle::Tensor& not_need_stop) { + bool *stop_data = const_cast(not_need_stop.data()); + stop_data[0] = true; +} + +PD_BUILD_OP(reset_stop_value) + .Inputs({"not_need_stop"}) + .Outputs({"not_need_stop_out"}) + .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}}) + .SetKernelFn(PD_KERNEL(SetStopValue)); diff --git a/csrc/generation/save_with_output_msg.cc b/csrc/generation/save_with_output_msg.cc new file mode 100644 index 000000000000..ea04f8e3e6a0 --- /dev/null +++ b/csrc/generation/save_with_output_msg.cc @@ -0,0 +1,58 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" + +#define MAX_BSZ 512 + +struct msgdata { + long mtype; + int mtext[MAX_BSZ + 2]; // stop_flag, bsz, tokens +}; + +void SaveOutMmsg(const paddle::Tensor& x, + const paddle::Tensor& not_need_stop, + int64_t rank_id) { + if (rank_id > 0) return; + auto x_cpu = x.copy_to(paddle::CPUPlace(), false); + int64_t *x_data = x_cpu.data(); + static struct msgdata msg_sed; + static key_t key = ftok("./", 1); + static int msgid = msgget(key, IPC_CREAT | 0666); + + msg_sed.mtype = 1; + bool not_need_stop_data = not_need_stop.data()[0]; + msg_sed.mtext[0] = not_need_stop_data ? 
1 : -1; + int bsz = x.shape()[0]; + msg_sed.mtext[1] = bsz; + for (int i = 2; i < bsz + 2; i++) { + msg_sed.mtext[i] = (int)x_data[i - 2]; + } + if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) { + // printf("full msg buffer\n"); + } + return; +} + +PD_BUILD_OP(save_output) + .Inputs({"x", "not_need_stop"}) + .Attrs({"rank_id: int64_t"}) + .Outputs({"x_out"}) + .SetInplaceMap({{"x", "x_out"}}) + .SetKernelFn(PD_KERNEL(SaveOutMmsg)); \ No newline at end of file diff --git a/csrc/generation/set_value_by_flags_v2.cu b/csrc/generation/set_value_by_flags_v2.cu new file mode 100644 index 000000000000..f954c8c96d1d --- /dev/null +++ b/csrc/generation/set_value_by_flags_v2.cu @@ -0,0 +1,58 @@ +#include "paddle/extension.h" + +__global__ void set_value_by_flag_and_id_v2(const bool *stop_flags, + int64_t *pre_ids_all, + const int64_t *input_ids, + const int *seq_lens_encoder, + const int *seq_lens_decoder, + const int64_t *step_idx, + int bs, + int length, + int length_input_ids) { + int tid = threadIdx.x; + if (tid < bs && !stop_flags[tid]) { + int64_t *pre_ids_all_now = pre_ids_all + tid * length; + const int64_t *input_ids_now = input_ids + tid * length_input_ids; + const int seq_len_dec = seq_lens_decoder[tid]; + const int seq_len_enc = seq_lens_encoder[tid]; + if (seq_len_dec == 0 && seq_len_enc == 0) return; // stoped + if (step_idx[tid] >= 0) { + if (seq_len_dec == 0) { // encoder, get last token accord to seq_lens_encoder + pre_ids_all_now[step_idx[tid]] = input_ids_now[seq_len_enc - 1]; + } else { // decoedr, get first token + pre_ids_all_now[step_idx[tid]] = input_ids_now[0]; + } + } + } +} + +void SetValueByFlagsAndIdxV2(const paddle::Tensor& pre_ids_all, + const paddle::Tensor& input_ids, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& step_idx, + const paddle::Tensor& stop_flags) { + auto cu_stream = stop_flags.stream(); + std::vector pre_ids_all_shape = pre_ids_all.shape(); + + int bs = seq_lens_this_time.shape()[0]; + int length = pre_ids_all_shape[1]; + int length_input_ids = input_ids.shape()[1]; + int block_size = (bs + 32 - 1) / 32 * 32; + set_value_by_flag_and_id_v2<<<1, block_size, 0, cu_stream>>>(stop_flags.data(), + const_cast(pre_ids_all.data()), + input_ids.data(), + seq_lens_encoder.data(), + seq_lens_decoder.data(), + step_idx.data(), + bs, + length, + length_input_ids); +} + +PD_BUILD_OP(set_value_by_flags_and_idx_v2) + .Inputs({"pre_ids_all", "input_ids", "seq_lens_this_time", "seq_lens_encoder", "seq_lens_decoder", "step_idx", "stop_flags"}) + .Outputs({"pre_ids_all_out"}) + .SetInplaceMap({{"pre_ids_all", "pre_ids_all_out"}}) + .SetKernelFn(PD_KERNEL(SetValueByFlagsAndIdxV2)); \ No newline at end of file diff --git a/csrc/generation/step.cu b/csrc/generation/step.cu new file mode 100644 index 000000000000..b586db566916 --- /dev/null +++ b/csrc/generation/step.cu @@ -0,0 +1,354 @@ +#include "helper.h" + +// #define DEBUG_STEP + +__global__ void free_and_dispatch_block(bool *stop_flags, + int *seq_lens_this_time, + int *seq_lens_decoder, + int *block_tables, + int *encoder_block_lens, + bool *is_block_step, + int *step_block_list, // [bsz] + int *step_len, + int *recover_block_list, + int *recover_len, + int *need_block_list, + int *need_block_len, + int *used_list_len, + int *free_list, + int *free_list_len, + const int bsz, + const int block_size, + const int block_num_per_seq, + const int max_decoder_block_num) { + typedef cub::BlockReduce, 512> 
BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + const int tid = threadIdx.x; + if (tid < bsz) { + int *block_table_now = block_tables + tid * block_num_per_seq; + if (stop_flags[tid] && !is_block_step[tid]) { + // 回收block块 + const int encoder_block_len = encoder_block_lens[tid]; + const int decoder_used_len = used_list_len[tid]; + if (decoder_used_len > 0) { + const int ori_free_list_len = atomicAdd(free_list_len, decoder_used_len); +#ifdef DEBUG_STEP + printf("free block seq_id: %d, free block num: %d, encoder_block_len: %d, ori_free_list_len: %d\n", + tid, decoder_used_len, encoder_block_len, ori_free_list_len); +#endif + for (int i = 0; i < decoder_used_len; i++) { + free_list[ori_free_list_len + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; + } + encoder_block_lens[tid] = 0; + used_list_len[tid] = 0; + } + } else if (seq_lens_decoder[tid] != 0 && block_table_now[seq_lens_decoder[tid] / block_size] == -1) { + // 统计需要分配block的位置和总数 + const int ori_need_block_len = atomicAdd(need_block_len, 1); + need_block_list[ori_need_block_len] = tid; +#ifdef DEBUG_STEP + printf("seq_id: %d need block\n", tid); +#endif + } + } + __syncthreads(); + if (tid == 0) { + printf("need_block_len: %d, free_list_len: %d\n", need_block_len[0], free_list_len[0]); + } + + while (need_block_len[0] > free_list_len[0]) { + +#ifdef DEBUG_STEP + if (tid == 0) { + printf("need_block_len: %d, free_list_len: %d\n", need_block_len[0], free_list_len[0]); + } +#endif + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len + const int used_block_num = tid < bsz && !is_block_step[tid]? used_list_len[tid] : 0; + cub::KeyValuePair kv_pair = {tid, used_block_num}; + kv_pair = BlockReduce(temp_storage).Reduce(kv_pair, cub::ArgMax()); + if (tid == 0) { + const int encoder_block_len = encoder_block_lens[kv_pair.key]; +#ifdef DEBUG_STEP + printf("max_id: %d, max_num: %d, encoder_block_len: %d\n", kv_pair.key, kv_pair.value, encoder_block_len); +#endif + int *block_table_now = block_tables + kv_pair.key * block_num_per_seq; + for (int i = 0; i < kv_pair.value; i++) { + free_list[free_list_len[0] + i] = block_table_now[encoder_block_len + i]; + block_table_now[encoder_block_len + i] = -1; + } + step_block_list[step_len[0]] = kv_pair.key; + step_len[0] += 1; + free_list_len[0] += kv_pair.value; + stop_flags[kv_pair.key] = true; + is_block_step[kv_pair.key] = true; + seq_lens_this_time[kv_pair.key] = 0; + seq_lens_decoder[kv_pair.key] = 0; + } + __syncthreads(); + } + + // 为需要block的位置分配block,每个位置分配一个block + if (tid < need_block_len[0]) { + const int need_block_id = need_block_list[tid]; + if (!stop_flags[need_block_id]) { + // 如果需要的位置正好是上一步中被释放的位置,不做处理 + used_list_len[need_block_id] += 1; + const int ori_free_list_len = atomicSub(free_list_len, 1); + int *block_table_now = block_tables + need_block_id * block_num_per_seq; + block_table_now[seq_lens_decoder[need_block_id] / block_size] = free_list[ori_free_list_len - 1]; + } + need_block_list[tid] = -1; + } + __syncthreads(); + + // 计算可以复原的query id + if (tid == 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_len = step_len[0]; + printf("ori_step_len %d\n", ori_step_len); + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = tmp_used_len < max_decoder_block_num ? 
tmp_used_len + 1 : tmp_used_len; + while (ori_step_len > 0 && ori_free_list_len >= used_len) { +#ifdef DEBUG_STEP + printf("recover seq_id: %d, free_list_len: %d, used_list_len: %d\n", + ori_step_block_id, ori_free_list_len, used_len); +#endif + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step[ori_step_block_id] = false; + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len < max_decoder_block_num ? tmp_used_len + 1 : tmp_used_len; + } + } + need_block_len[0] = 0; + } +} + +// 根据上一步计算出的可以复原的query_id进行状态恢复 +__global__ void recover_block(int *recover_block_list, // [bsz] + int *recover_len, + bool *stop_flags, + int *seq_lens_this_time, + int *ori_seq_lens_encoder, + int *seq_lens_encoder, + int *seq_lens_decoder, + int *block_tables, + int *free_list, + int *free_list_len, + int64_t *input_ids, + int64_t *pre_ids, + int64_t *step_idx, + int *encoder_block_lens, + int *used_list_len, + const int64_t *next_tokens, + const int bsz, + const int block_num_per_seq, + const int length, + const int pre_id_length, + const int first_token_id) { + const int bid = blockIdx.x; + const int tid = threadIdx.x; + __shared__ int ori_free_list_len; + if (bid < recover_len[0]) { + const int recover_id = recover_block_list[bid]; + const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + const int step_idx_now = step_idx[recover_id]; + const int seq_len = ori_seq_len_encoder + step_idx_now; + const int encoder_block_len = encoder_block_lens[recover_id]; + const int decoder_used_len = used_list_len[recover_id]; + int *block_table_now = block_tables + recover_id * block_num_per_seq; + int64_t *input_ids_now = input_ids + recover_id * length; + int64_t *pre_ids_now = pre_ids + recover_id * pre_id_length; + if (tid == 0) { + seq_lens_this_time[recover_id] = seq_len; + seq_lens_encoder[recover_id] = seq_len; + stop_flags[recover_id] = false; + input_ids_now[ori_seq_len_encoder + step_idx_now - 1] = next_tokens[recover_id]; // next tokens + input_ids_now[0] = first_token_id; // set first prompt token + const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len); + ori_free_list_len = ori_free_list_len_tid0; +#ifdef DEBUG_STEP + printf("seq_id: %d, ori_seq_len_encoder: %d, step_idx_now: %d, seq_len: %d, ori_free_list_len_tid0: %d, ori_free_list_len: %d\n", + recover_id, ori_seq_len_encoder, step_idx_now, seq_len, ori_free_list_len_tid0, ori_free_list_len); +#endif + } + __syncthreads(); + // 恢复block table + for (int i = tid; i < decoder_used_len; i += blockDim.x) { + block_table_now[encoder_block_len + i] = free_list[ori_free_list_len - i - 1]; + } + // 恢复input_ids + for (int i = tid; i < step_idx_now - 1; i += blockDim.x) { + input_ids_now[ori_seq_len_encoder + i] = pre_ids_now[i + 1]; + } + } + + if (bid == 0 && tid == 0) { + recover_len[0] = 0; + } +} + +void StepPaddle(const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& ori_seq_lens_encoder, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& step_block_list, + const 
paddle::Tensor& step_lens, + const paddle::Tensor& recover_block_list, + const paddle::Tensor& recover_lens, + const paddle::Tensor& need_block_list, + const paddle::Tensor& need_block_len, + const paddle::Tensor& used_list_len, + const paddle::Tensor& free_list, + const paddle::Tensor& free_list_len, + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, + const paddle::Tensor& next_tokens, + const int block_size, + const int encoder_decoder_block_num, + const int64_t first_token_id) { + auto cu_stream = seq_lens_this_time.stream(); + const int bsz = seq_lens_this_time.shape()[0]; + const int block_num_per_seq = block_tables.shape()[1]; + const int length = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + constexpr int BlockSize = 512; // bsz < 256 + const int max_decoder_block_num = pre_id_length / block_size - encoder_decoder_block_num; +#ifdef DEBUG_STEP + printf("bsz: %d, block_num_per_seq: %d, length: %d, max_decoder_block_num: %d\n", bsz, block_num_per_seq, length, max_decoder_block_num); +#endif + free_and_dispatch_block<<<1, BlockSize, 0, cu_stream>>>( + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(encoder_block_lens.data()), + const_cast(is_block_step.data()), + const_cast(step_block_list.data()), + const_cast(step_lens.data()), + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(need_block_list.data()), + const_cast(need_block_len.data()), + const_cast(used_list_len.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + bsz, + block_size, + block_num_per_seq, + max_decoder_block_num + ); +#ifdef DEBUG_STEP + cudaDeviceSynchronize(); +#endif + auto cpu_recover_lens = recover_lens.copy_to(paddle::CPUPlace(), false); + const int grid_size = cpu_recover_lens.data()[0]; +#ifdef DEBUG_STEP + printf("grid_size2 %d\n", grid_size); +#endif + if (grid_size > 0) { + recover_block<<>>( + const_cast(recover_block_list.data()), + const_cast(recover_lens.data()), + const_cast(stop_flags.data()), + const_cast(seq_lens_this_time.data()), + const_cast(ori_seq_lens_encoder.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(block_tables.data()), + const_cast(free_list.data()), + const_cast(free_list_len.data()), + const_cast(input_ids.data()), + const_cast(pre_ids.data()), + const_cast(step_idx.data()), + const_cast(encoder_block_lens.data()), + const_cast(used_list_len.data()), + next_tokens.data(), + bsz, + block_num_per_seq, + length, + pre_id_length, + first_token_id + ); +#ifdef DEBUG_STEP + cudaDeviceSynchronize(); +#endif + } +} + +PD_BUILD_OP(step_paddle) + .Inputs({"stop_flags", + "seq_lens_this_time", + "ori_seq_lens_encoder", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables", + "encoder_block_lens", + "is_block_step", + "step_block_list", + "step_lens", + "recover_block_list", + "recover_lens", + "need_block_list", + "need_block_len", + "used_list_len", + "free_list", + "free_list_len", + "input_ids", + "pre_ids", + "step_idx", + "next_tokens"}) + .Attrs({"block_size: int", + "encoder_decoder_block_num: int", + "first_token_id: int64_t"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", 
+ "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"block_tables", "block_tables_out"}, + {"encoder_block_lens", "encoder_block_lens_out"}, + {"is_block_step", "is_block_step_out"}, + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", "used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}}) + .SetKernelFn(PD_KERNEL(StepPaddle)); \ No newline at end of file diff --git a/csrc/generation/stop_generation_multi_ends_v2.cu b/csrc/generation/stop_generation_multi_ends_v2.cu new file mode 100644 index 000000000000..7f23029681a5 --- /dev/null +++ b/csrc/generation/stop_generation_multi_ends_v2.cu @@ -0,0 +1,75 @@ +#include "paddle/extension.h" +#include +#include +#include +#include +#include +#include +#include +#include + +__device__ bool is_in_end_v2(const int64_t id, const int64_t *end_ids, int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} + +__global__ void set_value_by_flags_v2( + bool *stop_flags, + int64_t *topk_ids, + int64_t *next_tokens, + const int64_t *end_ids, + const int *seq_lens, + const int bs, + const int end_length) { + int tid = threadIdx.x; + if (tid < bs) { + if (stop_flags[tid]) { + if (seq_lens[tid] == 0) { + topk_ids[tid] = -1; + } else { + topk_ids[tid] = end_ids[0]; + next_tokens[tid] = end_ids[0]; + } + } else { + next_tokens[tid] = topk_ids[tid]; + } + if (is_in_end_v2(topk_ids[tid], end_ids, end_length)) { + stop_flags[tid] = true; + } + } +} + +void GetStopFlagsMultiV2(const paddle::Tensor& topk_ids, + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens, + const paddle::Tensor& end_ids, + const paddle::Tensor& next_tokens) { + PD_CHECK(topk_ids.dtype() == paddle::DataType::INT64); + PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL); + + auto cu_stream = topk_ids.stream(); + std::vector shape = topk_ids.shape(); + int64_t bs_now = shape[0]; + int64_t end_length = end_ids.shape()[0]; + int block_size = (bs_now + 32 - 1) / 32 * 32; + set_value_by_flags_v2<<<1, block_size, 0, cu_stream>>>( + const_cast(stop_flags.data()), + const_cast(topk_ids.data()), + const_cast(next_tokens.data()), + end_ids.data(), + seq_lens.data(), + bs_now, end_length); +} + +PD_BUILD_OP(set_stop_value_multi_ends_v2) + .Inputs({"topk_ids", "stop_flags", "seq_lens", "end_ids", "next_tokens"}) + .Outputs({"topk_ids_out", "stop_flags_out", "next_tokens_out"}) + .SetInplaceMap({{"topk_ids", "topk_ids_out"}, + {"stop_flags", "stop_flags_out"}, + {"next_tokens", "next_tokens_out"}}) + .SetKernelFn(PD_KERNEL(GetStopFlagsMultiV2)); \ No newline at end of file diff --git a/csrc/generation/token_penalty_multi_scores.cu b/csrc/generation/token_penalty_multi_scores.cu index 3ef010501921..d32f872d705c 100644 --- a/csrc/generation/token_penalty_multi_scores.cu +++ b/csrc/generation/token_penalty_multi_scores.cu @@ -228,4 +228,4 @@ PD_BUILD_OP(get_token_penalty_multi_scores) .Outputs({"logits_out"}) 
.SetKernelFn(PD_KERNEL(TokenPenaltyMultiScores)) .SetInferShapeFn(PD_INFER_SHAPE(TokenPenaltyMultiScoresInferShape)) - .SetInferDtypeFn(PD_INFER_DTYPE(TokenPenaltyMultiScoresInferDtype)); + .SetInferDtypeFn(PD_INFER_DTYPE(TokenPenaltyMultiScoresInferDtype)); \ No newline at end of file diff --git a/csrc/generation/token_penalty_multi_scores_v2.cu b/csrc/generation/token_penalty_multi_scores_v2.cu new file mode 100644 index 000000000000..b1bbdd4a40d6 --- /dev/null +++ b/csrc/generation/token_penalty_multi_scores_v2.cu @@ -0,0 +1,250 @@ +#include "helper.h" + + +template +__global__ inline void min_length_logits_process_v2(T* logits, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t bs, + const int64_t length, + const int64_t end_length) { + int bi = threadIdx.x; + if (bi >= bs) return; + if (cur_len[bi] < 0) { + return; + } + if (cur_len[bi] < min_len[bi]) { + for (int i=0; i < end_length; i++) { + logits[bi * length + eos_token_id[i]] = -1e10; + } + } +} + +template<> +__global__ inline void min_length_logits_process_v2(half* logits, + const int64_t *cur_len, + const int64_t *min_len, + const int64_t *eos_token_id, + const int64_t bs, + const int64_t length, + const int64_t end_length) { + int bi = threadIdx.x; + if (bi >= bs) return; + if (cur_len[bi] < 0) { + return; + } + if (cur_len[bi] < min_len[bi]) { + for (int i=0; i < end_length; i++) { + logits[bi * length + eos_token_id[i]] = -1e4; + } + } +} + + +__global__ void update_repeat_times_v2(const int64_t *pre_ids, + const int64_t *cur_len, + int *repeat_times, + const int64_t bs, + const int64_t length, + const int64_t length_id) { + int bi = blockIdx.x; + if (cur_len[bi] < 0) { + return; + } + int tid = threadIdx.x; + const int64_t *pre_ids_now = pre_ids + bi * length_id; + int *repeat_times_now = repeat_times + bi * length; + for (int i = tid; i < length_id; i += blockDim.x) { + int64_t id = pre_ids_now[i]; + if (id < 0) break; + atomicAdd(&repeat_times_now[id], 1); + } +} + +template +__global__ void update_value_by_repeat_times_v2(const int *repeat_times, + const T *penalty_scores, + const T *frequency_score, + const T *presence_score, + const float *temperatures, + T *logits, + const int64_t bs, + const int64_t length) { + int bi = blockIdx.x; + int tid = threadIdx.x; + T *logits_now = logits + bi * length; + const int *repeat_times_now = repeat_times + bi * length; + float alpha = static_cast(penalty_scores[bi]); + float beta = static_cast(frequency_score[bi]); + float gamma = static_cast(presence_score[bi]); + for (int i = tid; i < length; i += blockDim.x) { + int times = repeat_times_now[i]; + float logit_now = static_cast(logits_now[i]); + if (times != 0) { + logit_now = logit_now < 0 ? 
logit_now * alpha : logit_now / alpha; + logit_now = logit_now - times * beta - gamma; + } + logits_now[i] = static_cast(logit_now / temperatures[bi]); + } +} + +template +__global__ void ban_bad_words(T *logits, + const int64_t *bad_words_list, + const int64_t bs, + const int64_t length, + const int64_t bad_words_length) { + const int bi = blockIdx.x; + int tid = threadIdx.x; + T *logits_now = logits + bi * length; + for (int i = tid; i < bad_words_length; i += blockDim.x) { + const int64_t bad_words_token_id = bad_words_list[i]; + if (bad_words_token_id >= length || bad_words_token_id < 0) continue; + logits_now[bad_words_token_id] = -1e10; + } +} + +template +void token_penalty_multi_scores_kernel_v2(const paddle::Tensor& pre_ids, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_score, + const paddle::Tensor& presence_score, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id) { + + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + auto cu_stream = logits.stream(); + std::vector shape = logits.shape(); + auto repeat_times = paddle::full(shape, 0, paddle::DataType::INT32, pre_ids.place()); + int64_t bs = shape[0]; + int64_t length = shape[1]; + int64_t length_id = pre_ids.shape()[1]; + int64_t length_bad_words = bad_tokens.shape()[0]; + + int64_t end_length = eos_token_id.shape()[0]; + + int block_size = (bs + 32 - 1) / 32 * 32; + min_length_logits_process_v2<<<1, block_size, 0, cu_stream>>>( + reinterpret_cast(const_cast(logits.data())), + cur_len.data(), + min_len.data(), + eos_token_id.data(), + bs, length, end_length); + + block_size = (length_id + 32 - 1) / 32 * 32; + block_size = min(block_size, 512); + update_repeat_times_v2<<>>( + pre_ids.data(), + cur_len.data(), + repeat_times.data(), + bs, + length, + length_id); + + block_size = (length + 32 - 1) / 32 * 32; + block_size = min(block_size, 512); + update_value_by_repeat_times_v2<<>>( + repeat_times.data(), + reinterpret_cast(const_cast(penalty_scores.data())), + reinterpret_cast(const_cast(frequency_score.data())), + reinterpret_cast(const_cast(presence_score.data())), + temperatures.data(), + reinterpret_cast(const_cast(logits.data())), + bs, + length); + + block_size = (length_bad_words + 32 - 1) / 32 * 32; + block_size = min(block_size, 512); + ban_bad_words<<>>( + reinterpret_cast(const_cast(logits.data())), + bad_tokens.data(), + bs, + length, + length_bad_words + ); +} + +void TokenPenaltyMultiScoresV2(const paddle::Tensor& pre_ids, + const paddle::Tensor& logits, + const paddle::Tensor& penalty_scores, + const paddle::Tensor& frequency_scores, + const paddle::Tensor& presence_scores, + const paddle::Tensor& temperatures, + const paddle::Tensor& bad_tokens, + const paddle::Tensor& cur_len, + const paddle::Tensor& min_len, + const paddle::Tensor& eos_token_id) { + + switch (logits.type()) { + case paddle::DataType::BFLOAT16: { + return token_penalty_multi_scores_kernel_v2( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id + ); + } + case paddle::DataType::FLOAT16: { + return token_penalty_multi_scores_kernel_v2( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id + ); + } + case 
paddle::DataType::FLOAT32: { + return token_penalty_multi_scores_kernel_v2( + pre_ids, + logits, + penalty_scores, + frequency_scores, + presence_scores, + temperatures, + bad_tokens, + cur_len, + min_len, + eos_token_id + ); + } + default: { + PD_THROW( + "NOT supported data type. " + "Only float16, bfloat16 and float32 are supported. "); + break; + } + } +} + +PD_BUILD_OP(get_token_penalty_multi_scores_v2) + .Inputs({"pre_ids", + "logits", + "penalty_scores", + "frequency_scores", + "presence_scores", + "temperatures", + "bad_tokens", + "cur_len", + "min_len", + "eos_token_id"}) + .Outputs({"logits_out"}) + .SetInplaceMap({{"logits", "logits_out"}}) + .SetKernelFn(PD_KERNEL(TokenPenaltyMultiScoresV2)); \ No newline at end of file diff --git a/csrc/generation/update_inputs.cu b/csrc/generation/update_inputs.cu new file mode 100644 index 000000000000..ab9bcde27208 --- /dev/null +++ b/csrc/generation/update_inputs.cu @@ -0,0 +1,106 @@ +#include "helper.h" + +template +__global__ void update_inputs_kernel( + bool *not_need_stop, + int *seq_lens_this_time, + int *seq_lens_encoder, + int *seq_lens_decoder, + int64_t *input_ids, + const int64_t *stop_nums, + const bool *stop_flags, + const bool *is_block_step, + const int64_t *next_tokens, + const int bsz, + const int max_bsz, + const int input_ids_stride) { + int thread_idx = threadIdx.x; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + bool stop_flag_now = false; + int64_t stop_flag_now_int = 0; + if (thread_idx < max_bsz) { + if (thread_idx < bsz) { + stop_flag_now = stop_flags[thread_idx]; + if (is_block_step[thread_idx]) { + stop_flag_now_int = 0; + } else { + stop_flag_now_int = static_cast(stop_flag_now); + } + } else { + stop_flag_now_int = 1; + } + } + if (thread_idx < bsz) { + const int seq_len_this_time = seq_lens_this_time[thread_idx]; + const int seq_len_encoder = seq_lens_encoder[thread_idx]; + const int seq_len_decoder = seq_lens_decoder[thread_idx]; + + seq_lens_decoder[thread_idx] = stop_flag_now ? 0 : (seq_len_decoder == 0 ? seq_len_encoder : seq_len_decoder + 1); + + seq_lens_this_time[thread_idx] = stop_flag_now ? 
0 : 1; + seq_lens_encoder[thread_idx] = 0; + int64_t *input_ids_now = input_ids + thread_idx * input_ids_stride; + input_ids_now[0] = next_tokens[thread_idx]; + } + __syncthreads(); + int64_t stop_sum = BlockReduce(temp_storage).Sum(stop_flag_now_int); + if (thread_idx == 0) { + not_need_stop[0] = stop_sum < stop_nums[0]; + } +} + +void UpdateInputes(const paddle::Tensor& stop_flags, + const paddle::Tensor& not_need_stop, // cpu + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_decoder, + const paddle::Tensor& input_ids, + const paddle::Tensor& stop_nums, + const paddle::Tensor& next_tokens, + const paddle::Tensor& is_block_step) { + const int max_bsz = stop_flags.shape()[0]; + const int now_bsz = seq_lens_this_time.shape()[0]; + const int input_ids_stride = input_ids.shape()[1]; + auto not_need_stop_gpu = not_need_stop.copy_to(stop_flags.place(), false); + update_inputs_kernel<1024><<<1, 1024, 0, input_ids.stream()>>>( + const_cast(not_need_stop_gpu.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_encoder.data()), + const_cast(seq_lens_decoder.data()), + const_cast(input_ids.data()), + stop_nums.data(), + stop_flags.data(), + is_block_step.data(), + next_tokens.data(), + now_bsz, + max_bsz, + input_ids_stride + ); + auto not_need_stop_cpu = not_need_stop_gpu.copy_to(not_need_stop.place(), false); + bool *not_need_stop_data = const_cast(not_need_stop.data()); + not_need_stop_data[0] = not_need_stop_cpu.data()[0]; +} + +PD_BUILD_OP(update_inputs) + .Inputs({"stop_flags", + "not_need_stop", + "seq_lens_this_time", + "seq_lens_encoder", + "seq_lens_decoder", + "input_ids", + "stop_nums", + "next_tokens", + "is_block_step"}) + .Outputs({"not_need_stop_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "input_ids_out"}) + .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"input_ids", "input_ids_out"}}) + .SetKernelFn(PD_KERNEL(UpdateInputes)); diff --git a/csrc/generation/write_int8_cache_kv.cu b/csrc/generation/write_int8_cache_kv.cu new file mode 100644 index 000000000000..3e423f0d9db7 --- /dev/null +++ b/csrc/generation/write_int8_cache_kv.cu @@ -0,0 +1,349 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
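+// The kernels in this file dynamically quantize the cache KV written below, with one
+// scale per (batch, head) pair:
+//   quant_scale   = 127 / max(|x|)   (abs-max reduced over the [seq_len, dim_head] slice)
+//   q(x)          = clip(round(x * quant_scale), -127, 127) + 128, stored as uint8
+//   dequant_scale = 1 / quant_scale
+// The K cache is stored transposed as [bsz, num_head, dim_head/x, max_seq_len, x] and the
+// V cache as [bsz, num_head, max_seq_len, dim_head/x, x], where x is the vector width
+// (VEC_16B / sizeof(T), see helper.h).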
+ +#include "helper.h" + +constexpr int32_t WARP_SIZE = 32; +constexpr int32_t HALF_WARP = 16; +constexpr float QUANT_MAX_BOUND = 127.0; +constexpr float QUANT_MIN_BOUND = -127.0; + +template +inline __device__ __host__ T div_up(T m, T n) { + return (m + n - 1) / n; +} + +template +struct QuantFunc{ + __host__ __device__ uint8_t operator()(T x, float quant_scale) { + float tmp = static_cast(x) * quant_scale; + tmp = round(tmp); + if (tmp > QUANT_MAX_BOUND) + tmp = QUANT_MAX_BOUND; + else if (tmp < QUANT_MIN_BOUND) + tmp = QUANT_MIN_BOUND; + return static_cast(tmp + 128.0f);; + } +}; + +template +struct MaxFunc{ + __device__ T operator()(T a, T b){ + return max(a, b); + } +}; + +template<> +struct MaxFunc{ + __device__ half operator()(half a, half b){ +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return max(static_cast(a), static_cast(b)); +#endif + } +}; + +template<> +struct MaxFunc<__nv_bfloat16>{ + __device__ __nv_bfloat16 operator()(__nv_bfloat16 a, __nv_bfloat16 b){ +#if __CUDA_ARCH__ >= 800 + return __hmax(a, b); +#else + return max(static_cast(a), static_cast(b)); +#endif + } +}; + +template +struct AbsFunc{ + __device__ T operator()(T x){ + return abs(x); + } +}; + +template<> +struct AbsFunc{ + __device__ half operator()(half x){ + #if __CUDA_ARCH__ >= 800 + return __habs(x); + #else + return abs(static_cast(x)); + #endif + } +}; + +template<> +struct AbsFunc<__nv_bfloat16>{ + __device__ __nv_bfloat16 operator()(__nv_bfloat16 x){ + #if __CUDA_ARCH__ >= 800 + return __habs(x); + #else + return abs(static_cast(x)); + #endif + } +}; + +template +__inline__ __device__ T LocalReduceMax(Vec& vec) { + T local_max = static_cast(0.0); + #pragma unroll + for (int i = 0; i < VecSize; ++i) { + local_max = vec[i] > local_max ? vec[i] : local_max; + } + return local_max; +} + +template +__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) { + #pragma unroll + for (int mask = HALF_WARP; mask > 0; mask >>= 1){ + val = MaxFunc()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE)); + } + return val; +} + +template +__inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) { + static __shared__ T smem[WARP_SIZE]; + int32_t lane_id = threadIdx.x % WARP_SIZE; + int32_t warp_id = threadIdx.x / WARP_SIZE; + + val = WarpReduceAbsMax(val, mask); + + if (lane_id == 0) { + smem[warp_id] = val; + } + + __syncthreads(); + + T abs_max_val = (threadIdx.x < (blockDim.x / WARP_SIZE)) ? 
smem[threadIdx.x] : static_cast(0.0f); + abs_max_val = WarpReduceAbsMax(abs_max_val, mask); + return abs_max_val; +} + + +template +__global__ void write_cache_k_int8_kernel(const T* k, const int64_t num_head, const int64_t dim_head, const int64_t seq_len, int max_seq_len, uint8_t* cache, float* quant_scales, float* dequant_scales) { + const int bi = blockIdx.y; + const int hi = blockIdx.x; + + using InVec = AlignedVector; + using OutVec = AlignedVector; + + InVec in_vec; + OutVec out_vec; + InVec abs_max_vec; +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = 0.0f; + } + + T local_abs_max; + + for (int idx = threadIdx.x * VecSize; idx < seq_len * dim_head; idx += blockDim.x * VecSize) { + int linear_idx = bi * num_head * seq_len * dim_head + hi * seq_len * dim_head + idx; + Load(k + linear_idx, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = MaxFunc()(abs_max_vec[i], AbsFunc()(in_vec[i])); + } + } + + local_abs_max = LocalReduceMax(abs_max_vec); + T abs_max_val = BlockReduceAbsMax(local_abs_max, 0xffffffff); + + __shared__ float quant_scale; + if (threadIdx.x == 0) { + quant_scale = 127.0f / static_cast(abs_max_val); + } + + __syncthreads(); + + for (int idx = threadIdx.x * VecSize; idx < seq_len * dim_head; idx += blockDim.x * VecSize) { + int linear_idx = bi * num_head * seq_len * dim_head + hi * seq_len * dim_head + idx; + // [bsz, num_head, seq_len, dim_head/x, x] + Load(k + linear_idx, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = QuantFunc()(in_vec[i], quant_scale); + } + int dim_head_div_x = dim_head / VecSize; + int seq_id = idx / dim_head; + int vec_id = threadIdx.x % dim_head_div_x; + // [bsz, num_head, dim_head/x, max_seq_len, x] + Store(out_vec, cache + bi * num_head * max_seq_len * dim_head + hi * max_seq_len * dim_head + vec_id * max_seq_len * VecSize + seq_id * VecSize); + } + + if (threadIdx.x == 0) { + quant_scales[bi * num_head + hi] = quant_scale; + dequant_scales[bi * num_head + hi] = 1.0f / quant_scale; + } +} + +template +__global__ void write_cache_v_int8_kernel(const T* v, const int64_t num_head, const int64_t dim_head, const int64_t seq_len, int max_seq_len, uint8_t* cache, float* quant_scales, float* dequant_scales) { + const int bi = blockIdx.y; + const int hi = blockIdx.x; + + using InVec = AlignedVector; + using OutVec = AlignedVector; + + InVec in_vec; + OutVec out_vec; + InVec abs_max_vec; + #pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = 0.0f; + } + + T local_abs_max; + + for (int idx = threadIdx.x * VecSize; idx < seq_len * dim_head; idx += blockDim.x * VecSize) { + int linear_idx = bi * num_head * seq_len * dim_head + hi * seq_len * dim_head + idx; + Load(v + linear_idx, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + abs_max_vec[i] = MaxFunc()(abs_max_vec[i], AbsFunc()(in_vec[i])); + } + } + + local_abs_max = LocalReduceMax(abs_max_vec); + T abs_max_val = BlockReduceAbsMax(local_abs_max, 0xffffffff); + + __shared__ float quant_scale; + if (threadIdx.x == 0) { + quant_scale = 127.0f / static_cast(abs_max_val); + } + + __syncthreads(); + for (int idx = threadIdx.x * VecSize; idx < seq_len * dim_head; idx += blockDim.x * VecSize) { + int linear_idx = bi * num_head * seq_len * dim_head + hi * seq_len * dim_head + idx; + // [bsz, num_head, seq_len, dim_head/x, x] + Load(v + linear_idx, &in_vec); +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + out_vec[i] = QuantFunc()(in_vec[i], quant_scale); + } + int dim_head_div_x = 
dim_head / VecSize; + int seq_id = idx / dim_head; + int vec_id = threadIdx.x % dim_head_div_x; + // [bsz, num_head, max_seq_len, dim_head/x, x] + Store(out_vec, cache + bi * num_head * max_seq_len * dim_head + hi * max_seq_len * dim_head + seq_id * dim_head + vec_id * VecSize); + } + + if (threadIdx.x == 0) { + quant_scales[bi * num_head + hi] = quant_scale; + dequant_scales[bi * num_head + hi] = 1.0f / quant_scale; + } +} + +template +void LaunchWriteInt8CacheKV(const paddle::Tensor& input_k, + const paddle::Tensor& input_v, + const paddle::Tensor& cache_kv, + const paddle::Tensor& k_quant_scales, + const paddle::Tensor& v_quant_scales, + const paddle::Tensor& k_dequant_scales, + const paddle::Tensor& v_dequant_scales + ) { + typedef PDTraits traits_; + typedef typename traits_::DataType DataType_; + typedef typename traits_::data_t data_t; + + const int64_t bsz = input_k.shape()[0]; + const int64_t seq_len = input_k.shape()[2]; + const int64_t cache_bsz = cache_kv.shape()[1]; + const int64_t num_head = cache_kv.shape()[2]; + const int64_t dim_head = cache_kv.shape()[4]; + + auto cache_kv_out = paddle::full({1}, -1, paddle::DataType::UINT8, cache_kv.place()); + + const DataType_ *k_ptr = reinterpret_cast(input_k.data()); + const DataType_ *v_ptr = reinterpret_cast(input_v.data()); + + // [2, bsz, num_head, max_seq_len, head_dim] + int max_seq_len = cache_kv.shape()[3]; + uint8_t *cache_kv_data = reinterpret_cast(const_cast(cache_kv.data())); + + float* k_quant_scales_data = const_cast(k_quant_scales.data()); + float* k_dequant_scales_data = const_cast(k_dequant_scales.data()); + + float* v_quant_scales_data = const_cast(v_quant_scales.data()); + float* v_dequant_scales_data = const_cast(v_dequant_scales.data()); + + int64_t cache_k_size = cache_bsz * num_head * max_seq_len * dim_head; + + uint8_t *cache_k_ptr = cache_kv_data; + uint8_t *cache_v_ptr = cache_kv_data + cache_k_size; + + constexpr int block_sz = 512; + constexpr int VecSize = VEC_16B / sizeof(DataType_); + + assert(dim_head % VecSize == 0); + // PD_CHECK((dim_head % x) == 0, "PD_CHECK returns ", false, ", dim_head must be divisible by vec_size."); + + dim3 grid(num_head, bsz); + + // transpose [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, dim_head/x, max_seq_len, x] + write_cache_k_int8_kernel<<>>( + k_ptr, num_head, dim_head, seq_len, max_seq_len, cache_k_ptr, k_quant_scales_data, k_dequant_scales_data); + + + // copy [bsz, num_head, seq_len, dim_head/x, x]-> + // [bsz, num_head, max_seq_len, dim_head/x, x] + write_cache_v_int8_kernel<<>>( + v_ptr, num_head, dim_head, seq_len, max_seq_len, cache_v_ptr, v_quant_scales_data, v_dequant_scales_data); + +} + + +void WriteInt8CacheKV(const paddle::Tensor& input_k, + const paddle::Tensor& input_v, + const paddle::Tensor& cache_kv, + const paddle::Tensor& k_quant_scales, + const paddle::Tensor& v_quant_scales, + const paddle::Tensor& k_dequant_scales, + const paddle::Tensor& v_dequant_scales) { + switch (input_k.type()) { + case paddle::DataType::BFLOAT16: { + return LaunchWriteInt8CacheKV( + input_k, input_v, cache_kv, k_quant_scales, v_quant_scales, k_dequant_scales, v_dequant_scales + ); + } + case paddle::DataType::FLOAT16: { + return LaunchWriteInt8CacheKV( + input_k, input_v, cache_kv, k_quant_scales, v_quant_scales, k_dequant_scales, v_dequant_scales + ); + } + case paddle::DataType::FLOAT32: { + return LaunchWriteInt8CacheKV( + input_k, input_v, cache_kv, k_quant_scales, v_quant_scales, k_dequant_scales, v_dequant_scales + ); + } + default: { + PD_THROW( + 
"NOT supported data type. " + "Only bfloat16, float16 and float32 are supported. "); + break; + } + } +} + +PD_BUILD_OP(write_int8_cache_kv) + .Inputs({"input_k", "input_v", "cache_kv", "k_quant_scales", "v_quant_scales", "q_dequant_scales", "v_dequant_scales"}) + .Outputs({"cache_kv_out"}) + .SetInplaceMap({{"cache_kv", "cache_kv_out"}}) + .SetKernelFn(PD_KERNEL(WriteInt8CacheKV)); \ No newline at end of file diff --git a/csrc/setup_cuda.py b/csrc/setup_cuda.py index e3087ec48f1e..e2957ecf6501 100644 --- a/csrc/setup_cuda.py +++ b/csrc/setup_cuda.py @@ -57,6 +57,7 @@ def get_gencode_flags(): "./generation/save_with_output.cc", "./generation/set_value_by_flags.cu", "./generation/token_penalty_multi_scores.cu", + "./generation/token_penalty_multi_scores_v2.cu", "./generation/stop_generation_multi_ends.cu", "./generation/fused_get_rope.cu", "./generation/get_padding_offset.cu", @@ -65,6 +66,16 @@ def get_gencode_flags(): "./generation/transpose_removing_padding.cu", "./generation/write_cache_kv.cu", "./generation/encode_rotary_qk.cu", + "./generation/get_padding_offset_v2.cu", + "./generation/rebuild_padding_v2.cu", + "./generation/set_value_by_flags_v2.cu", + "./generation/stop_generation_multi_ends_v2.cu", + "./generation/update_inputs.cu", + "./generation/get_output.cc", + "./generation/reset_need_stop_value.cc", + "./generation/save_with_output_msg.cc", + "./generation/write_int8_cache_kv.cu", + "./generation/step.cu", "./generation/quant_int8.cu", "./generation/dequant_int8.cu", ], diff --git a/llm/benchmark.sh b/llm/benchmark.sh index 3545a9f232a7..d49858b42b76 100644 --- a/llm/benchmark.sh +++ b/llm/benchmark.sh @@ -33,4 +33,4 @@ python predictor.py \ --mode "static" \ --batch_size 1 \ --benchmark \ - --inference_model \ No newline at end of file + --inference_model diff --git a/llm/docs/inference.md b/llm/docs/inference.md index ffc0ca8ae96d..a0e7b720f0d7 100644 --- a/llm/docs/inference.md +++ b/llm/docs/inference.md @@ -71,6 +71,8 @@ PaddleNLP 中已经添加高性能推理模型相关实现,支持: * WINT8:指Weight-Only Quantization INT8,即对权重进行INT8量化的模型。 * PTQ-A8W8:指使用PTQ对线性层的激活和权重都量化为INT8的模型。 +为了进一步提升推理的吞吐,我们基于PageAttention的思想设计并实现了BlockAttention,在保持高性能推理和动态插入的基础上可以动态地为cachekv分配存储空间,极大地节省显存,从而在同一时刻处理更多的query以获得吞吐的提升。下面分别给出关闭BlockAttention和打开BlockAttention进行高性能推理的命令参考。 + ### 2.2 环境准备 - PaddleNLP develop @@ -83,7 +85,9 @@ git clone https://github.com/PaddlePaddle/PaddleNLP cd ./paddlenlp/csrc && python setup_cuda.py install ``` -### 2.3 高性能动态图推理 +### 2.3 关闭BlockAttention的高性能推理 + +#### 2.3.1 动态图推理 ```shell # 动态图模型推理命令参考 @@ -103,7 +107,7 @@ python predictor.py --model_name_or_path checkpoints/llama_ptq_ckpts --inference 2. PrefixTuning推理需要传入相应的pre_cache,需要额外设置`export_precache`为`true`,并且传入对应的PrefixTuning参数保存路径`prefix_path`。 3. 使用Weight Only Int8 推理需要额外传入 `quant_type`。 -### 2.4 高性能静态图推理 +#### 2.3.2 静态图推理 **step1:动转静** ```shell # 动转静命令参考 @@ -150,6 +154,64 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_ 4. 
A8W8推理传入的 `model_name_or_path` 为PTQ校准产出的量化模型。 +### 2.4 打开BlockAttention的高性能推理 + +#### 2.4.1 动态图推理 + +```shell +# 动态图模型推理命令参考 +python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn + +# Weight Only Int8 动态图推理参考 +python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --quant_type weight_only_int8 --block_attn + +# PTQ-A8W8推理命令参考 +python predictor.py --model_name_or_path checkpoints/llama_ptq_ckpts --inference_model --dtype float16 --block_attn + +# CacheKV 动态量化推理命令参考 +python predictor.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --dtype float16 --block_attn --cachekv_int8 +``` + +#### 2.4.2 静态图推理 +**step1:动转静** +```shell +# 动转静命令参考 +python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn + +# Weight Only Int8 动转静命令参考 +python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --quant_type weight_only_int8 --block_attn + +# PTQ-A8W8动转静命令参考 +python export_model.py --model_name_or_path checkpoints/llama_ptq_ckpts --inference_model --output_path ./inference --dtype float16 --block_attn + +# CacheKV 动态量化动转静命令参考 +python export_model.py --model_name_or_path meta-llama/Llama-2-7b-chat --inference_model --output_path ./inference --dtype float16 --block_attn --cachekv_int8 +``` + +**step2:静态图推理** +```shell +# 静态图推理命令参考 +python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn + +# Weight Only Int8 静态图推理命令参考 +python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --quant_type weight_only_int8 --block_attn + +# PTQ-A8W8静态图推理命令参考 +# 以下环境变量用于开启int8矩阵乘的算法选择以获得更快的推理速度,打开之后第一次执行会执行算法选择从而导致速度较慢。 +export FLAGS_use_autotune=1 +export FLAGS_cublaslt_exhaustive_search_times=10 +export FLAGS_cache_inference_while_scope=1 + +python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --block_attn + +# CacheKV 动态量化8静态图推理命令参考 +python predictor.py --model_name_or_path ./inference --inference_model --dtype "float16" --mode "static" --cachekv_int8 --block_attn +``` +**Note**: +1. 使用Weight Only Int8 推理需要额外传入 `quant_type`。 +2. A8W8推理传入的 `model_name_or_path` 为PTQ校准产出的量化模型。 + + ## 3. 
推理参数介绍 - `model_name_or_path`: 必须,预训练模型名称或者本地的模型路径,用于热启模型和分词器,默认为None。 @@ -168,3 +230,6 @@ python predictor.py --model_name_or_path ./inference --inference_model --quant_ - `model_type`: 初始化不同类型模型,gpt-3: GPTForCausalLM; ernie-3.5-se: Ernie35ForCausalLM; 默认为 None。 - `mode`: 使用动态图或者静态图推理,值为:[dynamic, static],默认为 dynamic。 - `inference_model`: 是否使用Inference Model 推理,默认值为 False。 +- `block_attn`: 是否使用Block Attention 推理, 默认值为False。 +- `block_size`: 如果使用Block Attention 推理,指定一个Block可以存储的token数量,默认值为64。 +- `cachekv_int8`: 是否使用cachekv int8量化用于节省显存,默认值为False。 diff --git a/llm/export_model.py b/llm/export_model.py index ef707db67463..48932560efe1 100644 --- a/llm/export_model.py +++ b/llm/export_model.py @@ -83,7 +83,11 @@ def main(): predictor.model.to_static( get_infer_model_path(export_args.output_path, predictor_args.model_prefix), - {"dtype": predictor_args.dtype, "export_precache": predictor_args.export_precache}, + { + "dtype": predictor_args.dtype, + "export_precache": predictor_args.export_precache, + "use_cachekv_int8": predictor_args.use_cachekv_int8, + }, ) predictor.model.config.save_pretrained(export_args.output_path) predictor.tokenizer.save_pretrained(export_args.output_path) diff --git a/llm/llama/ptq_argument.json b/llm/llama/ptq_argument.json index 3f3f432f1371..0a64f3818834 100644 --- a/llm/llama/ptq_argument.json +++ b/llm/llama/ptq_argument.json @@ -1,22 +1,22 @@ { - "model_name_or_path": "./checkpoints/llama_sft_ckpts", - "per_device_train_batch_size": 8, - "per_device_eval_batch_size": 8, - "eval_accumulation_steps":16, - "src_length": 1024, - "max_length": 2048, - "fp16": true, - "fp16_opt_level": "O2", - "dataset_name_or_path": "./data", - "output_dir": "./checkpoints/llama_ptq_ckpts", - "do_eval": true, - "eval_with_do_generation": false, - "do_ptq": true, - "ptq_step": 16, - "smooth": true, - "smooth_step": 16, - "smooth_all_linears": true, - "smooth_piecewise_search": true, - "smooth_k_piece": 3, - "smooth_search_piece": true - } \ No newline at end of file + "model_name_or_path": "./checkpoints/llama_sft_ckpts", + "per_device_train_batch_size": 8, + "per_device_eval_batch_size": 8, + "eval_accumulation_steps":16, + "src_length": 1024, + "max_length": 2048, + "fp16": true, + "fp16_opt_level": "O2", + "dataset_name_or_path": "./data", + "output_dir": "./checkpoints/llama_ptq_ckpts", + "do_eval": true, + "eval_with_do_generation": false, + "do_ptq": true, + "ptq_step": 16, + "smooth": true, + "smooth_step": 16, + "smooth_all_linears": true, + "smooth_piecewise_search": true, + "smooth_k_piece": 3, + "smooth_search_piece": true +} \ No newline at end of file diff --git a/llm/predictor.py b/llm/predictor.py index 7175a58e697d..f3e468876277 100644 --- a/llm/predictor.py +++ b/llm/predictor.py @@ -25,7 +25,9 @@ import numpy as np import paddle import paddle.distributed.fleet.base.topology as tp +import paddle.incubate.multiprocessing as mp from paddle.distributed import fleet +from paddlenlp_ops import reset_stop_value from utils import ( dybatch_preprocess, get_alibi_slopes, @@ -36,6 +38,7 @@ get_prefix_tuning_params, init_chat_template, load_real_time_tokens, + read_res, ) from paddlenlp.generation import GenerationConfig, TextIteratorStreamer @@ -54,6 +57,9 @@ from paddlenlp.utils.import_utils import import_module, is_paddlenlp_ops_available from paddlenlp.utils.log import logger +# Note(@RochardWooSJTU): MAX_BSZ must be the same as definition in get_output / save_output +MAX_BSZ = 512 + @dataclass class PredictorArgument: @@ -100,6 +106,13 @@ class PredictorArgument: }, ) + 
block_attn: bool = field(default=False, metadata={"help": "whether use block attention"}) + block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."}) + cachekv_int8: bool = field( + default=False, + metadata={"help": "If cachekv_int8 set as `True`, cache kv would be quantized to int8 dynamically. "}, + ) + chat_template: str = field( default=None, metadata={ @@ -115,6 +128,10 @@ class PredictorArgument: def total_max_length(self): return self.src_length + self.max_length + @property + def use_cachekv_int8(self): + return "dynamic" if self.cachekv_int8 else "None" + @dataclass class ModelArgument: @@ -687,6 +704,424 @@ def _infer(self, inputs: dict[str, paddle.Tensor]): return None +class BlockInferencePredictorMixin: + def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer): + + self.num_layers = len(self.cache_kvs_shape) // 2 + self.num_attention_heads = self.cache_kvs_shape[0][-3] + self.head_dim = self.cache_kvs_shape[0][-1] + self.max_block_nums = self.cache_kvs_shape[0][0] + self.batch_size = config.batch_size + self.model_name_or_path = config.model_name_or_path + + self.architectures = self.model_config.architectures[0].lower() + + self.dtype = config.dtype or self.model_config + + self.total_max_length = config.src_length + config.max_length + self.block_size = config.block_size + self.pre_max_block_num = (self.total_max_length + config.block_size - 1) // config.block_size + self.max_block_nums = config.batch_size * self.pre_max_block_num + + self.pre_cache_length = 0 + + if config.export_precache: + pre_cache_npy = np.load(config.prefix_path) + self.pre_cache_length = pre_cache_npy.shape[-2] + config.max_length -= self.pre_cache_length + self.pre_caches = [ + paddle.zeros( + [config.batch_size, self.num_attention_heads, self.pre_cache_length, self.head_dim], + dtype=self.dtype, + ) + for _ in range(2 * self.num_layers) + ] + for i in range(self.num_layers): + self.pre_caches[2 * i][:, :, :, :] = paddle.to_tensor(pre_cache_npy[i][0], dtype=self.dtype).unsqueeze( + 0 + ) + self.pre_caches[2 * i + 1][:, :, :, :] = paddle.to_tensor( + pre_cache_npy[i][1], dtype=self.dtype + ).unsqueeze(0) + + self.pre_cache_mask = paddle.zeros( + shape=[config.batch_size, 1, config.src_length, config.src_length + self.pre_cache_length], + dtype=config.dtype, + ) + self.pre_cache_mask[:, :, :, : self.pre_cache_length] = 1 + self.pre_cache_mask[:, :, :, self.pre_cache_length :] = paddle.tril( + paddle.ones(shape=[config.batch_size, 1, config.src_length, config.src_length], dtype=config.dtype) + ) + + if config.use_cachekv_int8 == "dynamic": + self.k_quant_scales = [ + paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") + for _ in range(self.num_layers) + ] + self.v_quant_scales = [ + paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") + for _ in range(self.num_layers) + ] + self.k_dequant_scales = [ + paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") + for _ in range(self.num_layers) + ] + self.v_dequant_scales = [ + paddle.zeros([config.batch_size, self.num_attention_heads], dtype="float32") + for _ in range(self.num_layers) + ] + + if config.benchmark: + self.min_length = config.max_length + else: + self.min_length = 2 + + self.free_list = [i for i in range(self.max_block_nums)][::-1] + self.used_list = [[] for _ in range(config.batch_size)] + + def init_inputs(self, config: PredictorArgument): + self.inputs = {} + + if config.export_precache: + self.inputs["src_mask"] = 
(self.pre_cache_mask - 1) * 1e4 + self.inputs["pre_ids"] = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64") + self.inputs["bad_tokens"] = paddle.to_tensor( + [ + -1, + ], + dtype="int64", + ) + self.inputs["penalty_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") + self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") + self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") + + self.inputs["min_length"] = paddle.full( + shape=[config.batch_size, 1], fill_value=self.min_length, dtype="int64" + ) + self.inputs["max_length"] = paddle.full( + shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64" + ) + self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64") + self.inputs["rope_emb"] = self._get_rotary_position_embedding( + paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim + ) + eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config) + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.inputs["eos_token_id"] = paddle.to_tensor( + np.array(eos_token_id * config.batch_size).reshape(-1, 1).astype("int64") + ) + # need update + self.inputs["block_tables"] = paddle.full( + shape=[config.batch_size, self.pre_max_block_num], fill_value=-1, dtype="int32" + ) + self.inputs["input_ids"] = paddle.full( + shape=[config.batch_size, self.total_max_length], fill_value=-1, dtype="int64" + ) + self.inputs["top_p"] = paddle.full(shape=[config.batch_size, 1], fill_value=config.top_p, dtype="float32") + self.inputs["temperature"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") + self.inputs["seq_lens_this_time"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") + self.inputs["seq_lens_encoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") + self.inputs["seq_lens_decoder"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int32") + self.inputs["step_idx"] = paddle.full(shape=[config.batch_size, 1], fill_value=0, dtype="int64") + self.inputs["not_need_stop"] = paddle.full(shape=[1], fill_value=False, dtype="bool").cpu() + self.inputs["stop_flags"] = paddle.full(shape=[config.batch_size, 1], fill_value=True, dtype="bool") + self.inputs["next_tokens"] = paddle.full(shape=[config.batch_size, 1], fill_value=-1, dtype="int64") + self.inputs["is_block_step"] = paddle.full(shape=[config.batch_size], fill_value=False, dtype="bool") + free_list = list(range(self.pre_max_block_num - 1, int(self.pre_max_block_num * 0.75) - 1, -1)) + self.inputs["encoder_block_lens"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32") + self.inputs["step_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") + self.inputs["step_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32") + self.inputs["recover_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") + self.inputs["recover_lens"] = paddle.full(shape=[1], fill_value=0, dtype="int32") + self.inputs["need_block_list"] = paddle.full(shape=[config.batch_size], fill_value=-1, dtype="int32") + self.inputs["need_block_len"] = paddle.full(shape=[1], fill_value=0, dtype="int32") + self.inputs["used_list_len"] = paddle.full(shape=[config.batch_size], fill_value=0, dtype="int32") + self.inputs["free_list"] = paddle.to_tensor(free_list, 
dtype="int32") + self.inputs["free_list_len"] = paddle.full(shape=[1], fill_value=self.pre_max_block_num * 0.25, dtype="int32") + + def _get_rotary_position_embedding(self, position_ids, head_dim): + """ + Pre-calculate rotary position embedding for position_ids. + + Args: + position_ids: [1, S] + head_dim: D + + Returns: + rot_emb: [2, 1, S, 1, D], cos + sin + """ + bsz, max_seq_len = position_ids.shape[:2] + rot_emb = paddle.zeros((2, bsz, max_seq_len, 1, head_dim), dtype="float32") + inv_freq = 10000 ** (-paddle.arange(0, head_dim, 2, dtype="float32") / head_dim) + + # shape: [B, S, D/2] + freqs = paddle.einsum("ij,k->ijk", position_ids.cast("float32"), inv_freq) + # shape: [B, S, 1, D] + emb = paddle.concat([freqs, freqs], axis=-1).reshape((bsz, max_seq_len, 1, head_dim)) + + rot_emb[0] = paddle.cos(emb) + rot_emb[1] = paddle.sin(emb) + return rot_emb + + def _preprocess(self, source): + if self.tokenizer.chat_template is not None: + source = [source] if isinstance(source, str) else source + source = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in source] + + for i, text in enumerate(source): + tokens = self.tokenizer( + text, + return_tensors="np", + padding=True, + max_length=self.config.src_length, + # if use chat_template, it will not add special_tokens + add_special_tokens=self.tokenizer.chat_template is None + or isinstance(self.tokenizer, ChatGLMv2Tokenizer), + ) + input_ids = tokens["input_ids"][0] + length = len(input_ids) + self.inputs["input_ids"][i : i + 1, :length] = input_ids + self.inputs["penalty_score"][i : i + 1] = self.config.repetition_penalty + self.inputs["frequency_score"][i : i + 1] = 0.0 + self.inputs["presence_score"][i : i + 1] = 0.0 + self.inputs["top_p"][i : i + 1] = self.config.top_p + self.inputs["temperature"][i : i + 1] = self.config.temperature + self.inputs["seq_lens_this_time"][i : i + 1] = length + self.inputs["seq_lens_encoder"][i : i + 1] = length + self.inputs["seq_lens_decoder"][i : i + 1] = 0 + self.inputs["step_idx"][i : i + 1] = 0 + self.inputs["stop_flags"][i : i + 1] = False + reset_stop_value(self.inputs["not_need_stop"]) + need_block_nums = ( + length + self.config.max_length + self.pre_cache_length + self.block_size - 1 + ) // self.block_size + for bi in range(need_block_nums): + bi_now = self.free_list.pop() + self.used_list[i].append(bi_now) + self.inputs["block_tables"][i : i + 1, bi] = bi_now + + +class DygraphBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): + def __init__( + self, + config: PredictorArgument, + model: PretrainedModel = None, + tokenizer: PretrainedTokenizer = None, + ): + self.cache_kvs_shape = model.get_cache_kvs_shape(model.config, config.batch_size) + BasePredictor.__init__(self, config, tokenizer) + BlockInferencePredictorMixin.__init__(self, config, tokenizer) + + if config.use_cachekv_int8 == "dynamic" or config.use_cachekv_int8 == "static": + self.cache_kvs = [paddle.zeros(shape, dtype="uint8") for shape in self.cache_kvs_shape] + else: + self.cache_kvs = [paddle.zeros(shape, dtype=self.dtype) for shape in self.cache_kvs_shape] + + self.model = model + + self.init_inputs(config) + if config.export_precache: + self.inputs["pre_caches"] = self.pre_caches + if config.use_cachekv_int8 == "dynamic": + self.inputs["k_quant_scales"] = self.k_quant_scales + self.inputs["v_quant_scales"] = self.v_quant_scales + self.inputs["k_dequant_scales"] = self.k_dequant_scales + self.inputs["v_dequant_scales"] = self.v_dequant_scales + + self.inputs["cache_kvs"] = self.cache_kvs 
+ + @paddle.no_grad() + def _infer(self, inputs: dict[str, paddle.Tensor]): + self.model.generate( + **inputs, + ) + + @paddle.no_grad() + def predict(self, input_texts: str | list[str]): + self._preprocess(input_texts) + + result_queue = mp.Queue() + tensor_queue = mp.Queue() + + output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + output_tensor = output_tensor.cpu() + tensor_queue.put(output_tensor) + + read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue]) + read_res_process.start() + + while self.inputs["not_need_stop"]: + self._infer(self.inputs) + # reset free_list + for i in range(self.config.batch_size): + self.free_list.extend(self.used_list[i]) + self.used_list[i] = [] + reset_stop_value(self.inputs["not_need_stop"]) + + outputs = [] + while len(outputs) < self.batch_size: + outputs.append(result_queue.get(timeout=1)[-1]) + return outputs + + +class StaticBlockInferencePredictor(BlockInferencePredictorMixin, BasePredictor): + def __init__( + self, + config: PredictorArgument, + cache_kvs_shape: list[list[int]], + tokenizer: PretrainedTokenizer = None, + ): + self.cache_kvs_shape = cache_kvs_shape + BasePredictor.__init__(self, config, tokenizer) + BlockInferencePredictorMixin.__init__(self, config, tokenizer) + + self.init_inputs(config) + + if config.export_precache: + for i in range(self.num_layers): + self.inputs["pre_caches_{}".format(i)] = self.pre_caches[i] + + self.cache_kvs = {} + if config.use_cachekv_int8 == "dynamic" or config.use_cachekv_int8 == "static": + for i in range(len(self.cache_kvs_shape) // 2): + self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros(self.cache_kvs_shape[2 * i], dtype="uint8") + self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros( + self.cache_kvs_shape[2 * i + 1], dtype="uint8" + ) + else: + for i in range(len(self.cache_kvs_shape) // 2): + self.cache_kvs["key_caches_{}".format(i)] = paddle.zeros( + self.cache_kvs_shape[2 * i], dtype=config.dtype + ) + self.cache_kvs["value_caches_{}".format(i)] = paddle.zeros( + self.cache_kvs_shape[2 * i + 1], dtype=config.dtype + ) + + for i in range(self.num_layers): + if self.config.use_cachekv_int8 == "dynamic": + self.inputs["k_quant_scales_" + str(i)] = self.k_quant_scales[i] + self.inputs["v_quant_scales_" + str(i)] = self.v_quant_scales[i] + self.inputs["k_dequant_scales_" + str(i)] = self.k_dequant_scales[i] + self.inputs["v_dequant_scales_" + str(i)] = self.v_dequant_scales[i] + + self._create_predictor(config) + self.input_names = self.predictor.get_input_names() + + self._share_data() + self.seq_lens_handle = self.predictor.get_input_handle("seq_lens_this_time") + + def _create_predictor(self, predictor_args: PredictorArgument): + if not is_paddlenlp_ops_available(): + raise ValueError( + "you should install the paddlenlp ops to run inference predictor, " + "https://github.com/PaddlePaddle/PaddleNLP/blob/develop/csrc/README.md" + ) + + infer_model_path = get_infer_model_path(predictor_args.model_name_or_path, predictor_args.model_prefix) + + config = paddle.inference.Config(infer_model_path + ".pdmodel", infer_model_path + ".pdiparams") + + config.switch_ir_optim(False) + device_id = int(os.environ.get("FLAGS_selected_gpus", 0)) + config.enable_use_gpu(100, device_id) + # config.disable_glog_info() + # config.enable_memory_optim() + + if self.tensor_parallel_degree > 1: + trainer_endpoints = fleet.worker_endpoints() + current_endpoint = trainer_endpoints[self.tensor_parallel_rank] + + dist_config = 
config.dist_config() + dist_config.set_ranks(self.tensor_parallel_degree, self.tensor_parallel_rank) + dist_config.set_endpoints(trainer_endpoints, current_endpoint) + dist_config.enable_dist_model(True) + + dist_config.set_comm_init_config(os.path.join(predictor_args.model_name_or_path, "rank_mapping.csv")) + config.set_dist_config(dist_config) + + self.predictor = paddle.inference.create_predictor(config) + + def _share_data(self): + """ + Share external data for inference predictor. + """ + for name in self.input_names: + if "pre_key_" in name or "pre_value_" in name: + input_tensor = self.predictor.get_input_handle(name) + input_tensor.share_external_data(self.inputs[name]) + continue + if "caches" in name: + input_tensor = self.predictor.get_input_handle(name) + input_tensor.share_external_data(self.cache_kvs[name]) + continue + if "seq_lens_this_time" in name: + continue + input_tensor = self.predictor.get_input_handle(name) + input_tensor.share_external_data(self.inputs[name]) + + def _infer(self): + self.predictor.run() + + def predict(self, input_texts: str | list[str]): + + s_time = time.time() + self._preprocess(input_texts) + real_bsz = len(input_texts) + + import copy + + seq_lens_this_time = copy.deepcopy(self.inputs["seq_lens_this_time"][:real_bsz]) + self.seq_lens_handle.share_external_data(seq_lens_this_time) + logger.info(f"preprocess spend {time.time() - s_time}") + + result_queue = mp.Queue() + tensor_queue = mp.Queue() + + output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64") + output_tensor = output_tensor.cpu() + tensor_queue.put(output_tensor) + + read_res_process = mp.Process(target=read_res, args=[self.model_name_or_path, tensor_queue, result_queue]) + read_res_process.start() + + s_time = time.time() + while self.inputs["not_need_stop"]: + self.predictor.run() + logger.info(f"running spend {time.time() - s_time}") + + # reset free_list + for i in range(self.config.batch_size): + self.free_list.extend(self.used_list[i]) + self.used_list[i] = [] + reset_stop_value(self.inputs["not_need_stop"]) + + outputs = [] + while len(outputs) < self.batch_size: + outputs.append(result_queue.get(timeout=1)[-1]) + return outputs + + def _preprocess(self, source): + BlockInferencePredictorMixin._preprocess(self, source) + for i, text in enumerate(source): + tokens = self.tokenizer(text, return_tensors="np", padding=False, max_length=(self.config.src_length)) + input_ids = tokens["input_ids"][0] + length = len(input_ids) + need_block_nums = ( + length + self.config.max_length + self.pre_cache_length + self.block_size - 1 + ) // self.block_size + self.inputs["encoder_block_lens"][i : i + 1] = need_block_nums + + +def get_ptq_multicards_num(directory): + count = 0 + prefix = "act_scales_" + for filename in os.listdir(directory): + if filename.startswith(prefix): + count += 1 + return count + + def create_predictor( predictor_args: PredictorArgument, model_args: ModelArgument, @@ -790,6 +1225,8 @@ def create_predictor( config.weight_only_quant_bits = -1 config.quant_type = None config.model_name_or_path = "" + config.use_cachekv_int8 = predictor_args.use_cachekv_int8 + config.single_card_ptq = True if predictor_args.quant_type is not None and predictor_args.quant_type.startswith("weight_only_int"): weight_only_quant_bits = int(predictor_args.quant_type[-1]) @@ -800,6 +1237,11 @@ def create_predictor( config.model_name_or_path = predictor_args.model_name_or_path config.quant_type = config.quantization_config.quant_type + ptq_multicards_num = 
get_ptq_multicards_num(config.model_name_or_path) + logger.info(f"PTQ from {ptq_multicards_num} cards, so we will not split") + if ptq_multicards_num > 1: + config.single_card_ptq = False + # Turn on GEMM int8 kernel tuning paddle.base.core.enable_autotune() paddle.base.core.update_autotune_status() @@ -810,13 +1252,30 @@ def create_predictor( from paddlenlp.experimental.transformers import ( LlamaForMiniGPT4InferenceModel as LlamaInferenceModel, ) + elif predictor_args.block_attn: + config.max_seq_len = predictor_args.total_max_length + config.block_size = predictor_args.block_size + from paddlenlp.experimental.transformers import ( + LlamaForCausalLMBlockInferenceModel as LlamaInferenceModel, + ) + + model = LlamaInferenceModel.from_pretrained( + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, + tensor_parallel_degree=tensor_parallel_degree, + tensor_parallel_rank=tensor_parallel_rank, + ) else: from paddlenlp.experimental.transformers import ( LlamaForCausalLMInferenceModel as LlamaInferenceModel, ) - model = LlamaInferenceModel.from_pretrained( - predictor_args.model_name_or_path, config=config, dtype=predictor_args.dtype - ) + + model = LlamaInferenceModel.from_pretrained( + predictor_args.model_name_or_path, + config=config, + dtype=predictor_args.dtype, + ) model.eval() elif "opt" in config.architectures[0].lower(): @@ -893,15 +1352,35 @@ def create_predictor( model.eval() else: raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]") - predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) + if predictor_args.block_attn: + predictor = DygraphBlockInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) + else: + predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer) + elif predictor_args.mode == "static": config = AutoConfig.from_pretrained(predictor_args.model_name_or_path) if "llama" in config.architectures[0].lower(): + if predictor_args.block_attn: + config.block_size = predictor_args.block_size + config.max_seq_len = predictor_args.total_max_length + config.use_dynamic_cachekv_quant = predictor_args.use_cachekv_int8 == "dynamic" + from paddlenlp.experimental.transformers import ( + LlamaForCausalLMBlockInferenceModel as LlamaInferenceModel, + ) + else: + from paddlenlp.experimental.transformers import ( + LlamaForCausalLMInferenceModel as LlamaInferenceModel, + ) + + cache_kvs_shape = LlamaInferenceModel.get_cache_kvs_shape( + config, predictor_args.batch_size, predictor_args.total_max_length + ) + elif "chatglmv2forcausallm" in config.architectures[0].lower(): from paddlenlp.experimental.transformers import ( - LlamaForCausalLMInferenceModel, + ChatGLMv2ForCausalLMInferenceModel, ) - cache_kvs_shape = LlamaForCausalLMInferenceModel.get_cache_kvs_shape( + cache_kvs_shape = ChatGLMv2ForCausalLMInferenceModel.get_cache_kvs_shape( config, predictor_args.batch_size, predictor_args.total_max_length ) elif "chatglmv2forcausallm" in config.architectures[0].lower(): @@ -946,7 +1425,10 @@ def create_predictor( ) else: raise ValueError("the `model type` should be one of [llama, chatglm, bloom, gpt, qwen]") - predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) + if predictor_args.block_attn: + predictor = StaticBlockInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) + else: + predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer) else: raise 
ValueError("the `mode` should be one of [dynamic, static]") @@ -989,7 +1471,9 @@ def predict(): with open(model_args.output_file, "w", encoding="utf-8") as f: for bs, batch_source_text in enumerate(batch_source_texts): + logger.info("Start predict") outputs = predictor.predict(batch_source_text) + logger.info("End predict") if predictor.tensor_parallel_rank > 0: continue @@ -1009,27 +1493,42 @@ def predict(): def benchmark(predictor, predictor_args, model_args): # Just construct a simple benchmark input. We pad input to the src_length. - test_texts = "hello world, how are you?" - benchmark_texts = [test_texts + "" * predictor_args.src_length for _ in range(predictor_args.batch_size)] + test_texts = "who are you" + benchmark_texts = [test_texts + "" * (predictor_args.src_length) for _ in range(predictor_args.batch_size)] batch_benchmark_texts = batchfy_text(benchmark_texts, predictor_args.batch_size) print("***********Start Benchmark**********") - warmup_time = 10 - test_time = 100 + warmup_time = 2 + test_time = 10 print("***********Start Warmup**********") - for _ in range(warmup_time): - for bs, batch_source_text in enumerate(batch_benchmark_texts): - outputs = predictor.predict(batch_source_text) + for i in range(warmup_time): + print("warm up ", i) + for _, batch_source_text in enumerate(batch_benchmark_texts): + predictor.predict(batch_source_text) + + from paddle import profiler + + # 创建性能分析器相关的代码 + def my_on_trace_ready(prof): # 定义回调函数,性能分析器结束采集数据时会被调用 + callback = profiler.export_chrome_tracing("./profiler_demo") # 创建导出性能数据到profiler_demo文件夹的回调函数 + callback(prof) # 执行该导出函数 + prof.summary(sorted_by=profiler.SortedKeys.GPUTotal) # 打印表单,按GPUTotal排序表单项 + + p = profiler.Profiler(scheduler=[3, 4], on_trace_ready=my_on_trace_ready, timer_only=False) # 初始化Profiler对象 print("***********Start Speed Test**********") start = time.perf_counter() output_tokens = 0 - for _ in range(test_time): - for bs, batch_source_text in enumerate(batch_benchmark_texts): - outputs = predictor.predict(batch_source_text) - output_tokens += sum([len(output) for output in outputs]) + p.start() + for i in range(test_time): + print("test ", i) + for _, batch_source_text in enumerate(batch_benchmark_texts): + predictor.predict(batch_source_text) + output_tokens += predictor_args.max_length * predictor_args.batch_size + p.step() + p.stop() end = time.perf_counter() print("Avg Elapse time is: ", (end - start) / test_time) print("Output tokens is: ", output_tokens) diff --git a/llm/utils.py b/llm/utils.py index d15d72dcddae..3ea18855456f 100644 --- a/llm/utils.py +++ b/llm/utils.py @@ -22,14 +22,17 @@ import numpy as np import paddle import paddle.distributed as dist +import paddle.incubate.multiprocessing as mp from paddle.distributed import fleet from paddle.io import BatchSampler, DataLoader, DistributedBatchSampler +from paddlenlp_ops import get_output from sklearn.metrics import accuracy_score from paddlenlp.datasets import InTokensIterableDataset from paddlenlp.trainer import Trainer, TrainerCallback from paddlenlp.trainer.trainer_utils import IterableDatasetShard, has_length from paddlenlp.transformers import ( + AutoTokenizer, ChatGLMv2Tokenizer, LlamaForCausalLMPipe, PretrainedConfig, @@ -687,3 +690,32 @@ def get_default_max_encoding_length(config: PretrainedConfig, default: int = 102 if max_position_embeddings is None: return default return max_position_embeddings // 4 * 3 + + +def read_res(model_name_or_path: str, tensor_queue: mp.Queue, result_queue: mp.Queue): + tokenizer = AutoTokenizer.from_pretrained( + 
model_name_or_path, + ) + + paddle.device.set_device("cpu") + outputs = [] + output_tensor = tensor_queue.get(timeout=1) + + logger.info("Start read result message") + logger.info(f"Current path is {os.getcwd()}") + while True: + get_output(output_tensor, 0, True) + if output_tensor[0, 0] == -2: # read none + continue + bsz = output_tensor[1, 0].numpy() + output_numpy = output_tensor[2 : bsz + 2].numpy() + output_numpy[output_numpy == -1] = 2 + outputs.append(output_numpy) + if output_tensor[0, 0] == -1: + break + output = np.concatenate(outputs, axis=1).tolist() + seqs = tokenizer.batch_decode(output, skip_special_tokens=True, clean_up_tokenization_spaces=False) + for i, seq in enumerate(seqs): + result_queue.put([i, seq]) + + logger.info("Finish read result message") diff --git a/paddlenlp/experimental/model_utils.py b/paddlenlp/experimental/model_utils.py index d1746c401116..151a90f2e9ae 100644 --- a/paddlenlp/experimental/model_utils.py +++ b/paddlenlp/experimental/model_utils.py @@ -396,3 +396,27 @@ def __init__( self.scale["ffn1_weight_scale"].append( np.concatenate([self.scale["ffn1_1_weight_scale"][i, :], self.scale["ffn1_2_weight_scale"][i, :]]) ) + + +class CacheScaleLoader: + def __init__( + self, scale_json_file_path="cache_scales.json", key_map_dict=None, num_of_layers=None, num_heads=None + ): + with open(scale_json_file_path) as json_file: + self.scale_dict = json.load(json_file) + self.key_map = key_map_dict + self.scale = {} + for scale_type, key_template in self.key_map.items(): + if "cache_k" in scale_type: + scale_type_out = "cache_k_out_scale" + else: + scale_type_out = "cache_v_out_scale" + self.scale[scale_type] = np.full([num_of_layers, num_heads], fill_value=-1.0) + self.scale[scale_type_out] = np.full([num_of_layers, num_heads], fill_value=-1.0) + + for i in range(num_of_layers): + if key_template.replace("#", str(i)) in self.scale_dict.keys(): + self.scale[scale_type][i, :] = [ + 127.0 / num for num in self.scale_dict[key_template.replace("#", str(i))] + ] + self.scale[scale_type_out][i, :] = [1.0 / self.scale[scale_type][i, j] for j in range(num_heads)] diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py index 241d67a93375..0e7205932ed3 100644 --- a/paddlenlp/experimental/transformers/fused_transformer_layers.py +++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py @@ -36,6 +36,7 @@ qkv_transpose_split, quant_int8, rebuild_padding, + rebuild_padding_v2, transpose_remove_padding, write_cache_kv, ) @@ -52,6 +53,9 @@ "FusedMultiTransformerPostLayernorm", "FusedMultiTransformerWeightOnly", "FusedMultiTransformerWeightOnlyPostLayernorm", + "FusedBlockMultiTransformer", + "FusedBlockMultiTransformerWeightOnly", + "FusedBlockMultiTransformerA8W8", ] @@ -176,6 +180,10 @@ def __init__( linear_smooth_attrs=None, ffn2_shift_attrs=None, ffn2_smooth_attrs=None, + cache_k_scale_attrs=None, + cache_v_scale_attrs=None, + cache_k_out_scale_attrs=None, + cache_v_out_scale_attrs=None, quant_round_type=0, quant_max_bound=127.0, quant_min_bound=-127.0, @@ -186,6 +194,8 @@ def __init__( trans_qkvw=True, ring_id=-1, kv_num_heads=-1, + use_dynamic_cachekv_quant=True, + rank_id=-1, ): self.embed_dim = embed_dim self.num_heads = num_heads @@ -227,14 +237,21 @@ def __init__( self.linear_smooth_attrs = linear_smooth_attrs self.ffn2_shift_attrs = ffn2_shift_attrs self.ffn2_smooth_attrs = ffn2_smooth_attrs + self.cache_k_scale_attrs = cache_k_scale_attrs + self.cache_v_scale_attrs = 
cache_v_scale_attrs + self.cache_k_out_scale_attrs = cache_k_out_scale_attrs + self.cache_v_out_scale_attrs = cache_v_out_scale_attrs + self.quant_round_type = quant_round_type self.quant_max_bound = quant_max_bound self.quant_min_bound = quant_min_bound + self.use_dynamic_cachekv_quant = use_dynamic_cachekv_quant self.epsilon = epsilon self.residual_alpha = residual_alpha self.num_layers = num_layers self.nranks = nranks + self.rank_id = rank_id self.trans_qkvw = trans_qkvw self.ring_id = ring_id @@ -243,6 +260,8 @@ class FusedMultiTransformerBase(Layer): def __init__(self, config: FusedMultiTransformerConfig): super().__init__() + self.config = config + assert config.embed_dim > 0, "Expected embed_dim to be greater than 0, " "but received {}".format( config.embed_dim ) @@ -298,6 +317,8 @@ def __init__(self, config: FusedMultiTransformerConfig): self.ffn_ln_scales, self.ffn_ln_biases = [], [] self.ffn1_weights, self.ffn1_biases = [], [] self.ffn2_weights, self.ffn2_biases = [], [] + self.cache_k_scales, self.cache_v_scales = [], [] + self.cache_k_out_scales, self.cache_v_out_scales = [], [] for i in range(self.num_layers): ln_scale_attr = self.get_attr(config.ln_scale_attrs, i) @@ -315,6 +336,11 @@ def __init__(self, config: FusedMultiTransformerConfig): ffn2_weight_attr = self.get_attr(config.ffn2_weight_attrs, i) ffn2_bias_attr = self.get_attr(config.ffn2_bias_attrs, i) + cache_k_scale_attr = self.get_attr(config.cache_k_scale_attrs, i) + cache_v_scale_attr = self.get_attr(config.cache_v_scale_attrs, i) + cache_k_out_scale_attr = self.get_attr(config.cache_k_out_scale_attrs, i) + cache_v_out_scale_attr = self.get_attr(config.cache_v_out_scale_attrs, i) + ln_scale = self.create_parameter( attr=ln_scale_attr, shape=[config.embed_dim], @@ -413,6 +439,42 @@ def __init__(self, config: FusedMultiTransformerConfig): is_bias=True, ) + cache_k_scale = None + if cache_k_scale_attr: + cache_k_scale = self.create_parameter( + shape=[self.num_heads], + attr=cache_k_scale_attr, + dtype="float32", + is_bias=False, + ) + + cache_v_scale = None + if cache_v_scale_attr: + cache_v_scale = self.create_parameter( + shape=[self.num_heads], + attr=cache_v_scale_attr, + dtype="float32", + is_bias=False, + ) + + cache_k_out_scale = None + if cache_k_out_scale_attr: + cache_k_out_scale = self.create_parameter( + shape=[self.num_heads], + attr=cache_k_out_scale_attr, + dtype="float32", + is_bias=False, + ) + + cache_v_out_scale = None + if cache_v_out_scale_attr: + cache_v_out_scale = self.create_parameter( + shape=[self.num_heads], + attr=cache_v_out_scale_attr, + dtype="float32", + is_bias=False, + ) + # tensor model parallel if config.nranks > 1: # column parallel @@ -438,6 +500,11 @@ def __init__(self, config: FusedMultiTransformerConfig): self.ffn2_weights.append(ffn2_weight) self.ffn2_biases.append(ffn2_bias) + self.cache_k_scales.append(cache_k_scale) + self.cache_v_scales.append(cache_v_scale) + self.cache_k_out_scales.append(cache_k_out_scale) + self.cache_v_out_scales.append(cache_v_out_scale) + self._add_parameter(ln_scale) self._add_parameter(ln_bias) self._add_parameter(qkv_weight) @@ -452,6 +519,11 @@ def __init__(self, config: FusedMultiTransformerConfig): self._add_parameter(ffn2_weight) self._add_parameter(ffn2_bias) + self._add_parameter(cache_k_scale) + self._add_parameter(cache_v_scale) + self._add_parameter(cache_k_out_scale) + self._add_parameter(cache_v_out_scale) + self.dropout_rate = config.dropout_rate from paddle.incubate.nn.functional import fused_linear @@ -595,6 +667,7 @@ def 
compute_attn( pre_caches_length, attn_mask, i, + **kwargs, ): # fmha compute if time_step is None: # context @@ -643,6 +716,7 @@ def compute_ffn2(self, ffn1_out, i): return paddle.matmul(ffn1_out, self.ffn2_weights[i]) def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layers): + if i != num_layers - 1: norm_out = self.norm_func( ffn2_out, @@ -666,6 +740,23 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer )[0] return tmp_out, residual_input + def pre_process(self, **kwargs): + pass + + def post_process(self, **kwargs): + time_step = kwargs.get("time_step", None) + multi_block_output = kwargs.get("multi_block_output", None) + cum_offsets = kwargs.get("cum_offsets", None) + seq_lens = kwargs.get("seq_lens", None) + input_ids = kwargs.get("input_ids", None) + + if time_step is None: + out = rebuild_padding(multi_block_output, cum_offsets, seq_lens, input_ids) + else: + out = multi_block_output + + return out + def forward( self, input_ids, @@ -680,6 +771,7 @@ def forward( rotary_emb_dims=0, seq_lens=None, time_step=None, + **kwargs, ): r""" Applies multi transformer layers on the input. @@ -716,11 +808,16 @@ def forward( tuple (output, caches), which output is the output of Transformer layers, caches is inplace with input `caches`. """ + self.pre_process(**kwargs) + kwargs["cum_offsets"] = cum_offsets + if caches is not None: - assert len(caches) == len(self.qkv_weights) + assert len(caches) == len(self.qkv_weights) or len(caches) == 2 * len(self.qkv_weights) + + assert self.num_layers == len(self.qkv_weights) residual_input = src - for i in range(len(caches)): + for i in range(self.num_layers): qkv_out, residual_input = self.compute_qkv(src, residual_input, i) out_linear_out = self.compute_attn( time_step, @@ -735,6 +832,7 @@ def forward( pre_caches_length, attn_mask, i, + **kwargs, ) # all_reduce if self.nranks > 1: @@ -755,13 +853,17 @@ def forward( dist.all_reduce(ffn2_out) # norm + residual_add_bias - tmp_out, residual_input = self.compute_bias_residual_layernorm(ffn2_out, residual_input, i, len(caches)) + tmp_out, residual_input = self.compute_bias_residual_layernorm( + ffn2_out, residual_input, i, self.num_layers + ) src = tmp_out - if time_step is None: - out = rebuild_padding(tmp_out, cum_offsets, seq_lens, input_ids) - else: - out = tmp_out + kwargs["time_step"] = time_step + kwargs["multi_block_output"] = tmp_out + kwargs["seq_lens"] = seq_lens + kwargs["input_ids"] = input_ids + + out = self.post_process(**kwargs) return out, caches @@ -1228,3 +1330,161 @@ def compute_bias_residual_layernorm(self, ffn2_out, residual_input, i, num_layer residual=residual_input, )[0] return tmp_out, residual_input + + +class FusedBlockMultiTransformer(FusedMultiTransformerBase): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + def compute_attn( + self, + time_step, + qkv_out, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + **kwargs, + ): + k_quant_scales = kwargs.get("k_quant_scales", None) + v_quant_scales = kwargs.get("v_quant_scales", None) + k_dequant_scales = kwargs.get("k_dequant_scales", None) + v_dequant_scales = kwargs.get("v_dequant_scales", None) + + if not self.config.use_dynamic_cachekv_quant: + k_quant_scales = self.cache_k_scales + v_quant_scales = self.cache_v_scales + k_dequant_scales = self.cache_k_out_scales + v_dequant_scales = self.cache_v_out_scales + + fmha_out = 
paddle.incubate.nn.functional.block_multihead_attention( + qkv_out, + caches[2 * i], + caches[2 * i + 1], + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + kwargs.get("seq_lens_this_time", None), + kwargs.get("padding_offsets", None), + kwargs.get("cum_offsets", None), + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + kwargs.get("block_tables", None), + pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache + pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache + k_quant_scales[i] if k_quant_scales is not None else None, + v_quant_scales[i] if v_quant_scales is not None else None, + k_dequant_scales[i] if k_dequant_scales is not None else None, + v_dequant_scales[i] if v_dequant_scales is not None else None, + None, # qkv_out_scales + None, # qkv_bias + None, # out_shifts + None, # out_smooths + rotary_embs, + attn_mask, + kwargs.get("tgt_mask", None), + kwargs.get("max_input_length", -1), + kwargs.get("block_size", 64), + self.use_neox_rotary_style, + self.config.use_dynamic_cachekv_quant, + quant_round_type=self.config.quant_round_type, + quant_max_bound=self.config.quant_max_bound, + quant_min_bound=self.config.quant_min_bound, + )[0] + + out_linear_out = self.compute_out_linear(fmha_out, i) + + return out_linear_out + + def post_process(self, **kwargs): + multi_block_output = kwargs.get("multi_block_output", None) + cum_offsets = kwargs.get("cum_offsets", None) + seq_lens_encoder = kwargs.get("seq_lens_encoder", None) + seq_lens_decoder = kwargs.get("seq_lens_decoder", None) + max_input_length = kwargs.get("max_input_length", -1) + + out = rebuild_padding_v2(multi_block_output, cum_offsets, seq_lens_decoder, seq_lens_encoder, max_input_length) + + return out + + +class FusedBlockMultiTransformerWeightOnly(FusedBlockMultiTransformer, FusedMultiTransformerWeightOnly): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + +class FusedBlockMultiTransformerA8W8(FusedBlockMultiTransformer, FusedMultiTransformerA8W8): + def __init__(self, config: FusedMultiTransformerConfig): + super().__init__(config) + + def compute_attn( + self, + time_step, + qkv_out, + padding_offset, + seq_lens, + input_ids, + rotary_embs, + rotary_emb_dims, + caches, + pre_caches, + pre_caches_length, + attn_mask, + i, + **kwargs, + ): + k_quant_scales = kwargs.get("k_quant_scales", None) + v_quant_scales = kwargs.get("v_quant_scales", None) + k_dequant_scales = kwargs.get("k_dequant_scales", None) + v_dequant_scales = kwargs.get("v_dequant_scales", None) + + if not self.config.use_dynamic_cachekv_quant: + k_quant_scales = self.cache_k_scales + v_quant_scales = self.cache_v_scales + k_dequant_scales = self.cache_k_out_scales + v_dequant_scales = self.cache_v_out_scales + + fmha_out = paddle.incubate.nn.functional.block_multihead_attention( + qkv_out, + caches[2 * i], + caches[2 * i + 1], + kwargs.get("seq_lens_encoder", None), + kwargs.get("seq_lens_decoder", None), + kwargs.get("seq_lens_this_time", None), + kwargs.get("padding_offsets", None), + kwargs.get("cum_offsets", None), + kwargs.get("cu_seqlens_q", None), + kwargs.get("cu_seqlens_k", None), + kwargs.get("block_tables", None), + pre_caches[2 * i] if pre_caches is not None else None, # pre_key_cache + pre_caches[2 * i + 1] if pre_caches is not None else None, # pre_value_cache + k_quant_scales[i] if k_quant_scales is not None else None, + v_quant_scales[i] if v_quant_scales is not None else None, + k_dequant_scales[i] if 
k_dequant_scales is not None else None, + v_dequant_scales[i] if v_dequant_scales is not None else None, + self.qkv_out_scales[i], + self.qkv_biases[i] if len(self.qkv_biases) > 0 else None, + self.linear_shifts[i] if len(self.linear_shifts) > 0 else None, + self.linear_smooths[i] if len(self.linear_smooths) > 0 else None, + rotary_embs, + attn_mask, + kwargs.get("tgt_mask", None), + kwargs.get("max_input_length", -1), + kwargs.get("block_size", 64), + self.use_neox_rotary_style, + self.config.use_dynamic_cachekv_quant, + quant_round_type=self.quant_round_type, + quant_max_bound=self.quant_max_bound, + quant_min_bound=self.quant_min_bound, + out_scale=self.act_scales["out_linear_in_scale"][i], + compute_dtype=self._fuse_kernel_compute_dtype, + )[0] + + out_linear_out = self.compute_out_linear(fmha_out, i) + + return out_linear_out diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py index 842d3612f369..cd95d0685986 100644 --- a/paddlenlp/experimental/transformers/generation_utils.py +++ b/paddlenlp/experimental/transformers/generation_utils.py @@ -19,14 +19,19 @@ import paddle.nn.functional as F from paddlenlp_ops import ( get_token_penalty_multi_scores, + get_token_penalty_multi_scores_v2, + save_output, save_with_output, set_stop_value_multi_ends, + set_stop_value_multi_ends_v2, set_value_by_flags_and_idx, + set_value_by_flags_and_idx_v2, + update_inputs, ) from paddlenlp.generation import GenerationMixin, LogitsProcessor, LogitsProcessorList -__all__ = ["GenerationInferenceModel"] +__all__ = ["GenerationInferenceModel", "GenerationBlockInferenceModel"] class ForcedDecodingEOSTokenLogitsProcessor(LogitsProcessor): @@ -155,7 +160,6 @@ def generate( pre_caches=None, **model_kwargs, ): - model_kwargs["position_ids"] = position_ids model_kwargs["attention_mask"] = attention_mask @@ -389,3 +393,306 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs): model_kwargs["seq_len_decoder"], model_kwargs["tgt_pos"], ) + + +class GenerationBlockInferenceModel(GenerationMixin): + @classmethod + def get_cache_kvs_shape(cls, max_batch_size: int = None, max_length: int = None) -> list[list[int]]: + raise NotImplementedError + + def to_static(self, output_path: str, config: dict): + dtype = config.get("dtype", paddle.get_default_dtype()) + cachekv_dtype = dtype + + cache_kvs_shapes = self.get_cache_kvs_shape( + self.config, max_batch_size=config.get("max_batch_size", -1), max_length=config.get("max_length", None) + ) + export_precache = config.get("export_precache", False) + if export_precache: + precache_kv_spec = [ + paddle.static.InputSpec(shape=[None, None, None, None], dtype=dtype, name=f"pre_caches_{i}") + for i in range(len(cache_kvs_shapes)) + ] + else: + precache_kv_spec = None + use_cachekv_int8 = config.get("use_cachekv_int8", "None") + + if use_cachekv_int8 == "static" or use_cachekv_int8 == "dynamic": + cachekv_dtype = "uint8" + + if use_cachekv_int8 == "dynamic": + cache_k_quant_scales = [ + paddle.static.InputSpec( + shape=[None, self.config.num_attention_heads], + dtype="float32", + name="k_quant_scales_{}".format(i), + ) + for i in range(int(len(cache_kvs_shapes) / 2)) + ] + + cache_v_quant_scales = [ + paddle.static.InputSpec( + shape=[None, self.config.num_attention_heads], + dtype="float32", + name="v_quant_scales_{}".format(i), + ) + for i in range(int(len(cache_kvs_shapes) / 2)) + ] + + cache_k_dequant_scales = [ + paddle.static.InputSpec( + shape=[None, 
self.config.num_attention_heads], + dtype="float32", + name="k_dequant_scales_{}".format(i), + ) + for i in range(int(len(cache_kvs_shapes) / 2)) + ] + cache_v_dequant_scales = [ + paddle.static.InputSpec( + shape=[None, self.config.num_attention_heads], + dtype="float32", + name="v_dequant_scales_{}".format(i), + ) + for i in range(int(len(cache_kvs_shapes) / 2)) + ] + else: + cache_k_quant_scales = None + cache_v_quant_scales = None + cache_k_dequant_scales = None + cache_v_dequant_scales = None + + caches = [] + for i in range(len(cache_kvs_shapes) // 2): + caches.append( + paddle.static.InputSpec( + shape=cache_kvs_shapes[2 * i], dtype=cachekv_dtype, name="key_caches_{}".format(i) + ) + ) + caches.append( + paddle.static.InputSpec( + shape=cache_kvs_shapes[2 * i + 1], dtype=cachekv_dtype, name="value_caches_{}".format(i) + ) + ) + if export_precache: + src_mask_spec = paddle.static.InputSpec(shape=[None, 1, None, None], dtype=dtype, name="src_mask") + else: + src_mask_spec = None + input_spec = [ + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"), # input_ids + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="temperature"), # temperature + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="top_p"), # top_p + paddle.static.InputSpec(shape=[None], dtype="int64", name="eos_token_id"), # eos_token_id + src_mask_spec, # src_mask + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="penalty_score"), # penalty_score + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="frequency_score"), # frequency_score + paddle.static.InputSpec(shape=[None, 1], dtype="float32", name="presence_score"), # presence_score + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="next_tokens"), # next_tokens + paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="is_block_step"), # is_block_step + paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_lens_this_time"), # seq_lens_this_time + paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_lens_encoder"), # seq_lens_encoder + paddle.static.InputSpec(shape=[None, 1], dtype="int32", name="seq_lens_decoder"), # seq_lens_decoder + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="step_idx"), # step_idx + paddle.static.InputSpec(shape=[None, 1], dtype="bool", name="stop_flags"), # stop_flags + paddle.static.InputSpec( + shape=[2, None, self.config.max_seq_len, None, None], dtype="float32", name="rope_emb" + ), # rope_emb + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="min_length"), # min_dec_len + paddle.static.InputSpec(shape=[None, 1], dtype="int64", name="max_length"), # max_dec_len + paddle.static.InputSpec(shape=[1, 1], dtype="int64", name="stop_nums"), # stop_nums + paddle.static.InputSpec(shape=[None], dtype="int64", name="bad_tokens"), # bad_tokens + paddle.static.InputSpec(shape=[1, 1], dtype="bool", name="not_need_stop"), # not_need_stop + paddle.static.InputSpec(shape=[None, None], dtype="int32", name="block_tables"), # block_tables + paddle.static.InputSpec(shape=[None, None], dtype="int64", name="pre_ids"), # pre_ids + precache_kv_spec, + caches, # cache_kvs + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_dequant_scales, + cache_v_dequant_scales, + ] + model = paddle.jit.to_static(self.generate, input_spec=input_spec) + paddle.jit.save( + model, output_path, skip_prune_program=True + ) # Note(Zhengzekang): If we prune program it may cause some inference error. 
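
When `use_cachekv_int8` is enabled, the exported caches switch to `uint8` storage and the graph takes per-layer, per-head quant/dequant scale tensors as extra inputs. The scale convention matches `CacheScaleLoader` earlier in this patch (quant scale = 127 / abs-max, dequant scale = its reciprocal). The numpy sketch below illustrates that convention only, assuming symmetric per-head abs-max quantization; the real work happens inside the fused kernels (e.g. `block_multihead_attention`, `write_int8_cache_kv`), not in Python.

```python
# Sketch of the per-head cache-KV int8 scale convention; illustrative only.
import numpy as np

num_heads, seq_len, head_dim = 4, 8, 16
key_cache = np.random.randn(num_heads, seq_len, head_dim).astype("float32")

# one scale per head: quant scale = 127 / abs-max, dequant scale = 1 / quant scale
quant_scales = 127.0 / np.abs(key_cache).max(axis=(1, 2))  # shape [num_heads]
dequant_scales = 1.0 / quant_scales

# float cache values are rounded into int8 codes (the patch allocates the
# cache tensors as uint8; int8 codes are kept here for clarity)
quantized = np.clip(
    np.round(key_cache * quant_scales[:, None, None]), -127, 127
).astype("int8")

# the attention kernel dequantizes on the fly before the matmuls
recovered = quantized.astype("float32") * dequant_scales[:, None, None]

print(float(np.abs(recovered - key_cache).max()))  # per-head quantization error
```

In the dynamic case the kernels compute these scales on the fly into `k_quant_scales` / `v_quant_scales`; in the static case they are loaded once from the PTQ-produced `cachekv_act_scales.json`.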
+ + @staticmethod + def prepare_input_ids_for_generation(bos_token_id, encoder_output=None): + batch_size = 1 + seq_len = 1 + if bos_token_id is None: + raise ValueError("`bos_token_id` should be defined when no " "`input_ids` are provided.") + if encoder_output is not None: + batch_size = encoder_output.shape[0] + seq_len = encoder_output.shape[1] + return paddle.ones([batch_size, seq_len], dtype="int64") * bos_token_id + + @paddle.no_grad() + def generate( + self, + input_ids=None, + temperature=None, + top_p=None, + eos_token_id=None, + src_mask=None, + penalty_score=None, + frequency_score=None, + presence_score=None, + next_tokens=None, + is_block_step=None, + seq_lens_this_time=None, # update + seq_lens_encoder=None, # update + seq_lens_decoder=None, # update + step_idx=None, + stop_flags=None, + rope_emb=None, + min_length=None, + max_length=None, + stop_nums=None, + bad_tokens=None, + not_need_stop=None, + block_tables=None, + pre_ids=None, + pre_caches=None, + cache_kvs=[], + k_quant_scales=None, + v_quant_scales=None, + k_dequant_scales=None, + v_dequant_scales=None, + **model_kwargs, + ): + + model_kwargs["input_ids"] = input_ids + model_kwargs["penalty_score"] = penalty_score + model_kwargs["frequency_score"] = frequency_score + model_kwargs["presence_score"] = presence_score + model_kwargs["seq_lens_this_time"] = seq_lens_this_time + model_kwargs["seq_lens_encoder"] = seq_lens_encoder + model_kwargs["seq_lens_decoder"] = seq_lens_decoder + model_kwargs["step_idx"] = step_idx + model_kwargs["stop_flags"] = stop_flags + model_kwargs["min_dec_len"] = min_length + model_kwargs["max_dec_len"] = max_length + model_kwargs["stop_nums"] = stop_nums + model_kwargs["rope_emb"] = rope_emb + model_kwargs["bad_tokens"] = bad_tokens + model_kwargs["block_tables"] = block_tables + model_kwargs["pre_ids"] = pre_ids + model_kwargs["not_need_stop"] = not_need_stop + model_kwargs["caches"] = cache_kvs + model_kwargs["k_quant_scales"] = k_quant_scales + model_kwargs["v_quant_scales"] = v_quant_scales + model_kwargs["k_dequant_scales"] = k_dequant_scales + model_kwargs["v_dequant_scales"] = v_dequant_scales + model_kwargs["pre_caches"] = pre_caches + model_kwargs["next_tokens"] = next_tokens + model_kwargs["is_block_step"] = is_block_step + model_kwargs["src_mask"] = src_mask + + ret = self.sample( + eos_token_id, + top_k=0, + top_p=top_p, + temperature=temperature, + **model_kwargs, + ) + return ret + + def sample( + self, + eos_token_id, + top_k, + top_p, + penalty_score, + frequency_score, + presence_score, + temperature=None, + min_tokens_to_keep=1, + **model_kwargs + ): + def _forward_(**args): + model_inputs = self.prepare_inputs_for_generation(**args) + return self(**model_inputs) + + def _post_process_( + outputs, + top_k, + top_p, + penalty_score, + frequency_score, + presence_score, + temperature, + model_kwargs, + ): + step_idx = model_kwargs["step_idx"] + set_value_by_flags_and_idx_v2( + model_kwargs["pre_ids"], + model_kwargs["input_ids"], + model_kwargs["seq_lens_this_time"], + model_kwargs["seq_lens_encoder"], + model_kwargs["seq_lens_decoder"], + step_idx, + model_kwargs["stop_flags"], + ) + + logits = paddle.cast(outputs, paddle.float32) + + # pre-process distribution + logits = get_token_penalty_multi_scores_v2( + model_kwargs["pre_ids"], + logits, + penalty_score, + frequency_score, + presence_score, + temperature, + model_kwargs["bad_tokens"], + step_idx, + model_kwargs["min_dec_len"], + eos_token_id, + ) + + # sample + probs = F.softmax(logits) + # _, next_tokens = 
top_p_sampling(probs, top_p, -1) + _, next_tokens = paddle.topk(probs, 1, -1) + + if self.config.tensor_parallel_degree > 1: + paddle.distributed.broadcast(next_tokens, 0) + + step_idx = paddle.where(model_kwargs["stop_flags"], model_kwargs["step_idx"], model_kwargs["step_idx"] + 1) + paddle.assign(step_idx, model_kwargs["step_idx"]) + length_cond = paddle.greater_equal(model_kwargs["step_idx"], model_kwargs["max_dec_len"]) + stop_flags = paddle.logical_or(model_kwargs["stop_flags"], length_cond) + set_stop_value_multi_ends_v2( + next_tokens, stop_flags, model_kwargs["seq_lens_this_time"], eos_token_id, model_kwargs["next_tokens"] + ) # multi ends + paddle.assign(stop_flags, model_kwargs["stop_flags"]) + # update inputs + update_inputs( + model_kwargs["stop_flags"], + model_kwargs["not_need_stop"], + model_kwargs["seq_lens_this_time"], + model_kwargs["seq_lens_encoder"], + model_kwargs["seq_lens_decoder"], + model_kwargs["input_ids"], + model_kwargs["stop_nums"], + next_tokens, + model_kwargs["is_block_step"], + ) + save_output(next_tokens, model_kwargs["not_need_stop"], self.config.tensor_parallel_rank) + return next_tokens + + # encoder + outputs = _forward_(**model_kwargs) # [bs, 1, dim_embed] + # first decoder + next_tokens = _post_process_( + outputs, + top_k, + top_p, + penalty_score, + frequency_score, + presence_score, + temperature, + model_kwargs, + ) + + return next_tokens diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py index cbd2c78c0500..6923ba0db0ec 100644 --- a/paddlenlp/experimental/transformers/llama/modeling.py +++ b/paddlenlp/experimental/transformers/llama/modeling.py @@ -15,22 +15,35 @@ import json import os +from functools import partial import numpy as np import paddle from paddle import nn from paddle.distributed import fleet from paddle.nn.quant import weight_quantize -from paddlenlp_ops import fused_get_rotary_embedding, get_padding_offset +from paddlenlp_ops import ( + fused_get_rotary_embedding, + get_padding_offset, + get_padding_offset_v2, +) -from paddlenlp.experimental.model_utils import ActScalesLoader, WeightScalesLoader +from paddlenlp.experimental.model_utils import ( + ActScalesLoader, + CacheScaleLoader, + WeightScalesLoader, +) from paddlenlp.experimental.transformers.fused_transformer_layers import ( + FusedBlockMultiTransformer, + FusedBlockMultiTransformerA8W8, + FusedBlockMultiTransformerWeightOnly, FusedMultiTransformerA8W8, FusedMultiTransformerBase, FusedMultiTransformerConfig, FusedMultiTransformerWeightOnly, ) from paddlenlp.experimental.transformers.generation_utils import ( + GenerationBlockInferenceModel, GenerationInferenceModel, ) from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel @@ -43,8 +56,14 @@ dy2st_nocheck_guard_context, register_base_model, ) +from paddlenlp.utils.log import logger -__all__ = ["LlamaInferenceModel", "LlamaForCausalLMInferenceModel", "LlamaForMiniGPT4InferenceModel"] +__all__ = [ + "LlamaInferenceModel", + "LlamaForCausalLMInferenceModel", + "LlamaForCausalLMBlockInferenceModel", + "LlamaForMiniGPT4InferenceModel", +] class FusedLlamaRMSNorm(nn.Layer): @@ -242,6 +261,25 @@ def __init__(self, config: LlamaConfig): paddle.ParamAttr(name="fusellama.{}.ffn2_weight_scale".format(i)) for i in range(self.num_layers) ] + cache_k_scale_attrs = None + cache_v_scale_attrs = None + cache_k_out_scale_attrs = None + cache_v_out_scale_attrs = None + + if config.use_cachekv_int8 == "static": + cache_k_scale_attrs = [ + 
paddle.ParamAttr(name="fusellama.{}.cache_k_scale".format(i)) for i in range(self.num_layers) + ] + cache_v_scale_attrs = [ + paddle.ParamAttr(name="fusellama.{}.cache_v_scale".format(i)) for i in range(self.num_layers) + ] + cache_k_out_scale_attrs = [ + paddle.ParamAttr(name="fusellama.{}.cache_k_out_scale".format(i)) for i in range(self.num_layers) + ] + cache_v_out_scale_attrs = [ + paddle.ParamAttr(name="fusellama.{}.cache_v_out_scale".format(i)) for i in range(self.num_layers) + ] + transformer_config = FusedMultiTransformerConfig( self.hidden_size, self.num_attention_heads, @@ -275,18 +313,18 @@ def __init__(self, config: LlamaConfig): ffn_ln_bias_attrs=ffn_ln_bias_attrs, ffn1_bias_attrs=ffn1_bias_attrs, ffn2_bias_attrs=ffn2_bias_attrs, + cache_k_scale_attrs=cache_k_scale_attrs, + cache_v_scale_attrs=cache_v_scale_attrs, + cache_k_out_scale_attrs=cache_k_out_scale_attrs, + cache_v_out_scale_attrs=cache_v_out_scale_attrs, epsilon=self.epsilon, norm_type="rmsnorm", use_neox_rotary_style=True, + use_dynamic_cachekv_quant=config.use_cachekv_int8 == "dynamic", + rank_id=config.tensor_parallel_rank, ) - if self.use_weight_only: - self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) - elif self.quant_type == "a8w8": - self.transformer_block = FusedMultiTransformerA8W8(transformer_config) - else: - self.transformer_block = FusedMultiTransformerBase(transformer_config) - + self.set_transformer_block(transformer_config) self.norm = FusedLlamaRMSNorm(config) self.cache_kvs = None @@ -294,6 +332,14 @@ def __init__(self, config: LlamaConfig): self.gradient_checkpointing = False + def set_transformer_block(self, transformer_config): + if self.use_weight_only: + self.transformer_block = FusedMultiTransformerWeightOnly(transformer_config) + elif self.quant_type == "a8w8": + self.transformer_block = FusedMultiTransformerA8W8(transformer_config) + else: + self.transformer_block = FusedMultiTransformerBase(transformer_config) + def get_input_embeddings(self): return self.embed_tokens @@ -431,6 +477,10 @@ def set_state_dict(self, state_dict): self.norm.weight.set_value(paddle.to_tensor(state_dict["llama.norm.weight"], dtype=self.norm.weight.dtype)) for idx in range(self.config.num_hidden_layers): + logger.info(f"set state for layer {idx}") + + if self.use_weight_only: + logger.info("weight only is enabled") unfused_state_dict = {} unfused_state_dict["self_attn.q_proj.weight"] = state_dict[ "llama.layers.{}.self_attn.q_proj.weight".format(idx) @@ -471,8 +521,11 @@ def set_state_dict(self, state_dict): qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight) qkv_weight_tensor = paddle.transpose(qkv_weight_tensor, perm=[1, 0]) qkv_quanted_weight_tensor, qkv_weight_scale_tensor = weight_quantize( - qkv_weight_tensor, algo=self.quant_type + qkv_weight_tensor.cuda(), algo=self.quant_type ) + qkv_quanted_weight_tensor = qkv_quanted_weight_tensor.cpu() + qkv_weight_scale_tensor = qkv_weight_scale_tensor.cpu() + qkv_weight_scale_tensor = qkv_weight_scale_tensor.cast(qkv_weight_tensor.dtype) self.transformer_block.qkv_weights[idx].set_value(qkv_quanted_weight_tensor) self.transformer_block.qkv_weights_scale[idx].set_value(qkv_weight_scale_tensor) elif self.quant_type == "a8w8": @@ -485,8 +538,11 @@ def set_state_dict(self, state_dict): linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]) if self.use_weight_only: linear_quanted_weight_tensor, linear_weight_scale_tensor = weight_quantize( - linear_weight_tensor, algo=self.quant_type + 
linear_weight_tensor.cuda(), algo=self.quant_type ) + linear_quanted_weight_tensor = linear_quanted_weight_tensor.cpu() + linear_weight_scale_tensor = linear_weight_scale_tensor.cpu() + linear_weight_scale_tensor = linear_weight_scale_tensor.cast(linear_weight_tensor.dtype) self.transformer_block.linear_weights[idx].set_value(linear_quanted_weight_tensor) self.transformer_block.linear_weights_scale[idx].set_value(linear_weight_scale_tensor) elif self.quant_type == "a8w8": @@ -503,8 +559,11 @@ def set_state_dict(self, state_dict): if self.use_weight_only: ffn1_quanted_weight_tensor, ffn1_weight_scale_tensor = weight_quantize( - ffn1_weight_tensor, algo=self.quant_type + ffn1_weight_tensor.cuda(), algo=self.quant_type ) + ffn1_quanted_weight_tensor = ffn1_quanted_weight_tensor.cpu() + ffn1_weight_scale_tensor = ffn1_weight_scale_tensor.cpu() + ffn1_weight_scale_tensor = ffn1_weight_scale_tensor.cast(ffn1_weight_tensor.dtype) self.transformer_block.ffn1_weights[idx].set_value(ffn1_quanted_weight_tensor) self.transformer_block.ffn1_weights_scale[idx].set_value(ffn1_weight_scale_tensor) elif self.quant_type == "a8w8": @@ -517,8 +576,11 @@ def set_state_dict(self, state_dict): ffn2_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.mlp.down_proj.weight".format(idx)]) if self.use_weight_only: ffn2_quanted_weight_tensor, ffn2_weight_scale_tensor = weight_quantize( - ffn2_weight_tensor, algo=self.quant_type + ffn2_weight_tensor.cuda(), algo=self.quant_type ) + ffn2_quanted_weight_tensor = ffn2_quanted_weight_tensor.cpu() + ffn2_weight_scale_tensor = ffn2_weight_scale_tensor.cpu() + ffn2_weight_scale_tensor = ffn2_weight_scale_tensor.cast(ffn2_weight_tensor.dtype) self.transformer_block.ffn2_weights[idx].set_value(ffn2_quanted_weight_tensor) self.transformer_block.ffn2_weights_scale[idx].set_value(ffn2_weight_scale_tensor) elif self.quant_type == "a8w8": @@ -622,10 +684,18 @@ def set_state_dict(self, state_dict): scale_map_dict = json.load(json_file) act_scale_map_dict = scale_map_dict["act_scale"] weight_scale_map_dict = scale_map_dict["weight_scale"] + cache_scale_map_dict = scale_map_dict["cachekv_scale"] # TODO(RichardWooSJTU): support multi-cards act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json") weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + act_scale_json_path = os.path.join( + self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json" + ) + weight_scale_json_path = os.path.join( + self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json" + ) act_scale_loader = ActScalesLoader( act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers ) @@ -638,29 +708,72 @@ def set_state_dict(self, state_dict): concat_qkv=True, concat_ffn1=True, ) + + if self.config.use_cachekv_int8 == "static": + cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_act_scales.json") + if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq: + cache_scale_json_path = os.path.join( + self.quant_model_path, f"cachekv_act_scales_{self.config.tensor_parallel_rank}.json" + ) + cache_scales_loader = CacheScaleLoader( + cache_scale_json_path, + cache_scale_map_dict, + num_of_layers=self.config.num_hidden_layers, + num_heads=self.num_attention_heads // self.config.tensor_parallel_degree, + ) + for k, v in cache_scales_loader.scale.items(): + for i_layer, weight_scale in 
enumerate(v): + weight_scale = weight_scale.astype("float32") + if k == "cache_k_scale": + self.transformer_block.cache_k_scales[i_layer].set_value(weight_scale) + elif k == "cache_v_scale": + self.transformer_block.cache_v_scales[i_layer].set_value(weight_scale) + elif k == "cache_k_out_scale": + self.transformer_block.cache_k_out_scales[i_layer].set_value(weight_scale) + else: + self.transformer_block.cache_v_out_scales[i_layer].set_value(weight_scale) + for k, v in weight_scales_loader.scale.items(): if "qkv_" in k: for i_layer, weight_scale in enumerate(v): tmp = paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer]) + weight_scale + / ( + 127.0 * 127.0 * act_scale_loader.scale["qkv_in_scale"][i_layer] + ) # [3 * num_head * dim_head] ).reshape([-1]) + + if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: + tmp = ( + tmp.reshape([3, self.num_attention_heads, head_size]) + .split(self.config.tensor_parallel_degree, axis=1)[ + self.config.tensor_parallel_rank + ] + .reshape([-1]) + ) self.transformer_block.qkv_out_scales[i_layer].set_value(tmp) pass elif "out_linear_" in k: for i_layer, weight_scale in enumerate(v): - self.transformer_block.linear_out_scales[i_layer].set_value( - paddle.to_tensor( - weight_scale - / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer]) - ) + tmp = paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["out_linear_in_scale"][i_layer]) ) + self.transformer_block.linear_out_scales[i_layer].set_value(tmp) elif "ffn1_weight_scale" in k: for i_layer, weight_scale in enumerate(v): - self.transformer_block.ffn1_out_scales[i_layer].set_value( - paddle.to_tensor( - weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer]) - ) + tmp = paddle.to_tensor( + weight_scale / (127.0 * 127.0 * act_scale_loader.scale["ffn1_in_scale"][i_layer]) ) + if self.config.tensor_parallel_degree > 1 and self.config.single_card_ptq: + tmp = paddle.split(tmp, self.config.tensor_parallel_degree * 2) + tmp = paddle.concat( + [ + tmp[self.config.tensor_parallel_rank], + tmp[self.config.tensor_parallel_rank + self.config.tensor_parallel_degree], + ], + axis=0, + ) + self.transformer_block.ffn1_out_scales[i_layer].set_value(tmp) elif "ffn2" in k: for i_layer, weight_scale in enumerate(v): self.transformer_block.ffn2_out_scales[i_layer].set_value( @@ -670,6 +783,75 @@ def set_state_dict(self, state_dict): ) +@register_base_model +class LlamaBlockInferenceModel(LlamaInferenceModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.max_seq_len = config.max_seq_len + self.block_size = config.block_size + + def set_transformer_block(self, transformer_config): + if self.use_weight_only: + self.transformer_block = FusedBlockMultiTransformerWeightOnly(transformer_config) + elif self.quant_type == "a8w8": + self.transformer_block = FusedBlockMultiTransformerA8W8(transformer_config) + else: + self.transformer_block = FusedBlockMultiTransformer(transformer_config) + + def remove_padding(self, input_ids, seq_lens_this_time): + cum_offsets_now = paddle.cumsum(self.max_seq_len - seq_lens_this_time) + token_num = paddle.sum(seq_lens_this_time) + ids_remove_padding, cum_offsets, padding_offset, cu_seqlens_q, cu_seqlens_k = get_padding_offset_v2( + input_ids, cum_offsets_now, token_num, seq_lens_this_time + ) + return ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k + + def forward( + self, + input_ids=None, + 
attention_mask=None, + inputs_embeds=None, + caches=None, + pre_caches=None, + output_attentions=False, + output_hidden_states=None, + return_dict=False, + **kwargs, + ): + + seq_lens_this_time = kwargs.get("seq_lens_this_time", None) + rope_emb = kwargs.get("rope_emb", None) + ids_remove_padding, padding_offset, cum_offsets, cu_seqlens_q, cu_seqlens_k = self.remove_padding( + input_ids, seq_lens_this_time + ) + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_k"] = cu_seqlens_k + kwargs["padding_offsets"] = padding_offset + kwargs["max_input_length"] = self.max_seq_len + + inputs_embeds = self.embed_tokens(ids_remove_padding) + + with dy2st_nocheck_guard_context(): + hidden_states, _ = self.transformer_block( + input_ids=input_ids, + src=inputs_embeds, + cum_offsets=cum_offsets, + attn_mask=attention_mask, + caches=caches, + pre_caches=pre_caches, + rotary_embs=rope_emb, + **kwargs, + ) + hidden_states = self.norm(hidden_states) + + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=None, + hidden_states=None, + attentions=None, + ) + + class LlamaForCausalLMInferenceModel(GenerationInferenceModel, LlamaPretrainedModel): """ Dynamic Batching for LLaMA Model with pretraining tasks on top. @@ -831,6 +1013,266 @@ def set_state_dict(self, state_dict): self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) +class LlamaForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, LlamaPretrainedModel): + """ + Dynamic Batching for LLaMA Model with pretraining tasks on top. + """ + + _keys_to_ignore_on_load_missing = [r"lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.llama = LlamaBlockInferenceModel(config) + self.lm_head = LlamaLMHead(config) + + @classmethod + def _get_tensor_parallel_mappings(cls, config: LlamaConfig, is_split=True): + + logger.info("llama inference model _get_tensor_parallel_mappings") + + from paddlenlp.transformers.conversion_utils import split_or_merge_func + + fn = split_or_merge_func( + is_split=is_split, + tensor_parallel_degree=config.tensor_parallel_degree, + tensor_parallel_rank=config.tensor_parallel_rank, + num_attention_heads=config.num_attention_heads, + ) + + def get_tensor_parallel_split_mappings(num_layers): + final_actions = {} + + base_actions = { + "lm_head.weight": partial(fn, is_column=True), + # Row Linear + "embed_tokens.weight": partial(fn, is_column=False), + "layers.0.self_attn.o_proj.weight": partial(fn, is_column=False), + "layers.0.mlp.down_proj.weight": partial(fn, is_column=False), + } + + if config.quant_type == "a8w8": + if config.quantization_config.shift_smooth_all_linears: + base_actions["layers.0.self_attn.o_proj.shift_bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.o_proj.smooth_weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.shift_bias"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.down_proj.smooth_weight"] = partial(fn, is_column=True) + + if config.quantization_config.shift: + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.bias"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.bias"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. 
+ if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.bias"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.bias"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.bias"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.bias"] = partial(fn, is_column=True) + + # Column Linear + if config.fuse_attention_qkv: + base_actions["layers.0.self_attn.qkv_proj.weight"] = partial(fn, is_column=True) + else: + base_actions["layers.0.self_attn.q_proj.weight"] = partial(fn, is_column=True) + # if we have enough num_key_value_heads to split, then split it. + if config.num_key_value_heads % config.tensor_parallel_degree == 0: + base_actions["layers.0.self_attn.k_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.self_attn.v_proj.weight"] = partial(fn, is_column=True) + + if config.fuse_attention_ffn: + base_actions["layers.0.mlp.gate_up_fused_proj.weight"] = partial( + fn, is_column=True, is_naive_2fuse=True + ) + else: + base_actions["layers.0.mlp.gate_proj.weight"] = partial(fn, is_column=True) + base_actions["layers.0.mlp.up_proj.weight"] = partial(fn, is_column=True) + + for key, action in base_actions.items(): + if "layers.0." in key: + for i in range(num_layers): + final_actions[key.replace("layers.0.", f"layers.{i}.")] = action + final_actions[key] = action + + return final_actions + + mappings = get_tensor_parallel_split_mappings(config.num_hidden_layers) + + return mappings + + @classmethod + def from_pretrained( + cls, pretrained_model_name_or_path, from_hf_hub: bool = False, subfolder: str | None = None, *args, **kwargs + ): + # TODO: Support safetensors loading. 
+        kwargs["use_safetensors"] = False
+        from paddlenlp.transformers.utils import (
+            ContextManagers,
+            is_safetensors_available,
+            resolve_cache_dir,
+        )
+
+        config = kwargs.pop("config", None)
+        from_aistudio = kwargs.get("from_aistudio", False)
+        subfolder = kwargs.get("subfolder", None)
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None if is_safetensors_available() else False)
+        convert_from_torch = kwargs.pop("convert_from_torch", None)
+        cache_dir = kwargs.pop("cache_dir", None)
+
+        cache_dir = resolve_cache_dir(pretrained_model_name_or_path, from_hf_hub, cache_dir)
+
+        init_contexts = []
+        with ContextManagers(init_contexts):
+            model = cls(config)
+
+        if not config.single_card_ptq:
+            resolved_archive_file = pretrained_model_name_or_path
+        else:
+            resolved_archive_file = cls._resolve_model_file_path(
+                pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                subfolder=subfolder,
+                from_hf_hub=from_hf_hub,
+                from_aistudio=from_aistudio,
+                config=config,
+                convert_from_torch=convert_from_torch,
+                use_safetensors=use_safetensors,
+                variant=variant,
+            )[0]
+        logger.info(f"Load model from {resolved_archive_file}")
+
+        if config.tensor_parallel_degree > 1 and config.single_card_ptq:
+            logger.info(f"convert_tensor_parallel {config.tensor_parallel_degree}")
+            model.state_dict = model.convert_tensor_parallel(resolved_archive_file, config)
+        elif config.tensor_parallel_degree > 1:
+            resolved_archive_file = os.path.join(
+                resolved_archive_file, f"mp_{config.tensor_parallel_rank:0>2d}_sharding_00_pp_00", "model.pdparams"
+            )
+            model.state_dict = paddle.load(resolved_archive_file, return_numpy=True)
+        else:
+            model.state_dict = paddle.load(resolved_archive_file, return_numpy=True)
+        model.set_state_dict(model.state_dict)
+
+        return model
+
+    @classmethod
+    def get_cache_kvs_shape(
+        cls, config: LlamaConfig, max_batch_size: int = None, max_length: int = None
+    ) -> list[list[int]]:
+        """get the shape of cache_kvs for the llama model
+
+        Args:
+            config (LlamaConfig): the model config, which provides max_seq_len, block_size and head numbers
+            max_batch_size (int): the max batch size
+            max_length (int | None, optional): the max_length of cache_kvs. Defaults to None.
+ + Returns: + list[paddle.Tensor]: the list tensor shape for cache + """ + max_block_per_seq = (config.max_seq_len + config.block_size - 1) // config.block_size + if max_batch_size == -1: + max_block_nums = None + else: + max_block_nums = max_batch_size * max_block_per_seq + + cache_kvs = [] + for _ in range(config.num_hidden_layers): + cache_kv_shape = [ + max_block_nums, + config.num_attention_heads // max(config.tensor_parallel_degree, 1), + config.block_size, + config.hidden_size // config.num_attention_heads, + ] + cache_kvs.append(cache_kv_shape) + cache_kvs.append(cache_kv_shape) + return cache_kvs + + def prepare_inputs_for_generation(self, **kwargs): + # only last token for inputs_ids if cache is defined in kwargs + input_ids = kwargs["input_ids"] + src_mask = kwargs.get("src_mask", None) + block_tables = kwargs.get("block_tables", None) + + pre_caches = kwargs.get("pre_caches", None) + caches = kwargs.get("caches", None) + + rope_emb = kwargs["rope_emb"] + seq_lens_this_time = kwargs["seq_lens_this_time"] + seq_lens_encoder = kwargs["seq_lens_encoder"] + seq_lens_decoder = kwargs["seq_lens_decoder"] + k_quant_scales = kwargs.get("k_quant_scales", None) + v_quant_scales = kwargs.get("v_quant_scales", None) + k_dequant_scales = kwargs.get("k_dequant_scales", None) + v_dequant_scales = kwargs.get("v_dequant_scales", None) + model_inputs = { + "input_ids": input_ids, + "src_mask": src_mask, + "rope_emb": rope_emb, + "pre_caches": pre_caches, + "caches": caches, + "seq_lens_this_time": seq_lens_this_time, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "block_tables": block_tables, + "k_quant_scales": k_quant_scales, + "v_quant_scales": v_quant_scales, + "k_dequant_scales": k_dequant_scales, + "v_dequant_scales": v_dequant_scales, + } + return model_inputs + + def forward( + self, + input_ids, + src_mask=None, + pre_caches=None, + caches=None, + seq_lens_this_time=None, + seq_lens_encoder=None, + seq_lens_decoder=None, + rope_emb=None, + block_tables=None, + k_quant_scales=None, + v_quant_scales=None, + k_dequant_scales=None, + v_dequant_scales=None, + ): + outputs = self.llama( + input_ids, + src_mask=src_mask, + caches=caches, + rope_emb=rope_emb, + block_tables=block_tables, + pre_caches=pre_caches, + seq_lens_this_time=seq_lens_this_time, + seq_lens_encoder=seq_lens_encoder, + seq_lens_decoder=seq_lens_decoder, + k_quant_scales=k_quant_scales, + v_quant_scales=v_quant_scales, + k_dequant_scales=k_dequant_scales, + v_dequant_scales=v_dequant_scales, + ) + + hidden_states = outputs[0] + logits = self.lm_head( + hidden_states, + tensor_parallel_output=False, + ) + + return logits + + @paddle.no_grad() + def set_state_dict(self, state_dict): + if "lm_head.weight" in state_dict: + self.lm_head.weight.set_value(state_dict["lm_head.weight"]) + self.llama.set_state_dict({k: state_dict[k] for k in state_dict.keys()}) + + class LlamaForMiniGPT4InferenceModel(LlamaForCausalLMInferenceModel): """ This class is 99% like LlamaForCausalLMInferenceModel. 
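For reference, the block layout produced by get_cache_kvs_shape above is easiest to see with concrete numbers. The sketch below replays its arithmetic in plain Python; every config value is a hypothetical example chosen for illustration, not a default from this patch.

# Illustration of the block-based cache layout computed by get_cache_kvs_shape.
# All config values below are hypothetical examples.
max_seq_len = 2048
block_size = 64
num_attention_heads = 32
hidden_size = 4096
num_hidden_layers = 2          # kept small for the example
tensor_parallel_degree = 1
max_batch_size = 8

# ceil(max_seq_len / block_size): blocks needed to hold one full sequence
max_block_per_seq = (max_seq_len + block_size - 1) // block_size   # 32
max_block_nums = max_batch_size * max_block_per_seq                # 256 blocks in the pool

cache_kvs = []
for _ in range(num_hidden_layers):
    cache_kv_shape = [
        max_block_nums,                                         # number of cache blocks
        num_attention_heads // max(tensor_parallel_degree, 1),  # heads per rank
        block_size,                                             # tokens stored per block
        hidden_size // num_attention_heads,                     # head dim
    ]
    cache_kvs.append(cache_kv_shape)  # K cache shape for this layer
    cache_kvs.append(cache_kv_shape)  # V cache shape for this layer

print(len(cache_kvs))  # 2 * num_hidden_layers entries
print(cache_kvs[0])    # [256, 32, 64, 128]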
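remove_padding in LlamaBlockInferenceModel packs the ragged batch into a single token stream before the fused transformer block runs. get_padding_offset_v2 is a custom op, so the following numpy sketch only imitates the quantities the model prepares for it (cum_offsets_now, token_num) and the general effect of dropping pad positions; the sequence lengths are made up and the cu_seqlens interpretation is an assumption.

import numpy as np

# Imitating the bookkeeping that remove_padding feeds into get_padding_offset_v2,
# using plain numpy. Sequence lengths are hypothetical.
max_seq_len = 8
input_ids = np.array([
    [11, 12, 13, 0, 0, 0, 0, 0],    # seq 0: 3 valid tokens
    [21, 22, 23, 24, 25, 0, 0, 0],  # seq 1: 5 valid tokens
])
seq_lens_this_time = np.array([3, 5])

# same quantities the model computes before calling the custom op
cum_offsets_now = np.cumsum(max_seq_len - seq_lens_this_time)  # [5, 8]
token_num = seq_lens_this_time.sum()                           # 8

# what "removing padding" amounts to: concatenate only the valid tokens
ids_remove_padding = np.concatenate(
    [input_ids[b, : seq_lens_this_time[b]] for b in range(len(seq_lens_this_time))]
)
# cumulative sequence boundaries of the packed stream (assumed role of cu_seqlens_q/k)
cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens_this_time)])  # [0, 3, 8]

print(ids_remove_padding)      # [11 12 13 21 22 23 24 25]
print(token_num, cu_seqlens)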
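When config.use_cachekv_int8 == "static", the per-layer cache K/V scales are loaded from cachekv_act_scales.json through CacheScaleLoader rather than computed at run time as in the "dynamic" mode. The round trip below is a minimal sketch assuming a symmetric INT8 scheme in which cache_k_scale maps float key values into int8 range and cache_k_out_scale undoes it; the real quantization happens inside the fused block-attention kernels, so treat the convention here as illustrative only.

import numpy as np

# Minimal sketch of a symmetric INT8 cache-KV round trip.
# The scale convention (127 / absmax and its inverse) is an assumption.
rng = np.random.default_rng(0)
k = rng.standard_normal((4, 128)).astype("float32")  # [num_heads, head_dim], hypothetical

k_absmax = np.abs(k).max(axis=-1, keepdims=True)      # per-head max magnitude
cache_k_scale = 127.0 / k_absmax                      # assumed: float -> int8 factor
cache_k_out_scale = 1.0 / cache_k_scale               # assumed: int8 -> float factor

k_int8 = np.clip(np.round(k * cache_k_scale), -127, 127).astype("int8")
k_restored = k_int8.astype("float32") * cache_k_out_scale

print(np.abs(k - k_restored).max())  # small quantization error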
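In the A8W8 path, set_state_dict folds each weight scale and the matching activation in-scale into what is effectively a per-output-channel dequantization factor, weight_scale / (127.0 * 127.0 * in_scale), and, when the scales come from a single-card PTQ export, slices out the channels belonging to the current tensor-parallel rank. The snippet below replays that arithmetic with dummy shapes and scale values; all numbers are hypothetical.

import numpy as np

# Replaying the qkv_out_scales arithmetic from set_state_dict with dummy data.
num_attention_heads = 8
head_size = 16
tensor_parallel_degree = 2
tensor_parallel_rank = 0

qkv_in_scale = 0.05  # assumed per-layer activation scale for the qkv GEMM input
weight_scale = np.full(3 * num_attention_heads * head_size, 0.02, dtype="float32")

# fold activation and weight scales into one dequant factor per output channel
qkv_out_scale = weight_scale / (127.0 * 127.0 * qkv_in_scale)  # [3 * num_heads * head_size]

# single_card_ptq: scales were exported unsplit, so take this rank's heads
heads_per_rank = num_attention_heads // tensor_parallel_degree
start = tensor_parallel_rank * heads_per_rank
qkv_out_scale = (
    qkv_out_scale.reshape([3, num_attention_heads, head_size])[:, start : start + heads_per_rank, :]
    .reshape([-1])
)

print(qkv_out_scale.shape)  # (3 * heads_per_rank * head_size,) -> (192,)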
diff --git a/paddlenlp/experimental/transformers/llama/ptq_scales_map.json b/paddlenlp/experimental/transformers/llama/ptq_scales_map.json index 4667101da4db..409db47f2f1f 100644 --- a/paddlenlp/experimental/transformers/llama/ptq_scales_map.json +++ b/paddlenlp/experimental/transformers/llama/ptq_scales_map.json @@ -13,5 +13,9 @@ "ffn1_1_weight_scale":"llama.layers.#.mlp.gate_proj.weight_quanter", "ffn1_2_weight_scale":"llama.layers.#.mlp.up_proj.weight_quanter", "ffn2_weight_scale":"llama.layers.#.mlp.down_proj.weight_quanter" + }, + "cachekv_scale":{ + "cache_k_scale": "llama.layers.#.self_attn.cachek_matmul.activation_quanter", + "cache_v_scale": "llama.layers.#.self_attn.cachev_matmul.activation_quanter" } } \ No newline at end of file diff --git a/paddlenlp/experimental/transformers/llama/ptq_scales_map_shift_smooth.json b/paddlenlp/experimental/transformers/llama/ptq_scales_map_shift_smooth.json index 3ec1d62b104a..4aa5122355a9 100644 --- a/paddlenlp/experimental/transformers/llama/ptq_scales_map_shift_smooth.json +++ b/paddlenlp/experimental/transformers/llama/ptq_scales_map_shift_smooth.json @@ -13,5 +13,9 @@ "ffn1_1_weight_scale":"llama.layers.#.mlp.gate_proj.weight_quanter", "ffn1_2_weight_scale":"llama.layers.#.mlp.up_proj.weight_quanter", "ffn2_weight_scale":"llama.layers.#.mlp.down_proj.layer.weight_quanter" + }, + "cachekv_scale":{ + "cache_k_scale": "llama.layers.#.self_attn.cachek_matmul.activation_quanter", + "cache_v_scale": "llama.layers.#.self_attn.cachev_matmul.activation_quanter" } } \ No newline at end of file diff --git a/tests/fixtures/llm/ptq.yaml b/tests/fixtures/llm/ptq.yaml index 6f00b1dbfa14..e48bd48e302e 100644 --- a/tests/fixtures/llm/ptq.yaml +++ b/tests/fixtures/llm/ptq.yaml @@ -14,7 +14,7 @@ ptq: ptq_step: 4 default: llama: - model_name_or_path: __internal_testing__/tiny-random-llama + model_name_or_path: __internal_testing__/tiny-fused-llama-inference5.2 chatglm: model_name_or_path: __internal_testing__/tiny-fused-chatglm chatglm2: diff --git a/tests/llm/test_predictor.py b/tests/llm/test_predictor.py index 209ff82fef6a..a19c3f5861c1 100644 --- a/tests/llm/test_predictor.py +++ b/tests/llm/test_predictor.py @@ -205,3 +205,81 @@ def test_create_predictor_with_unexpected_length(self): with argv_context_guard(config): predict() + + +@parameterized_class( + ["model_name_or_path", "model_class"], + [ + ["__internal_testing__/tiny-fused-llama-inference5.2", LlamaForCausalLM], + ], +) +class BlockAttnPredictorTest(LLMTest, unittest.TestCase): + config_path: str = "./tests/fixtures/llm/predictor.yaml" + model_name_or_path: str = None + model_class = None + + def setUp(self) -> None: + super().setUp() + paddle.set_default_dtype("float32") + self.model_class.from_pretrained(self.model_name_or_path, dtype="float16").save_pretrained(self.output_dir) + AutoTokenizer.from_pretrained(self.model_name_or_path).save_pretrained(self.output_dir) + + def test_blha(self): + self.run_predictor({"inference_model": True, "block_attn": True}) + result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) + self.run_predictor({"inference_model": False}) + result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) + + # compare the generation result of inference & dygraph model + assert len(result_0) == len(result_1) + + count, full_match = 0, 0 + for inference_item, no_inference_item in zip(result_0, result_1): + min_length = min(len(inference_item), len(no_inference_item)) + count += int(inference_item[: min_length // 2] == no_inference_item[: 
min_length // 2]) + full_match += int(inference_item[:min_length] == no_inference_item[:min_length]) + + self.assertGreaterEqual(full_match / len(result_0), 0.3) + + if self.model_name_or_path == "__internal_testing__/tiny-fused-chatglm": + self.assertGreaterEqual(count / len(result_0), 0.3) + else: + self.assertGreaterEqual(count / len(result_0), 0.4) + + def test_wint8(self): + self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8", "block_attn": True}) + result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) + self.run_predictor({"inference_model": True, "quant_type": "weight_only_int8"}) + result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) + + assert len(result_0) == len(result_1) + count, full_match = 0, 0 + + for inference_item, no_inference_item in zip(result_0, result_1): + min_length = min(len(inference_item), len(no_inference_item)) + count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2]) + full_match += int(inference_item[:min_length] == no_inference_item[:min_length]) + + self.assertGreaterEqual(full_match / len(result_0), 0.75) + + if self.model_name_or_path == "__internal_testing__/tiny-fused-chatglm": + self.assertGreaterEqual(count / len(result_0), 0.3) + else: + self.assertGreaterEqual(count / len(result_0), 0.4) + + def test_cachekv_int8(self): + self.run_predictor({"inference_model": True, "block_attn": True, "cachekv_int8": True}) + result_0 = self._read_result(os.path.join(self.output_dir, "predict.json")) + self.run_predictor({"inference_model": True, "block_attn": True}) + result_1 = self._read_result(os.path.join(self.output_dir, "predict.json")) + print(f"result_0 {result_0}, result_1 {result_1}") + + assert len(result_0) == len(result_1) + count, full_match = 0, 0 + + for inference_item, no_inference_item in zip(result_0, result_1): + min_length = min(len(inference_item), len(no_inference_item)) + count += int(inference_item[: min_length // 2] == no_inference_item[: min_length // 2]) + full_match += int(inference_item[:min_length] == no_inference_item[:min_length]) + + self.assertGreaterEqual(count / len(result_0), 0.2) diff --git a/tests/llm/test_ptq.py b/tests/llm/test_ptq.py index d230189b3332..2f41cead554d 100644 --- a/tests/llm/test_ptq.py +++ b/tests/llm/test_ptq.py @@ -52,6 +52,19 @@ def test_ptq(self): self.run_predictor({"inference_model": True}) + def test_blha(self): + finetune_config = load_test_config(self.config_path, "ptq", self.model_dir) + + finetune_config["dataset_name_or_path"] = self.data_dir + finetune_config["output_dir"] = self.output_dir + + with argv_context_guard(finetune_config): + from finetune_generation import main + + main() + + self.run_predictor({"inference_model": True, "block_attn": True}) + def test_ptq_smooth(self): finetune_config = load_test_config(self.config_path, "ptq", self.model_dir)