[ONNX] Add imports for BERT contrib operators #10949
Changes from 8 commits
@@ -807,9 +807,10 @@ def _impl_v1(cls, inputs, attr, params):
         x = inputs[0]

         # Declare consts
-        half = _expr.const(0.5)
-        one = _expr.const(1.0)
-        sqrt2 = _expr.const(math.sqrt(2))
+        const_dtype = infer_type(x).checked_type.dtype
+        half = _expr.const(0.5, dtype=const_dtype)
+        one = _expr.const(1.0, dtype=const_dtype)
+        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

         # Compute gelu
         term1 = _op.multiply(half, x)
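For reference (illustrative, not part of this diff): the constants above feed the erf-form Gelu, 0.5 * x * (1 + erf(x / sqrt(2))), that the surrounding converter appears to build. Pinning them to the input's dtype matters for reduced-precision graphs. A minimal Relay sketch of the situation the change avoids:

    import tvm
    from tvm import relay

    # A float16 input, e.g. from a mixed-precision BERT graph.
    x = relay.var("x", shape=(4, 8), dtype="float16")

    # relay.const(0.5) defaults to float32; multiplying a float32 constant with a
    # float16 tensor is rejected during type inference, hence dtype=const_dtype above.
    half = relay.const(0.5, dtype="float16")
    y = relay.multiply(half, x)

    mod = tvm.IRModule.from_expr(relay.Function([x], y))
    print(relay.transform.InferType()(mod))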
@@ -836,6 +837,208 @@ def _impl_v1(cls, inputs, attr, params):
         return Gelu._impl_v1([inp], attr, params)


+class EmbedLayerNormalization(OnnxOpConverter):
+    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer embeds the input tokens, sums them, and applies layer normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        input_ids = inputs[0]
+        segment_ids = inputs[1]
+        word_emb = inputs[2]
+        pos_emb = inputs[3]
+        segment_emb = inputs[4]
+        gamma = inputs[5]
+        beta = inputs[6]
+
+        mask = inputs[7]
+        pos_ids = inputs[8]
+
+        eps = attr.get("epsilon", 1e-12)
+
+        (batch_size, seq_len) = infer_shape(input_ids)
+
+        if segment_ids:
+            assert segment_emb
+
+        if pos_ids is None:
+            # one row of [0, 1, ..., seq_len - 1] per batch element
+            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")
+
+        word_vec = _op.take(word_emb, input_ids, axis=0)
+        segment_vec = _op.take(segment_emb, segment_ids, axis=0)
+        pos_vec = _op.take(pos_emb, pos_ids, axis=0)
+
+        vec_sum = _op.add(word_vec, pos_vec)
+        if segment_ids:
+            vec_sum = _op.add(vec_sum, segment_vec)
+
+        eps_dtype = infer_type(word_emb).checked_type.dtype
+
+        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
+        ln = _op.divide(
+            _op.subtract(vec_sum, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        ln = _op.multiply(ln, gamma) + beta
+
+        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
+        if mask:
+            # calculate number of words per sentence
+            mask_index = _op.sum(mask, axis=1)
+
+        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
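For reference (illustrative, not part of this diff): a NumPy sketch of the computation this converter lowers, under the assumption of a (batch, seq) integer mask and default position ids:

    import numpy as np

    def embed_layer_norm_ref(input_ids, segment_ids, word_emb, pos_emb, segment_emb,
                             gamma, beta, mask=None, eps=1e-12):
        batch_size, seq_len = input_ids.shape
        pos_ids = np.tile(np.arange(seq_len), (batch_size, 1))

        # the embedding lookups correspond to the _op.take calls above
        vec_sum = word_emb[input_ids] + pos_emb[pos_ids] + segment_emb[segment_ids]

        # layer normalization over the hidden dimension
        u = vec_sum.mean(axis=-1, keepdims=True)
        s = vec_sum.var(axis=-1, keepdims=True)
        ln = (vec_sum - u) / np.sqrt(s + eps) * gamma + beta

        # number of real (unmasked) tokens per sentence
        if mask is None:
            mask_index = np.zeros((batch_size,), dtype="int64")
        else:
            mask_index = mask.sum(axis=1)
        return ln, mask_index, vec_sum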
+
+
+class SkipLayerNormalization(OnnxOpConverter):
+    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.
+
+    This layer sums the two input tensors (along with optional bias), and applies layer
+    normalization.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        data = inputs[0]
+        skip = inputs[1]
+        gamma = inputs[2]
+        beta = inputs[3]
+        bias = inputs[4]
+
+        eps = attr["epsilon"] if "epsilon" in attr else 1e-12
+
+        x = _op.add(data, skip)
+        if bias is not None:
+            x = _op.add(x, bias)
+
+        eps_dtype = infer_type(x).checked_type.dtype
+
+        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
+        output = _op.divide(
+            _op.subtract(x, u),
+            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
+        )
+        output = _op.multiply(output, gamma)
+        if beta:
+            output = _op.add(output, beta)
+
+        placeholder = _op.const(0, dtype="float32")

Contributor: what is this placeholder for? optional returns are mean and inverse standard variance right?

Contributor (Author): that's true according to the documentation, however both CUDA and C++ onnxruntime implementations of the kernels do not actually ever return or calculate values for these outputs: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc

+
+        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
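For reference (illustrative, not part of this diff): the equivalent computation in NumPy, ignoring the two unused placeholder outputs:

    import numpy as np

    def skip_layer_norm_ref(data, skip, gamma, beta=None, bias=None, eps=1e-12):
        x = data + skip if bias is None else data + skip + bias
        u = x.mean(axis=-1, keepdims=True)
        s = x.var(axis=-1, keepdims=True)
        output = (x - u) / np.sqrt(s + eps) * gamma
        return output if beta is None else output + beta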
+
+
+class Attention(OnnxOpConverter):
+    """Operator converter for Attention from Microsoft onnxruntime contrib opset.
+
+    This is the self-attention mechanism used in transformer models.
+    """
+
+    @classmethod
+    def _impl_v1(cls, inputs, attr, params):
+        num_heads = attr["num_heads"]
+        assert (
+            "qkv_hidden_sizes" not in attr
+        ), "different hidden sizes for Q, K, V are not currently supported"
+        assert "unidirectional" not in attr, "unidirectional attention not currently supported"
+
+        # (batch, seq, in_hidden)
+        input_emb = inputs[0]
+
+        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        weight = inputs[1]
+
+        # (3 * out_hidden,)
+        bias = inputs[2]
+
+        # The mask_index input can take several shapes:
+        # 1. (batch, 1, max_seq, max_seq)
+        # 2. (batch, past_seq + seq)
+        # 3. (batch, seq, past_seq + seq)
+        # 4. (batch,)
+        # 5. (2 * batch,)
+        # For now, we only support case 2.
+        mask_index = inputs[3]
+
+        # (2, batch, num_heads, past_seq, head_size)
+        past = inputs[4]
+
+        # (batch, num_heads, seq, seq)
+        extra_add = inputs[5]
+
+        (batch_size, seq_len, _) = infer_shape(input_emb)
+        (out_hidden_x3,) = infer_shape(bias)
+        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % num_heads == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // num_heads
+
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == batch_size
+            and mask_index_shape[1] == seq_len
+        ), "currently only support (batch_size, sequence_length) mask index"
+
+        assert past is None, "past K, V state is not currently supported"
+        assert extra_add is None, "extra add to QxK not currently supported"
+
+        # split weight and biases and do the matmuls
+        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
+        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
+        # need to merge batch dimensions since TVM matmul is 2D
+        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
+        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
+        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
+        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
+
+        # massage tensors in preparation for batched matmul
+        def massage(tensor):
+            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+            # (batch_size, num_heads, seq_len, head_size)
+            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+            # (batch_size * num_heads, seq_len, head_size)
+            return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+        Q = massage(Q)
+        K = massage(K)
+        V = massage(V)
+
+        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
+        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
+        present = _op.stack([K_present, V_present], axis=0)
+
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(np.sqrt(head_size), dtype=score_dtype),
+        )
+        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
+
+        # build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+
+        # apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
+
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, out_hidden))
+
+        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
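For reference (illustrative, not part of this diff): a NumPy sketch of the attention the converter lowers, assuming a (batch, seq) 0/1 mask, no past state, and equal Q/K/V hidden sizes, matching the asserts above:

    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    def attention_ref(x, weight, bias, mask, num_heads):
        batch, seq, _ = x.shape
        out_hidden = bias.shape[0] // 3
        head_size = out_hidden // num_heads

        # fused QKV projection; splitting the result equals splitting the weight first
        qkv = x @ weight + bias                       # (batch, seq, 3 * out_hidden)
        q, k, v = np.split(qkv, 3, axis=-1)

        def to_heads(t):
            t = t.reshape(batch, seq, num_heads, head_size)
            return t.transpose(0, 2, 1, 3)            # (batch, heads, seq, head_size)

        q, k, v = to_heads(q), to_heads(k), to_heads(v)

        # scaled dot-product scores, then the additive -10000 mask used above
        scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_size)
        scores = scores + (1.0 - mask[:, None, None, :]) * -10000.0
        probs = softmax(scores, axis=-1)

        out = probs @ v                               # (batch, heads, seq, head_size)
        out = out.transpose(0, 2, 1, 3).reshape(batch, seq, out_hidden)
        present = np.stack([k, v], axis=0)            # (2, batch, heads, seq, head_size)
        return out, present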
+
+
 class Gemm(OnnxOpConverter):
     """Operator converter for Gemm."""


@@ -4737,6 +4940,12 @@ def _get_convert_map(opset):
         "Elu": Elu.get_converter(opset),
         "Gelu": Gelu.get_converter(opset),
         "BiasGelu": BiasGelu.get_converter(opset),
+        # TODO: We need a better way to handle different domains, in case
+        # of name collisions. EmbedLayerNormalization, SkipLayerNormalization,
+        # and Attention are in the `com.microsoft` domain.
+        "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
+        "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
+        "Attention": Attention.get_converter(opset),
         "Exp": Renamer("exp"),
         "Greater": Renamer("greater"),
         "GreaterOrEqual": Renamer("greater_equal"),
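For reference (illustrative, not part of this diff, and not the PR's test code): one way to exercise a new converter is to build a tiny `com.microsoft`-domain model with onnx.helper and import it through the Relay frontend. Names, shapes, and opset versions below are made up and may need adjusting:

    from onnx import TensorProto, helper
    from tvm import relay

    batch, seq, hidden = 1, 4, 8

    # SkipLayerNormalization node from the com.microsoft contrib domain
    node = helper.make_node(
        "SkipLayerNormalization",
        inputs=["input", "skip", "gamma", "beta"],
        outputs=["output", "mean", "inv_std_var"],
        domain="com.microsoft",
        epsilon=1e-12,
    )

    graph = helper.make_graph(
        [node],
        "skip_ln_example",
        inputs=[
            helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, seq, hidden]),
            helper.make_tensor_value_info("skip", TensorProto.FLOAT, [batch, seq, hidden]),
            helper.make_tensor_value_info("gamma", TensorProto.FLOAT, [hidden]),
            helper.make_tensor_value_info("beta", TensorProto.FLOAT, [hidden]),
        ],
        outputs=[
            helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, seq, hidden]),
        ],
    )

    model = helper.make_model(
        graph,
        opset_imports=[helper.make_opsetid("", 14), helper.make_opsetid("com.microsoft", 1)],
    )

    mod, params = relay.frontend.from_onnx(model)
    print(mod)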