diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index affe3b0dd3f7..fee425215b52 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -4,6 +4,7 @@ Copyright 2022 The Microsoft DeepSpeed Team
 
 #include <c10/cuda/CUDAStream.h>
 #include <torch/extension.h>
+#include <stdexcept>
 #include <vector>
 #include "inference_context.h"
 #include "inference_cublas_wrappers.h"
@@ -16,6 +17,41 @@ std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
 // will be incorrect.
 enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2 };
 
+enum class TransformerType : uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 };
+
+// NOTE: this is a temporary and dodgy solution to distinguish GPT and BERT style models
+// based on the dimensions of the corresponding attention mask.
+inline auto infer_transformer_type(at::Tensor& attn_mask) -> TransformerType
+{
+    auto attn_mask_num_dims = attn_mask.sizes().size();
+
+    if (attn_mask_num_dims > 2) {
+        return TransformerType::GPTType;
+    } else if (attn_mask_num_dims == 2) {
+        return TransformerType::BERTType;
+    } else {
+        return TransformerType::UNKNOWN;
+    }
+}
+
+// infer stride of attention mask memory layout based on the model type.
+inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int
+{
+    auto trnsfrmr_type = infer_transformer_type(attn_mask);
+
+    if (trnsfrmr_type == TransformerType::GPTType) {
+        return attn_mask.size(2);
+    } else if (trnsfrmr_type == TransformerType::BERTType) {
+        // Bert style models have always a mask stride of 1.
+        return 1;
+    } else if (trnsfrmr_type == TransformerType::UNKNOWN) {
+        throw std::runtime_error("Unknown transformer type.");
+    }
+
+    // this is just to make the compiler happy.
+    return 0;
+}
+
 template <typename T>
 at::Tensor ds_softmax(at::Tensor& attn_scores,
                       at::Tensor& attn_mask,
@@ -42,8 +78,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
     int heads = 1;
     if (len > 1) heads = attn_scores_c.size(1);
 
-    int mask_stride = 1;
-    if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2);
+    auto mask_stride = get_attn_mask_stride(attn_mask);
 
     launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
                            (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
@@ -142,8 +177,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
     float gemm_beta = 0.0;
     auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
     int k = prev_value_cont.size(2) / heads;
-    int mask_stride = heads;
-    if (attn_mask.sizes().size() > 2 && attn_mask.size(2) == 1) mask_stride *= seq_len;
+
+    auto mask_stride = get_attn_mask_stride(attn_mask);
+
     cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
     cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
                                 soft_len,
@@ -271,8 +307,8 @@ void ds_softmax_internal(T* attn_scores,
                          int soft_len,
                          int heads)
 {
-    int mask_stride = 1;
-    if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2);
+    auto mask_stride = get_attn_mask_stride(attn_mask);
+
     launch_attn_softmax_v2((T*)attn_scores,
                            (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
                            (alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr),
diff --git a/csrc/transformer/inference/csrc/softmax.cu b/csrc/transformer/inference/csrc/softmax.cu
index cc7c784913d7..ce7c2e77759d 100644
--- a/csrc/transformer/inference/csrc/softmax.cu
+++ b/csrc/transformer/inference/csrc/softmax.cu
@@ -274,7 +274,7 @@ __global__ void attn_softmax_v2(float* vals,
         int batch_idx = iter_offset / (num_seq * heads);
         int alibi_offset = batch_idx * heads * mp_size + head_offset;
         int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
-
+        mask_offset = mask_offset * sequence_length;
         int seq_id = iter_offset % num_seq;
         int seq_id4 = seq_id >> 2;
 
@@ -305,7 +305,7 @@ __global__ void attn_softmax_v2(float* vals,
                              (data_id + 3) > window_stride)
                                 ? vals[data_id + 3]
                                 : minus_infinity;
-                if (attn_mask && recompute) {
+                if (attn_mask) {
                     data[i].x += attn_mask[data_id + mask_offset];
                     data[i].y += attn_mask[data_id + mask_offset + 1];
                     data[i].z += attn_mask[data_id + mask_offset + 2];
@@ -322,7 +322,7 @@ __global__ void attn_softmax_v2(float* vals,
                                 ? (vals[data_id + 2])
                                 : minus_infinity;
                 data[i].w = minus_infinity;
-                if (attn_mask && recompute) {
+                if (attn_mask) {
                     data[i].x += attn_mask[data_id + mask_offset];
                     if ((data_id + 1) < sequence_length)
                         data[i].y += attn_mask[data_id + mask_offset + 1];
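For reference (not part of the patch): a minimal sketch of how the new mask-stride helper from pt_binding.cpp is expected to behave. The tensor shapes and the example_mask_strides wrapper are illustrative assumptions, not taken from the diff; the sketch assumes get_attn_mask_stride is visible in the translation unit.

    // Hypothetical usage sketch of get_attn_mask_stride() with assumed mask shapes.
    #include <torch/extension.h>

    void example_mask_strides()
    {
        // GPT-style: mask has more than two dims, so the stride is mask.size(2).
        auto gpt_mask = at::zeros({8, 1, 128});  // assumed [bsz, 1, soft_len]
        int gpt_stride = get_attn_mask_stride(gpt_mask);  // -> 128

        // BERT-style: a 2-D mask always maps to a stride of 1.
        auto bert_mask = at::zeros({8, 128});  // assumed [bsz, soft_len]
        int bert_stride = get_attn_mask_stride(bert_mask);  // -> 1
    }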