215 changes: 212 additions & 3 deletions python/tvm/relay/frontend/onnx.py
@@ -807,9 +807,10 @@ def _impl_v1(cls, inputs, attr, params):
        x = inputs[0]

        # Declare consts
        half = _expr.const(0.5)
        one = _expr.const(1.0)
        sqrt2 = _expr.const(math.sqrt(2))
        const_dtype = infer_type(x).checked_type.dtype
        half = _expr.const(0.5, dtype=const_dtype)
        one = _expr.const(1.0, dtype=const_dtype)
        sqrt2 = _expr.const(math.sqrt(2), dtype=const_dtype)

        # Compute gelu
        term1 = _op.multiply(half, x)
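For reference, these constants feed the exact (erf-based) GELU formula, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). The change pins the constants' dtype to the input's dtype; without it they default to float32, which fails type checking for inputs such as float16.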
@@ -836,6 +837,208 @@ def _impl_v1(cls, inputs, attr, params):
        return Gelu._impl_v1([inp], attr, params)


class EmbedLayerNormalization(OnnxOpConverter):
    """Operator converter for EmbedLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer embeds the input tokens, sums the embeddings, and applies layer normalization.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        input_ids = inputs[0]
        segment_ids = inputs[1]
        word_emb = inputs[2]
        pos_emb = inputs[3]
        segment_emb = inputs[4]
        gamma = inputs[5]
        beta = inputs[6]

        mask = inputs[7]
        pos_ids = inputs[8]

        eps = attr.get("epsilon", 1e-12)

        (batch_size, seq_len) = infer_shape(input_ids)

        if segment_ids:
            assert segment_emb

        if pos_ids is None:
            pos_ids = _op.const([list(range(seq_len))] * batch_size, dtype="int64")

        word_vec = _op.take(word_emb, input_ids, axis=0)
        if segment_ids:
            segment_vec = _op.take(segment_emb, segment_ids, axis=0)
        pos_vec = _op.take(pos_emb, pos_ids, axis=0)

        vec_sum = _op.add(word_vec, pos_vec)
        if segment_ids:
            vec_sum = _op.add(vec_sum, segment_vec)

        eps_dtype = infer_type(word_emb).checked_type.dtype

        u, s = _op.mean_variance(vec_sum, axis=-1, keepdims=True)
        ln = _op.divide(
            _op.subtract(vec_sum, u),
            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
        )
        ln = _op.multiply(ln, gamma) + beta

        mask_index = _op.const(np.zeros((batch_size,), dtype="int64"))
        if mask:
            # calculate the number of words per sentence
            mask_index = _op.sum(mask, axis=1)

        return _expr.TupleWrapper(_expr.Tuple([ln, mask_index, vec_sum]), 3)
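For intuition, here is a minimal NumPy sketch of the computation this converter emits, for the no-segment, no-mask case (function name, shapes, and sample data are illustrative, not part of the PR):

import numpy as np

def embed_layer_norm_reference(input_ids, word_emb, pos_emb, gamma, beta, eps=1e-12):
    """NumPy reference: embed tokens, add position embeddings, layer-normalize."""
    # Gather word vectors: (batch, seq) int ids -> (batch, seq, hidden)
    word_vec = word_emb[input_ids]
    # Position embeddings for positions 0..seq-1, broadcast over the batch
    pos_vec = pos_emb[np.arange(input_ids.shape[1])]
    vec_sum = word_vec + pos_vec
    # Layer normalization over the hidden axis
    u = vec_sum.mean(axis=-1, keepdims=True)
    s = vec_sum.var(axis=-1, keepdims=True)
    return gamma * (vec_sum - u) / np.sqrt(s + eps) + beta

rng = np.random.default_rng(0)
ids = rng.integers(0, 100, size=(2, 8))
out = embed_layer_norm_reference(
    ids,
    rng.standard_normal((100, 32)),  # word embedding table
    rng.standard_normal((512, 32)),  # position embedding table
    np.ones(32),
    np.zeros(32),
)
print(out.shape)  # (2, 8, 32)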


class SkipLayerNormalization(OnnxOpConverter):
    """Operator converter for SkipLayerNormalization from Microsoft onnxruntime contrib opset.

    This layer sums the two input tensors (along with an optional bias), and applies layer
    normalization.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        data = inputs[0]
        skip = inputs[1]
        gamma = inputs[2]
        beta = inputs[3]
        bias = inputs[4]

        eps = attr.get("epsilon", 1e-12)

        x = _op.add(data, skip)
        if bias is not None:
            x = _op.add(x, bias)

        eps_dtype = infer_type(x).checked_type.dtype

        u, s = _op.mean_variance(x, axis=-1, keepdims=True)
        output = _op.divide(
            _op.subtract(x, u),
            _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
        )
        output = _op.multiply(output, gamma)

[Review comment (Contributor)]
nit: this is basically the same normalization calculation as 877-886 above, right? If it's easy, can we pull it out into a common helper function?
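If that refactor were done, it might look something like this sketch (a hypothetical helper, assuming this file's existing _op and infer_type imports; not part of the PR):

def _layer_norm(x, gamma, beta, eps, axis=-1):
    """Hypothetical shared helper: normalize x over `axis`, then scale and shift."""
    eps_dtype = infer_type(x).checked_type.dtype
    u, s = _op.mean_variance(x, axis=axis, keepdims=True)
    out = _op.divide(
        _op.subtract(x, u),
        _op.sqrt(_op.add(s, _op.const(eps, dtype=eps_dtype))),
    )
    out = _op.multiply(out, gamma)
    # EmbedLayerNormalization always has beta; SkipLayerNormalization's is optional
    if beta is not None:
        out = _op.add(out, beta)
    return out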

        if beta:
            output = _op.add(output, beta)

        placeholder = _op.const(0, dtype="float32")

[Review comment (Contributor)]
what is this placeholder for? The optional returns are the mean and the inverse standard variance, right?

[Reply (Contributor Author)]
That's true according to the documentation; however, both the CUDA and C++ onnxruntime implementations of the kernel never actually calculate or return values for these outputs:

https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc

        return _expr.TupleWrapper(_expr.Tuple([output, placeholder, placeholder]), 3)


class Attention(OnnxOpConverter):
    """Operator converter for Attention from Microsoft onnxruntime contrib opset.

    This is the self-attention mechanism used in transformer models.
    """

    @classmethod
    def _impl_v1(cls, inputs, attr, params):
        num_heads = attr["num_heads"]
        assert (
            "qkv_hidden_sizes" not in attr
        ), "different hidden sizes for Q, K, V are not currently supported"
        assert "unidirectional" not in attr, "unidirectional attention is not currently supported"

        # (batch, seq, in_hidden)
        input_emb = inputs[0]

        # (in_hidden, 3 * out_hidden), where out_hidden = num_heads * head_size
        weight = inputs[1]

        # (3 * out_hidden,)
        bias = inputs[2]

        # 1. (batch, 1, max_seq, max_seq)
        # 2. (batch, past_seq + seq)
        # 3. (batch, seq, past_seq + seq)
        # 4. (batch,)
        # 5. (2 * batch,)
        # For now, we only support case 2.
        mask_index = inputs[3]

        # (2, batch, num_heads, past_seq, head_size)
        past = inputs[4]

        # (batch, num_heads, seq, seq)
        extra_add = inputs[5]

        (batch_size, seq_len, _) = infer_shape(input_emb)
        (out_hidden_x3,) = infer_shape(bias)
        assert out_hidden_x3 % 3 == 0, "bias dimension should be divisible by 3"
        out_hidden = out_hidden_x3 // 3
        assert (
            out_hidden % num_heads == 0
        ), "output hidden size should be divisible by the number of attention heads"
        head_size = out_hidden // num_heads

        mask_index_shape = infer_shape(mask_index)
        assert (
            len(mask_index_shape) == 2
            and mask_index_shape[0] == batch_size
            and mask_index_shape[1] == seq_len
        ), "currently only support (batch_size, sequence_length) mask index"

        assert past is None, "past K, V state is not currently supported"
        assert extra_add is None, "extra add to QxK not currently supported"

        # split the weights and biases and do the matmuls
        w_Q, w_K, w_V = _op.split(weight, 3, axis=1)
        b_Q, b_K, b_V = _op.split(bias, 3, axis=0)
        # need to merge batch dimensions since TVM matmul is 2D
        input_emb = _op.reverse_reshape(input_emb, (-1, 0))
        Q = _op.add(_op.nn.matmul(input_emb, w_Q), b_Q)
        K = _op.add(_op.nn.matmul(input_emb, w_K), b_K)
        V = _op.add(_op.nn.matmul(input_emb, w_V), b_V)

        # massage tensors in preparation for batched matmul
        def massage(tensor):
            tensor = _op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))

            # (batch_size, num_heads, seq_len, head_size)
            tensor = _op.transpose(tensor, axes=[0, 2, 1, 3])

            # (batch_size * num_heads, seq_len, head_size)
            return _op.reverse_reshape(tensor, (-1, 0, 0))

        Q = massage(Q)
        K = massage(K)
        V = massage(V)

        K_present = _op.reshape(K, (batch_size, num_heads, seq_len, head_size))
        V_present = _op.reshape(V, (batch_size, num_heads, seq_len, head_size))
        present = _op.stack([K_present, V_present], axis=0)

        att_scores = _op.nn.batch_matmul(Q, K, transpose_a=False, transpose_b=True)
        score_dtype = infer_type(att_scores).checked_type.dtype
        att_scores = _op.divide(
            att_scores,
            _op.const(np.sqrt(head_size), dtype=score_dtype),
        )
        att_scores = _op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))

        # build the attention mask
        att_mask = _op.cast(mask_index, score_dtype)
        att_mask = _op.expand_dims(att_mask, 1, num_newaxis=2)
        att_mask = _op.subtract(_op.const(1, dtype=score_dtype), att_mask)
        att_mask = _op.multiply(att_mask, _op.const(-10000, dtype=score_dtype))

        # apply the mask
        att_scores = _op.add(att_scores, att_mask)
        att_scores = _op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))

        att_probs = _op.nn.softmax(att_scores, axis=-1)

        output = _op.nn.batch_matmul(att_probs, V, transpose_a=False, transpose_b=False)
        output = _op.reverse_reshape(output, (-1, num_heads, 0, 0))
        output = _op.transpose(output, axes=[0, 2, 1, 3])
        output = _op.reshape(output, (0, 0, out_hidden))

        return _expr.TupleWrapper(_expr.Tuple([output, present]), 2)
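As a cross-check, the sequence above is standard scaled dot-product attention. A minimal NumPy sketch of the per-(batch * head) computation (shapes and the pre-broadcast 0/1 padding mask are illustrative, not part of the PR):

import numpy as np

def attention_reference(Q, K, V, mask, head_size):
    """softmax(Q @ K^T / sqrt(head_size) + (1 - mask) * -10000) @ V."""
    # Q, K, V: (batch * num_heads, seq, head_size); mask: (batch * num_heads, seq)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_size)
    scores = scores + (1.0 - mask[:, None, :]) * -10000.0  # hide padding tokens
    scores -= scores.max(axis=-1, keepdims=True)           # numerically stable softmax
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)
    return probs @ V                                       # (batch * num_heads, seq, head_size)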


class Gemm(OnnxOpConverter):
"""Operator converter for Gemm."""

@@ -4737,6 +4940,12 @@ def _get_convert_map(opset):
"Elu": Elu.get_converter(opset),
"Gelu": Gelu.get_converter(opset),
"BiasGelu": BiasGelu.get_converter(opset),
# TODO: We need a better way to handle different domains, in case
# of name collisions. EmbedLayerNormalization, SkipLayerNormalization, and Attention
# are in the `com.microsoft` domain.
"EmbedLayerNormalization": EmbedLayerNormalization.get_converter(opset),
"SkipLayerNormalization": SkipLayerNormalization.get_converter(opset),
"Attention": Attention.get_converter(opset),
"Exp": Renamer("exp"),
"Greater": Renamer("greater"),
"GreaterOrEqual": Renamer("greater_equal"),
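With the three converters registered in _get_convert_map, a model exported with these com.microsoft contrib ops should import through the regular ONNX frontend. A minimal sketch (the model file name and input shapes are hypothetical):

import onnx
import tvm
from tvm import relay

# Hypothetical BERT-style model that uses the com.microsoft contrib ops
model = onnx.load("bert_with_contrib_ops.onnx")

shape_dict = {
    "input_ids": (1, 128),
    "segment_ids": (1, 128),
    "input_mask": (1, 128),
}
mod, params = relay.frontend.from_onnx(model, shape_dict)

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)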