diff --git a/src/operator/contrib/transformer.cc b/src/operator/contrib/transformer.cc index 2ca6f8c71093..58826a2d96a8 100644 --- a/src/operator/contrib/transformer.cc +++ b/src/operator/contrib/transformer.cc @@ -122,6 +122,531 @@ static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs, return true; } +void strided_batch_sgemm(bool transA, bool transB, + index_t m, index_t n, index_t k, + float alpha, const float *a, index_t lda, + index_t strideA, const float *b, index_t ldb, + index_t strideB, float beta, float *c, index_t ldc, + index_t strideC, int32_t batchCount) { + std::vector pp_A(batchCount, nullptr); + std::vector pp_B(batchCount, nullptr); + std::vector pp_C(batchCount, nullptr); + + for (int i = 0; i < batchCount; i++) { + pp_A[i] = a + i * strideA; + pp_B[i] = b + i * strideB; + pp_C[i] = c + i * strideC; + } + +#if (MSHADOW_USE_MKL && INTEL_MKL_VERSION >= 20160000) + const int GROUP_SIZE = 1; + MKL_INT p_m[GROUP_SIZE] = {m}; + MKL_INT p_n[GROUP_SIZE] = {n}; + MKL_INT p_k[GROUP_SIZE] = {k}; + MKL_INT p_lda[GROUP_SIZE] = {lda}; + MKL_INT p_ldb[GROUP_SIZE] = {ldb}; + MKL_INT p_ldc[GROUP_SIZE] = {ldc}; + + float p_alpha[GROUP_SIZE] = {alpha}; + float p_beta[GROUP_SIZE] = {beta}; + + CBLAS_TRANSPOSE cblas_a_trans = transA ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE cblas_b_trans = transB ? CblasTrans : CblasNoTrans; + + MKL_INT p_group_sizeb[GROUP_SIZE] = {batchCount}; + CBLAS_TRANSPOSE p_transa[GROUP_SIZE] = {cblas_a_trans}; + CBLAS_TRANSPOSE p_transb[GROUP_SIZE] = {cblas_b_trans}; + + cblas_sgemm_batch(CblasColMajor, p_transa, p_transb, + p_m, p_n, p_k, p_alpha, pp_A.data(), p_lda, pp_B.data(), + p_ldb, p_beta, pp_C.data(), p_ldc, GROUP_SIZE, p_group_sizeb); +#else + for (int i = 0; i < batchCount; ++i) { + cblas_sgemm(CblasColMajor, + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + m, n, k, + alpha, pp_A[i], lda, + pp_B[i], ldb, beta, pp_C[i], ldc); + } +#endif +} + +void InterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + + if (req[0] == kNullOp) + return; + + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + float* output = outputs[0].FlatTo2D(s).dptr_; + + const index_t qkv_seq_len = inputs[0].shape_[0]; + const index_t sequences = inputs[0].shape_[1]; + const index_t output_lin_dim = inputs[0].shape_[2]; + const index_t embed_dim = output_lin_dim / 3; + const index_t head_dim = embed_dim / params.heads; + const index_t attn_batches = params.heads * sequences; + const index_t lead_dim = attn_batches * 3 * head_dim; + const index_t batch_stride = 3 * head_dim; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + + strided_batch_sgemm(true, + false, + qkv_seq_len, + qkv_seq_len, + head_dim, + scale, + queries_keys_values + head_dim, + lead_dim, + batch_stride, + queries_keys_values, + lead_dim, + batch_stride, + beta, + output, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + attn_batches); +} + +void BackwardInterleavedMatMulSelfAttQKCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) + return; + + mshadow::Stream* s = ctx.get_stream(); + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + const float* output_grads = inputs[0].FlatTo2D(s).dptr_; + const float* queries_keys_values = inputs[1].FlatTo2D(s).dptr_; + float* queries_keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + const index_t qkv_seq_len = inputs[1].shape_[0]; + const index_t sequences = inputs[1].shape_[1]; + const index_t output_lin_dim = inputs[1].shape_[2]; + const index_t embed_dim = output_lin_dim / 3; + const index_t head_dim = embed_dim / params.heads; + const index_t attn_batches = params.heads * sequences; + const index_t lead_dim = attn_batches * 3 * head_dim; + const index_t batch_stride = 3 * head_dim; + const float scale = 1.0 / sqrt(static_cast(head_dim)); + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + if (req[0] == kWriteTo) { + memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float)); + } + + strided_batch_sgemm(false, + false, + head_dim, + qkv_seq_len, + qkv_seq_len, + scale, + queries_keys_values + head_dim, + lead_dim, + batch_stride, + output_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads, + lead_dim, + batch_stride, + attn_batches); + + strided_batch_sgemm(false, + true, + head_dim, + qkv_seq_len, + qkv_seq_len, + scale, + queries_keys_values, + lead_dim, + batch_stride, + output_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads + head_dim, + lead_dim, + batch_stride, + attn_batches); +} + +void InterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) + return; + + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* queries_keys_values = inputs[0].FlatTo2D(s).dptr_; + const float* attention_maps = inputs[1].FlatTo2D(s).dptr_; + float* output = outputs[0].FlatTo2D(s).dptr_; + const index_t qkv_seq_len = inputs[0].shape_[0]; + const index_t sequences = inputs[0].shape_[1]; + const index_t output_lin_dim = inputs[0].shape_[2]; + const index_t embed_dim = output_lin_dim / 3; + const index_t head_dim = embed_dim / params.heads; + const index_t attn_batches = params.heads * sequences; + const index_t lead_dim = attn_batches * 3 * head_dim; + const index_t batch_stride = 3 * head_dim; + const float alpha = 1.f; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + strided_batch_sgemm(false, + false, + head_dim, + qkv_seq_len, + qkv_seq_len, + alpha, + queries_keys_values + 2 * head_dim, + lead_dim, + batch_stride, + attention_maps, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + output, + head_dim * attn_batches, + head_dim, + attn_batches); +} + +void BackwardInterleavedMatMulSelfAttValAttCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) + return; + + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* output_grads = inputs[0].FlatTo2D(s).dptr_; + const float* queries_keys_values = inputs[1].FlatTo2D(s).dptr_; + const float* attention_maps = inputs[2].FlatTo2D(s).dptr_; + float* queries_keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + float* attention_maps_grads = outputs[1].FlatTo2D(s).dptr_; + const index_t qkv_seq_len = inputs[1].shape_[0]; + const index_t sequences = inputs[1].shape_[1]; + const index_t output_lin_dim = inputs[1].shape_[2]; + const index_t embed_dim = output_lin_dim / 3; + const index_t head_dim = embed_dim / params.heads; + const index_t attn_batches = params.heads * sequences; + const index_t lead_dim = attn_batches * 3 * head_dim; + const index_t batch_stride = 3 * head_dim; + const float alpha = 1.f; + if (req[0] != kNullOp) { + if (req[0] == kWriteTo) { + memset(queries_keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float)); + } + + const float beta = req[0] == kAddTo ? 1.f : 0.f; + strided_batch_sgemm(false, + true, + head_dim, + qkv_seq_len, + qkv_seq_len, + alpha, + output_grads, + head_dim * attn_batches, + head_dim, + attention_maps, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + beta, + queries_keys_values_grads + 2 * head_dim, + lead_dim, + batch_stride, + attn_batches); + } + if (req[1] != kNullOp) { + const float beta = req[1] == kAddTo ? 1.f : 0.f; + strided_batch_sgemm(true, + false, + qkv_seq_len, + qkv_seq_len, + head_dim, + alpha, + queries_keys_values + 2 * head_dim, + lead_dim, + batch_stride, + output_grads, + head_dim * attn_batches, + head_dim, + beta, + attention_maps_grads, + qkv_seq_len, + qkv_seq_len * qkv_seq_len, + attn_batches); + } +} + +void InterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) + return; + + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* queries = inputs[0].FlatTo2D(s).dptr_; + const float* keys_values = inputs[1].FlatTo2D(s).dptr_; + float* output = outputs[0].FlatTo2D(s).dptr_; + const index_t q_seq_len = inputs[0].shape_[0]; + const index_t sequences = inputs[0].shape_[1]; + const index_t output_lin_q_dim = inputs[0].shape_[2]; + const index_t kv_seq_len = inputs[1].shape_[0]; + const index_t embed_dim = output_lin_q_dim; + const index_t head_dim = embed_dim / params.heads; + const index_t attn_batches = params.heads * sequences; + const index_t lead_dim_q = attn_batches * head_dim; + const index_t lead_dim_kv = attn_batches * 2 * head_dim; + const index_t batch_stride_q = head_dim; + const index_t batch_stride_kv = head_dim * 2; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + const float scale = 1.f / sqrt(static_cast(head_dim)); + + strided_batch_sgemm(true, + false, + kv_seq_len, + q_seq_len, + head_dim, + scale, + keys_values, + lead_dim_kv, + batch_stride_kv, + queries, + lead_dim_q, + batch_stride_q, + beta, + output, + kv_seq_len, + kv_seq_len * q_seq_len, + attn_batches); +} + +void BackwardInterleavedMatMulEncDecQKCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) + return; + + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* output_grads = inputs[0].FlatTo2D(s).dptr_; + const float* queries = inputs[1].FlatTo2D(s).dptr_; + const float* keys_values = inputs[2].FlatTo2D(s).dptr_; + float* queries_grads = outputs[0].FlatTo2D(s).dptr_; + float* keys_values_grads = outputs[1].FlatTo2D(s).dptr_; + const index_t q_seq_len = inputs[1].shape_[0]; + const index_t sequences = inputs[1].shape_[1]; + const index_t output_lin_q_dim = inputs[1].shape_[2]; + const index_t kv_seq_len = inputs[2].shape_[0]; + const index_t embed_dim = output_lin_q_dim; + const index_t head_dim = embed_dim / params.heads; + const index_t attn_batches = params.heads * sequences; + const index_t lead_dim_q = attn_batches * head_dim; + const index_t lead_dim_kv = attn_batches * 2 * head_dim; + const index_t batch_stride_q = head_dim; + const index_t batch_stride_kv = head_dim * 2; + const float scale = 1.f / sqrt(static_cast(head_dim)); + + if (req[0] != kNullOp) { + const float beta = req[0] == kAddTo ? 1.f : 0.f; + strided_batch_sgemm(false, + false, + head_dim, + q_seq_len, + kv_seq_len, + scale, + keys_values, + lead_dim_kv, + batch_stride_kv, + output_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + queries_grads, + lead_dim_q, + batch_stride_q, + attn_batches); + } + if (req[1] != kNullOp) { + if (req[1] == kWriteTo) { + memset(keys_values_grads, 0, outputs[1].shape_.Size() * sizeof (float)); + } + const float beta = req[1] == kAddTo ? 1.f : 0.f; + strided_batch_sgemm(false, + true, + head_dim, + kv_seq_len, + q_seq_len, + scale, + queries, + lead_dim_q, + batch_stride_q, + output_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + keys_values_grads, + lead_dim_kv, + batch_stride_kv, + attn_batches); + } +} + +void InterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + if (req[0] == kNullOp) + return; + + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* keys_values = inputs[0].FlatTo2D(s).dptr_; + const float* attention_maps = inputs[1].FlatTo2D(s).dptr_; + float* output = outputs[0].FlatTo2D(s).dptr_; + const index_t kv_seq_len = inputs[0].shape_[0]; + const index_t output_lin_kv_dim = inputs[0].shape_[2]; + const index_t attn_batches = inputs[1].shape_[0]; + const index_t q_seq_len = inputs[1].shape_[1]; + const index_t embed_dim = output_lin_kv_dim / 2; + const index_t head_dim = embed_dim / params.heads; + const index_t lead_dim_kv = attn_batches * head_dim * 2; + const index_t batch_stride_kv = 2 * head_dim; + const float alpha = 1.f; + const float beta = req[0] == kAddTo ? 1.f : 0.f; + + strided_batch_sgemm(false, + false, + head_dim, + q_seq_len, + kv_seq_len, + alpha, + keys_values + head_dim, + lead_dim_kv, + batch_stride_kv, + attention_maps, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + output, + head_dim * attn_batches, + head_dim, + attn_batches); +} + +void BackwardInterleavedMatMulEncDecValAttCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + const auto& params = nnvm::get(attrs.parsed); + CHECK_EQ(inputs[0].type_flag_, mshadow::kFloat32) + << "Only FP32 is supported on CPU at the moment"; + + mshadow::Stream* s = ctx.get_stream(); + const float* output_grads = inputs[0].FlatTo2D(s).dptr_; + const float* keys_values = inputs[1].FlatTo2D(s).dptr_; + const float* attention_maps = inputs[2].FlatTo2D(s).dptr_; + float* keys_values_grads = outputs[0].FlatTo2D(s).dptr_; + float* attention_maps_grads = outputs[1].FlatTo2D(s).dptr_; + const index_t kv_seq_len = inputs[1].shape_[0]; + const index_t output_lin_kv_dim = inputs[1].shape_[2]; + const index_t attn_batches = inputs[2].shape_[0]; + const index_t q_seq_len = inputs[2].shape_[1]; + const index_t embed_dim = output_lin_kv_dim / 2; + const index_t head_dim = embed_dim / params.heads; + const index_t lead_dim_kv = attn_batches * head_dim * 2; + const index_t batch_stride_kv = 2 * head_dim; + const float alpha = 1.f; + + if (req[0] != kNullOp) { + if (req[0] == kWriteTo) { + memset(keys_values_grads, 0, outputs[0].shape_.Size() * sizeof (float)); + } + const float beta = req[0] == kAddTo ? 1.f : 0.f; + strided_batch_sgemm(false, + true, + head_dim, + kv_seq_len, + q_seq_len, + alpha, + output_grads, + head_dim * attn_batches, + head_dim, + attention_maps, + kv_seq_len, + kv_seq_len * q_seq_len, + beta, + keys_values_grads + head_dim, + lead_dim_kv, + batch_stride_kv, + attn_batches); + } + if (req[1] != kNullOp) { + const float beta = req[1] == kAddTo ? 1.f : 0.f; + strided_batch_sgemm(true, + false, + kv_seq_len, + q_seq_len, + head_dim, + alpha, + keys_values + head_dim, + lead_dim_kv, + batch_stride_kv, + output_grads, + head_dim * attn_batches, + head_dim, + beta, + attention_maps_grads, + kv_seq_len, + kv_seq_len * q_seq_len, + attn_batches); + } +} + NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_qk) .describe(R"code(Compute the matrix multiplication between the projections of queries and keys in multihead attention use as self attention. @@ -138,8 +663,6 @@ q_proj = mx.nd.contrib.div_sqrt_dim(q_proj) k_proj = mx.nd.transpose(tmp[:,:,:,1,:], axes=(1, 2, 0, 3)) k_proj = mx.nd.reshap(k_proj, shape=(-1, 0, 0), reverse=True) output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True) - -This Op is GPU only )code" ADD_FILELINE) .set_num_inputs(1) .set_num_outputs(1) @@ -152,6 +675,7 @@ This Op is GPU only }) .set_attr("FInferShape", InterleavedMatMulSelfAttQKShape) .set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", InterleavedMatMulSelfAttQKCPU) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"}) .add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values") @@ -161,7 +685,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk) .set_num_inputs(2) .set_num_outputs(1) .set_attr("TIsBackward", true) -.set_attr_parser(ParamParser); +.set_attr_parser(ParamParser) +.set_attr("FCompute", BackwardInterleavedMatMulSelfAttQKCPU); NNVM_REGISTER_OP(_contrib_interleaved_matmul_selfatt_valatt) .describe(R"code(Compute the matrix multiplication between the projections of @@ -182,8 +707,6 @@ output = mx.nd.batch_dot(attention, v_proj, transpose_b=True) output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True) output = mx.nd.transpose(output, axes=(0, 2, 1, 3)) output = mx.nd.reshape(output, shape=(0, 0, -1)) - -This Op is GPU only )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -196,6 +719,7 @@ This Op is GPU only }) .set_attr("FInferShape", InterleavedMatMulSelfAttValAttShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", InterleavedMatMulSelfAttValAttCPU) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"}) .add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved") @@ -206,7 +730,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt) .set_num_inputs(3) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr_parser(ParamParser); +.set_attr_parser(ParamParser) +.set_attr("FCompute", BackwardInterleavedMatMulSelfAttValAttCPU); NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_qk) .describe(R"code(Compute the matrix multiplication between the projections of @@ -226,8 +751,6 @@ tmp = mx.nd.reshape(keys_values, shape=(0, 0, num_heads, 2, -1)) k_proj = mx.nd.transpose(tmp[:,:,:,0,:], axes=(1, 2, 0, 3)) k_proj = mx.nd.reshap(k_proj, shape=(-1, 0, 0), reverse=True) output = mx.nd.batch_dot(q_proj, k_proj, transpose_b=True) - -This Op is GPU only )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -240,6 +763,7 @@ This Op is GPU only }) .set_attr("FInferShape", InterleavedMatMulEncDecQKShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", InterleavedMatMulEncDecQKCPU) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"}) .add_argument("queries", "NDArray-or-Symbol", "Queries") @@ -250,7 +774,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk) .set_num_inputs(3) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr_parser(ParamParser); +.set_attr_parser(ParamParser) +.set_attr("FCompute", BackwardInterleavedMatMulEncDecQKCPU); NNVM_REGISTER_OP(_contrib_interleaved_matmul_encdec_valatt) .describe(R"code(Compute the matrix multiplication between the projections of @@ -272,8 +797,6 @@ output = mx.nd.batch_dot(attention, v_proj, transpose_b=True) output = mx.nd.reshape(output, shape=(-1, num_heads, 0, 0), reverse=True) output = mx.nd.transpose(output, axes=(0, 2, 1, 3)) output = mx.nd.reshape(output, shape=(0, 0, -1)) - -This Op is GPU only )code" ADD_FILELINE) .set_num_inputs(2) .set_num_outputs(1) @@ -286,6 +809,7 @@ This Op is GPU only }) .set_attr("FInferShape", InterleavedMatMulEncDecValAttShape) .set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FCompute", InterleavedMatMulEncDecValAttCPU) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"}) .add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved") @@ -296,7 +820,8 @@ NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt) .set_num_inputs(3) .set_num_outputs(2) .set_attr("TIsBackward", true) -.set_attr_parser(ParamParser); +.set_attr_parser(ParamParser) +.set_attr("FCompute", BackwardInterleavedMatMulEncDecValAttCPU); // relu diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index e548217b9369..721eaaebab31 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2548,323 +2548,6 @@ def test_arange_like_dtype(): for v in out: assert v.dtype == t -@with_seed() -def check_multihead_attention_selfatt(dtype): - def convert_weight(F, q_weight, k_weight, v_weight, num_heads): - q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True) - k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) - v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) - all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2) - all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) - return all_weights - - def convert_bias(F, q_bias, k_bias, v_bias, num_heads): - q_bias = F.reshape(q_bias, shape=(num_heads, -1)) - k_bias = F.reshape(k_bias, shape=(num_heads, -1)) - v_bias = F.reshape(v_bias, shape=(num_heads, -1)) - all_bias = F.stack(q_bias, k_bias, v_bias, axis=1) - all_bias = F.reshape(all_bias, shape=(-1,)) - return all_bias - - batch_size = 2 - qkv_length = 7 # length of a sequence - qkv_dim = 9 # dimension of encoding - num_heads = 3 # number of attention head - head_dim = 5 # head size - out_dim = 13 * num_heads - qkv_units = num_heads * head_dim - - arg_params = { - 'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), - 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), - 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), - 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), - 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), - } - - qkv = mx.sym.Variable('qkv') - sonde = mx.sym.Variable('sonde') - q_weight = mx.sym.Variable('q_weight') - k_weight = mx.sym.Variable('k_weight') - v_weight = mx.sym.Variable('v_weight') - q_bias = mx.sym.Variable('q_bias') - k_bias = mx.sym.Variable('k_bias') - v_bias = mx.sym.Variable('v_bias') - out_weight = mx.sym.Variable('out_weight') - out_bias = mx.sym.Variable('out_bias') - qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads) - qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads) - qkv = mx.sym.transpose(qkv, axes=(1, 0, 2)) - qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False, - num_hidden=qkv_units * 3, no_bias=False) - att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk( - qkv_proj, heads=num_heads) - att_score = att_score + sonde - weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt( - qkv_proj, att_score, heads=num_heads) - output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, - num_hidden=out_dim, no_bias=False) - output = mx.sym.transpose(output, axes=(1, 0, 2)) - output = mx.sym.Group([output, att_score]) - executor = output.simple_bind(ctx=mx.gpu(0), - qkv=(batch_size, qkv_length, qkv_dim), - q_weight=(qkv_units, qkv_dim), - q_bias=(qkv_units,), - k_weight=(qkv_units, qkv_dim), - k_bias=(qkv_units,), - v_weight=(qkv_units, qkv_dim), - v_bias=(qkv_units,), - type_dict={'qkv': dtype, - 'q_weight': dtype, - 'k_weight': dtype, - 'v_weight': dtype, - 'q_bias': dtype, - 'k_bias': dtype, - 'v_bias': dtype, - 'sonde': dtype}, - grad_req='write', force_rebind=True) - output_shape = executor.outputs[0].shape - output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 - executor.copy_params_from(arg_params, {}) - executor.arg_dict['sonde'][:] = 0. - executor.arg_dict['sonde'].wait_to_read() - executor.forward(is_train=True) - output_opti = executor.outputs[0].asnumpy() - att_score_opti = executor.outputs[1].asnumpy() - executor.backward([mx.nd.array(output_grads, dtype=dtype), - mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) - grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} - qkv = mx.sym.Variable('qkv') - sonde = mx.sym.Variable('sonde') - q_weight = mx.sym.Variable('q_weight') - k_weight = mx.sym.Variable('k_weight') - v_weight = mx.sym.Variable('v_weight') - q_bias = mx.sym.Variable('q_bias') - k_bias = mx.sym.Variable('k_bias') - v_bias = mx.sym.Variable('v_bias') - out_weight = mx.sym.Variable('out_weight') - out_bias = mx.sym.Variable('out_bias') - - q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) - q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) - q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) - k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) - k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) - k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) - q = mx.sym.contrib.div_sqrt_dim(q) - att_score = mx.sym.batch_dot(q, k, transpose_b=True) - att_score = att_score + sonde - v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) - v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) - v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) - weighted_value = mx.sym.batch_dot(att_score, v) - weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), - reverse=True) - weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) - weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) - output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, - num_hidden=out_dim, no_bias=False) - output = mx.sym.Group([output, att_score]) - executor = output.simple_bind(ctx=mx.gpu(0), - qkv=(batch_size, qkv_length, qkv_dim), - type_dict={'qkv': dtype}, - grad_req='write', force_rebind=True) - executor.copy_params_from(arg_params, {}) - executor.arg_dict['sonde'][:] = 0. - executor.arg_dict['sonde'].wait_to_read() - executor.forward(is_train=True) - output_orig = executor.outputs[0].asnumpy() - att_score_orig = executor.outputs[1].asnumpy() - executor.backward([mx.nd.array(output_grads, dtype=dtype), - mx.nd.zeros(att_score_orig.shape, dtype=dtype)]) - grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} - assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) - assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) - - for k in grads_opti.keys(): - assert(grads_orig[k].dtype == grads_opti[k].dtype) - assert(grads_orig[k].shape == grads_opti[k].shape) - assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) - -@assert_raises_cuda_not_satisfied(min_version='9.1') -def test_multihead_attention_selfatt(): - for dtype in ['float16', 'float32']: - check_multihead_attention_selfatt(dtype=dtype) - -def check_multihead_attention_encdec(dtype): - def convert_weight(F, k_weight, v_weight, num_heads): - k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) - v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) - all_weights = F.concat(k_weight, v_weight, dim=-2) - all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) - return all_weights - - def convert_bias(F, k_bias, v_bias, num_heads): - k_bias = F.reshape(k_bias, shape=(num_heads, -1)) - v_bias = F.reshape(v_bias, shape=(num_heads, -1)) - all_bias = F.stack(k_bias, v_bias, axis=1) - all_bias = F.reshape(all_bias, shape=(-1,)) - return all_bias - - batch_size = 2 - qkv_length = 7 # length of a sequence - qkv_dim = 9 # dimension of encoding - num_heads = 3 # number of attention head - head_dim = 5 # head size - out_dim = 13 * num_heads - qkv_units = num_heads * head_dim - - arg_params = { - 'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), - 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), - 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), - 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), - 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), - 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), - } - - q = mx.sym.Variable('q') - kv = mx.sym.Variable('kv') - sonde = mx.sym.Variable('sonde') - q_weight = mx.sym.Variable('q_weight') - k_weight = mx.sym.Variable('k_weight') - v_weight = mx.sym.Variable('v_weight') - q_bias = mx.sym.Variable('q_bias') - k_bias = mx.sym.Variable('k_bias') - v_bias = mx.sym.Variable('v_bias') - out_weight = mx.sym.Variable('out_weight') - out_bias = mx.sym.Variable('out_bias') - kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads) - kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads) - kv = mx.sym.transpose(kv, axes=(1, 0, 2)) - kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False, - num_hidden=qkv_units * 2, no_bias=False) - q = mx.sym.transpose(q, axes=(1, 0, 2)) - q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - att_score = mx.sym.contrib.interleaved_matmul_encdec_qk( - q_proj, kv_proj, heads=num_heads) - att_score = att_score + sonde - weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt( - kv_proj, att_score, heads=num_heads) - output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, - num_hidden=out_dim, no_bias=False) - output = mx.sym.transpose(output, axes=(1, 0, 2)) - output = mx.sym.Group([output, att_score]) - executor = output.simple_bind(ctx=mx.gpu(0), - q=(batch_size, qkv_length, qkv_dim), - kv=(batch_size, qkv_length, qkv_dim), - q_weight=(qkv_units, qkv_dim), - q_bias=(qkv_units,), - k_weight=(qkv_units, qkv_dim), - k_bias=(qkv_units,), - v_weight=(qkv_units, qkv_dim), - v_bias=(qkv_units,), - out_weight=(out_dim, qkv_units), - out_bias=(out_dim,), - type_dict={'q': dtype, - 'kv': dtype, - 'q_weight': dtype, - 'q_bias': dtype, - 'k_weight': dtype, - 'k_bias': dtype, - 'v_weight': dtype, - 'v_bias': dtype, - 'out_weight': dtype, - 'out_bias': dtype, - }, - grad_req='write', force_rebind=True) - output_shape = executor.outputs[0].shape - output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 - executor.copy_params_from(arg_params, {}) - executor.arg_dict['sonde'][:] = 0. - executor.arg_dict['sonde'].wait_to_read() - executor.forward(is_train=True) - output_opti = executor.outputs[0].asnumpy() - att_score_opti = executor.outputs[1].asnumpy() - executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) - - grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} - - q = mx.sym.Variable('q') - kv = mx.sym.Variable('kv') - sonde = mx.sym.Variable('sonde') - q_weight = mx.sym.Variable('q_weight') - k_weight = mx.sym.Variable('k_weight') - v_weight = mx.sym.Variable('v_weight') - q_bias = mx.sym.Variable('q_bias') - k_bias = mx.sym.Variable('k_bias') - v_bias = mx.sym.Variable('v_bias') - out_weight = mx.sym.Variable('out_weight') - out_bias = mx.sym.Variable('out_bias') - - q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False, - num_hidden=qkv_units, no_bias=False) - q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) - q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) - q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) - k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) - k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) - k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) - q = mx.sym.contrib.div_sqrt_dim(q) - att_score = mx.sym.batch_dot(q, k, transpose_b=True) - att_score = att_score + sonde - v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) - v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) - v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) - weighted_value = mx.sym.batch_dot(att_score, v) - weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), - reverse=True) - weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) - weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) - output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, - num_hidden=out_dim, no_bias=False) - output = mx.sym.Group([output, att_score]) - executor = output.simple_bind(ctx=mx.gpu(0), - q=(batch_size, qkv_length, qkv_dim), - kv=(batch_size, qkv_length, qkv_dim), - type_dict={'q': dtype, - 'kv': dtype}, - grad_req='write', force_rebind=True) - executor.copy_params_from(arg_params, {}) - executor.arg_dict['sonde'][:] = 0. - executor.arg_dict['sonde'].wait_to_read() - executor.forward(is_train=True) - output_orig = executor.outputs[0].asnumpy() - att_score_orig = executor.outputs[1].asnumpy() - executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_orig.shape, dtype=dtype)]) - grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} - assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) - assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) - - for k in grads_opti.keys(): - assert(grads_orig[k].dtype == grads_opti[k].dtype) - assert(grads_orig[k].shape == grads_opti[k].shape) - assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) - -@assert_raises_cuda_not_satisfied(min_version='9.1') -def test_multihead_attention_encdec(): - for dtype in ['float16', 'float32']: - check_multihead_attention_encdec(dtype=dtype) - if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index d59c3063f95a..72350de4d85f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -9373,6 +9373,330 @@ def check_random_uniform(): hight = 1 assertRaises(MXNetError, mx.nd.random_uniform, alpha, beta, shape) +def check_multihead_attention_selfatt(dtype): + def convert_weight(F, q_weight, k_weight, v_weight, num_heads): + q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True) + k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) + v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) + all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2) + all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) + return all_weights + + def convert_bias(F, q_bias, k_bias, v_bias, num_heads): + q_bias = F.reshape(q_bias, shape=(num_heads, -1)) + k_bias = F.reshape(k_bias, shape=(num_heads, -1)) + v_bias = F.reshape(v_bias, shape=(num_heads, -1)) + all_bias = F.stack(q_bias, k_bias, v_bias, axis=1) + all_bias = F.reshape(all_bias, shape=(-1,)) + return all_bias + + batch_size = 2 + qkv_length = 7 # length of a sequence + qkv_dim = 9 # dimension of encoding + num_heads = 3 # number of attention head + head_dim = 5 # head size + out_dim = 13 * num_heads + qkv_units = num_heads * head_dim + + arg_params = { + 'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), + 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), + } + + qkv = mx.sym.Variable('qkv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads) + qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads) + qkv = mx.sym.transpose(qkv, axes=(1, 0, 2)) + qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False, + num_hidden=qkv_units * 3, no_bias=False) + att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk( + qkv_proj, heads=num_heads) + att_score = att_score + sonde + weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt( + qkv_proj, att_score, heads=num_heads) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.transpose(output, axes=(1, 0, 2)) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=default_context(), + qkv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + type_dict={'qkv': dtype, + 'q_weight': dtype, + 'k_weight': dtype, + 'v_weight': dtype, + 'q_bias': dtype, + 'k_bias': dtype, + 'v_bias': dtype, + 'sonde': dtype}, + grad_req='write', force_rebind=True) + output_shape = executor.outputs[0].shape + output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_opti = executor.outputs[0].asnumpy() + att_score_opti = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), + mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) + grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} + qkv = mx.sym.Variable('qkv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + + q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) + q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) + q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) + k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) + k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) + k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) + q = mx.sym.contrib.div_sqrt_dim(q) + att_score = mx.sym.batch_dot(q, k, transpose_b=True) + att_score = att_score + sonde + v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) + v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) + v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) + weighted_value = mx.sym.batch_dot(att_score, v) + weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), + reverse=True) + weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) + weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=default_context(), + qkv=(batch_size, qkv_length, qkv_dim), + type_dict={'qkv': dtype}, + grad_req='write', force_rebind=True) + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_orig = executor.outputs[0].asnumpy() + att_score_orig = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), + mx.nd.zeros(att_score_orig.shape, dtype=dtype)]) + grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} + assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) + assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) + + for k in grads_opti.keys(): + assert(grads_orig[k].dtype == grads_opti[k].dtype) + assert(grads_orig[k].shape == grads_opti[k].shape) + assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) + + +@with_seed() +def test_multihead_attention_selfatt(): + dtypes = ['float32'] + if default_context().device_type == 'gpu': + dtypes += ['float16'] + + for dtype in dtypes: + check_multihead_attention_selfatt(dtype=dtype) + +def check_multihead_attention_encdec(dtype): + def convert_weight(F, k_weight, v_weight, num_heads): + k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) + v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) + all_weights = F.concat(k_weight, v_weight, dim=-2) + all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) + return all_weights + + def convert_bias(F, k_bias, v_bias, num_heads): + k_bias = F.reshape(k_bias, shape=(num_heads, -1)) + v_bias = F.reshape(v_bias, shape=(num_heads, -1)) + all_bias = F.stack(k_bias, v_bias, axis=1) + all_bias = F.reshape(all_bias, shape=(-1,)) + return all_bias + + batch_size = 2 + qkv_length = 7 # length of a sequence + qkv_dim = 9 # dimension of encoding + num_heads = 3 # number of attention head + head_dim = 5 # head size + out_dim = 13 * num_heads + qkv_units = num_heads * head_dim + + arg_params = { + 'q': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'kv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), + 'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), + 'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), + 'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), + } + + q = mx.sym.Variable('q') + kv = mx.sym.Variable('kv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + kv_weight = convert_weight(mx.sym, k_weight, v_weight, num_heads) + kv_bias = convert_bias(mx.sym, k_bias, v_bias, num_heads) + kv = mx.sym.transpose(kv, axes=(1, 0, 2)) + kv_proj = mx.sym.FullyConnected(kv, weight=kv_weight, bias=kv_bias, flatten=False, + num_hidden=qkv_units * 2, no_bias=False) + q = mx.sym.transpose(q, axes=(1, 0, 2)) + q_proj = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + att_score = mx.sym.contrib.interleaved_matmul_encdec_qk( + q_proj, kv_proj, heads=num_heads) + att_score = att_score + sonde + weighted_value = mx.sym.contrib.interleaved_matmul_encdec_valatt( + kv_proj, att_score, heads=num_heads) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.transpose(output, axes=(1, 0, 2)) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=default_context(), + q=(batch_size, qkv_length, qkv_dim), + kv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + out_weight=(out_dim, qkv_units), + out_bias=(out_dim,), + type_dict={'q': dtype, + 'kv': dtype, + 'q_weight': dtype, + 'q_bias': dtype, + 'k_weight': dtype, + 'k_bias': dtype, + 'v_weight': dtype, + 'v_bias': dtype, + 'out_weight': dtype, + 'out_bias': dtype, + }, + grad_req='write', force_rebind=True) + output_shape = executor.outputs[0].shape + output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1 + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_opti = executor.outputs[0].asnumpy() + att_score_opti = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_opti.shape, dtype=dtype)]) + + grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()} + + q = mx.sym.Variable('q') + kv = mx.sym.Variable('kv') + sonde = mx.sym.Variable('sonde') + q_weight = mx.sym.Variable('q_weight') + k_weight = mx.sym.Variable('k_weight') + v_weight = mx.sym.Variable('v_weight') + q_bias = mx.sym.Variable('q_bias') + k_bias = mx.sym.Variable('k_bias') + v_bias = mx.sym.Variable('v_bias') + out_weight = mx.sym.Variable('out_weight') + out_bias = mx.sym.Variable('out_bias') + + q = mx.sym.FullyConnected(q, weight=q_weight, bias=q_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + k = mx.sym.FullyConnected(kv, weight=k_weight, bias=k_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + v = mx.sym.FullyConnected(kv, weight=v_weight, bias=v_bias, flatten=False, + num_hidden=qkv_units, no_bias=False) + q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1)) + q = mx.sym.transpose(q, axes=(0, 2, 1, 3)) + q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True) + k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1)) + k = mx.sym.transpose(k, axes=(0, 2, 1, 3)) + k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True) + q = mx.sym.contrib.div_sqrt_dim(q) + att_score = mx.sym.batch_dot(q, k, transpose_b=True) + att_score = att_score + sonde + v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1)) + v = mx.sym.transpose(v, axes=(0, 2, 1, 3)) + v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True) + weighted_value = mx.sym.batch_dot(att_score, v) + weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0), + reverse=True) + weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3)) + weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1)) + output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) + output = mx.sym.Group([output, att_score]) + executor = output.simple_bind(ctx=default_context(), + q=(batch_size, qkv_length, qkv_dim), + kv=(batch_size, qkv_length, qkv_dim), + type_dict={'q': dtype, + 'kv': dtype}, + grad_req='write', force_rebind=True) + executor.copy_params_from(arg_params, {}) + executor.arg_dict['sonde'][:] = 0. + executor.arg_dict['sonde'].wait_to_read() + executor.forward(is_train=True) + output_orig = executor.outputs[0].asnumpy() + att_score_orig = executor.outputs[1].asnumpy() + executor.backward([mx.nd.array(output_grads, dtype=dtype), mx.nd.zeros(att_score_orig.shape, dtype=dtype)]) + grads_orig = {k : v.asnumpy() for k, v in executor.grad_dict.items()} + assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3) + assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3) + + for k in grads_opti.keys(): + assert(grads_orig[k].dtype == grads_opti[k].dtype) + assert(grads_orig[k].shape == grads_opti[k].shape) + assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3) + +@with_seed() +def test_multihead_attention_encdec(): + dtypes = ['float32'] + if default_context().device_type == 'gpu': + dtypes += ['float16'] + + for dtype in dtypes: + check_multihead_attention_encdec(dtype=dtype) @with_seed() def test_im2col_col2im():