Merged
29 commits
df895f7
add initial sd policy
jeffra Sep 29, 2022
d0576d4
Merge branch 'master' into jeffra/sd-policy
jeffra Oct 5, 2022
c8163e2
formatting
jeffra Oct 5, 2022
f34eb5b
add attention kernel and enable cuda-graph for SD models
Oct 7, 2022
5aec40c
add new files
Oct 7, 2022
23674fd
formatting and add dtype to unet
jeffra Oct 7, 2022
6bf0c73
adding more optimization by enabling ds-encoder with CUDA-Graph
Oct 10, 2022
8b0762b
Merge branch 'master' into jeffra/sd-policy
RezaYazdaniAminabadi Oct 10, 2022
ac29fdc
add missing file
Oct 10, 2022
c177070
Merge branch 'jeffra/sd-policy' of github.com:microsoft/DeepSpeed int…
Oct 10, 2022
27e752b
adapt the triton kernel to be used in more places
Oct 11, 2022
a40780b
add more fusion
Oct 11, 2022
923ddd2
allocate workspace using the padded hidden_size
Oct 11, 2022
2dca3ac
skip the clip-encoder injection for now
Oct 11, 2022
52dd412
Merge branch 'master' into jeffra/sd-policy
jeffra Oct 11, 2022
6079065
add triton to new extra
jeffra Oct 11, 2022
fb1605f
lazy import triton, add sd extra, formatting
jeffra Oct 11, 2022
16bcef1
delay import
jeffra Oct 12, 2022
691fd34
fix previous issue i added
jeffra Oct 13, 2022
c758b03
fix bug with adding bias
Oct 13, 2022
da01337
Merge branch 'master' into jeffra/sd-policy
RezaYazdaniAminabadi Oct 13, 2022
406832c
fixes for triton import and add acks to triton-ops file
jeffra Oct 13, 2022
79b05ca
Merge branch 'jeffra/sd-policy' of github.com:microsoft/DeepSpeed int…
jeffra Oct 13, 2022
a36275c
Merge branch 'master' into jeffra/sd-policy
jeffra Oct 13, 2022
75fbcfe
merge fix & formatting
Oct 13, 2022
e617595
Merge branch 'jeffra/sd-policy' of github.com:microsoft/DeepSpeed int…
Oct 13, 2022
eff95e7
fix small issue
Oct 13, 2022
770e88b
skip cuda-graph for clip-encoder for now (it has issue on larger batc…
Oct 13, 2022
3148881
Merge branch 'master' into jeffra/sd-policy
jeffra Oct 13, 2022
118 changes: 118 additions & 0 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -475,3 +475,121 @@ template void launch_moe_res_matmul(__half* residual,
int seq_len,
int hidden_dim,
cudaStream_t stream);

__global__ void pad_data_kernel(__half* padded_output,
__half* output,
int head_size,
int padded_head_size)
{
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bid = blockIdx.x * (blockDim.y) + threadIdx.y;
int idx = threadIdx.x;
padded_output_cast += (bid * padded_head_size);
output_cast += (bid * head_size);
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;
if (idx < head_size)
padded_output_cast[idx] = output_cast[idx];
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_data_kernel(float* padded_output,
float* output,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_data(T* padded_output,
T* output,
int bsz,
int head_size,
int padded_head_size,
cudaStream_t stream)
{
dim3 grid_dim((bsz - 1) / 16 + 1);
dim3 block_dim(padded_head_size / 8, 16);
pad_data_kernel<<<grid_dim, block_dim, 0, stream>>>(
padded_output, output, head_size / 8, padded_head_size / 8);
}
template void pad_data(__half* padded_output,
__half* output,
int bsz,
int head_size,
int padded_head_size,
cudaStream_t stream);
template void pad_data(float* padded_output,
float* output,
int bsz,
int head_size,
int padded_head_size,
cudaStream_t stream);

__global__ void pad_head_seq_kernel(__half* padded_output,
__half* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
float4* padded_output_cast = reinterpret_cast<float4*>(padded_output);
float4* output_cast = reinterpret_cast<float4*>(output);
int bsz = blockIdx.x;
int bid = blockIdx.y * (blockDim.y) + threadIdx.y;
int idx = threadIdx.x;
padded_output_cast += (bsz * padded_seq_len + bid) * padded_head_size;
output_cast += (bsz * seq_len + bid) * head_size;
float4 ZERO;
const __half2 zero_h = __float2half2_rn(0.f);
__half2* ZERO_h = reinterpret_cast<__half2*>(&ZERO);
#pragma unroll
for (int i = 0; i < 4; i++) ZERO_h[i] = zero_h;

if (idx < head_size && bid < seq_len)
padded_output_cast[idx] = output_cast[idx];
else
padded_output_cast[idx] = ZERO;
}
__global__ void pad_head_seq_kernel(float* padded_output,
float* output,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size)
{
}
template <typename T>
void pad_head_seq(T* padded_output,
T* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
cudaStream_t stream)
{
dim3 grid_dim(bsz, padded_seq_len / 16);
dim3 block_dim(padded_head_size / 8, 16);
pad_head_seq_kernel<<<grid_dim, block_dim, 0, stream>>>(
padded_output, output, seq_len, padded_seq_len, head_size / 8, padded_head_size / 8);
}
template void pad_head_seq(__half* padded_output,
__half* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
cudaStream_t stream);
template void pad_head_seq(float* padded_output,
float* output,
int bsz,
int seq_len,
int padded_seq_len,
int head_size,
int padded_head_size,
cudaStream_t stream);
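
The two kernels added above zero-pad attention data along the head dimension so that the fused (flash-/Triton-style) attention path can assume a head size of 32, 64, or 128: pad_data pads a [rows, head_size] layout row by row, and pad_head_seq additionally pads the sequence dimension of key/value tensors. Below is a minimal host-side sketch of how pad_data is driven, mirroring the selection logic used in ds_linear_layer in pt_binding.cpp; the function names and the include comment are illustrative and not part of this PR.

#include <cuda_fp16.h>
// pad_data<T> is declared in DeepSpeed's inference kernel headers (assumed include).

static int choose_padded_head_size(int head_size)
{
    // Same thresholds as the diff: head_size < 32 ? 32 : (head_size < 64 ? 64 : 128)
    if (head_size < 32) return 32;
    if (head_size < 64) return 64;
    return 128;
}

static void pad_qkv_for_fused_attention(__half* padded_out,
                                        __half* qkv,
                                        int bsz,
                                        int num_heads,
                                        int head_size,
                                        cudaStream_t stream)
{
    int padded_head_size = choose_padded_head_size(head_size);
    // 3 * bsz * num_heads rows (fused Q, K and V), each head_size halves wide,
    // copied into a padded_head_size-wide row; the extra lanes are zero-filled.
    // pad_data divides both sizes by 8 internally because the kernel moves
    // float4 vectors (8 halves) per thread.
    pad_data(padded_out, qkv, 3 * bsz * num_heads, head_size, padded_head_size, stream);
}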
195 changes: 189 additions & 6 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -831,6 +831,10 @@ template <typename T>
at::Tensor ds_linear_layer(at::Tensor& input,
at::Tensor& weight,
at::Tensor& bias,
bool add_bias,
bool external_cache,
bool do_flash_attn,
int num_heads,
unsigned num_layers)
{
auto input_cont = input.contiguous();
@@ -840,8 +844,23 @@ at::Tensor ds_linear_layer(at::Tensor& input,
.device(at::kCUDA)
.requires_grad(false);

int head_size = input_cont.size(2) / num_heads;
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
// Reallocate memory if we received a new prompt
if (!workspace) {
cublasSetStream(Context::Instance().GetCublasHandle(),
Context::Instance().GetCurrentStream());
allocate_workspace<T>(input.size(2),
input.size(0),
input.size(1),
num_layers,
num_heads,
1,
external_cache,
0);
workspace = (T*)Context::Instance().GetWorkSpace();
}
auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options);

float alpha = (T)1.0;
@@ -864,16 +883,172 @@ at::Tensor ds_linear_layer(at::Tensor& input,
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
if (add_bias)
launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
bool add_padding = (head_size % 32 != 0 && head_size < 64) || (head_size % 64 != 0);
if (do_flash_attn) {
if (add_padding) {
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
auto padded_output = workspace + output.numel();
auto final_output =
padded_output + (input.size(0) * input.size(1) * 3 * num_heads * padded_head_size);
pad_data(padded_output,
workspace,
3 * bsz * num_heads,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());

launch_bias_add((T*)output.data_ptr(),
(T*)bias.data_ptr(),
weight.size(1),
bsz,
Context::Instance().GetCurrentStream());
launch_bias_add_transform_0213<T>(
final_output,
final_output + (input.size(0) * input.size(1) * num_heads * padded_head_size),
final_output + (input.size(0) * input.size(1) * 2 * num_heads * padded_head_size),
padded_output,
nullptr,
input.size(0),
input.size(1),
0,
input.size(1),
(num_heads * padded_head_size),
num_heads,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(final_output,
{3, input.size(0), num_heads, input.size(1), padded_head_size},
options);
// return at::from_blob(padded_output, {input.size(0) * input.size(1), 3, num_heads,
// padded_head_size}, options);
} else {
auto final_output = workspace + output.numel();
launch_bias_add_transform_0213<T>(
final_output,
final_output + (input.size(0) * input.size(1) * input_cont.size(2)),
final_output + (input.size(0) * input.size(1) * 2 * input_cont.size(2)),
workspace,
nullptr,
input.size(0),
input.size(1),
0,
input.size(1),
input_cont.size(2),
num_heads,
-1,
false,
false,
Context::Instance().GetCurrentStream(),
3,
input.size(1));
return at::from_blob(
final_output, {3, input.size(0), num_heads, input.size(1), head_size}, options);
// return at::from_blob(workspace, {input.size(0) * input.size(1), 3, num_heads,
// head_size}, options);
}

} else
return output;
}

return output;
template <typename T>
std::vector<at::Tensor> add_padding(at::Tensor& query, at::Tensor& key, at::Tensor& value)
{
int head_size = query.size(3);
int padded_head_size = head_size < 32 ? 32 : (head_size < 64 ? 64 : 128);
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * query.size(1) * query.size(2);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * query.size(1) * 128;
pad_head_seq(workspace,
(T*)query.data_ptr(),
query.size(0) * query.size(1),
query.size(2),
query.size(2),
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
pad_head_seq(key_pad_ptr,
(T*)key.data_ptr(),
query.size(0) * query.size(1),
key.size(2),
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
pad_head_seq(value_pad_ptr,
(T*)value.data_ptr(),
query.size(0) * query.size(1),
key.size(2),
128,
head_size,
padded_head_size,
Context::Instance().GetCurrentStream());
return {
at::from_blob(workspace,
{query.size(0), query.size(1), query.size(2), padded_head_size},
query.options()),
at::from_blob(
key_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options()),
at::from_blob(
value_pad_ptr, {query.size(0), query.size(1), 128, padded_head_size}, query.options())};
}

template <typename T>
std::vector<at::Tensor> padd_add_transform(at::Tensor& query,
at::Tensor& key,
at::Tensor& value,
int heads,
bool add_padding)
{
int head_size = query.size(2) / heads;
int key_value_length = add_padding ? 128 : key.size(1);
int padded_head_size = add_padding ? (head_size < 32 ? 32 : (head_size < 64 ? 64 : 128))
: head_size;
T* workspace = (T*)Context::Instance().GetWorkSpace();
T* key_pad_ptr = workspace + padded_head_size * query.size(0) * heads * query.size(1);
T* value_pad_ptr = key_pad_ptr + padded_head_size * query.size(0) * heads * key_value_length;
launch_pad_add_transform_0213(workspace,
(T*)query.data_ptr(),
query.size(0),
query.size(2),
query.size(1),
query.size(1),
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
launch_pad_add_transform_0213(key_pad_ptr,
(T*)key.data_ptr(),
key.size(0),
key.size(2),
key.size(1),
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
launch_pad_add_transform_0213(value_pad_ptr,
(T*)value.data_ptr(),
value.size(0),
value.size(2),
value.size(1),
key_value_length,
heads,
padded_head_size,
Context::Instance().GetCurrentStream());
return {
at::from_blob(
workspace, {query.size(0), heads, query.size(1), padded_head_size}, query.options()),
at::from_blob(key_pad_ptr,
{query.size(0), heads, key_value_length, padded_head_size},
query.options()),
at::from_blob(value_pad_ptr,
{query.size(0), heads, key_value_length, padded_head_size},
query.options())};
}
template <typename T>
at::Tensor ds_linear_layer_int8(at::Tensor& input,
at::Tensor& weight,
@@ -1414,6 +1589,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
&einsum_sec_sm_ecm<__half>,
"DeepSpeed vector-MM with fp16 (CUDA)");
m.def("moe_res_matmul", &moe_res_matmul, "DeepSpeed moe residual matmul (CUDA)");
m.def("add_padding_fp32", &add_padding<float>, "DeepSpeed residual add with fp32 (CUDA)");
m.def("add_padding_fp16", &add_padding<__half>, "DeepSpeed residual add with fp16 (CUDA)");
m.def("pad_transform_fp32",
&padd_add_transform<float>,
"DeepSpeed residual add with fp32 (CUDA)");
m.def("pad_transform_fp16",
&padd_add_transform<__half>,
"DeepSpeed residual add with fp16 (CUDA)");
m.def("allocate_workspace_fp32",
&allocate_workspace<float>,
"DeepSpeed memory allocation for GPT inference with fp32 (CUDA)");
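
The new add_padding_* and pad_transform_* bindings expose the padding helpers above to the Python side so that query, key, and value can be padded and transposed for the fused attention path. For orientation, here is the shape arithmetic of padd_add_transform with add_padding enabled, using made-up batch/sequence/hidden sizes that are not taken from the diff.

// Illustrative shape arithmetic only; the numbers are invented for this example.
static void example_pad_transform_shapes()
{
    int batch = 2, seq = 77, hidden = 320, heads = 8;
    int head_size = hidden / heads;                        // 40
    int padded_head_size = head_size < 32 ? 32
                           : (head_size < 64 ? 64 : 128);  // 40 is padded up to 64
    int key_value_length = 128;                            // KV sequence length is padded to a fixed 128
    // padd_add_transform then returns (per the from_blob calls above):
    //   query: {batch, heads, seq, padded_head_size}              -> {2, 8, 77, 64}
    //   key:   {batch, heads, key_value_length, padded_head_size} -> {2, 8, 128, 64}
    //   value: {batch, heads, key_value_length, padded_head_size} -> {2, 8, 128, 64}
    (void)batch; (void)padded_head_size; (void)key_value_length;
}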
5 changes: 3 additions & 2 deletions csrc/transformer/inference/csrc/softmax.cu
@@ -12,7 +12,7 @@ Copyright 2022 The Microsoft DeepSpeed Team
#include <cstdlib>
#include <ctime>

#define ATTN_THREADS 1024
#define ATTN_THREADS 256
#define MAX_REG_SIZE 8

#define minus_infinity -10000.0
@@ -427,7 +427,8 @@ void launch_attn_softmax_v2(T* vals,
cudaStream_t stream)
{
int total_count = batch_size * heads * num_seq;
dim3 grid_dim((total_count - 1) / (WARP_SIZE / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
int warp_num = ATTN_THREADS / WARP_SIZE;
dim3 grid_dim((total_count - 1) / (warp_num / ((sequence_length - 1) / ATTN_THREADS + 1)) + 1);
dim3 block_dim(ATTN_THREADS);

const int reduce_width = ((sequence_length - 1) / ATTN_THREADS + 1) * WARP_SIZE;
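
In the softmax launcher, ATTN_THREADS drops from 1024 to 256 and the grid-size formula now divides by the number of warps per block (warp_num = ATTN_THREADS / WARP_SIZE) rather than by WARP_SIZE itself. A short sketch of the resulting grid sizing, assuming WARP_SIZE is 32; batch_size, heads, num_seq, and sequence_length are the launcher's arguments, and the helper below is illustrative rather than part of the diff.

// Illustrative only: grid sizing after this change (WARP_SIZE assumed to be 32).
static dim3 softmax_grid(int batch_size, int heads, int num_seq, int sequence_length)
{
    const int attn_threads = 256;                    // new ATTN_THREADS value
    const int warp_num = attn_threads / 32;          // 8 warps per block
    int total_count = batch_size * heads * num_seq;  // softmax rows to process
    int warps_per_row = (sequence_length - 1) / attn_threads + 1;
    int rows_per_block = warp_num / warps_per_row;   // 8 rows/block for seq <= 256, 2 for seq = 1024
    return dim3((total_count - 1) / rows_per_block + 1);
}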