48 changes: 42 additions & 6 deletions csrc/transformer/inference/csrc/pt_binding.cpp
@@ -4,6 +4,7 @@ Copyright 2022 The Microsoft DeepSpeed Team

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
+#include <stdexcept>
#include <vector>
#include "inference_context.h"
#include "inference_cublas_wrappers.h"
@@ -16,6 +17,41 @@ std::array<int, 3> gemm_algos = std::array<int, 3>({99, 99, 99});
// will be incorrect.
enum class ActivationFuncType { UNKNOWN = 0, GELU = 1, ReLU = 2 };

+enum class TransformerType : uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 };
+
+// NOTE: this is a temporary and dodgy solution to distinguish GPT- and BERT-style
+// models based on the dimensions of the corresponding attention mask.
+inline auto infer_transformer_type(at::Tensor& attn_mask) -> TransformerType
+{
+    auto attn_mask_num_dims = attn_mask.sizes().size();
+
+    if (attn_mask_num_dims > 2) {
+        return TransformerType::GPTType;
+    } else if (attn_mask_num_dims == 2) {
+        return TransformerType::BERTType;
+    } else {
+        return TransformerType::UNKNOWN;
+    }
+}
+
+// Infer the stride of the attention mask's memory layout based on the model type.
+inline auto get_attn_mask_stride(at::Tensor& attn_mask) -> int
+{
+    auto trnsfrmr_type = infer_transformer_type(attn_mask);
+
+    if (trnsfrmr_type == TransformerType::GPTType) {
+        return attn_mask.size(2);
+    } else if (trnsfrmr_type == TransformerType::BERTType) {
+        // BERT-style models always have a mask stride of 1.
+        return 1;
+    } else if (trnsfrmr_type == TransformerType::UNKNOWN) {
+        throw std::runtime_error("Unknown transformer type.");
+    }
+
+    // This is just to make the compiler happy.
+    return 0;
+}
+
template <typename T>
at::Tensor ds_softmax(at::Tensor& attn_scores,
at::Tensor& attn_mask,
@@ -42,8 +78,7 @@ at::Tensor ds_softmax(at::Tensor& attn_scores,
int heads = 1;
if (len > 1) heads = attn_scores_c.size(1);

-int mask_stride = 1;
-if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2);
+auto mask_stride = get_attn_mask_stride(attn_mask);

launch_attn_softmax_v2((T*)attn_scores_c.data_ptr(),
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
@@ -142,8 +177,9 @@ void attention_unfused(at::Tensor& prev_key_cont,
float gemm_beta = 0.0;
auto attn_score = at::empty({bsz, heads, seq_len, soft_len}, options);
int k = prev_value_cont.size(2) / heads;
-int mask_stride = heads;
-if (attn_mask.sizes().size() > 2 && attn_mask.size(2) == 1) mask_stride *= seq_len;
+
+auto mask_stride = get_attn_mask_stride(attn_mask);
+
cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream());
cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(),
soft_len,
@@ -271,8 +307,8 @@ void ds_softmax_internal(T* attn_scores,
int soft_len,
int heads)
{
-int mask_stride = 1;
-if (attn_mask.sizes().size() > 2) mask_stride = attn_mask.size(2);
+auto mask_stride = get_attn_mask_stride(attn_mask);
+
launch_attn_softmax_v2((T*)attn_scores,
(attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr),
(alibi.sizes().size() > 1 ? (T*)alibi.data_ptr() : nullptr),
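The two helpers added above centralize the mask-stride logic that the later hunks previously duplicated at each call site: a mask with more than two dimensions is treated as GPT-style (stride taken from dimension 2), a 2-D mask as BERT-style (stride 1). Below is a minimal standalone sketch of that dispatch, using a plain vector of sizes in place of `at::Tensor`; the shapes in `main` are illustrative assumptions, not values from this PR.

```cpp
// Standalone sketch of the dispatch in infer_transformer_type /
// get_attn_mask_stride, with a vector of dimension sizes standing in
// for at::Tensor. Shapes are illustrative only.
#include <cstdint>
#include <cstdio>
#include <stdexcept>
#include <vector>

enum class TransformerType : std::uint8_t { UNKNOWN = 0, GPTType = 1, BERTType = 2 };

TransformerType infer_type(const std::vector<long>& mask_sizes)
{
    if (mask_sizes.size() > 2) return TransformerType::GPTType;
    if (mask_sizes.size() == 2) return TransformerType::BERTType;
    return TransformerType::UNKNOWN;
}

int mask_stride(const std::vector<long>& mask_sizes)
{
    switch (infer_type(mask_sizes)) {
        case TransformerType::GPTType: return static_cast<int>(mask_sizes[2]);
        case TransformerType::BERTType: return 1;  // BERT-style masks always stride by 1.
        default: throw std::runtime_error("Unknown transformer type.");
    }
}

int main()
{
    std::printf("%d\n", mask_stride({8, 1, 128, 128}));  // GPT-style [b, 1, s, s] -> 128
    std::printf("%d\n", mask_stride({8, 128}));           // BERT-style [b, s]     -> 1
}
```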
6 changes: 3 additions & 3 deletions csrc/transformer/inference/csrc/softmax.cu
@@ -274,7 +274,7 @@ __global__ void attn_softmax_v2(float* vals,
int batch_idx = iter_offset / (num_seq * heads);
int alibi_offset = batch_idx * heads * mp_size + head_offset;
int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
-
+mask_offset = mask_offset * sequence_length;
int seq_id = iter_offset % num_seq;
int seq_id4 = seq_id >> 2;

@@ -305,7 +305,7 @@ __global__ void attn_softmax_v2(float* vals,
(data_id + 3) > window_stride)
? vals[data_id + 3]
: minus_infinity;
-if (attn_mask && recompute) {
+if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
data[i].y += attn_mask[data_id + mask_offset + 1];
data[i].z += attn_mask[data_id + mask_offset + 2];
@@ -322,7 +322,7 @@ __global__ void attn_softmax_v2(float* vals,
? (vals[data_id + 2])
: minus_infinity;
data[i].w = minus_infinity;
-if (attn_mask && recompute) {
+if (attn_mask) {
data[i].x += attn_mask[data_id + mask_offset];
if ((data_id + 1) < sequence_length)
data[i].y += attn_mask[data_id + mask_offset + 1];
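The kernel side makes two matching changes: `mask_offset` is now scaled by `sequence_length`, so each batch element addresses a full row of the mask rather than a single element, and the mask is applied unconditionally (the old `recompute` guard no longer gates it). A small host-side sketch of the new offset arithmetic; the sizes below are assumed for illustration, not taken from this PR.

```cpp
// Sketch of the updated mask_offset arithmetic in attn_softmax_v2.
// Variable names mirror the kernel; all values are illustrative.
#include <cstdio>

int main()
{
    const int heads = 12, num_seq = 4, sequence_length = 128;
    const int mask_stride = 1;  // BERT-style mask, per get_attn_mask_stride

    // Two iterations, one per batch element (in the real kernel iter_offset
    // advances across the flattened [batch, heads, num_seq] grid).
    for (int iter_offset = 0; iter_offset < 2 * heads * num_seq;
         iter_offset += heads * num_seq) {
        int batch_idx = iter_offset / (num_seq * heads);
        int mask_offset = batch_idx * mask_stride + (iter_offset % mask_stride);
        // New in this PR: scale by sequence_length so batch b reads row b
        // of a [batch, seq_len] mask instead of element b.
        mask_offset = mask_offset * sequence_length;
        std::printf("batch %d -> mask_offset %d\n", batch_idx, mask_offset);
    }
    // Prints: batch 0 -> mask_offset 0
    //         batch 1 -> mask_offset 128
}
```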