Skip to content

Commit

Permalink
add MXNet Ops for fast multihead attention
Browse files Browse the repository at this point in the history
  • Loading branch information
Caenorst committed Oct 9, 2019
1 parent d5666ed commit e987614
Show file tree
Hide file tree
Showing 4 changed files with 1,377 additions and 1 deletion.
12 changes: 12 additions & 0 deletions src/operator/contrib/transformer-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@
namespace mxnet {
namespace op {

struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> {
int heads;
bool bwd_ignore_zero_init;
DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) {
DMLC_DECLARE_FIELD(heads)
.describe("Set number of heads");
DMLC_DECLARE_FIELD(bwd_ignore_zero_init)
.describe("Make backward pass ignore AddTo and not init to 0.")
.set_default(false);
}
};

template<typename xpu>
static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
157 changes: 157 additions & 0 deletions src/operator/contrib/transformer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,163 @@
namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(InterleavedMatMulParam);

static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1);
auto qkv_shape = in_shape->at(0);
CHECK_EQ(qkv_shape.ndim(), 3);
out_shape->resize(1);
SHAPE_ASSIGN_CHECK(*out_shape, 0,
mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));
return true;
}

static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
CHECK_EQ(in_shape->size(), 2);
auto qkv_shape = in_shape->at(0);
auto att_shape = in_shape->at(1);
CHECK_EQ(qkv_shape.ndim(), 3);
CHECK_EQ(att_shape.ndim(), 3);
CHECK_EQ(qkv_shape[0], att_shape[1]);
CHECK_EQ(qkv_shape[0], att_shape[2]);
CHECK_EQ(qkv_shape[2] % 3, 0);
SHAPE_ASSIGN_CHECK(*out_shape, 0,
mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
return true;
}

static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 2);
auto q_shape = in_shape->at(0);
auto kv_shape = in_shape->at(1);
CHECK_EQ(q_shape.ndim(), 3);
CHECK_EQ(kv_shape.ndim(), 3);
CHECK_EQ(q_shape[2] * 2, kv_shape[2]);
CHECK_EQ(q_shape[1], kv_shape[1]);
SHAPE_ASSIGN_CHECK(*out_shape, 0,
mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]}));
return true;
}

static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
mxnet::ShapeVector* in_shape,
mxnet::ShapeVector* out_shape) {
const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 2);
auto kv_shape = in_shape->at(0);
auto att_shape = in_shape->at(1);
CHECK_EQ(kv_shape[0], att_shape[2]);
CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]);
SHAPE_ASSIGN_CHECK(*out_shape, 0,
mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2}));
return true;
}

NNVM_REGISTER_OP(interleaved_matmul_selfatt_qk)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"queries_keys_values"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttQKShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_qk"})
.add_argument("queries_keys_values", "NDArray-or-Symbol", "Interleaved queries, keys and values")
.add_arguments(InterleavedMatMulParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_qk)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<InterleavedMatMulParam>);

NNVM_REGISTER_OP(interleaved_matmul_selfatt_valatt)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"queries_keys_values", "attention"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulSelfAttValAttShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_interleaved_matmul_selfatt_valatt"})
.add_argument("queries_keys_values", "NDArray-or-Symbol", "Queries, keys and values interleaved")
.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
.add_arguments(InterleavedMatMulParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_interleaved_matmul_selfatt_valatt)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<InterleavedMatMulParam>);

NNVM_REGISTER_OP(interleaved_matmul_encdec_qk)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"queries", "keys_values"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecQKShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_qk"})
.add_argument("queries", "NDArray-or-Symbol", "Queries")
.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved")
.add_arguments(InterleavedMatMulParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_qk)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<InterleavedMatMulParam>);

NNVM_REGISTER_OP(interleaved_matmul_encdec_valatt)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<InterleavedMatMulParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"keys_values", "attention"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
.set_attr<mxnet::FInferShape>("FInferShape", InterleavedMatMulEncDecValAttShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<2, 1>)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseIn{"_backward_interleaved_matmul_encdec_valatt"})
.add_argument("keys_values", "NDArray-or-Symbol", "Keys and values interleaved")
.add_argument("attention", "NDArray-or-Symbol", "Attention maps")
.add_arguments(InterleavedMatMulParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_interleaved_matmul_encdec_valatt)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<InterleavedMatMulParam>);


// relu
MXNET_OPERATOR_REGISTER_UNARY(_contrib_div_sqrt_dim)
.describe(R"code(Rescale the input by the square root of the channel dimension.
Expand Down
Loading

0 comments on commit e987614

Please sign in to comment.