[ONNX] Add imports for BERT contrib operators #10949
Changes from all commits
@@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3):
         return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype)


+def layer_norm(x, eps, gamma, beta):
+    """Common function to handle layer norm"""
+    eps_dtype = infer_type(x).checked_type.dtype
+
+    u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+    output = _op.divide(
+        _op.subtract(x, u),
+        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+    )
+    output = _op.multiply(output, gamma)
+    if beta is not None:
+        output = _op.add(output, beta)
+
+    return output
+
+
 class OnnxOpConverter(object):
     """A helper class for holding onnx op converters."""
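For readers who want the math behind the helper: `layer_norm` decomposes the standard normalization `(x - mean) / sqrt(var + eps) * gamma + beta` over the last axis into Relay ops. A minimal NumPy sketch of the same computation (illustrative only, not part of the diff):

```python
import numpy as np

def layer_norm_ref(x, eps, gamma, beta=None):
    # Normalize over the last axis, mirroring the Relay decomposition above.
    u = x.mean(axis=-1, keepdims=True)   # mean
    s = x.var(axis=-1, keepdims=True)    # variance (ddof=0) over the hidden axis
    out = (x - u) / np.sqrt(s + eps)
    out = out * gamma
    if beta is not None:
        out = out + beta
    return out

# Example: normalize a (batch, seq, hidden) activation.
x = np.random.randn(2, 4, 8).astype("float32")
print(layer_norm_ref(x, 1e-12, np.ones(8, "float32"), np.zeros(8, "float32")).shape)  # (2, 4, 8)
```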
@@ -807,9 +823,10 @@ def _impl_v1(cls, inputs, attr, params):
         x = inputs[0]

         # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

         # Compute gelu
         term1 = _op.multiply(half, x)
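The Gelu change above is a dtype fix: constants created without an explicit dtype default to float32, which can break float16 models, so the constants now inherit the input's dtype. For reference, the exact GELU these constants implement is `0.5 * x * (1 + erf(x / sqrt(2)))`; a small NumPy check of the dtype behaviour (illustrative only):

```python
import math
import numpy as np

def gelu_ref(x):
    # Exact GELU, evaluated elementwise and kept in the input's dtype,
    # mirroring the const_dtype fix above.
    erf = np.vectorize(math.erf)
    sqrt2 = np.array(math.sqrt(2), dtype=x.dtype)
    out = 0.5 * x * (1.0 + erf(x / sqrt2))
    return out.astype(x.dtype)

x16 = np.linspace(-3, 3, 7, dtype="float16")
print(gelu_ref(x16).dtype)  # float16: constants follow the input dtype
```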
@@ -836,6 +853,201 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)


+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int32")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        ln = layer_norm(vec_sum, eps, gamma, beta)
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int32"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2)
+
+
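In plain terms, the converter above gathers the word, position, and (optional) segment embeddings, sums them, layer-normalizes the sum, and returns the per-sentence token count derived from the mask. A rough NumPy reference of that behaviour, assuming position ids default to `0..seq_len-1` repeated for each batch row (names and defaults here are illustrative, not the PR's code):

```python
import numpy as np

def layer_norm_ref(x, eps, gamma, beta):
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)
    return (x - u) / np.sqrt(s + eps) * gamma + beta

def embed_layer_norm_ref(input_ids, segment_ids, word_emb, pos_emb, segment_emb,
                         gamma, beta, mask=None, eps=1e-12):
    batch_size, seq_len = input_ids.shape
    # Default position ids: 0..seq_len-1 for every sentence in the batch.
    pos_ids = np.tile(np.arange(seq_len), (batch_size, 1))

    vec_sum = word_emb[input_ids] + pos_emb[pos_ids]
    if segment_ids is not None:
        vec_sum = vec_sum + segment_emb[segment_ids]

    ln = layer_norm_ref(vec_sum, eps, gamma, beta)

    # mask_index = number of non-padding tokens per sentence.
    if mask is not None:
        mask_index = mask.sum(axis=1).astype("int32")
    else:
        mask_index = np.zeros(batch_size, dtype="int32")
    return ln, mask_index

# Tiny smoke test with made-up vocabulary/embedding sizes.
rng = np.random.default_rng(0)
ids = np.array([[1, 2, 3, 0]])
seg = np.zeros_like(ids)
ln, idx = embed_layer_norm_ref(
    ids, seg,
    rng.normal(size=(10, 8)), rng.normal(size=(4, 8)), rng.normal(size=(2, 8)),
    np.ones(8), np.zeros(8), mask=np.array([[1, 1, 1, 0]]),
)
print(ln.shape, idx)  # (1, 4, 8) [3]
```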
+class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        assert (
+            beta is not None and bias is not None
+        ), "SkipLayerNormalization import currently only supports required beta and bias"
+
+        eps = attr.get("epsilon", 1e-12)
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        output = layer_norm(x, eps, gamma, beta)
+
+        # onnxruntime doesn't compute the other outputs, despite the documentation
+        placeholder = _op.const(0, dtype="float32")
Reviewer (Contributor): What is this placeholder for? The optional returns are the mean and the inverse standard variance, right?

Author (Contributor): That's true according to the documentation; however, both the CUDA and the C++ onnxruntime implementations of these kernels never actually calculate or return values for those outputs: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
+
+
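The converter above boils down to `layer_norm(data + skip + bias)`; per the review thread, onnxruntime never populates the optional mean/inverse-variance outputs, hence the placeholders. A minimal NumPy sketch of the computed output (illustrative only):

```python
import numpy as np

def skip_layer_norm_ref(data, skip, gamma, beta, bias, eps=1e-12):
    # Residual add (+ bias), then layer norm over the hidden axis.
    x = data + skip + bias
    u = x.mean(axis=-1, keepdims=True)
    s = x.var(axis=-1, keepdims=True)
    return (x - u) / np.sqrt(s + eps) * gamma + beta

data = np.random.randn(2, 4, 8).astype("float32")
skip = np.random.randn(2, 4, 8).astype("float32")
out = skip_layer_norm_ref(data, skip, np.ones(8, "float32"),
                          np.zeros(8, "float32"), np.zeros(8, "float32"))
print(out.shape)  # (2, 4, 8)
```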
+class Attention(OnnxOpConverter):
+    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not currently supported"
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # 1. (batch, 1, max_seq, max_seq)
+        # 2. (batch, past_seq + seq)
+        # 3. (batch, seq, past_seq + seq)
+        # 4. (batch,)
+        # 5. (2 * batch,)
+        # For now, we only support case 2.
+        mask_index = inputs[3]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[4]
+
+        # (batch, num_heads, seq, seq)
+        extra_add = inputs[5]
+
+        (batch_size, seq_len, _) = infer_shape(input_emb)
+        (out_hidden_x3,) = infer_shape(bias)
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        assert (
+            mask_index is not None
+        ), "Attention import currently only supports required mask_index"
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"
+
+        # split weight and biases and do the matmuls
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
+
+        # massage tensors in preparation for batched matmul
+        def massage(tensor):
+            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
+            return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+        Q = massage(Q)
+        K = massage(K)
+        V = massage(V)
+
+        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
+        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
+        present = _op.stack([K_present, V_present], axis=0)
+
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=score_dtype),
+        )
+        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
+
+        # build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+
+        # apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
+
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))
+
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
+
+
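The graph built above is standard multi-head scaled dot-product attention with a fused QKV projection and an additive padding mask (0 for attended positions, -10000 for masked ones); the second output stacks K and V as the `present` state. A compact NumPy reference of the main output, under the (batch_size, sequence_length) mask layout this import supports (illustrative only; the `present` output is omitted):

```python
import numpy as np

def attention_ref(x, weight, bias, mask, num_heads):
    # x: (batch, seq, in_hidden); weight: (in_hidden, 3 * out_hidden); bias: (3 * out_hidden,)
    batch, seq, _ = x.shape
    out_hidden = bias.shape[0] // 3
    head_size = out_hidden // num_heads

    # Fused QKV projection, then split into Q, K, V.
    qkv = x @ weight + bias
    q, k, v = np.split(qkv, 3, axis=-1)

    def heads(t):  # (batch, seq, out_hidden) -> (batch, heads, seq, head_size)
        return t.reshape(batch, seq, num_heads, head_size).transpose(0, 2, 1, 3)

    q, k, v = heads(q), heads(k), heads(v)

    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
    # Additive mask: 0 where attended, -10000 where padded.
    scores = scores + (1.0 - mask[:, None, None, :]) * -10000.0

    probs = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = probs / probs.sum(axis=-1, keepdims=True)

    out = probs @ v  # (batch, heads, seq, head_size)
    return out.transpose(0, 2, 1, 3).reshape(batch, seq, out_hidden)

x = np.random.randn(2, 5, 16).astype("float32")
w = np.random.randn(16, 3 * 16).astype("float32")
b = np.zeros(3 * 16, dtype="float32")
mask = np.ones((2, 5), dtype="float32")
print(attention_ref(x, w, b, mask, num_heads=4).shape)  # (2, 5, 16)
```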
 class Gemm(OnnxOpConverter):
     """Operator converter for Gemm."""
@@ -4737,6 +4949,12 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        # TODO: We need a better way to handle different domains, in case
+        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
+        # are in the `com.microsoft` domain.
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+        "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+        "Attention": Attention.get_converter(opset),
         "Exp": Renamer("exp"),
         "Greater": Renamer("greater"),
         "GreaterOrEqual": Renamer("greater_equal"),
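With these entries registered, an ONNX graph that uses the `com.microsoft` contrib ops above (for example, a BERT model optimized by onnxruntime's transformer tooling) can be imported like any other model. A rough usage sketch; the file name and the input names/shapes below are placeholders, not something defined by this PR:

```python
import onnx
import tvm
from tvm import relay

# Load an optimized BERT graph that contains the contrib ops handled above.
model = onnx.load("bert_optimized.onnx")  # hypothetical path

# Input shapes are model-specific; these names and sizes are only an example.
shape_dict = {
    "input_ids": (1, 128),
    "segment_ids": (1, 128),
    "input_mask": (1, 128),
}
mod, params = relay.frontend.from_onnx(model, shape_dict)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)
```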