diff --git a/src/operator/contrib/transformer-inl.h b/src/operator/contrib/transformer-inl.h index da3d14e33cf4..f83aa0bbda2e 100644 --- a/src/operator/contrib/transformer-inl.h +++ b/src/operator/contrib/transformer-inl.h @@ -34,6 +34,18 @@ namespace mxnet { namespace op { +struct InterleavedMatMulParam : public dmlc::Parameter { + int heads; + bool bwd_ignore_zero_init; + DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) { + DMLC_DECLARE_FIELD(heads) + .describe("Set number of heads"); + DMLC_DECLARE_FIELD(bwd_ignore_zero_init) + .describe("Make backward pass ignore AddTo and not init to 0.") + .set_default(false); + } +}; + template static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs, const OpContext& ctx, diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc index 00085c0dc7aa..411c1c9737d7 100644 --- a/src/operator/contrib/transformer.cc +++ b/src/operator/contrib/transformer.cc @@ -29,6 +29,163 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(InterleavedMatMulParam); + +static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 1); + auto qkv_shape = in_shape->at(0); + CHECK_EQ(qkv_shape.ndim(), 3); + out_shape->resize(1); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]})); + return true; +} + +static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + CHECK_EQ(in_shape->size(), 2); + auto qkv_shape = in_shape->at(0); + auto att_shape = in_shape->at(1); + CHECK_EQ(qkv_shape.ndim(), 3); + CHECK_EQ(att_shape.ndim(), 3); + CHECK_EQ(qkv_shape[0], att_shape[1]); + CHECK_EQ(qkv_shape[0], att_shape[2]); + CHECK_EQ(qkv_shape[2] % 3, 0); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3})); + return true; +} + +static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 2); + auto q_shape = in_shape->at(0); + auto kv_shape = in_shape->at(1); + CHECK_EQ(q_shape.ndim(), 3); + CHECK_EQ(kv_shape.ndim(), 3); + CHECK_EQ(q_shape[2] * 2, kv_shape[2]); + CHECK_EQ(q_shape[1], kv_shape[1]); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]})); + return true; +} + +static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 2); + auto kv_shape = in_shape->at(0); + auto att_shape = in_shape->at(1); + CHECK_EQ(kv_shape[0], att_shape[2]); + CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]); + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2})); + return true; +} + +NNVM_REGISTER_OP(interleaved_matmul_selfatt_qk) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"queries_keys_values"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FInferShape", InterleavedMatMulSelfAttQKShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FGradient", + ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"}) +.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values") +.add_arguments(InterleavedMatMulParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser); + +NNVM_REGISTER_OP(interleaved_matmul_selfatt_valatt) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"queries_keys_values", "attention"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FInferShape", InterleavedMatMulSelfAttValAttShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FGradient", + ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"}) +.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") +.add_argument("attention", "NDArray-or-Symbol", "Attention maps") +.add_arguments(InterleavedMatMulParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser); + +NNVM_REGISTER_OP(interleaved_matmul_encdec_qk) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"queries", "keys_values"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FInferShape", InterleavedMatMulEncDecQKShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FGradient", + ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"}) +.add_argument("queries", "NDArray-or-Symbol", "Queries") +.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved") +.add_arguments(InterleavedMatMulParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser); + +NNVM_REGISTER_OP(interleaved_matmul_encdec_valatt) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", [](const NodeAttrs& attrs) { + return std::vector{"keys_values", "attention"}; +}) +.set_attr("FListOutputNames", [](const NodeAttrs& attrs) { + return std::vector{"output"}; +}) +.set_attr("FInferShape", InterleavedMatMulEncDecValAttShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FGradient", + ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"}) +.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved") +.add_argument("attention", "NDArray-or-Symbol", "Attention maps") +.add_arguments(InterleavedMatMulParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr_parser(ParamParser); + + // relu MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim) .describe(R"code(Rescale the input by the square root of the channel dimension. diff --git a/src/operator/contrib/transformer.cu b/src/operator/contrib/transformer.cu index 6ed073db6011..812cc78781b0 100644 --- a/src/operator/contrib/transformer.cu +++ b/src/operator/contrib/transformer.cu @@ -22,12 +22,898 @@ * \file transformer.cu * \brief GPU implementation of the operators used in Transformer */ + +#include +#include +#include +#include + #include #include "./transformer-inl.h" +#include "../../common/cuda_utils.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/wmma_matrix.h" +#ifdef CUTLASS_USE_WMMA_API +#include "cutlass/gemm/wmma_gemm_traits.h" namespace mxnet { namespace op { +// gemm_switch_fp32accum and the functions called are almost fully copied from: +// MLPerf v0.6 submission repository from NVIDIA by https://github.com/kevinstephano +template +void CublasStridedBatchedGemm(mshadow::Stream* s, bool transA, bool transB, + int32_t m, int32_t n, int32_t k, + float alpha, const DType* a, int32_t lda, int32_t strideA, + const DType *b, int32_t ldb, int32_t strideB, float beta, + DType *c, int32_t ldc, int32_t strideC, int32_t batchCount, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) { + using namespace mxnet::common::cuda; + CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream::OwnHandle) + << "Must init CuBLAS handle in stream"; + + cublasHandle_t blas_handle = mshadow::Stream::GetBlasHandle(s); + auto err = CUBLAS_STATUS_SUCCESS; + // TODO(cfujitsang): handle computation_precision + err = cublasGemmStridedBatchedEx( + blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB), + static_cast(m), static_cast(n), static_cast(k), + reinterpret_cast(&alpha), + a, CublasType::kCudaFlag, static_cast(lda), strideA, + b, CublasType::kCudaFlag, static_cast(ldb), strideB, + reinterpret_cast(&beta), + c, CublasType::kCudaFlag, static_cast(ldc), strideC, + static_cast(batchCount), CUDA_R_32F, algo); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail."; +} + +template<::cutlass::MatrixLayout::Kind A_LAYOUT, + ::cutlass::MatrixLayout::Kind B_LAYOUT, + int SRC_A, int SRC_B, int DST_C, typename DType> +void CutlassGemm_FP32Accum(cudaStream_t, int32_t m, int32_t n, int32_t k, + float alpha, const DType *a, int32_t lda, + int32_t strideA, const DType *b, int32_t ldb, + int32_t strideB, float beta, DType *c, int32_t ldc, + int32_t strideC, int32_t batchCount) { + LOG(FATAL) << "Not implemented with this DType and shape (Cutlass)"; +} + + +template<::cutlass::MatrixLayout::Kind A_LAYOUT, + ::cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C> +void CutlassGemm_FP32Accum(cudaStream_t stream, int32_t m, int32_t n, int32_t k, + float alpha, const mshadow::half::half_t *a, int32_t lda, + int32_t strideA, const mshadow::half::half_t *b, int32_t ldb, + int32_t strideB, float beta, mshadow::half::half_t *c, int32_t ldc, + int32_t strideC, int32_t batchCount) { + typedef cutlass::gemm::WmmaGemmTraits< + A_LAYOUT, + B_LAYOUT, + cutlass::Shape<32, 16, 16>, + half, + half, + half, + cutlass::gemm::LinearScaling, + float, + typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp< + typename cutlass::Shape<32, 16, 16> >::Shape, + typename cutlass::Shape<16, 16, 16>, + SRC_A, // kScalarsPerLdgA_ + SRC_B, // kScalarsPerLdgB_ + SRC_A, // KScalarsPerLdsA_ + SRC_B, // KScalarsPerLdsB_ + DST_C, // kScalarsPerLdgCAndStgD_ + DST_C/2, // kScalarsPerStsD_ + DST_C/2 // kScalarsPerLdsD_ + > + WmmaGemmTraits; + + typedef cutlass::gemm::Gemm Gemm; + typename Gemm::Params params; + + + int result = params.initialize( + m, // M dimension for each batch + n, // N dimension for each batch + k, // K dimension for each batch + alpha, // scalar alpha + reinterpret_cast(a), + lda, + strideA, // distance in memory between the first element of neighboring batch + reinterpret_cast(b), + ldb, + strideB, // distance in memory between the first element of neighboring batch + beta, // scalar beta + reinterpret_cast<__half*>(c), // source matrix C + ldc, + strideC, // distance in memory between the first element of neighboring batch + reinterpret_cast<__half*>(c), // destination matrix C (may be different memory than C) + ldc, + strideC, // distance in memory between the first element of neighboring batch + batchCount); + + CHECK_EQ(result, 0) << "Failed to initialize CUTLASS Gemm::Params object."; + + // Launch the CUTLASS GEMM kernel. + Gemm::launch(params); +} + +template +void gemm_switch_fp32accum(mshadow::Stream* s, bool transA, bool transB, + int32_t m, int32_t n, int32_t k, + float alpha, const DType *a, int32_t lda, + int32_t strideA, const DType *b, int32_t ldb, + int32_t strideB, float beta, DType *c, int32_t ldc, + int32_t strideC, int32_t batchCount) { + using cutlass::MatrixLayout::kRowMajor; + using cutlass::MatrixLayout::kColumnMajor; + cudaStream_t stream = mshadow::Stream::GetStream(s); + if (transA && (!transB)) { + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount); + } + } else if ((!transA) && (!transB)) { + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount); + } + } else if ((!transA) && transB) { + if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x7)) { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount, CUBLAS_GEMM_ALGO0_TENSOR_OP); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x7) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x3) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x7) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x3) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x7)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x3)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else if (!(lda & 0x1) && !(ldb & 0x1) && !(ldc & 0x1)) { + CutlassGemm_FP32Accum(stream, m, n, k, alpha, a, lda, + strideA, b, ldb, strideB, beta, c, ldc, strideC, batchCount); + } else { + CublasStridedBatchedGemm(s, transA, transB, m, n, k, alpha, a, lda, strideA, b, ldb, + strideB, beta, c, ldc, strideC, batchCount); + } + } else { + LOG(FATAL) << "transA and transB are invalid"; + } + CHECK_CUDA_ERROR("Error at InterleavedMatMul"); +} + +// TODO(cfujitsang): use scale as optional ? +void InterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_dim = inputs[0].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + true, + false, + qkv_seq_len, + qkv_seq_len, + head_dim, + scale, + queries_keys_values + head_dim, + lead_dim, + batch_stride, + queries_keys_values, + lead_dim, + batch_stride, + beta, + output, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + attn_batches); + }) +} + +void BackwardInterleavedMatMulSelfAttQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* queries_keys_values = inputs[1].FlatTo2D(s).dptr_; + DType* queries_keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_dim = inputs[1].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + const float beta = ((req[0] == kAddTo) && !params.bwd_ignore_zero_init) ? 1.f : 0.f; + + if (req[0] == kNullOp) + return; + + if (req[0] == kWriteTo && !params.bwd_ignore_zero_init) { + cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + + gemm_switch_fp32accum(s, + false, + false, + head_dim, + qkv_seq_len, + qkv_seq_len, + scale, + queries_keys_values + head_dim, + lead_dim, + batch_stride, + output_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads, + lead_dim, + batch_stride, + attn_batches); + gemm_switch_fp32accum(s, + false, + true, + head_dim, + qkv_seq_len, + qkv_seq_len, + scale, + queries_keys_values, + lead_dim, + batch_stride, + output_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads + head_dim, + lead_dim, + batch_stride, + attn_batches); + }) +} + +void InterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[1].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_dim = inputs[0].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float alpha = 1.f; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + false, + false, + head_dim, + qkv_seq_len, + qkv_seq_len, + alpha, + queries_keys_values + 2 * head_dim, + lead_dim, + batch_stride, + attention_maps, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + output, + head_dim * attn_batches, + head_dim, + attn_batches); + }) +} + +void BackwardInterleavedMatMulSelfAttValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* queries_keys_values = inputs[1].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[2].FlatTo2D(s).dptr_; + DType* queries_keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + DType* attention_maps_grads = outputs[1].FlatTo2D(s).dptr_; + const int32_t qkv_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_dim = inputs[1].shape_[2]; + const int32_t embed_dim = output_lin_dim / 3; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim = attn_batches * 3 * head_dim; + const int32_t batch_stride = 3 * head_dim; + const float alpha = 1.f; + if (req[0] != kNullOp) { + if (req[0] == kWriteTo && !params.bwd_ignore_zero_init) { + cudaMemsetAsync(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + const float beta = ((req[0] == kAddTo) && !params.bwd_ignore_zero_init) ? 1.f : 0.f; + gemm_switch_fp32accum(s, + false, + true, + head_dim, + qkv_seq_len, + qkv_seq_len, + alpha, + output_grads, + head_dim * attn_batches, + head_dim, + attention_maps, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads + 2 * head_dim, + lead_dim, + batch_stride, + attn_batches); + } + if (req[1] != kNullOp) { + const float beta = req[1] == kAddTo ? 1.f : 0.f; + gemm_switch_fp32accum(s, + true, + false, + qkv_seq_len, + qkv_seq_len, + head_dim, + alpha, + queries_keys_values + 2 * head_dim, + lead_dim, + batch_stride, + output_grads, + head_dim * attn_batches, + head_dim, + beta, + attention_maps_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + attn_batches); + } + }) +} + + +void InterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* queries = inputs[0].FlatTo2D(s).dptr_; + const DType* keys_values = inputs[1].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t q_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_q_dim = inputs[0].shape_[2]; + const int32_t kv_seq_len = inputs[1].shape_[0]; + const int32_t output_lin_kv_dim = inputs[1].shape_[2]; + const int32_t embed_dim = output_lin_q_dim; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim_q = attn_batches * head_dim; + const int32_t lead_dim_kv = attn_batches * 2 * head_dim; + const int32_t batch_stride_q = head_dim; + const int32_t batch_stride_kv = head_dim * 2; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + const float scale = 1.f / sqrt(static_cast(head_dim)); + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + true, + false, + kv_seq_len, + q_seq_len, + head_dim, + scale, + keys_values, + lead_dim_kv, + batch_stride_kv, + queries, + lead_dim_q, + batch_stride_q, + beta, + output, + kv_seq_len, + kv_seq_len * q_seq_len, + attn_batches); + }) +} + +void BackwardInterleavedMatMulEncDecQKGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* queries = inputs[1].FlatTo2D(s).dptr_; + const DType* keys_values = inputs[2].FlatTo2D(s).dptr_; + DType* queries_grads = outputs[0].FlatTo2D(s).dptr_; + DType* keys_values_grads = outputs[1].FlatTo2D(s).dptr_; + const int32_t q_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_q_dim = inputs[1].shape_[2]; + const int32_t kv_seq_len = inputs[2].shape_[0]; + const int32_t output_lin_kv_dim = inputs[2].shape_[2]; + const int32_t embed_dim = output_lin_q_dim; + const int32_t head_dim = embed_dim / params.heads; + const int32_t attn_batches = params.heads * sequences; + const int32_t lead_dim_q = attn_batches * head_dim; + const int32_t lead_dim_kv = attn_batches * 2 * head_dim; + const int32_t batch_stride_q = head_dim; + const int32_t batch_stride_kv = head_dim * 2; + const float scale = 1.f / sqrt(static_cast(head_dim)); + + if (req[0] != kNullOp) { + const float beta = req[0] == kAddTo ? 1.f : 0.f; + gemm_switch_fp32accum(s, + false, + false, + head_dim, + q_seq_len, + kv_seq_len, + scale, + keys_values, + lead_dim_kv, + batch_stride_kv, + output_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + queries_grads, + lead_dim_q, + batch_stride_q, + attn_batches); + } + if (req[1] != kNullOp) { + if (req[1] == kWriteTo && !params.bwd_ignore_zero_init) { + cudaMemsetAsync(keys_values_grads, 0, outputs[1].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + const float beta = ((req[1] == kAddTo) && !params.bwd_ignore_zero_init) ? 1.f : 0.f; + gemm_switch_fp32accum(s, + false, + true, + head_dim, + kv_seq_len, + q_seq_len, + scale, + queries, + lead_dim_q, + batch_stride_q, + output_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + keys_values_grads, + lead_dim_kv, + batch_stride_kv, + attn_batches); + } + }) +} + +void InterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* keys_values = inputs[0].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[1].FlatTo2D(s).dptr_; + DType* output = outputs[0].FlatTo2D(s).dptr_; + const int32_t kv_seq_len = inputs[0].shape_[0]; + const int32_t sequences = inputs[0].shape_[1]; + const int32_t output_lin_kv_dim = inputs[0].shape_[2]; + const int32_t attn_batches = inputs[1].shape_[0]; + const int32_t q_seq_len = inputs[1].shape_[1]; + const int32_t embed_dim = output_lin_kv_dim / 2; + int32_t head_dim = embed_dim / params.heads; + const int32_t lead_dim_kv = attn_batches * head_dim * 2; + const int32_t batch_stride_kv = 2 * head_dim; + const float alpha = 1.f; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + if (req[0] == kNullOp) + return; + + gemm_switch_fp32accum(s, + false, + false, + head_dim, + q_seq_len, + kv_seq_len, + alpha, + keys_values + head_dim, + lead_dim_kv, + batch_stride_kv, + attention_maps, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + output, + head_dim * attn_batches, + head_dim, + attn_batches); + }) +} + +void BackwardInterleavedMatMulEncDecValAttGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + mshadow::Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + const DType* output_grads = inputs[0].FlatTo2D(s).dptr_; + const DType* keys_values = inputs[1].FlatTo2D(s).dptr_; + const DType* attention_maps = inputs[2].FlatTo2D(s).dptr_; + DType* keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + DType* attention_maps_grads = outputs[1].FlatTo2D(s).dptr_; + const int32_t kv_seq_len = inputs[1].shape_[0]; + const int32_t sequences = inputs[1].shape_[1]; + const int32_t output_lin_kv_dim = inputs[1].shape_[2]; + const int32_t attn_batches = inputs[2].shape_[0]; + const int32_t q_seq_len = inputs[2].shape_[1]; + const int32_t embed_dim = output_lin_kv_dim / 2; + int32_t head_dim = embed_dim / params.heads; + const int32_t lead_dim_kv = attn_batches * head_dim * 2; + const int32_t batch_stride_kv = 2 * head_dim; + const float alpha = 1.f; + + if (req[0] != kNullOp) { + if (req[0] == kWriteTo && !params.bwd_ignore_zero_init) { + cudaMemsetAsync(keys_values_grads, 0, outputs[0].shape_.Size() * sizeof(DType), + mshadow::Stream::GetStream(s)); + } + const float beta = ((req[0] == kAddTo) && !params.bwd_ignore_zero_init) ? 1.f : 0.f; + gemm_switch_fp32accum(s, + false, + true, + head_dim, + kv_seq_len, + q_seq_len, + alpha, + output_grads, + head_dim * attn_batches, + head_dim, + attention_maps, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + keys_values_grads + head_dim, + lead_dim_kv, + batch_stride_kv, + attn_batches); + } + if (req[1] != kNullOp) { + const float beta = req[1] == kAddTo ? 1.f : 0.f; + gemm_switch_fp32accum(s, + true, + false, + kv_seq_len, + q_seq_len, + head_dim, + alpha, + keys_values + head_dim, + lead_dim_kv, + batch_stride_kv, + output_grads, + head_dim * attn_batches, + head_dim, + beta, + attention_maps_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + attn_batches); + } + }) +} + +NNVM_REGISTER_OP(interleaved_matmul_selfatt_qk) +.set_attr("FCompute", InterleavedMatMulSelfAttQKGPU); + +NNVM_REGISTER_OP(interleaved_matmul_selfatt_valatt) +.set_attr("FCompute", InterleavedMatMulSelfAttValAttGPU); + +NNVM_REGISTER_OP(interleaved_matmul_encdec_qk) +.set_attr("FCompute", InterleavedMatMulEncDecQKGPU); + +NNVM_REGISTER_OP(interleaved_matmul_encdec_valatt) +.set_attr("FCompute", InterleavedMatMulEncDecValAttGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk) +.set_attr("FCompute", BackwardInterleavedMatMulSelfAttQKGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt) +.set_attr("FCompute", BackwardInterleavedMatMulSelfAttValAttGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk) +.set_attr("FCompute", BackwardInterleavedMatMulEncDecQKGPU); + +NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt) +.set_attr("FCompute", BackwardInterleavedMatMulEncDecValAttGPU); + // relu NNVM_REGISTER_OP(_contrib_div_sqrt_dim) .set_attr("FCompute", DivSqrtDimForward_); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 0f1cd93755c3..ee20cec25dbd 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2493,13 +2493,334 @@ def test_arange_like_dtype(): x = mx.sym.Variable('x', dtype=t) y = mx.sym.reshape(x, shape=(0, 0, -1)) z = mx.sym.contrib.arange_like(y, axis=-1) - + mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null') mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t) out = mod.forward(is_train=False) for v in out: assert v.dtype == t +@with_seed() +def check_multihead_attention_selfatt(bwd_ignore_zero_init): + def convert_weight(F, q_weight, k_weight, v_weight, num_heads): + q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True) + k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) + v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) + all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2) + all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) + return all_weights + + def convert_bias(F, q_bias, k_bias, v_bias, num_heads): + q_bias = F.reshape(q_bias, shape=(num_heads, -1)) + k_bias = F.reshape(k_bias, shape=(num_heads, -1)) + v_bias = F.reshape(v_bias, shape=(num_heads, -1)) + all_bias = F.stack(q_bias, k_bias, v_bias, axis=1) + all_bias = F.reshape(all_bias, shape=(-1,)) + return all_bias + + dtype='float16' + batch_size = 2 + qkv_length = 7 # length of a sequence + qkv_dim = 9 # dimension of encoding + num_heads = 3 # number of attention head + head_dim = 5 # head size + out_dim = 13 * num_heads + qkv_units = num_heads * head_dim + + arg_params = { + 'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), + 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), + } + + qkv = mx.sym.Variable('qkv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads) + qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads) + qkv = mx.sym.transpose(qkv, axes=(1, 0, 2)) + qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False, + num_hidden=qkv_units * 3, no_bias=False) + att_score = mx.sym.interleaved_matmul_selfatt_qk(qkv_proj, heads=num_heads, + bwd_ignore_zero_init=bwd_ignore_zero_init) + att_score = att_score + sonde + weighted_value = mx.sym.interleaved_matmul_selfatt_valatt(qkv_proj, att_score, heads=num_heads, + bwd_ignore_zero_init=bwd_ignore_zero_init) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.transpose(output, axes=(1, 0, 2)) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + qkv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + type_dict={'qkv': dtype, + 'q_weight': dtype, + 'k_weight': dtype, + 'v_weight': dtype, + 'q_bias': dtype, + 'k_bias': dtype, + 'v_bias': dtype, + 'sonde': dtype}, + grad_req='write', force_rebind=True) + output_shape = executor.outputs[0].shape + output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_opti = executor.outputs[0].asnumpy() + att_score_opti = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), + mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) + grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} + qkv = mx.sym.Variable('qkv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + + q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) + q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) + q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) + k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) + k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) + k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) + q = mx.sym.contrib.div_sqrt_dim(q) + att_score = mx.sym.batch_dot(q, k, transpose_b=True) + att_score = att_score + sonde + v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) + v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) + v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) + weighted_value = mx.sym.batch_dot(att_score, v) + weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), + reverse=True) + weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) + weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + qkv=(batch_size, qkv_length, qkv_dim), + type_dict={'qkv': dtype}, + grad_req='write', force_rebind=True) + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_orig = executor.outputs[0].asnumpy() + att_score_orig = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), + mx.nd.zeros(att_score_orig.shape, dtype=dtype)]) + grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} + assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) + assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) + + for k in grads_opti.keys(): + assert(grads_orig[k].dtype == grads_opti[k].dtype) + assert(grads_orig[k].shape == grads_opti[k].shape) + assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) + +def test_multihead_attention_selfatt(): + os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '0' + check_multihead_attention_selfatt(bwd_ignore_zero_init=False) + os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '1' + check_multihead_attention_selfatt(bwd_ignore_zero_init=False) + check_multihead_attention_selfatt(bwd_ignore_zero_init=True) + +def check_multihead_attention_encdec(bwd_ignore_zero_init): + def convert_weight(F, k_weight, v_weight, num_heads): + k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) + v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) + all_weights = F.concat(k_weight, v_weight, dim=-2) + all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) + return all_weights + + def convert_bias(F, k_bias, v_bias, num_heads): + k_bias = F.reshape(k_bias, shape=(num_heads, -1)) + v_bias = F.reshape(v_bias, shape=(num_heads, -1)) + all_bias = F.stack(k_bias, v_bias, axis=1) + all_bias = F.reshape(all_bias, shape=(-1,)) + return all_bias + + batch_size = 2 + qkv_length = 7 # length of a sequence + qkv_dim = 9 # dimension of encoding + num_heads = 3 # number of attention head + head_dim = 5 # head size + out_dim = 13 * num_heads + qkv_units = num_heads * head_dim + + arg_params = { + 'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype('float16') * 0.1, dtype='float16'), + 'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype('float16') * 0.1, dtype='float16'), + 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype('float16') * 0.1, dtype='float16'), + 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype('float16') * 0.1, dtype='float16'), + 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype('float16') * 0.1, dtype='float16'), + 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype('float16') * 0.1, dtype='float16'), + 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype('float16') * 0.1, dtype='float16'), + 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype('float16') * 0.1, dtype='float16'), + 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype('float16') * 0.1, dtype='float16'), + 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype('float16') * 0.1, dtype='float16'), + } + + q = mx.sym.Variable('q') + kv = mx.sym.Variable('kv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads) + kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads) + kv = mx.sym.transpose(kv, axes=(1, 0, 2)) + kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False, + num_hidden=qkv_units * 2, no_bias=False) + q = mx.sym.transpose(q, axes=(1, 0, 2)) + q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + att_score = mx.sym.interleaved_matmul_encdec_qk(q_proj, kv_proj, heads=num_heads, + bwd_ignore_zero_init=bwd_ignore_zero_init) + att_score = att_score + sonde + weighted_value = mx.sym.interleaved_matmul_encdec_valatt(kv_proj, att_score, heads=num_heads, + bwd_ignore_zero_init=bwd_ignore_zero_init) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.transpose(output, axes=(1, 0, 2)) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + q=(batch_size, qkv_length, qkv_dim), + kv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + out_weight=(out_dim, qkv_units), + out_bias=(out_dim,), + type_dict={'q': 'float16', + 'kv': 'float16', + 'q_weight': 'float16', + 'q_bias': 'float16', + 'k_weight': 'float16', + 'k_bias': 'float16', + 'v_weight': 'float16', + 'v_bias': 'float16', + 'out_weight': 'float16', + 'out_bias': 'float16', + }, + grad_req='write', force_rebind=True) + output_shape = executor.outputs[0].shape + output_grads = np.random.rand(*output_shape).astype('float16') * 0.1 + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_opti = executor.outputs[0].asnumpy() + att_score_opti = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype='float16'), mx.nd.zeros(att_score_opti.shape, dtype='float16')]) + + grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} + + q = mx.sym.Variable('q') + kv = mx.sym.Variable('kv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + + q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) + q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) + q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) + k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) + k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) + k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) + q = mx.sym.contrib.div_sqrt_dim(q) + att_score = mx.sym.batch_dot(q, k, transpose_b=True) + att_score = att_score + sonde + v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) + v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) + v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) + weighted_value = mx.sym.batch_dot(att_score, v) + weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), + reverse=True) + weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) + weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=mx.gpu(0), + q=(batch_size, qkv_length, qkv_dim), + kv=(batch_size, qkv_length, qkv_dim), + type_dict={'q': 'float16', + 'kv': 'float16'}, + grad_req='write', force_rebind=True) + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_orig = executor.outputs[0].asnumpy() + att_score_orig = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype='float16'), mx.nd.zeros(att_score_orig.shape, dtype='float16')]) + grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} + assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) + assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) + + for k in grads_opti.keys(): + assert(grads_orig[k].dtype == grads_opti[k].dtype) + assert(grads_orig[k].shape == grads_opti[k].shape) + assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) + +def test_multihead_attention_encdec(): + os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '0' + check_multihead_attention_encdec(bwd_ignore_zero_init=False) + os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '1' + check_multihead_attention_encdec(bwd_ignore_zero_init=False) + check_multihead_attention_encdec(bwd_ignore_zero_init=True) if __name__ == '__main__': import nose