Merged

27 commits
7cf7c89
Add Inference support for running the BigScience-BLOOM Architecture
Jul 9, 2022
6d33b82
Merge branch 'master' into ds-inference/bloom-support
jeffra Jul 11, 2022
6cc2340
formatting
jeffra Jul 11, 2022
2a69756
add the checkpoint loading at the same time as kernel-injection
Jul 12, 2022
e449ac7
Merge branch 'ds-inference/bloom-support' of github.com:microsoft/Dee…
Jul 12, 2022
6e304aa
releasing checkpoint CPU-memory after loading it
Jul 12, 2022
16dd1dc
some fixes plus formatting
Jul 12, 2022
e0dd488
fix layer_past; this caused issues when running inference on several …
Jul 13, 2022
02f9945
Add support for multi-batch inference
Jul 13, 2022
1cef202
fix the padding issue for large batch inference
Jul 14, 2022
644fea4
fixing a bug in the softmax kernel for batch_size>1
Jul 14, 2022
adb0b97
align alibi&mask addition with HF new changes
Jul 15, 2022
bd3c0a0
revert back some changes and support for very large batch size
Jul 15, 2022
1f92e55
reduce the max_token_length for now
Jul 15, 2022
5d1f351
fix mask-adding
Jul 15, 2022
3c12b89
fix the large-batch inference for MP > 1
Jul 16, 2022
a5bdd58
Merge branch 'master' into ds-inference/bloom-support
RezaYazdaniAminabadi Jul 18, 2022
aa5e01f
Ds inference/bloom support meta (#2104)
jeffra Jul 18, 2022
b6503ed
fix the Bert and GPT-J unit tests
Jul 18, 2022
a9459d6
fix for OneDevice
jeffra Jul 18, 2022
2cd301e
Merge branch 'master' into ds-inference/bloom-support
jeffra Jul 18, 2022
72aba56
added bloom inference tests
mrwyattii Jul 18, 2022
332f69d
fixing the masking stride for the GPT models
Jul 18, 2022
0ddf41c
Merge branch 'ds-inference/bloom-support' of github.com:microsoft/Dee…
Jul 18, 2022
d38464e
revert back some changes on replace_module
Jul 18, 2022
ac2d092
fix fp32 softmax
Jul 18, 2022
1e3ea74
allocate tensors initially on cpu at inference-api
Jul 18, 2022
44 changes: 22 additions & 22 deletions csrc/transformer/inference/csrc/gelu.cu
@@ -174,7 +174,7 @@ __global__ void fused_bias_residual(float* input,
float* attnbias,
int total_count,
int intermediate_size,
int mp_size,
float mp_scale,
bool preln)
{
float4* input_cast = reinterpret_cast<float4*>(input);
@@ -191,10 +191,10 @@ __global__ void fused_bias_residual(float* input,
float4 bias_data = bias_cast[offset % intermediate_size];
float4 attn_bias = attnbias_cast[offset % intermediate_size];
if (preln) {
data.x = (data.x + res_vec.x) * mp_size + (out.x + bias_data.x + attn_bias.x);
data.y = (data.y + res_vec.y) * mp_size + (out.y + bias_data.y + attn_bias.y);
data.z = (data.z + res_vec.z) * mp_size + (out.z + bias_data.z + attn_bias.z);
data.w = (data.w + res_vec.w) * mp_size + (out.w + bias_data.w + attn_bias.w);
data.x = (data.x + res_vec.x + bias_data.x + attn_bias.x) * mp_scale + (out.x);
data.y = (data.y + res_vec.y + bias_data.y + attn_bias.y) * mp_scale + (out.y);
data.z = (data.z + res_vec.z + bias_data.z + attn_bias.z) * mp_scale + (out.z);
data.w = (data.w + res_vec.w + bias_data.w + attn_bias.w) * mp_scale + (out.w);
} else {
data.x = data.x + out.x + bias_data.x;
data.y = data.y + out.y + bias_data.y;
Expand All @@ -212,7 +212,7 @@ __global__ void fused_bias_residual(__half* input,
__half* attn_bias,
int total_count,
int intermediate_size,
int mp_size,
float mp_scale,
bool preln)
{
#ifdef HALF_PRECISION_AVAILABLE
@@ -257,13 +257,13 @@ __global__ void fused_bias_residual(__half* input,

if (preln) {
low_data.x =
(low_data.x + low_res.x) * mp_size + (low_out.x + (low_bias.x + attn_low_bias.x));
(low_data.x + low_res.x + (low_bias.x + attn_low_bias.x)) * mp_scale + low_out.x;
low_data.y =
(low_data.y + low_res.y) * mp_size + (low_out.y + (low_bias.y + attn_low_bias.y));
high_data.x = (high_data.x + high_res.x) * mp_size +
(high_out.x + (high_bias.x + attn_high_bias.x));
high_data.y = (high_data.y + high_res.y) * mp_size +
(high_out.y + (high_bias.y + attn_high_bias.y));
(low_data.y + low_res.y + (low_bias.y + attn_low_bias.y)) * mp_scale + low_out.y;
high_data.x = (high_data.x + high_res.x + (high_bias.x + attn_high_bias.x)) * mp_scale +
high_out.x;
high_data.y = (high_data.y + high_res.y + (high_bias.y + attn_high_bias.y)) * mp_scale +
high_out.y;
} else {
low_data.x = (low_data.x + low_out.x + low_bias.x);
low_data.y = (low_data.y + low_out.y + low_bias.y);
@@ -310,7 +310,7 @@ __global__ void gptj_residual_add(float* input,
float* attnbias,
int total_count,
int intermediate_size,
float mp_size)
float mp_scale)
{
float4* input_cast = reinterpret_cast<float4*>(input);
float4* output_cast = reinterpret_cast<float4*>(output);
@@ -332,10 +332,10 @@ __global__ void gptj_residual_add(float* input,
data.z += attn_bias.z;
data.w += attn_bias.w;
}
data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x);
data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y);
data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z);
data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w);
data.x = data.x * mp_scale + (out.x + res_vec.x + bias_data.x);
data.y = data.y * mp_scale + (out.y + res_vec.y + bias_data.y);
data.z = data.z * mp_scale + (out.z + res_vec.z + bias_data.z);
data.w = data.w * mp_scale + (out.w + res_vec.w + bias_data.w);

output_cast[offset] = data;
}
@@ -348,7 +348,7 @@ __global__ void gptj_residual_add(__half* input,
__half* attn_bias,
int total_count,
int intermediate_size,
float mp_size)
float mp_scale)
{
#if __CUDA_ARCH__ >= 700 || defined(__HIP_PLATFORM_HCC__)

@@ -395,10 +395,10 @@ __global__ void gptj_residual_add(__half* input,
high_data.y += attn_high_bias.y;
}

low_data.x = low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x));
low_data.y = low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y));
high_data.x = high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x));
high_data.y = high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y));
low_data.x = low_data.x * mp_scale + (low_out.x + low_res.x + (low_bias.x));
low_data.y = low_data.y * mp_scale + (low_out.y + low_res.y + (low_bias.y));
high_data.x = high_data.x * mp_scale + (high_out.x + high_res.x + (high_bias.x));
high_data.y = high_data.y * mp_scale + (high_out.y + high_res.y + (high_bias.y));

vals_half[0] = __float22half2_rn(low_data);
vals_half[1] = __float22half2_rn(high_data);
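Note: the gelu.cu hunks above replace the integer `mp_size` argument with a float `mp_scale` and, on the pre-LN path, move both bias terms inside the scaled sum. Below is a minimal host-side sketch of the new per-element formula; the buffer names and the assumption that `mp_scale` is a float scale (e.g. `1.0f / model_parallel_size`) are illustrative, not taken from the kernel itself.

```cpp
#include <cstddef>

// Reference (non-CUDA) sketch of the revised pre-LN fused bias/residual math.
// hidden/residual/mlp_out/bias/attn_bias are illustrative names, and mp_scale
// is assumed to be a float scaling factor rather than the old integer mp_size.
void fused_bias_residual_preln_ref(float* hidden,
                                   const float* residual,
                                   const float* mlp_out,
                                   const float* bias,
                                   const float* attn_bias,
                                   float mp_scale,
                                   size_t n)
{
    for (size_t i = 0; i < n; ++i) {
        // New ordering: both bias terms are folded in before the model-parallel
        // scaling, and the MLP output is then added unscaled. The old code
        // computed (hidden + residual) * mp_size + (mlp_out + bias + attn_bias).
        hidden[i] = (hidden[i] + residual[i] + bias[i] + attn_bias[i]) * mp_scale + mlp_out[i];
    }
}
```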
1 change: 1 addition & 0 deletions csrc/transformer/inference/csrc/normalize.cu
@@ -88,6 +88,7 @@ __global__ void fused_bias_residual_layer_norm(__half* output,
int row_stride)
{
#ifdef HALF_PRECISION_AVAILABLE

int iteration_stride = blockDim.x;
int iterations = row_stride / iteration_stride;

67 changes: 55 additions & 12 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -11,27 +11,36 @@ std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
at::Tensor& alibi,
bool triangular,
bool recompute,
bool local_attention,
int window_size,
bool async_op)
bool async_op,
float layer_scale,
int head_offset,
int mp_size)
{
auto attn_scores_c = attn_scores.contiguous();
int bsz = attn_scores_c.size(0);

int seq_len = attn_scores_c.size(1);
int len = attn_scores_c.sizes().size();
if (len > 3) seq_len = attn_scores_c.size(2);
if (len > 2) seq_len = attn_scores_c.size(2);

int soft_len = attn_scores_c.size(2);
if (len > 3) soft_len = attn_scores_c.size(3);

int heads = 1;
if (len > 3) heads = attn_scores_c.size(1);
if (len > 1) heads = attn_scores_c.size(1);

int mask_stride = 1;
if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2);

launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
(alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr),
layer_scale,
triangular,
recompute,
local_attention,
@@ -40,7 +49,9 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
heads,
seq_len,
soft_len,
1.0,
head_offset,
mask_stride,
mp_size,
Context::Instance().GetCurrentStream(async_op));

return attn_scores_c;
@@ -123,6 +134,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
int mask_stride = heads;
if (attn_mask.sizes().size() > 2 && attn_mask.size(2) == 1) mask_stride *= seq_len;
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
@@ -144,8 +157,22 @@ void attention_unfused(at::Tensor& prev_key_cont,
#else
CUBLAS_GEMM_DEFAULT_TENSOR_OP);
#endif
attn_score = ds_softmax<T>(
attn_score, attn_mask, triangular, recompute, local_attention, window_size, false);
launch_attn_softmax_v2((T*)attn_score.data_ptr(),
(T*)(attn_mask.sizes().size() > 1 ? attn_mask.data_ptr() : nullptr),
(T*)nullptr,
1.0,
triangular,
recompute,
local_attention,
window_size,
bsz,
heads,
seq_len,
soft_len,
0,
mask_stride,
1,
Context::Instance().GetCurrentStream(false));
alpha = 1.0;
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
k,
@@ -225,6 +252,8 @@ std::vector<at::Tensor> ds_softmax_context1(at::Tensor& query,
template <typename T>
void ds_softmax_internal(T* attn_scores,
at::Tensor& attn_mask,
at::Tensor& alibi,
float& layer_scale,
bool triangular,
bool recompute,
bool local_attention,
@@ -234,8 +263,12 @@ void ds_softmax_internal(T* attn_scores,
int soft_len,
int heads)
{
int mask_stride = 1;
if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2);
launch_attn_softmax_v2((T*)attn_scores,
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
(alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr),
layer_scale,
triangular,
recompute,
local_attention,
@@ -244,7 +277,9 @@ void ds_softmax_internal(T* attn_scores,
heads,
seq_len,
soft_len,
1.0,
0,
mask_stride,
1,
at::cuda::getCurrentCUDAStream());
}

@@ -263,9 +298,12 @@ void attention_unfused(T* prev_key_cont,
bool triangular,
bool recompute,
bool local_attention,
int window_size)
int window_size,
at::Tensor& alibi,
int layer_id)
{
float alpha = norm_factor * norm_factor;
float layer_scale = alibi.sizes().size() > 1 ? std::max(1, layer_id) : 1.0;
float alpha = norm_factor * norm_factor / layer_scale;
float gemm_beta = 0.0;
T* workspace = (T*)output + bsz * seq_len * heads * k;

@@ -292,6 +330,8 @@ void attention_unfused(T* prev_key_cont,
#endif
ds_softmax_internal<T>(workspace,
attn_mask,
alibi,
layer_scale,
triangular,
recompute,
local_attention,
@@ -336,7 +376,8 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
int window_size,
bool no_masking,
unsigned layer_id,
unsigned num_layers)
unsigned num_layers,
at::Tensor& alibi)
{
unsigned bsz = query_key_value.size(0);
unsigned seq_len = query_key_value.size(1);
@@ -410,7 +451,9 @@ std::vector<at::Tensor> ds_softmax_context(at::Tensor& query_key_value,
(triangular && is_prompt),
is_prompt,
local_attention,
window_size);
window_size,
alibi,
layer_id);
launch_transform4d_0213<T>((T*)output.data_ptr(),
temp_buf,
bsz,
Expand Down Expand Up @@ -506,7 +549,7 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
{
int bsz = input.size(0) * input.size(1);
T* workspace = (T*)Context::Instance().GetWorkSpace();
workspace += (3 * input.size(0) * MAX_OUT_TOKES * input.size(2));
workspace += (3 * bsz * input.size(2));
ds_layernorm_internal<T>(workspace, input, gamma, beta, epsilon);
// cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream());

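Note: in pt_binding.cpp, `ds_softmax`, `ds_softmax_internal`, and `attention_unfused` now thread an ALiBi tensor and a per-layer scale through to `launch_attn_softmax_v2`. The sketch below shows how the two scale factors appear to interact based on the hunks above, assuming the softmax kernel multiplies `layer_scale` back into scores that the QK^T GEMM pre-divided; the helper name and the `has_alibi` flag are illustrative, not part of the PR.

```cpp
#include <algorithm>

// Illustrative helper (not from the PR): how the GEMM alpha and the softmax
// layer_scale relate in the new attention_unfused path, per the diff above.
struct SoftmaxScales {
    float alpha;        // applied to the QK^T strided-batched GEMM
    float layer_scale;  // forwarded to the fused softmax kernel
};

SoftmaxScales compute_scales(float norm_factor, int layer_id, bool has_alibi)
{
    // When an ALiBi tensor is present, scores are pre-divided by
    // max(1, layer_id) (presumably to keep fp16 intermediates in range),
    // and the same factor is handed to launch_attn_softmax_v2 as layer_scale.
    float layer_scale = has_alibi ? static_cast<float>(std::max(1, layer_id)) : 1.0f;
    return {norm_factor * norm_factor / layer_scale, layer_scale};
}
```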