Skip to content
Closed
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,14 +479,10 @@ struct OneHotAttrs : public tvm::AttrsNode<OneHotAttrs> {

/*! \brief Attributes used in matrix_set_diag operator */
struct MatrixSetDiagAttrs : public tvm::AttrsNode<MatrixSetDiagAttrs> {
int k1;
int k2;
bool super_diag_right_align;
bool sub_diag_right_align;

TVM_DECLARE_ATTRS(MatrixSetDiagAttrs, "relay.attrs.MatrixSetDiagAttrs") {
TVM_ATTR_FIELD(k1).set_default(0).describe("Lower limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(k2).set_default(0).describe("Upper limit (included) of the range of diagonals.");
TVM_ATTR_FIELD(super_diag_right_align)
.set_default(true)
.describe("Bool, true iff super-diagonal is right aligned (left-padded).");
Expand Down
22 changes: 11 additions & 11 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -1851,14 +1851,13 @@ inline Tensor sparse_to_dense(const Tensor& sparse_indices, const Array<PrimExpr
* \param tag output tensor tag.
* \return new tensor with given diagonal values.
*/
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k1, int k2,
bool super_diag_right_align, bool sub_diag_right_align,
inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, const Tensor& k1,
const Tensor& k2, bool super_diag_right_align,
bool sub_diag_right_align,
const std::string name = "T_matrix_set_diag",
const std::string tag = kInjective) {
size_t ndim = input->shape.size() - 1;

bool only_one_diagonal = k1 == k2;

return compute(
input->shape,
[&](const Array<Var>& iter_vars) {
Expand All @@ -1868,12 +1867,10 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
for (size_t i = 0; i < ndim - 1; i++) {
diagonal_indices.push_back(iter_vars[i]);
}
if (only_one_diagonal) {
k = k1;
} else {
auto multi_diagonals = [&]() {
// Determining which diagonal/sub-diagonal/super-diagonal it is
k = iter_vars[ndim] - iter_vars[ndim - 1];
diagonal_indices.push_back(k2 - k);
diagonal_indices.push_back(k2(0) - k);

// Calculating the offset in diagonal tensor for this diagonal
auto get_offset = [&](PrimExpr M, PrimExpr N) {
Expand All @@ -1886,13 +1883,16 @@ inline Tensor matrix_set_diag(const Tensor& input, const Tensor& diagonal, int k
: 0,
sub_diag_right_align ? get_offset(input->shape[ndim], input->shape[ndim - 1] + k)
: 0);
}
return k;
};
auto get_k = [&]() { return if_then_else(k1(0) == k2(0), k1(0), multi_diagonals()); };
k = get_k();
diagonal_indices.push_back(if_then_else(k >= 0, iter_vars[ndim - 1], iter_vars[ndim]) +
offset);
return diagonal(diagonal_indices);
};
return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1,
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2,
return if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] >= k1(0),
if_then_else((PrimExpr)iter_vars[ndim] - iter_vars[ndim - 1] <= k2(0),
get_diag(), input(iter_vars)),
input(iter_vars));
},
Expand Down
32 changes: 32 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4637,6 +4637,37 @@ def _impl_v1(cls, inputs, attr, params):
return _expr.TupleWrapper(_expr.Tuple(result), len(result))


class Trilu(OnnxOpConverter):
"""Operator converter for Trilu"""

@classmethod
def _impl_v14(cls, inputs, attr, params):
upper = attr.get("upper", 1)
input_shape = shape_of(inputs[0])
input_dims = infer_shape(input_shape)[0]
data_type = infer_type(inputs[0]).checked_type.dtype
k_tensor = relay.const(np.asarray(0), dtype=np.int64)
if len(inputs) == 2:
k_tensor = inputs[1]

diag_input = relay.zeros(fold_constant(input_shape), dtype=data_type)
k1, k2 = None, None
if upper == 0:
k1 = relay.add(k_tensor, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.take(input_shape, relay.const(input_dims - 1, dtype="int32"))
k2 = relay.expand_dims(k2, axis=0)
else:
k1 = relay.take(input_shape, relay.const(input_dims - 2, dtype="int32"))
k1 = relay.multiply(k1, relay.const(-1, dtype="int64"))
k1 = relay.subtract(k1, relay.const(1, dtype="int64"))
k1 = relay.expand_dims(k1, axis=0)
k2 = relay.subtract(k_tensor, relay.const(1, dtype="int64"))
k2 = relay.expand_dims(k2, axis=0)

return relay.matrix_set_diag(inputs[0], diag_input, k=(k1, k2))


# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -4810,6 +4841,7 @@ def _get_convert_map(opset):
"CumSum": CumSum.get_converter(opset),
"Unique": Unique.get_converter(opset),
"Einsum": Einsum.get_converter(opset),
"Trilu": Trilu.get_converter(opset),
# defs/control_flow
"Loop": Loop.get_converter(opset),
"If": If.get_converter(opset),
Expand Down
12 changes: 9 additions & 3 deletions python/tvm/relay/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3281,6 +3281,11 @@ def convert_matrix_set_diag(self, op):

input_expr = self.get_tensor_expr(input_tensors[0])
diagonal_expr = self.get_tensor_expr(input_tensors[1])
diag_shape = to_int_list(self.get_tensor_shape(input_tensors[1]))
input_shape = to_int_list(self.get_tensor_shape(input_tensors[0]))
if len(diag_shape) == len(input_shape) - 1:
diag_shape = np.insert(diag_shape, len(diag_shape) - 1, 1)
diagonal_expr = _op.reshape(diagonal_expr, diag_shape)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
return out
Expand All @@ -3300,14 +3305,15 @@ def convert_matrix_diag(self, op):
), "TFLite MATRIX_DIAG requires diagonal and output tensors' \
scale and zero points to be equal"

# Tflite's output tensor for matrix_diag has rank k+1
shape = to_int_list(self.get_tensor_shape(diagonal))
shape = np.append(shape, shape[-1])
# Diagonal's tensor has rank k. Therefore we remove the last dimension [:-1].
diag_shape = np.insert(shape[:-1], len(shape[:-1]) - 1, 1).astype(np.int32)
dtype = self.get_tensor_type_str(diagonal.tensor.Type())

input_expr = _op.zeros(tuple(shape), dtype)
diagonal_expr = self.get_tensor_expr(diagonal)

out = _op.matrix_set_diag(input_expr, diagonal_expr)
out = _op.matrix_set_diag(input_expr, _op.reshape(diagonal_expr, diag_shape))
return out

def convert_densify(self, op):
Expand Down
6 changes: 6 additions & 0 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# pylint: disable=import-outside-toplevel
"""Transform operators."""

import numpy as np
from ...tir import expr as _expr
from ..expr import Constant, Expr, Tuple, TupleWrapper, const
from . import _make
Expand Down Expand Up @@ -1409,6 +1410,11 @@ def matrix_set_diag(data, diagonal, k=0, align="RIGHT_LEFT"):
k_one = k
k_two = k

if not isinstance(k_one, Expr):
k_one = const(np.asarray([k_one], dtype=np.int64))
if not isinstance(k_two, Expr):
k_two = const(np.asarray([k_two], dtype=np.int64))

super_diag_right_align = align[:5] == "RIGHT"
sub_diag_right_align = align[-5:] == "RIGHT"

Expand Down
38 changes: 13 additions & 25 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3812,59 +3812,45 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs);
bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
// `types` contains: [input, diagonal, result]
ICHECK_EQ(types.size(), 3);
ICHECK_EQ(types.size(), 5);

const auto* input = types[0].as<TensorTypeNode>();
ICHECK(input);

const auto* diagonal = types[1].as<TensorTypeNode>();
ICHECK(diagonal);

const auto param = attrs.as<MatrixSetDiagAttrs>();
ICHECK_GE(param->k2, param->k1);

int d_ndims = diagonal->shape.size();
int i_ndims = input->shape.size();
const auto* k1 = types[2].as<TensorTypeNode>();
ICHECK(k1);

reporter->Assert(input->shape[i_ndims - 2] > -param->k1);
reporter->Assert(input->shape[i_ndims - 1] > param->k2);
const auto* k2 = types[3].as<TensorTypeNode>();
ICHECK(k2);

int d_ndims = diagonal->shape.size();
for (int i = 0; i < d_ndims - 2; i++) {
reporter->AssertEQ(input->shape[i], diagonal->shape[i]);
}
if (param->k1 != param->k2) {
reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1);
} else if (d_ndims >= 2) {
reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]);
}
auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <=
input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0),
input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0),
input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0));
reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len);

reporter->Assign(types[2], TensorType(input->shape, input->dtype));
reporter->Assign(types[4], TensorType(input->shape, input->dtype));
return true;
}

Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs,
const Type& out_type) {
const auto* param = attrs.as<MatrixSetDiagAttrs>();
ICHECK(param != nullptr);
return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2,
return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3],
param->super_diag_right_align,
param->sub_diag_right_align)};
}

Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align,
Expr MakeMatrixSetDiag(Expr input, Expr diagonal, Expr k1, Expr k2, bool super_diag_right_align,
bool sub_diag_right_align) {
auto attrs = make_object<MatrixSetDiagAttrs>();
attrs->k1 = k1;
attrs->k2 = k2;
attrs->super_diag_right_align = super_diag_right_align;
attrs->sub_diag_right_align = sub_diag_right_align;
static const Op& op = Op::Get("matrix_set_diag");
return Call(op, {input, diagonal}, Attrs(attrs), {});
return Call(op, {input, diagonal, k1, k2}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag);
Expand All @@ -3880,9 +3866,11 @@ RELAY_REGISTER_OP("matrix_set_diag")
**sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded).
)code" TVM_ADD_FILELINE)
.set_attrs_type<MatrixSetDiagAttrs>()
.set_num_inputs(2)
.set_num_inputs(4)
.add_argument("input", "Tensor", "Input Tensor.")
.add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.")
.add_argument("k1", "Tensor", "ILower limit (included) of the range of diagonals.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small typo here "ILower" -> "Lower"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I will fix it.

.add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.")
.set_support_level(10)
.add_type_rel("MatrixSetDiag", MatrixSetDiagRel)
.set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute)
Expand Down
5 changes: 2 additions & 3 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,10 @@ TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.matrix_set_diag").set_body([](TVMArgs args, TVMRetValue* rv) {
int k1 = args[2];
int k2 = args[3];
bool super_diag_right_align = args[4];
bool sub_diag_right_align = args[5];
*rv = matrix_set_diag(args[0], args[1], k1, k2, super_diag_right_align, sub_diag_right_align);
*rv = matrix_set_diag(args[0], args[1], args[2], args[3], super_diag_right_align,
sub_diag_right_align);
});

TVM_REGISTER_GLOBAL("topi.adv_index").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down
18 changes: 0 additions & 18 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -5114,24 +5114,6 @@ def verify_eyelike(indata):
"test_training_dropout_mask",
"test_training_dropout_zero_ratio",
"test_training_dropout_zero_ratio_mask",
"test_tril",
"test_tril_pos",
"test_tril_square",
"test_tril_square_neg",
"test_tril_neg",
"test_tril_one_row_neg",
"test_tril_out_neg",
"test_tril_out_pos",
"test_tril_zero",
"test_triu",
"test_triu_one_row",
"test_triu_out_neg_out",
"test_triu_out_pos",
"test_triu_neg",
"test_triu_pos",
"test_triu_square",
"test_triu_square_neg",
"test_triu_zero",
# These unsqueeze tests work, but take 2+ hrs to run
"test_unsqueeze_three_axes",
"test_unsqueeze_two_axes",
Expand Down
8 changes: 7 additions & 1 deletion tests/python/relay/test_op_level10.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,13 @@ def test_matrix_set_diag():
def _verify(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
input = relay.var("input", relay.TensorType(input_shape, dtype))
diagonal = relay.var("diagonal", relay.TensorType(diagonal_shape, dtype))
out = relay.matrix_set_diag(input, diagonal, k, align)
out = None
if len(diagonal_shape) == len(input_shape) - 1:
new_shape = list(diagonal_shape)
new_shape.insert(-1, 1)
out = relay.matrix_set_diag(input, relay.reshape(diagonal, new_shape), k, align)
else:
out = relay.matrix_set_diag(input, diagonal, k, align)

in_type = run_infer_type(input)
out_type = run_infer_type(out)
Expand Down
26 changes: 20 additions & 6 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,21 +752,38 @@ def check_device(target, dev):
def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"):
input = te.placeholder(shape=input_shape, name="input", dtype=dtype)
diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype)
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align)
k1 = te.placeholder(shape=(1,), name="k1", dtype="int64")
k2 = te.placeholder(shape=(1,), name="k2", dtype="int64")
matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align)

k_one, k_two = None, None
if isinstance(k, (tuple, list)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some comments to this test? It's a little hard to follow.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. I will add some comments and update the PR.

k_one = k[0]
if len(k) >= 2:
k_two = k[1]
else:
k_two = k[0]
else:
k_one = k
k_two = k

def check_device(target, dev):
dev = tvm.device(target, 0)
print("Running on target: %s" % target)
with tvm.target.Target(target):
s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result)
fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], target, name="matrix_set_diag")
fn = tvm.build(
s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag"
)
input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype)
diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype)
out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align)
input_nd = tvm.nd.array(input_npy, dev)
diagonal_nd = tvm.nd.array(diagonal_npy, dev)
k1_nd = tvm.nd.array(np.asarray([k_one]), dev)
k2_nd = tvm.nd.array(np.asarray([k_two]), dev)
out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev)
fn(input_nd, diagonal_nd, out_nd)
fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd)
out_topi = out_nd.numpy()
tvm.testing.assert_allclose(out_topi, out_npy)

Expand Down Expand Up @@ -1240,9 +1257,6 @@ def test_sparse_to_dense():
@tvm.testing.uses_gpu
def test_matrix_set_diag():
for dtype in ["float32", "int32"]:
verify_matrix_set_diag((2, 2), (2,), dtype)
verify_matrix_set_diag((4, 3, 3), (4, 3), dtype)
verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious, why are these tests removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been removed in the original PR. I ended up not putting it back and testing it again. Thanks for pointing it out since the test_matrix_set_diag is outputting an error for these three cases. I will check it out.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Josh, these three TOPI test cases are still not passing. I am debugging it to solve the Tensor dimension mismatch issue.

verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT")
verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT")
verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT")
Expand Down