diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index c4eb7774d756..ffd31317e9f5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1297,7 +1297,213 @@ def _impl_v1(cls, inputs, attr, params):
         return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)
 
 
-class Attention(OnnxOpConverter):
+class OrtAttentionBase:
+    """
+    Base class for Attention and QAttention from Microsoft onnxruntime contrib opset.
+    """
+
+    @classmethod
+    def _check_input_embeddings(cls, input_emb, valid_types, **kwargs):
+        assert infer_type(input_emb).checked_type.dtype in valid_types
+        assert (
+            len(infer_shape(input_emb)) == 3
+        ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
+        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
+        assert input_hidden > 0, (
+            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
+            f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
+        )
+        assert seq_len > 0, (
+            "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
+            f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
+        )
+
+        return batch_size, seq_len, input_hidden
+
+    @classmethod
+    def _check_weights(cls, weight, valid_types, **kwargs):
+        assert infer_type(weight).checked_type.dtype in valid_types
+        assert len(infer_shape(weight)) == 2, (
+            "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
+            "hidden_size = num_heads * head_size"
+        )
+        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
+        assert kwargs["input_hidden"] == input_hidden_weight
+        assert out_hidden_x3 % 3 == 0, "output hidden shape should be divisible by 3: W_Q, W_K, W_V"
+        out_hidden = out_hidden_x3 // 3
+        assert (
+            out_hidden % kwargs["num_heads"] == 0
+        ), "output hidden size should be divisible by number of attention heads"
+        head_size = out_hidden // kwargs["num_heads"]
+
+        return out_hidden_x3, out_hidden, head_size
+
+    @classmethod
+    def _check_bias(cls, bias, valid_types, **kwargs):
+        assert infer_type(bias).checked_type.dtype in valid_types
+        assert (
+            len(infer_shape(bias)) == 1
+        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
+        (out_hidden_x3_bias,) = infer_shape(bias)
+        assert kwargs["out_hidden_x3"] == out_hidden_x3_bias
+
+    @classmethod
+    def _check_mask_index(cls, mask_index, valid_types, **kwargs):
+        assert infer_type(mask_index).checked_type.dtype in valid_types
+        mask_index_shape = infer_shape(mask_index)
+        assert (
+            len(mask_index_shape) == 2
+            and mask_index_shape[0] == kwargs["batch_size"]
+            and mask_index_shape[1] >= kwargs["seq_len"]
+        ), "currently only support (batch_size, past_sequence_len + sequence_length) mask index"
+
+        return mask_index_shape[1]
+
+    @classmethod
+    def _check_past(cls, past, valid_types, **kwargs):
+        assert infer_type(past).checked_type.dtype in valid_types
+        past_shape = infer_shape(past)
+        assert len(past_shape) == 5, "past should be 5D tensor"
+        assert (
+            past_shape[0] == 2
+            and past_shape[1] == kwargs["batch_size"]
+            and past_shape[2] == kwargs["num_heads"]
+            and past_shape[3] + kwargs["seq_len"] == kwargs["total_seq_len"]
+            and past_shape[4] == kwargs["head_size"]
+        )
+        past_seq_len = past_shape[3]
+        return past_seq_len
+
+    @classmethod
+    def _split_into_heads(cls, tensor, batch_size, seq_len, num_heads, head_size):
+        """
+        In the Multi-Head Attention implementation we simply split the queries, keys, and values
+        computed for single-head attention into several parts:
+            (batch_size, num_heads, seq_len, head_size)
+        """
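+        # Illustrative example (not part of the original change): with batch_size=2, num_heads=3,
+        # seq_len=4, head_size=5, the (2, 4, 15) projection below is reshaped to (2, 4, 3, 5) and
+        # then transposed to (2, 3, 4, 5).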
+        tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
+
+        # (batch_size, num_heads, seq_len, head_size)
+        tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+
+        return tensor
+
+    @classmethod
+    def _merge_first_dimensions(cls, tensor):
+        """
+        nn.batch_matmul is expecting 3D tensor:
+            (batch_size * num_heads, past_seq_len + seq_len, head_size)
+        """
+        return _op.reverse_reshape(tensor, (-1, 0, 0))
+
+    @classmethod
+    def _create_unidirectional_mask(cls, left_value, right_value, past_seq_len, seq_len, dtype):
+        """
+        Create a (seq_len, past_seq_len + seq_len) mask filled with left_value ("lhs") on and
+        below the diagonal offset by past_seq_len, and with right_value ("rhs") above it:
+
+        [lhs rhs rhs ... rhs rhs]
+        [lhs lhs rhs ... rhs rhs]
+        [lhs lhs lhs ... rhs rhs]
+        .........................
+        [lhs lhs lhs ... lhs rhs]
+        [lhs lhs lhs ... lhs lhs]
+        """
+        numpy_unidirectional_mask = np.array(
+            [
+                np.concatenate(
+                    [
+                        np.full(past_seq_len + s_i + 1, left_value),
+                        np.full(seq_len - s_i - 1, right_value),
+                    ]
+                )
+                for s_i in range(seq_len)
+            ]
+        )
+        unidirectional_mask = _op.const(numpy_unidirectional_mask, dtype=dtype)
+        unidirectional_mask = _op.expand_dims(unidirectional_mask, 0, num_newaxis=2)
+
+        return unidirectional_mask
+
+    @classmethod
+    def _compute_attention(cls, Q, K, V, mask_index, **kwargs):
+        # Compute Attention scores
+        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
+        score_dtype = infer_type(att_scores).checked_type.dtype
+        att_scores = _op.divide(
+            att_scores,
+            _op.const(
+                np.sqrt(kwargs["head_size"]), dtype=infer_type(att_scores).checked_type.dtype
+            ),
+        )
+        att_scores = _op.reshape(
+            att_scores,
+            (
+                kwargs["batch_size"],
+                kwargs["num_heads"],
+                kwargs["seq_len"],
+                kwargs["past_seq_len"] + kwargs["seq_len"],
+            ),
+        )
+
+        # Build the attention mask
+        att_mask = _op.cast(mask_index, score_dtype)
+        # Attention mask has value 0 or 1. Here we convert 0 to -10000, and 1 to 0.
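+        # For example (illustration only), a mask_index row [1, 1, 0] becomes [0, 0, -10000],
+        # so the masked position gets a large negative score and close to zero softmax weight.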
+        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
+        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
+        # Expand for att_scores broadcast
+        # (batch_size, past_seq_len + seq_len) -> (batch_size, 1, seq_len, past_seq_len + seq_len)
+        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
+        att_mask = _op.concatenate([att_mask] * kwargs["seq_len"], axis=2)
+
+        if kwargs["unidirectional"]:
+            att_mask = _op.add(
+                att_mask,
+                cls._create_unidirectional_mask(
+                    0, -10000, kwargs["past_seq_len"], kwargs["seq_len"], score_dtype
+                ),
+            )
+
+        # Apply the mask
+        att_scores = _op.add(att_scores, att_mask)
+        # TODO(agladyshev):
+        #   Comment from ORT source code (onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h):
+        #   "Fix unidirectional mask to be parity with huggingface implementation"
+        if kwargs["unidirectional"]:
+            att_scores = _op.multiply(
+                att_scores,
+                cls._create_unidirectional_mask(
+                    1, 0, kwargs["past_seq_len"], kwargs["seq_len"], score_dtype
+                ),
+            )
+            att_scores = _op.add(
+                att_scores,
+                _op.multiply(
+                    att_mask,
+                    cls._create_unidirectional_mask(
+                        0, 1, kwargs["past_seq_len"], kwargs["seq_len"], score_dtype
+                    ),
+                ),
+            )
+
+        # Compute Softmax
+        att_scores = _op.reshape(
+            att_scores,
+            (
+                kwargs["batch_size"] * kwargs["num_heads"],
+                kwargs["seq_len"],
+                kwargs["past_seq_len"] + kwargs["seq_len"],
+            ),
+        )
+        att_probs = _op.nn.softmax(att_scores, axis=-1)
+
+        # Compute output
+        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
+        output = _op.reverse_reshape(output, (-1, kwargs["num_heads"], 0, 0))
+        output = _op.transpose(output, axes=[0, 2, 1, 3])
+        output = _op.reshape(output, (0, 0, kwargs["out_hidden"]))
+
+        return output
+
+
+class Attention(OrtAttentionBase, OnnxOpConverter):
     """Operator converter for Attention from Microsoft onnxruntime contrib opset.
 
     This is the self-attention mechanism used in transformer models.
@@ -1305,16 +1511,30 @@ class Attention(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        # ************************* Read attrs *************************
         num_heads = attr["num_heads"]
+        unidirectional = attr["unidirectional"]
+
+        assert (
+            "past_present_share_buffer" not in attr
+        ), "sharing past and present buffers is not currently supported"
         assert (
             "qkv_hidden_sizes" not in attr
         ), "different hidden sizes for Q, K, V are not currently supported"
-        assert "unidirectional" not in attr, "unidirectional attention not current supported"
 
+        # ************************* Read inputs *************************
         # (batch, seq, in_hidden)
         input_emb = inputs[0]
 
-        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
+        # TODO(agladyshev):
+        #   ORT documentation says:
+        #       The weights for input projection of Q, K and V are merged.
+        #       The data is stacked on the second dimension.
+        #       Its shape is (input_hidden_size, hidden_size + hidden_size + v_hidden_size).
+        #       Here hidden_size is the hidden dimension of Q and K, and v_hidden_size is that of V.
+        #   However, in our case, we consider that hidden_size == v_hidden_size.
+        #   Therefore, weight has the following shape:
+        #       (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
         weight = inputs[1]
 
         # (3 * out_hidden,)
@@ -1325,7 +1545,7 @@ def _impl_v1(cls, inputs, attr, params):
         # 3. (    batch,            seq, past_seq + seq,)
         # 4. (    batch,)
         # 5. (2 * batch,)
-        # For now, we only support case 2.
+        # TODO: For now, we only support case 2.
         mask_index = inputs[3]
 
         # (2, batch, num_heads, past_seq, head_size)
@@ -1333,28 +1553,47 @@ def _impl_v1(cls, inputs, attr, params):
         # (batch, num_heads, seq, seq)
         extra_add = inputs[5]
+        assert extra_add is None, "extra add to QxK not currently supported"
 
-        (batch_size, seq_len, _) = infer_shape(input_emb)
-        (out_hidden_x3,) = infer_shape(bias)
-        assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3"
-        out_hidden = out_hidden_x3 // 3
-        assert (
-            out_hidden % num_heads == 0
-        ), "output hidden size should be divisible by number of attention heads"
-        head_size = out_hidden // num_heads
+        # When past_present_share_buffer is used,
+        # it is required to specify past_sequence_length (could be 0)
+        past_seq_len = inputs[6]
+        assert past_seq_len is None, "past sequence length not currently supported"
+
+        # ************************* Parse inputs *************************
+        t = ["float32", "float16"]
+        m = ["int32"]
+
+        # input
+        batch_size, seq_len, input_hidden = cls._check_input_embeddings(input_emb, t)
+
+        # weight
+        out_hidden_x3, out_hidden, head_size = cls._check_weights(
+            weight, t, num_heads=num_heads, input_hidden=input_hidden
+        )
+        # bias
+        cls._check_bias(bias, t, out_hidden_x3=out_hidden_x3)
+
+        # mask_index
         assert (
             mask_index is not None
         ), "Attention import currently only supports required mask_index"
-        mask_index_shape = infer_shape(mask_index)
-        assert (
-            len(mask_index_shape) == 2
-            and mask_index_shape[0] == batch_size
-            and mask_index_shape[1] == seq_len
-        ), "currently only support (batch_size, sequence_length) mask index"
+        total_seq_len = cls._check_mask_index(mask_index, m, batch_size=batch_size, seq_len=seq_len)
 
-        assert past is None, "past K, V state is not currently supported"
-        assert extra_add is None, "extra add to QxK not currently supported"
+        # past
+        if past_seq_len is None:
+            past_seq_len = 0
+        if past is not None:
+            past_seq_len = cls._check_past(
+                past,
+                t,
+                batch_size=batch_size,
+                num_heads=num_heads,
+                seq_len=seq_len,
+                total_seq_len=total_seq_len,
+                head_size=head_size,
+            )
 
         # split weight and biases and do the matmuls
         w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
@@ -1365,53 +1604,44 @@ def _impl_v1(cls, inputs, attr, params):
         K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
         V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)
 
-        # massage tensors in preparation for batched matmul
-        def massage(tensor):
-            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
-
-            # (batch_size, num_heads, seq_len, head_size)
-            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
+        Q = cls._split_into_heads(Q, batch_size, seq_len, num_heads, head_size)
+        K = cls._split_into_heads(K, batch_size, seq_len, num_heads, head_size)
+        V = cls._split_into_heads(V, batch_size, seq_len, num_heads, head_size)
 
-            # (batch_size * num_heads, seq_len, head_size)
-            return _op.reverse_reshape(tensor, (-1, 0, 0))
-
-        Q = massage(Q)
-        K = massage(K)
-        V = massage(V)
+        # Concatenate (past_K, past_V) with (K, V) by sequence axis:
+        # (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        if past is not None and past_seq_len > 0:
+            K_past, V_past = _op.split(past, 2, axis=0)
+            K = _op.concatenate([_op.squeeze(K_past, axis=[0]), K], axis=2)
+            V = _op.concatenate([_op.squeeze(V_past, axis=[0]), V], axis=2)
 
-        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
-        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
-        present = _op.stack([K_present, V_present], axis=0)
+        # Prepare present state for Key and Value with shape
+        # (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        present = _op.stack([K, V], axis=0)
 
-        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
-        score_dtype = infer_type(att_scores).checked_type.dtype
-        att_scores = _op.divide(
-            att_scores,
-            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
+        Q = cls._merge_first_dimensions(Q)
+        K = cls._merge_first_dimensions(K)
+        V = cls._merge_first_dimensions(V)
+
+        # Compute Attention output
+        output = cls._compute_attention(
+            Q,
+            K,
+            V,
+            mask_index,
+            unidirectional=unidirectional,
+            batch_size=batch_size,
+            out_hidden=out_hidden,
+            num_heads=num_heads,
+            head_size=head_size,
+            seq_len=seq_len,
+            past_seq_len=past_seq_len,
         )
-        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
-
-        # build the attention mask
-        att_mask = _op.cast(mask_index, score_dtype)
-        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
-        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
-        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
-
-        # apply the mask
-        att_scores = _op.add(att_scores, att_mask)
-        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
-
-        att_probs = _op.nn.softmax(att_scores, axis=-1)
-
-        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
-        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
-        output = _op.transpose(output, axes=[0, 2, 1, 3])
-        output = _op.reshape(output, (0, 0, out_hidden))
 
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
 
 
-class QAttention(OnnxOpConverter):
+class QAttention(OrtAttentionBase, OnnxOpConverter):
     """Operator converter for QAttention from Microsoft onnxruntime contrib opset.
 
     This is the self-attention mechanism used in transformer models.
@@ -1473,42 +1703,15 @@ def _impl_v1(cls, inputs, attr, params):
         t4 = ["int32"]
 
         # input
-        assert infer_type(input_emb).checked_type.dtype in t1
-        assert (
-            len(infer_shape(input_emb)) == 3
-        ), "Input should be 3D tensor with shape (batch_size, sequence_length, input_hidden_size)"
-        (batch_size, seq_len, input_hidden) = infer_shape(input_emb)
-        assert input_hidden > 0, (
-            "The weight tensor has (input_hidden_size, 3 * output_hidden_size) shape, so it doesn't"
-            f" make sense to have ({input_hidden}, 3 * output_hidden_size) weight tensor."
-        )
-        assert seq_len > 0, (
-            "The output tensor has (batch_size, sequence_length, hidden_size) shape,"
-            f" so it doesn't make sense to have (batch_size, {seq_len}, hidden_size) output."
-        )
+        batch_size, seq_len, input_hidden = cls._check_input_embeddings(input_emb, t1)
 
         # weight
-        assert infer_type(weight).checked_type.dtype in t2
-        assert len(infer_shape(weight)) == 2, (
-            "Weight should be 2D input tensor with shape (input_hidden_size, 3 * hidden_size), "
-            "hidden_size = num_heads * head_size"
+        out_hidden_x3, out_hidden, head_size = cls._check_weights(
+            weight, t2, num_heads=num_heads, input_hidden=input_hidden
         )
-        (input_hidden_weight, out_hidden_x3) = infer_shape(weight)
-        assert input_hidden == input_hidden_weight
-        assert out_hidden_x3 % 3 == 0, "output hidden shape should be divisible by 3: W_Q, W_K, W_V"
-        out_hidden = out_hidden_x3 // 3
-        assert (
-            out_hidden % num_heads == 0
-        ), "output hidden size should be divisible by number of attention heads"
-        head_size = out_hidden // num_heads
 
         # bias
-        assert infer_type(bias).checked_type.dtype in t3
-        assert (
-            len(infer_shape(bias)) == 1
-        ), "Bias should be 1D input tensor with shape (3 * hidden_size)"
-        (out_hidden_x3_bias,) = infer_shape(bias)
-        assert out_hidden_x3 == out_hidden_x3_bias
+        cls._check_bias(bias, t3, out_hidden_x3=out_hidden_x3)
 
         # input_scale
         assert infer_type(input_scale).checked_type.dtype in t3
@@ -1527,13 +1730,9 @@ def _impl_v1(cls, inputs, attr, params):
         assert (
             mask_index is not None
         ), "Attention import currently only supports required mask_index"
-        assert infer_type(mask_index).checked_type.dtype in t4
-        mask_index_shape = infer_shape(mask_index)
-        assert (
-            len(mask_index_shape) == 2
-            and mask_index_shape[0] == batch_size
-            and mask_index_shape[1] >= seq_len
-        ), "currently only support (batch_size, sequence_length) mask index"
+        total_seq_len = cls._check_mask_index(
+            mask_index, t4, batch_size=batch_size, seq_len=seq_len
+        )
 
         # TODO(agladyshev): int32 required for qnn.batch_matmul (QnnBatchMatmulRel)
         zero_point_zero = _expr.const(0, "int32")
@@ -1557,17 +1756,15 @@ def _impl_v1(cls, inputs, attr, params):
         # past (2, batch_size, num_heads, past_sequence_length, head_size)
         past_seq_len = 0
         if past is not None:
-            assert infer_type(past).checked_type.dtype in t3
-            past_shape = infer_shape(past)
-            assert len(past_shape) == 5, "past should be 5D tensor"
-            assert (
-                past_shape[0] == 2
-                and past_shape[1] == batch_size
-                and past_shape[2] == num_heads
-                and past_shape[3] + seq_len == mask_index_shape[1]
-                and past_shape[4] == head_size
+            past_seq_len = cls._check_past(
+                past,
+                t3,
+                batch_size=batch_size,
+                num_heads=num_heads,
+                seq_len=seq_len,
+                total_seq_len=total_seq_len,
+                head_size=head_size,
             )
-            past_seq_len = past_shape[3]
 
         # ************************* Create Relay *************************
         # Add batch dimension for QNN Batch Matmul
@@ -1604,22 +1801,9 @@ def qmatmul_dequantize_bias(
             input_emb, w_V, input_scale, weight_scale, input_zero_point, weight_zero_point, b_V
         )
 
-        def split_into_heads(tensor):
-            """
-            In the implementation of Multi-head attention we just split queries, keys, and values
-            we compute for a single-head attention into several parts:
-                (batch_size, num_heads, seq_len, head_size)
-            """
-            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
-
-            # (batch_size, num_heads, seq_len, head_size)
-            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])
-
-            return tensor
-
-        Q = split_into_heads(Q)
-        K = split_into_heads(K)
-        V = split_into_heads(V)
+        Q = cls._split_into_heads(Q, batch_size, seq_len, num_heads, head_size)
+        K = cls._split_into_heads(K, batch_size, seq_len, num_heads, head_size)
+        V = cls._split_into_heads(V, batch_size, seq_len, num_heads, head_size)
 
         # Concatenate (past_K, past_V) with (K, V) by sequence axis:
         # (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
@@ -1632,78 +1816,25 @@ def split_into_heads(tensor):
         # (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
         present = _op.stack([K, V], axis=0)
 
-        def merge_first_dimensions(tensor):
-            """
-            nn.batch_matmul is expecting 3D tensor:
-                (batch_size * num_heads, past_seq_len + seq_len, head_size)
-            """
-            return _op.reverse_reshape(tensor, (-1, 0, 0))
-
-        Q = merge_first_dimensions(Q)
-        K = merge_first_dimensions(K)
-        V = merge_first_dimensions(V)
-
-        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
-        score_dtype = infer_type(att_scores).checked_type.dtype
-        att_scores = _op.divide(
-            att_scores,
-            _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype),
-        )
-        att_scores = _op.reshape(
-            att_scores, (batch_size, num_heads, seq_len, past_seq_len + seq_len)
+        Q = cls._merge_first_dimensions(Q)
+        K = cls._merge_first_dimensions(K)
+        V = cls._merge_first_dimensions(V)
+
+        # Compute Attention output
+        output = cls._compute_attention(
+            Q,
+            K,
+            V,
+            mask_index,
+            unidirectional=unidirectional,
+            batch_size=batch_size,
+            out_hidden=out_hidden,
+            num_heads=num_heads,
+            head_size=head_size,
+            seq_len=seq_len,
+            past_seq_len=past_seq_len,
         )
-        # Build the attention mask
-        att_mask = _op.cast(mask_index, score_dtype)
-        # Attention mask has value 0 or 1. Here we convert 0 to -10000, and 1 to 0.
-        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
-        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))
-        # Expand for att_scores broadcast
-        # (batch_size, past_seq_len + seq_len) -> (batch_size, 1, seq_len, past_seq_len + seq_len)
-        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
-        att_mask = _op.concatenate([att_mask] * seq_len, axis=2)
-
-        def create_unidirectional_mask(left_value, right_value):
-            numpy_unidirectional_mask = np.array(
-                [
-                    np.concatenate(
-                        [
-                            np.full(past_seq_len + s_i + 1, left_value),
-                            np.full(seq_len - s_i - 1, right_value),
-                        ]
-                    )
-                    for s_i in range(seq_len)
-                ]
-            )
-            unidirectional_mask = _op.const(numpy_unidirectional_mask, dtype=score_dtype)
-            unidirectional_mask = _op.expand_dims(unidirectional_mask, 0, num_newaxis=2)
-
-            return unidirectional_mask
-
-        if unidirectional:
-            att_mask = _op.add(att_mask, create_unidirectional_mask(0, -10000))
-
-        # Apply the mask
-        att_scores = _op.add(att_scores, att_mask)
-        # TODO(agladyshev):
-        #   Comment from ORT source code (onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h):
-        #   "Fix unidirectional mask to be parity with huggingface implementation"
-        if unidirectional:
-            att_scores = _op.multiply(att_scores, create_unidirectional_mask(1, 0))
-            att_scores = _op.add(att_scores, create_unidirectional_mask(0, -10000))
-
-        # Compute Softmax
-        att_scores = _op.reshape(
-            att_scores, (batch_size * num_heads, seq_len, past_seq_len + seq_len)
-        )
-        att_probs = _op.nn.softmax(att_scores, axis=-1)
-
-        # Compute output
-        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
-        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
-        output = _op.transpose(output, axes=[0, 2, 1, 3])
-        output = _op.reshape(output, (0, 0, out_hidden))
-
         return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index a84de82f3bab..f5b5f7c65cb5 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -5878,30 +5878,47 @@ def verify_embedlayernormalization(
 def test_attention(target, dev):
     """test_attention"""
 
-    def verify_attention(input_, weight, bias, mask_index, num_heads):
+    def verify_attention(_unidirectional, _input, _weight, _bias, _mask_index=None, _past=None):
+        input_names = ["input", "weight", "bias"]
+        if _mask_index is not None:
+            input_names.append("mask_index")
+        if _past is not None:
+            input_names.append("past")
+
         node = onnx.helper.make_node(
             "Attention",
-            inputs=["input", "weight", "bias", "mask_index"],
+            inputs=input_names,
             outputs=["output", "present"],
             domain="com.microsoft",
             num_heads=num_heads,
+            unidirectional=_unidirectional,
         )
 
+        past_shape = (2, batch_size, num_heads, past_sequence_length, head_size)
         present_output_shape = (2, batch_size, num_heads, sequence_length, head_size)
 
+        inputs_info = [
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, list(_input.shape)),
+            helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(_weight.shape)),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(_bias.shape)),
+        ]
+        if _mask_index is not None:
+            inputs_info.append(
+                helper.make_tensor_value_info(
+                    "mask_index", TensorProto.INT32, list(_mask_index.shape)
+                ),
+            )
+        if _past is not None:
+            inputs_info.append(
+                helper.make_tensor_value_info("past", TensorProto.FLOAT, list(past_shape))
+            )
+
         graph = helper.make_graph(
             [node],
             "attention_test",
-            inputs=[
-                helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input_.shape)),
-                helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)),
-                helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)),
-                helper.make_tensor_value_info(
-                    "mask_index", TensorProto.INT32, list(mask_index.shape)
-                ),
-            ],
+            inputs=inputs_info,
             outputs=[
-                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input_.shape)),
+                helper.make_tensor_value_info("output", TensorProto.FLOAT, list(_input.shape)),
                 helper.make_tensor_value_info(
                     "present", TensorProto.FLOAT, list(present_output_shape)
                 ),
@@ -5910,31 +5927,58 @@ def verify_attention(input_, weight, bias, mask_index, num_heads):
 
         model = helper.make_model(graph, producer_name="attention_test")
 
+        inputs = [_input, _weight, _bias]
+        if _mask_index is not None:
+            inputs.append(_mask_index)
+        if _past is not None:
+            inputs.append(_past)
+
         # "present" output should be nullptr when the "past" input isn't included,
         # but ort requires an output shape to be specified?
         verify_with_ort_with_inputs(
             model,
-            [input_, weight, bias, mask_index],
-            [input_.shape, present_output_shape],
+            inputs,
+            [_input.shape, present_output_shape],
             target=target,
             dev=dev,
             rtol=1e-4,
             atol=1e-4,
         )
 
-    hidden_size = 384
-    batch_size = 4
-    sequence_length = 4
-    num_heads = 12
-    head_size = 32
+    batch_size = 11
+    num_heads = 13
+    head_size = 37
+    sequence_length = 7
+    input_hidden_size = 147
+    weight_hidden_size = num_heads * head_size
+    past_sequence_length = 17
 
-    dtype = "float32"
-    input_array = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype)
-    weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1
-    bias = np.random.randn(3 * hidden_size).astype(dtype)
-    mask_index = np.full((batch_size, sequence_length), 1).astype("int32")
+    total_sequence_length = past_sequence_length + sequence_length
 
-    verify_attention(input_array, weight, bias, mask_index, num_heads)
+    # Required inputs
+    input_array = np.random.normal(size=(batch_size, sequence_length, input_hidden_size)).astype(
+        "float32"
+    )
+    weight = (
+        np.random.normal(size=(input_hidden_size, 3 * weight_hidden_size)).astype("float32") * 0.1
+    )
+    bias = np.random.randn(3 * weight_hidden_size).astype("float32")
+
+    # Optional inputs
+    past = np.random.random((2, batch_size, num_heads, past_sequence_length, head_size)).astype(
+        "float32"
+    )
+
+    for unidirectional in [0, 1]:
+        for have_past in [False, True]:
+            if not have_past:
+                mask_index = np.random.randint(0, 2, (batch_size, sequence_length)).astype("int32")
+                verify_attention(unidirectional, input_array, weight, bias, mask_index)
+            else:
+                mask_index = np.random.randint(0, 2, (batch_size, total_sequence_length)).astype(
+                    "int32"
+                )
+                verify_attention(unidirectional, input_array, weight, bias, mask_index, past)
 
 
 @tvm.testing.parametrize_targets