
No tensor cores for fp32 interleaved attention, remove div by 8 restriction (#17994) (#18085)

(cherry picked from commit afae030)
blchu committed Apr 16, 2020
1 parent b56571d commit 8cfc64a
Showing 1 changed file with 37 additions and 16 deletions: src/operator/contrib/transformer.cu
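The gist of the patch: gemm_switch_fp32accum previously requested the cuBLAS tensor-op algorithm whenever the leading dimensions lda, ldb, and ldc were all multiples of 8, which both forced the divisible-by-8 shape restriction and let fp32 inputs onto the tensor-core path. It now takes a using_fp16 flag, computed by each attention kernel from its input dtype, and requests tensor ops only for fp16 data. A minimal sketch of the two selection rules, using hypothetical helper names rather than the real MXNet signatures:

#include <cstdint>
#include <cublas_v2.h>

// Hypothetical condensation of the dispatch change; the real code feeds the
// chosen algorithm into CublasStridedBatchedGemm along with the GEMM arguments.

// Old rule: tensor-op algorithm whenever every leading dimension is a multiple of 8.
inline cublasGemmAlgo_t PickAlgoOld(int32_t lda, int32_t ldb, int32_t ldc) {
  const bool aligned = !(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7);
  return aligned ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
}

// New rule: tensor-op algorithm only for fp16 inputs, with no alignment check,
// so fp32 always uses CUBLAS_GEMM_DEFAULT.
inline cublasGemmAlgo_t PickAlgoNew(bool using_fp16) {
  return using_fp16 ? CUBLAS_GEMM_DEFAULT_TENSOR_OP : CUBLAS_GEMM_DEFAULT;
}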
@@ -43,7 +43,7 @@ void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB,
float alpha, const DType* a, int32_t lda, int32_t strideA,
const DType *b, int32_t ldb, int32_t strideB, float beta,
DType *c, int32_t ldc, int32_t strideC, int32_t batchCount,
- cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) {
+ cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT) {
#if CUDA_VERSION >= 9010
using namespace mxnet::common::cuda;
CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream<gpu>::OwnHandle)
@@ -142,9 +142,9 @@ void gemm_switch_fp32accum(mshadow::Stream<gpu>* s, bool transA, bool transB,
float alpha, const DType *a, int32_t lda,
int32_t strideA, const DType *b, int32_t ldb,
int32_t strideB, float beta, DType *c, int32_t ldc,
- int32_t strideC, int32_t batchCount) {
+ int32_t strideC, int32_t batchCount, bool using_fp16) {
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
- if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) {
+ if (using_fp16) {
CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb,
strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_DEFAULT_TENSOR_OP);
} else {
@@ -175,6 +175,7 @@ void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride = 3 * head_dim;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -196,7 +197,8 @@ void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
output,
qkv_seq_len,
qkv_seq_len * qkv_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}
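The hunk above, and every remaining forward and backward kernel in the file, derives that flag from its first input's dtype before calling into the GEMM switch. A compilable restatement of the one-line check, written as a hypothetical helper rather than code from the commit:

#include <mshadow/base.h>

// Hypothetical helper mirroring the per-kernel flag in this diff: only fp16
// tensors opt into the tensor-core path; fp32 and other dtypes never do.
inline bool UseTensorOps(int type_flag) {
  return type_flag == mshadow::kFloat16;
}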

@@ -220,7 +222,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t lead_dim = attn_batches * 3 * head_dim;
const int32_t batch_stride = 3 * head_dim;
const float scale = 1.0 / sqrt(static_cast<float>(head_dim));
- const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -247,7 +250,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
queries_keys_values_grads,
lead_dim,
batch_stride,
- attn_batches);
+ attn_batches,
+ using_fp16);
gemm_switch_fp32accum(s,
false,
true,
@@ -265,7 +269,8 @@ void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs,
queries_keys_values_grads + head_dim,
lead_dim,
batch_stride,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -290,6 +295,7 @@ void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride = 3 * head_dim;
const float alpha = 1.f;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -311,7 +317,8 @@ void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
output,
head_dim * attn_batches,
head_dim,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -337,6 +344,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t lead_dim = attn_batches * 3 * head_dim;
const int32_t batch_stride = 3 * head_dim;
const float alpha = 1.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;
+
if (req[0] != kNullOp) {
if (req[0] == kWriteTo) {
cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType),
@@ -360,7 +369,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
queries_keys_values_grads + 2 * head_dim,
lead_dim,
batch_stride,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
if (req[1] != kNullOp) {
const float beta = req[1] == kAddTo ? 1.f : 0.f;
@@ -381,7 +391,8 @@ void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs,
attention_maps_grads,
qkv_seq_len,
qkv_seq_len * qkv_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
})
}
@@ -412,6 +423,7 @@ void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride_kv = head_dim * 2;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -433,7 +445,8 @@ void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
output,
kv_seq_len,
kv_seq_len * q_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -463,6 +476,7 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride_q = head_dim;
const int32_t batch_stride_kv = head_dim * 2;
const float scale = 1.f / sqrt(static_cast<float>(head_dim));
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] != kNullOp) {
const float beta = req[0] == kAddTo ? 1.f : 0.f;
@@ -483,7 +497,8 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
queries_grads,
lead_dim_q,
batch_stride_q,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
if (req[1] != kNullOp) {
if (req[1] == kWriteTo) {
@@ -508,7 +523,8 @@ void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs,
keys_values_grads,
lead_dim_kv,
batch_stride_kv,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
})
}
@@ -535,6 +551,7 @@ void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t batch_stride_kv = 2 * head_dim;
const float alpha = 1.f;
const float beta = req[0] == kAddTo ? 1.f : 0.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] == kNullOp)
return;
@@ -556,7 +573,8 @@ void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
output,
head_dim * attn_batches,
head_dim,
- attn_batches);
+ attn_batches,
+ using_fp16);
})
}

@@ -583,6 +601,7 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
const int32_t lead_dim_kv = attn_batches * head_dim * 2;
const int32_t batch_stride_kv = 2 * head_dim;
const float alpha = 1.f;
+ const bool using_fp16 = inputs[0].type_flag_ == mshadow::kFloat16;

if (req[0] != kNullOp) {
if (req[0] == kWriteTo) {
@@ -607,7 +626,8 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
keys_values_grads + head_dim,
lead_dim_kv,
batch_stride_kv,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
if (req[1] != kNullOp) {
const float beta = req[1] == kAddTo ? 1.f : 0.f;
@@ -628,7 +648,8 @@ void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs,
attention_maps_grads,
kv_seq_len,
kv_seq_len * q_seq_len,
- attn_batches);
+ attn_batches,
+ using_fp16);
}
})
}
