[LLM] Support block_attention/cachekv quant for llama (#7649)
* support blha and cache kv quant

* lint

* fix unit test

* fix infer when blha is on

* code refine

* add docs and fix ops

* merge blha read res in predictor

* finish docs

* add docs and unittest

* add unittest

* migrate read res
RichardWooSJTU authored Jan 10, 2024
1 parent 5a32534 commit c5d8d5b
Showing 29 changed files with 3,380 additions and 79 deletions.
2 changes: 1 addition & 1 deletion .github/codecov.yml
@@ -10,4 +10,4 @@ coverage:
      threshold: 1% # Allow the coverage to drop by 1%, and posting a success status.
    patch:
      default:
-       target: 80% # lines adjusted Coverage < 80% CI will fail
+       target: 80% # lines adjusted Coverage < 80% CI will fail
69 changes: 69 additions & 0 deletions csrc/generation/get_output.cc
@@ -0,0 +1,69 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

void GetOutput(const paddle::Tensor& x,
               int64_t rank_id,
               bool wait_flag) {
  // Only rank 0 reads from the queue.
  if (rank_id > 0) return;

  static struct msgdata msg_rcv;

  static key_t key = ftok("./", 1);

  static int msgid = msgget(key, IPC_CREAT | 0666);

  int64_t* out_data = const_cast<int64_t*>(x.data<int64_t>());
  int ret = -1;
  if (!wait_flag) {
    // Non-blocking read: return immediately when the queue is empty.
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, IPC_NOWAIT);
  } else {
    // Blocking read: wait until a message arrives.
    ret = msgrcv(msgid, &msg_rcv, (MAX_BSZ + 2) * 4, 0, 0);
  }
  if (ret == -1) {
    // Nothing was read; signal "no data" to the caller.
    out_data[0] = -2;
    out_data[1] = 0;
    return;
  }

  int bsz = msg_rcv.mtext[1];

  for (int64_t i = 0; i < bsz + 2; i++) {
    out_data[i] = (int64_t)msg_rcv.mtext[i];
  }
  return;
}

PD_BUILD_OP(get_output)
    .Inputs({"x"})
    .Attrs({"rank_id: int64_t",
            "wait_flag: bool"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(GetOutput));
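
The tensor written by `get_output` mirrors the queue message layout: slot 0 is the flag (`1` continue, `-1` stop, `-2` nothing read), slot 1 the batch size, and slots `2..bsz+2` one token id per batch entry. A minimal host-side decoding sketch under that assumption (the `DecodedBatch` struct and `DecodeOutput` helper are illustrative, not part of the op):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: interpret the int64 buffer filled by get_output.
struct DecodedBatch {
  bool got_data;                // false when the queue was empty (flag == -2)
  bool should_stop;             // true when the producer signalled stop (flag == -1)
  std::vector<int64_t> tokens;  // one sampled token id per batch entry
};

DecodedBatch DecodeOutput(const int64_t* out_data) {
  DecodedBatch result;
  result.got_data = (out_data[0] != -2);
  result.should_stop = (out_data[0] == -1);
  if (result.got_data) {
    const int64_t bsz = out_data[1];
    // Token ids start at slot 2, one per batch entry.
    result.tokens.assign(out_data + 2, out_data + 2 + bsz);
  }
  return result;
}
```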
96 changes: 96 additions & 0 deletions csrc/generation/get_padding_offset_v2.cu
@@ -0,0 +1,96 @@
#include "paddle/extension.h"

__global__ void RemovePaddingV2(int64_t *output_data,
                                const int64_t *input_data,
                                const int *seq_lens,
                                const int *cum_offsets,
                                const int sequence_length) {
  const int bi = blockIdx.x;
  const int tid = threadIdx.x;

  // cum_offsets here is the exclusive prefix of padding counts, so each valid
  // token shifts left by the padding accumulated before batch entry bi.
  for (int i = tid; i < seq_lens[bi]; i += blockDim.x) {
    const int tgt_seq_id = bi * sequence_length - cum_offsets[bi] + i;
    const int src_seq_id = bi * sequence_length + i;
    output_data[tgt_seq_id] = input_data[src_seq_id];
  }
}

__global__ void GetPaddingOffsetKernelV2(int *padding_offset,
                                         int *cum_offsets_out,
                                         int *cu_seqlens_q,
                                         int *cu_seqlens_k,
                                         const int *cum_offsets,
                                         const int *seq_lens,
                                         const int max_seq_len) {
  // get padding offset of each batch
  const int bi = blockIdx.x;
  const int ti = threadIdx.x;
  // Convert the inclusive padding prefix into an exclusive one.
  int cum_offset = bi == 0 ? 0 : cum_offsets[bi - 1];
  for (int i = ti; i < seq_lens[bi]; i += blockDim.x) {
    padding_offset[bi * max_seq_len - cum_offset + i] = cum_offset;
  }
  if (ti == 0) {
    cum_offsets_out[bi] = cum_offset;
    int cum_seq_len = (bi + 1) * max_seq_len - cum_offsets[bi];
    cu_seqlens_q[bi + 1] = cum_seq_len;
    cu_seqlens_k[bi + 1] = cum_seq_len;
  }
}


std::vector<paddle::Tensor> GetPaddingOffsetV2(const paddle::Tensor& input_ids,
                                               const paddle::Tensor& cum_offsets,
                                               const paddle::Tensor& token_num,
                                               const paddle::Tensor& seq_len) {
  auto cu_stream = input_ids.stream();
  std::vector<int64_t> input_ids_shape = input_ids.shape();
  const int bsz = seq_len.shape()[0];
  const int seq_length = input_ids_shape[1];
  auto cum_offsets_out = cum_offsets.copy_to(cum_offsets.place(), false);
  auto cpu_token_num = token_num.copy_to(paddle::CPUPlace(), false);

  const int token_num_data = cpu_token_num.data<int64_t>()[0];
  auto x_remove_padding = paddle::full({token_num_data}, 0, paddle::DataType::INT64, input_ids.place());
  auto padding_offset = paddle::full({token_num_data}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_q = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  auto cu_seqlens_k = paddle::full({bsz + 1}, 0, paddle::DataType::INT32, input_ids.place());
  // Round the thread count up to a multiple of 32, capped at 128.
  int blockSize = min((token_num_data + 32 - 1) / 32 * 32, 128);
  GetPaddingOffsetKernelV2<<<bsz, 128, 0, cu_stream>>>(
      padding_offset.data<int>(),
      cum_offsets_out.data<int>(),
      cu_seqlens_q.data<int>(),
      cu_seqlens_k.data<int>(),
      cum_offsets.data<int>(),
      seq_len.data<int>(),
      seq_length);
  // Runs after GetPaddingOffsetKernelV2 on the same stream, so cum_offsets_out
  // already holds the exclusive prefix when this kernel reads it.
  RemovePaddingV2<<<bsz, blockSize, 0, cu_stream>>>(
      x_remove_padding.data<int64_t>(),
      input_ids.data<int64_t>(),
      seq_len.data<int>(),
      cum_offsets_out.data<int>(),
      seq_length);
  return {x_remove_padding, cum_offsets_out, padding_offset, cu_seqlens_q, cu_seqlens_k};
}

std::vector<std::vector<int64_t>> GetPaddingOffsetV2InferShape(const std::vector<int64_t>& input_ids_shape,
                                                               const std::vector<int64_t>& cum_offsets_shape,
                                                               const std::vector<int64_t>& token_num_shape,
                                                               const std::vector<int64_t>& seq_len_shape) {
  int64_t bsz = seq_len_shape[0];
  return {{-1}, {bsz}, {-1}, {bsz + 1}, {bsz + 1}};
}

std::vector<paddle::DataType> GetPaddingOffsetV2InferDtype(const paddle::DataType& input_ids_dtype,
                                                           const paddle::DataType& cum_offsets_dtype,
                                                           const paddle::DataType& token_num_dtype,
                                                           const paddle::DataType& seq_len_dtype) {
  return {input_ids_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype, seq_len_dtype};
}

PD_BUILD_OP(get_padding_offset_v2)
    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
    .Outputs({"x_remove_padding", "cum_offsets_out", "padding_offset", "cu_seqlens_q", "cu_seqlens_k"})
    .SetKernelFn(PD_KERNEL(GetPaddingOffsetV2))
    .SetInferShapeFn(PD_INFER_SHAPE(GetPaddingOffsetV2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(GetPaddingOffsetV2InferDtype));
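
To make the index arithmetic concrete, here is a CPU reference sketch of the same computation (the function `ComputePaddingOffsetsCPU` is illustrative and mirrors the two kernels above on a toy batch; it is not part of the commit):

```cpp
#include <cstdio>
#include <vector>

// CPU mirror of GetPaddingOffsetKernelV2 + RemovePaddingV2 for illustration.
// seq_lens: valid token count per batch entry; max_seq_len: padded length.
void ComputePaddingOffsetsCPU(const std::vector<int>& seq_lens, int max_seq_len) {
  const int bsz = static_cast<int>(seq_lens.size());
  std::vector<int> cum_offsets(bsz);        // inclusive prefix of padding counts
  std::vector<int> cu_seqlens(bsz + 1, 0);  // cumulative valid lengths
  int pad_so_far = 0;
  for (int bi = 0; bi < bsz; ++bi) {
    pad_so_far += max_seq_len - seq_lens[bi];
    cum_offsets[bi] = pad_so_far;
    cu_seqlens[bi + 1] = (bi + 1) * max_seq_len - cum_offsets[bi];
  }
  for (int bi = 0; bi < bsz; ++bi) {
    const int exclusive = bi == 0 ? 0 : cum_offsets[bi - 1];
    for (int i = 0; i < seq_lens[bi]; ++i) {
      // Every valid token moves left by the padding before its batch entry.
      printf("token (%d,%d): padded idx %d -> packed idx %d\n",
             bi, i, bi * max_seq_len + i, bi * max_seq_len - exclusive + i);
    }
  }
}

int main() {
  // bsz=2, max_seq_len=4, lengths {2, 3}: packed indices come out 0..4.
  ComputePaddingOffsetsCPU({2, 3}, 4);
  return 0;
}
```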
131 changes: 131 additions & 0 deletions csrc/generation/rebuild_padding_v2.cu
@@ -0,0 +1,131 @@
#include "helper.h"

template <typename T, int VecSize>
__global__ void RebuildPaddingV2Kernel(T *output_data,
                                       const T *input_data,
                                       const int *cum_offsets,
                                       const int *seq_len_decoder,
                                       const int *seq_len_encoder,
                                       const int seq_len,
                                       const int dim_embed,
                                       const int elem_nums) {
  using LoadT = AlignedVector<T, VecSize>;
  LoadT src_vec;
  const int global_idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int i = global_idx * VecSize; i < elem_nums; i += gridDim.x * blockDim.x * VecSize) {
    const int bi = i / dim_embed;
    const int bias_idx = i % dim_embed;
    int seq_id = 0;
    // just encoder or stop, get last token; just decoder, get first token.
    if (seq_len_decoder[bi] == 0) {
      if (seq_len_encoder[bi] != 0) {
        seq_id = seq_len_encoder[bi] - 1;
      } else {
        // Finished entry: nothing to copy. The launch config assigns one
        // vector per thread, so returning here skips only this entry.
        return;
      }
    }
    const int ori_token_idx = bi * seq_len - cum_offsets[bi] + seq_id;
    const int src_offset = ori_token_idx * dim_embed + bias_idx;
    Load<T, VecSize>(&input_data[src_offset], &src_vec);
    Store<T, VecSize>(src_vec, &output_data[i]);
  }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> rebuild_padding_v2(const paddle::Tensor& tmp_out, // [token_num, dim_embed]
                                               const paddle::Tensor& cum_offsets, // [bsz, 1]
                                               const paddle::Tensor& seq_lens_decoder,
                                               const paddle::Tensor& seq_lens_encoder,
                                               int max_input_length) {
  typedef PDTraits<D> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto cu_stream = tmp_out.stream();
  std::vector<int64_t> tmp_out_shape = tmp_out.shape();
  const int dim_embed = tmp_out_shape[1];
  const int bsz = cum_offsets.shape()[0];
  auto out = paddle::full({bsz, dim_embed}, 0, tmp_out.dtype(), tmp_out.place());
  constexpr int PackSize = VEC_16B / sizeof(DataType_);
  int elem_nums = out.numel();
  int pack_num = elem_nums / PackSize;
  const int blocksize = 128;
  // One vector of PackSize elements per thread.
  const int grid_size = (pack_num + blocksize - 1) / blocksize;
  RebuildPaddingV2Kernel<DataType_, PackSize><<<grid_size, blocksize, 0, cu_stream>>>(
      reinterpret_cast<DataType_*>(out.data<data_t>()),
      reinterpret_cast<DataType_*>(const_cast<data_t*>(tmp_out.data<data_t>())),
      cum_offsets.data<int>(),
      seq_lens_decoder.data<int>(),
      seq_lens_encoder.data<int>(),
      max_input_length,
      dim_embed,
      elem_nums);
  return {out};
}

std::vector<paddle::Tensor> RebuildPaddingV2(const paddle::Tensor& tmp_out, // [token_num, dim_embed]
                                             const paddle::Tensor& cum_offsets, // [bsz, 1]
                                             const paddle::Tensor& seq_lens_decoder,
                                             const paddle::Tensor& seq_lens_encoder,
                                             int max_input_length) {
  switch (tmp_out.type()) {
    case paddle::DataType::BFLOAT16: {
      return rebuild_padding_v2<paddle::DataType::BFLOAT16>(
          tmp_out,
          cum_offsets,
          seq_lens_decoder,
          seq_lens_encoder,
          max_input_length);
    }
    case paddle::DataType::FLOAT16: {
      return rebuild_padding_v2<paddle::DataType::FLOAT16>(
          tmp_out,
          cum_offsets,
          seq_lens_decoder,
          seq_lens_encoder,
          max_input_length);
    }
    case paddle::DataType::FLOAT32: {
      return rebuild_padding_v2<paddle::DataType::FLOAT32>(
          tmp_out,
          cum_offsets,
          seq_lens_decoder,
          seq_lens_encoder,
          max_input_length);
    }
    default: {
      PD_THROW(
          "NOT supported data type. "
          "Only float16, bfloat16 and float32 are supported. ");
      break;
    }
  }
}

std::vector<std::vector<int64_t>> RebuildPaddingV2InferShape(const std::vector<int64_t>& tmp_out_shape,
                                                             const std::vector<int64_t>& cum_offsets_shape,
                                                             const std::vector<int64_t>& seq_lens_decoder_shape,
                                                             const std::vector<int64_t>& seq_lens_encoder_shape) {
  int64_t bsz = cum_offsets_shape[0];
  int64_t dim_embed = tmp_out_shape[1];
  return {{bsz, dim_embed}};
}

std::vector<paddle::DataType> RebuildPaddingV2InferDtype(const paddle::DataType& tmp_out_dtype,
                                                         const paddle::DataType& cum_offsets_dtype,
                                                         const paddle::DataType& seq_lens_decoder_dtype,
                                                         const paddle::DataType& seq_lens_encoder_dtype) {
  return {tmp_out_dtype};
}

PD_BUILD_OP(rebuild_padding_v2)
    .Inputs({"tmp_out", "cum_offsets", "seq_lens_decoder", "seq_lens_encoder"})
    .Outputs({"out"})
    .Attrs({"max_input_length: int"})
    .SetKernelFn(PD_KERNEL(RebuildPaddingV2))
    .SetInferShapeFn(PD_INFER_SHAPE(RebuildPaddingV2InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(RebuildPaddingV2InferDtype));
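
The kernel's token-selection rule (prefill entries keep their last token, decoding entries their first, finished entries nothing) is easier to see on the CPU. A hedged reference sketch; the function `SelectTokenIndices` is illustrative, not part of the commit, and assumes `cum_offsets` is the exclusive padding prefix produced by `get_padding_offset_v2`:

```cpp
#include <vector>

// Hypothetical CPU mirror of the selection logic in RebuildPaddingV2Kernel:
// for each batch entry, compute the packed index of the hidden state to keep.
// Returns -1 for finished entries (no token to gather).
std::vector<int> SelectTokenIndices(const std::vector<int>& seq_lens_decoder,
                                    const std::vector<int>& seq_lens_encoder,
                                    const std::vector<int>& cum_offsets,
                                    int max_input_length) {
  std::vector<int> picked(seq_lens_decoder.size(), -1);
  for (size_t bi = 0; bi < picked.size(); ++bi) {
    int seq_id = 0;  // decoding entries keep their single (first) token
    if (seq_lens_decoder[bi] == 0) {
      if (seq_lens_encoder[bi] == 0) continue;  // finished: leave as -1
      seq_id = seq_lens_encoder[bi] - 1;        // prefill: keep the last token
    }
    picked[bi] = static_cast<int>(bi) * max_input_length - cum_offsets[bi] + seq_id;
  }
  return picked;
}
```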
12 changes: 12 additions & 0 deletions csrc/generation/reset_need_stop_value.cc
@@ -0,0 +1,12 @@
#include "paddle/extension.h"

void SetStopValue(const paddle::Tensor& not_need_stop) {
  bool *stop_data = const_cast<bool*>(not_need_stop.data<bool>());
  stop_data[0] = true;
}

PD_BUILD_OP(reset_stop_value)
    .Inputs({"not_need_stop"})
    .Outputs({"not_need_stop_out"})
    .SetInplaceMap({{"not_need_stop", "not_need_stop_out"}})
    .SetKernelFn(PD_KERNEL(SetStopValue));
58 changes: 58 additions & 0 deletions csrc/generation/save_with_output_msg.cc
@@ -0,0 +1,58 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stdio.h>
#include <string.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
#include "paddle/extension.h"

#define MAX_BSZ 512

struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];  // stop_flag, bsz, tokens
};

void SaveOutMmsg(const paddle::Tensor& x,
                 const paddle::Tensor& not_need_stop,
                 int64_t rank_id) {
  // Only rank 0 publishes tokens to the queue.
  if (rank_id > 0) return;
  auto x_cpu = x.copy_to(paddle::CPUPlace(), false);
  int64_t *x_data = x_cpu.data<int64_t>();
  static struct msgdata msg_sed;
  static key_t key = ftok("./", 1);
  static int msgid = msgget(key, IPC_CREAT | 0666);

  msg_sed.mtype = 1;
  bool not_need_stop_data = not_need_stop.data<bool>()[0];
  msg_sed.mtext[0] = not_need_stop_data ? 1 : -1;
  int bsz = x.shape()[0];
  msg_sed.mtext[1] = bsz;
  for (int i = 2; i < bsz + 2; i++) {
    msg_sed.mtext[i] = (int)x_data[i - 2];
  }
  if ((msgsnd(msgid, &msg_sed, (MAX_BSZ + 2) * 4, 0)) == -1) {
    // Queue full: the message is dropped silently.
    // printf("full msg buffer\n");
  }
  return;
}

PD_BUILD_OP(save_output)
    .Inputs({"x", "not_need_stop"})
    .Attrs({"rank_id: int64_t"})
    .Outputs({"x_out"})
    .SetInplaceMap({{"x", "x_out"}})
    .SetKernelFn(PD_KERNEL(SaveOutMmsg));
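
`save_output` and `get_output` communicate through the same System V message queue (key derived from `ftok("./", 1)`), so the protocol can be exercised without Paddle at all. A standalone round-trip sketch under that assumption (buffer layout and queue key copied from the two ops above; this harness is illustrative, not part of the commit):

```cpp
#include <stdio.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>

#define MAX_BSZ 512

// Same wire format as save_output/get_output:
// mtext[0] = stop flag, mtext[1] = bsz, mtext[2..bsz+2) = token ids.
struct msgdata {
  long mtype;
  int mtext[MAX_BSZ + 2];
};

int main() {
  key_t key = ftok("./", 1);
  int msgid = msgget(key, IPC_CREAT | 0666);

  // Producer side: one step with bsz=2, tokens {101, 102}, "continue" flag.
  struct msgdata snd;
  snd.mtype = 1;
  snd.mtext[0] = 1;  // 1 = continue, -1 = stop
  snd.mtext[1] = 2;  // batch size
  snd.mtext[2] = 101;
  snd.mtext[3] = 102;
  msgsnd(msgid, &snd, (MAX_BSZ + 2) * 4, 0);

  // Consumer side: blocking receive, mirroring get_output(wait_flag=true).
  struct msgdata rcv;
  if (msgrcv(msgid, &rcv, (MAX_BSZ + 2) * 4, 0, 0) != -1) {
    printf("flag=%d bsz=%d tokens=%d,%d\n",
           rcv.mtext[0], rcv.mtext[1], rcv.mtext[2], rcv.mtext[3]);
  }
  msgctl(msgid, IPC_RMID, NULL);  // clean up the queue
  return 0;
}
```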