diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 7dcb9952c7fb..3450d489af70 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -329,6 +329,22 @@ def flatten_to_nd(x, x_shape, nd=3): return _op.nn.dense(inputs[0], input_1_t, out_dtype=out_dtype) +def layer_norm(x, eps, gamma, beta): + """Common function to handle layer norm""" + eps_dtype = infer_type(x).checked_type.dtype + + u, s = _op.mean_variance(x, axis=-1, keepdims=True) + output = _op.divide( + _op.subtract(x, u), + _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))), + ) + output = _op.multiply(output, gamma) + if beta is not None: + output = _op.add(output, beta) + + return output + + class OnnxOpConverter(object): """A helper class for holding onnx op converters.""" @@ -807,9 +823,10 @@ def _impl_v1(cls, inputs, attr, params): x = inputs[0] # Declare consts - half = _expr.const(0.5) - one = _expr.const(1.0) - sqrt2 = _expr.const(math.sqrt(2)) + const_dtype = infer_type(x).checked_type.dtype + half = _expr.const(0.5, dtype=const_dtype) + one = _expr.const(1.0, dtype=const_dtype) + sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype) # Compute gelu term1 = _op.multiply(half, x) @@ -836,6 +853,201 @@ def _impl_v1(cls, inputs, attr, params): return Gelu._impl_v1([inp], attr, params) +class EmbedLayerNormalization(OnnxOpConverter): + """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset. + + This layer embeds the input tokens, sums them, and applies layer normalization. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + input_ids = inputs[0] + segment_ids = inputs[1] + word_emb = inputs[2] + pos_emb = inputs[3] + segment_emb = inputs[4] + gamma = inputs[5] + beta = inputs[6] + + mask = inputs[7] + pos_ids = inputs[8] + + eps = attr.get("epsilon", 1e-12) + + (batch_size, seq_len) = infer_shape(input_ids) + + if segment_ids: + assert segment_emb + + if pos_ids is None: + pos_ids = _op.const([list(range(seq_len))] * seq_len, dtype="int32") + + word_vec = _op.take(word_emb, input_ids, axis=0) + segment_vec = _op.take(segment_emb, segment_ids, axis=0) + pos_vec = _op.take(pos_emb, pos_ids, axis=0) + + vec_sum = _op.add(word_vec, pos_vec) + if segment_ids: + vec_sum = _op.add(vec_sum, segment_vec) + + ln = layer_norm(vec_sum, eps, gamma, beta) + + mask_index = _op.const(np.zeros((batch_size,), dtype="int32")) + if mask: + # calculate number of words per sentence + mask_index = _op.sum(mask, axis=1) + + # TODO(@anwang2009): onnxruntime v1.10.0 requires a third output of vec_sum + return _expr.TupleWrapper(_expr.Tuple([ln, mask_index]), 2) + + +class SkipLayerNormalization(OnnxOpConverter): + """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset. + + This layer sums the two input tensors (along with optional bias), and applies layer + normalization. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + data = inputs[0] + skip = inputs[1] + gamma = inputs[2] + beta = inputs[3] + bias = inputs[4] + + assert ( + beta is not None and bias is not None + ), "SkipLayerNormalization import currently only supports required beta and bias" + + eps = attr.get("epsilon", 1e-12) + + x = _op.add(data, skip) + if bias is not None: + x = _op.add(x, bias) + + output = layer_norm(x, eps, gamma, beta) + + # onnxruntime doesn't compute the other outputs, despite the documentation + placeholder = _op.const(0, dtype="float32") + + return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3) + + +class Attention(OnnxOpConverter): + """Operator converter for Attention from Microsoft onnxruntime contrib opset. + + This is the self-attention mechanism used in transformer models. + """ + + @classmethod + def _impl_v1(cls, inputs, attr, params): + num_heads = attr["num_heads"] + assert ( + "qkv_hidden_sizes" not in attr + ), "different hidden sizes for Q, K, V are not currently supported" + assert "unidirectional" not in attr, "unidirectional attention not current supported" + + # (batch, seq, in_hidden) + input_emb = inputs[0] + + # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size + weight = inputs[1] + + # (3 * out_hidden,) + bias = inputs[2] + + # 1. ( batch, 1, max_seq, max_seq) + # 2. ( batch, past_seq + seq,) + # 3. ( batch, seq, past_seq + seq,) + # 4. ( batch,) + # 5. (2 * batch,) + # For now, we only support case 2. + mask_index = inputs[3] + + # (2, batch, num_heads, past_seq, head_size) + past = inputs[4] + + # (batch, num_heads, seq, seq) + extra_add = inputs[5] + + (batch_size, seq_len, _) = infer_shape(input_emb) + (out_hidden_x3,) = infer_shape(bias) + assert out_hidden_x3 % 3 == 0, "bias shape should be divisible by 3" + out_hidden = out_hidden_x3 // 3 + assert ( + out_hidden % num_heads == 0 + ), "output hidden size should be divisible by number of attention heads" + head_size = out_hidden // num_heads + + assert ( + mask_index is not None + ), "Attention import currently only supports required mask_index" + mask_index_shape = infer_shape(mask_index) + assert ( + len(mask_index_shape) == 2 + and mask_index_shape[0] == batch_size + and mask_index_shape[1] == seq_len + ), "currently only support (batch_size, sequence_length) mask index" + + assert past is None, "past K, V state is not currently supported" + assert extra_add is None, "extra add to QxK not currently supported" + + # split weight and biases and do the matmuls + w_Q, w_K, w_V = _op.split(weight, 3, axis=1) + b_Q, b_K, b_V = _op.split(bias, 3, axis=0) + # need to merge batch dimensions since TVM matmul is 2D + input_emb = _op.reverse_reshape(input_emb, (-1, 0)) + Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q) + K = _op.add(_op.nn.matmul(input_emb, w_K), b_K) + V = _op.add(_op.nn.matmul(input_emb, w_V), b_V) + + # massage tensors in preparation for batched matmul + def massage(tensor): + tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size)) + + # (batch_size, num_heads, seq_len, head_size) + tensor = _op.transpose(tensor, axes=[0, 2, 1, 3]) + + # (batch_size * num_heads, seq_len, head_size) + return _op.reverse_reshape(tensor, (-1, 0, 0)) + + Q = massage(Q) + K = massage(K) + V = massage(V) + + K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size)) + V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size)) + present = _op.stack([K_present, V_present], axis=0) + + att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True) + score_dtype = infer_type(att_scores).checked_type.dtype + att_scores = _op.divide( + att_scores, + _op.const(np.sqrt(head_size), dtype=infer_type(att_scores).checked_type.dtype), + ) + att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len)) + + # build the attention mask + att_mask = _op.cast(mask_index, score_dtype) + att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2) + att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask) + att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype)) + + # apply the mask + att_scores = _op.add(att_scores, att_mask) + att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len)) + + att_probs = _op.nn.softmax(att_scores, axis=-1) + + output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False) + output = _op.reverse_reshape(output, (-1, num_heads, 0, 0)) + output = _op.transpose(output, axes=[0, 2, 1, 3]) + output = _op.reshape(output, (0, 0, out_hidden)) + + return _expr.TupleWrapper(_expr.Tuple([output, present]), 2) + + class Gemm(OnnxOpConverter): """Operator converter for Gemm.""" @@ -4737,6 +4949,12 @@ def _get_convert_map(opset): "Elu": Elu.get_converter(opset), "Gelu": Gelu.get_converter(opset), "BiasGelu": BiasGelu.get_converter(opset), + # TODO: We need a better way to handle different domains, in case + # of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention + # are in the `com.microsoft` domain. + "EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset), + "SkipLayerNormalization": SkipLayerNormalization.get_converter(opset), + "Attention": Attention.get_converter(opset), "Exp": Renamer("exp"), "Greater": Renamer("greater"), "GreaterOrEqual": Renamer("greater_equal"), diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 638b4b8f57eb..9a8647af48c1 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -39,6 +39,10 @@ def get_input_data_shape_dict(graph_def, input_data): shape_dict = {} for i, _ in enumerate(input_data): input_names[i] = graph_def.graph.input[i].name + if input_data[i] is None or input_data[i].shape == (): + # Skip adding input shape data when the input data is None; + # This is to enable optional arguments for onnx operators. + continue shape_dict[input_names[i]] = input_data[i].shape else: input_names = graph_def.graph.input[0].name @@ -5433,6 +5437,221 @@ def verify_biasgelu(x, bias): verify_biasgelu(x, bias) +@tvm.testing.parametrize_targets +def test_embedlayernormalization(target, dev): + def verify_embedlayernormalization( + input_ids, + segment_ids, + word_embedding, + position_embedding, + segment_embedding, + gamma, + beta, + ): + node = onnx.helper.make_node( + "EmbedLayerNormalization", + inputs=[ + "input_ids", + "" if segment_ids is None else "segment_ids", + "word_embedding", + "position_embedding", + "" if segment_embedding is None else "segment_embedding", + "gamma", + "beta", + ], + outputs=["output", "mask_index"], + domain="com.microsoft", + ) + + node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4)) + + segment_ids_shape = [] if segment_ids is None else segment_ids.shape + segment_embedding_shape = [] if segment_embedding is None else segment_embedding.shape + + graph = helper.make_graph( + [node], + "embedlayernormalization_test", + inputs=[ + helper.make_tensor_value_info( + "input_ids", TensorProto.INT32, list(input_ids.shape) + ), + helper.make_tensor_value_info("segment_ids", TensorProto.INT32, segment_ids_shape), + helper.make_tensor_value_info( + "word_embedding", TensorProto.FLOAT, list(word_embedding.shape) + ), + helper.make_tensor_value_info( + "position_embedding", TensorProto.FLOAT, list(position_embedding.shape) + ), + helper.make_tensor_value_info( + "segment_embedding", TensorProto.FLOAT, segment_embedding_shape + ), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)), + ], + outputs=[ + helper.make_tensor_value_info( + "output", TensorProto.FLOAT, list((batch_size, sequence_length, hidden_size)) + ), + helper.make_tensor_value_info("mask_index", TensorProto.INT32, [batch_size]), + ], + ) + + model = helper.make_model(graph, producer_name="embedlayernormalization_test") + + # TODO(@anwang2009): onnxruntime v1.9.0 requires empty list for optional argument, + # but v1.10.0+ requires None instead. + verify_with_ort_with_inputs( + model, + [ + input_ids, + np.empty(0, dtype="int32") if segment_ids is None else segment_ids, + word_embedding, + position_embedding, + np.empty(0, dtype="float32") if segment_embedding is None else segment_embedding, + gamma, + beta, + ], + [ + (batch_size, sequence_length, hidden_size), + batch_size, + ], + target=target, + dev=dev, + rtol=1e-4, + atol=1e-4, + ) + + hidden_size = 384 + batch_size = 4 + sequence_length = 4 + vocab_size = 5 + + input_ids = np.full((batch_size, sequence_length), 3).astype("int32") + segment_ids = np.zeros((batch_size, sequence_length)).astype("int32") + word_embedding = np.full((vocab_size, hidden_size), 1).astype("float32") + position_embedding = np.full((sequence_length, hidden_size), 2).astype("float32") + segment_embedding = np.full((vocab_size, hidden_size), 3).astype("float32") + + gamma = np.random.uniform(0.5, 0.7, hidden_size).astype("float32") + beta = np.random.randn(hidden_size).astype("float32") * 0.1 + + verify_embedlayernormalization( + input_ids, segment_ids, word_embedding, position_embedding, segment_embedding, gamma, beta + ) + + # Test with undefined segment embedding + verify_embedlayernormalization( + input_ids, None, word_embedding, position_embedding, None, gamma, beta + ) + + +@tvm.testing.parametrize_targets +def test_attention(target, dev): + def verify_attention(input, weight, bias, mask_index, num_heads): + node = onnx.helper.make_node( + "Attention", + inputs=["input", "weight", "bias", "mask_index"], + outputs=["output", "present"], + domain="com.microsoft", + num_heads=num_heads, + ) + + present_output_shape = (2, batch_size, num_heads, sequence_length, head_size) + + graph = helper.make_graph( + [node], + "attention_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)), + helper.make_tensor_value_info("weight", TensorProto.FLOAT, list(weight.shape)), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)), + helper.make_tensor_value_info( + "mask_index", TensorProto.INT32, list(mask_index.shape) + ), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)), + helper.make_tensor_value_info( + "present", TensorProto.FLOAT, list(present_output_shape) + ), + ], + ) + + model = helper.make_model(graph, producer_name="attention_test") + + # "present" output should be nullptr when the "past" input isn't included, + # but ort requires an output shape to be specified? + verify_with_ort_with_inputs( + model, + [input, weight, bias, mask_index], + [input.shape, present_output_shape], + target=target, + dev=dev, + rtol=1e-4, + atol=1e-4, + ) + + hidden_size = 384 + batch_size = 4 + sequence_length = 4 + num_heads = 12 + head_size = 32 + + dtype = "float32" + input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype) + weight = np.random.normal(size=(hidden_size, 3 * hidden_size)).astype(dtype) * 0.1 + bias = np.random.randn(3 * hidden_size).astype(dtype) + mask_index = np.full((batch_size, sequence_length), 1).astype("int32") + + verify_attention(input, weight, bias, mask_index, num_heads) + + +@tvm.testing.parametrize_targets +def test_skiplayernormalization(target, dev): + def verify_skiplayernormalization(input, skip, gamma, beta, bias): + node = onnx.helper.make_node( + "SkipLayerNormalization", + inputs=["input", "skip", "gamma", "beta", "bias"], + outputs=["output"], + domain="com.microsoft", + ) + + node.attribute.append(onnx.helper.make_attribute("epsilon", 1e-4)) + + graph = helper.make_graph( + [node], + "skiplayernormalization_test", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, list(input.shape)), + helper.make_tensor_value_info("skip", TensorProto.FLOAT, list(skip.shape)), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, list(gamma.shape)), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, list(beta.shape)), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, list(bias.shape)), + ], + outputs=[ + helper.make_tensor_value_info("output", TensorProto.FLOAT, list(input.shape)), + ], + ) + + model = helper.make_model(graph, producer_name="skiplayernormalization_test") + verify_with_ort_with_inputs( + model, [input, skip, gamma, beta, bias], [input.shape], target=target, dev=dev + ) + + hidden_size = 384 + batch_size = 4 + sequence_length = 4 + + dtype = "float32" + input = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype) + skip = np.random.random((batch_size, sequence_length, hidden_size)).astype(dtype) + gamma = np.random.uniform(0.5, 0.7, hidden_size).astype(dtype) + beta = np.random.randn(hidden_size).astype(dtype) * 0.1 + bias = np.random.randn(hidden_size).astype(dtype) + + verify_skiplayernormalization(input, skip, gamma, beta, bias) + + @tvm.testing.known_failing_targets("cuda") @tvm.testing.parametrize_targets def test_qlinearconv(target, dev):