-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[Relay/TOPI][ONNX/TFLite] Refactor MATRIX_SET_DIAG Operator for Relay/TOPI to support ONNX Trilu operator #10873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 13 commits
6119889
0adfb26
1a75e90
726d85e
33d3806
27dc879
06f902c
e297dbe
ccb63b7
2911137
166639d
c3381b9
eff297b
0f1535e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3812,59 +3812,45 @@ TVM_REGISTER_NODE_TYPE(MatrixSetDiagAttrs); | |
| bool MatrixSetDiagRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, | ||
| const TypeReporter& reporter) { | ||
| // `types` contains: [input, diagonal, result] | ||
AndrewZhaoLuo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ICHECK_EQ(types.size(), 3); | ||
| ICHECK_EQ(types.size(), 5); | ||
|
|
||
| const auto* input = types[0].as<TensorTypeNode>(); | ||
| ICHECK(input); | ||
|
|
||
| const auto* diagonal = types[1].as<TensorTypeNode>(); | ||
| ICHECK(diagonal); | ||
|
|
||
| const auto param = attrs.as<MatrixSetDiagAttrs>(); | ||
| ICHECK_GE(param->k2, param->k1); | ||
|
|
||
| int d_ndims = diagonal->shape.size(); | ||
| int i_ndims = input->shape.size(); | ||
| const auto* k1 = types[2].as<TensorTypeNode>(); | ||
| ICHECK(k1); | ||
|
|
||
| reporter->Assert(input->shape[i_ndims - 2] > -param->k1); | ||
| reporter->Assert(input->shape[i_ndims - 1] > param->k2); | ||
| const auto* k2 = types[3].as<TensorTypeNode>(); | ||
| ICHECK(k2); | ||
|
|
||
| int d_ndims = diagonal->shape.size(); | ||
| for (int i = 0; i < d_ndims - 2; i++) { | ||
| reporter->AssertEQ(input->shape[i], diagonal->shape[i]); | ||
| } | ||
| if (param->k1 != param->k2) { | ||
| reporter->AssertEQ(diagonal->shape[d_ndims - 2], param->k2 - param->k1 + 1); | ||
| } else if (d_ndims >= 2) { | ||
| reporter->AssertEQ(input->shape[d_ndims - 2], diagonal->shape[d_ndims - 2]); | ||
| } | ||
| auto max_diag_len = if_then_else(input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0) <= | ||
| input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0), | ||
| input->shape[i_ndims - 2] + (param->k2 > 0 ? param->k2 : 0), | ||
| input->shape[i_ndims - 1] + (param->k1 < 0 ? -param->k1 : 0)); | ||
| reporter->AssertEQ(diagonal->shape[d_ndims - 1], max_diag_len); | ||
|
|
||
| reporter->Assign(types[2], TensorType(input->shape, input->dtype)); | ||
| reporter->Assign(types[4], TensorType(input->shape, input->dtype)); | ||
| return true; | ||
| } | ||
|
|
||
| Array<te::Tensor> MatrixSetDiagCompute(const Attrs& attrs, const Array<te::Tensor>& inputs, | ||
| const Type& out_type) { | ||
| const auto* param = attrs.as<MatrixSetDiagAttrs>(); | ||
| ICHECK(param != nullptr); | ||
| return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], param->k1, param->k2, | ||
| return Array<te::Tensor>{topi::matrix_set_diag(inputs[0], inputs[1], inputs[2], inputs[3], | ||
| param->super_diag_right_align, | ||
| param->sub_diag_right_align)}; | ||
| } | ||
|
|
||
| Expr MakeMatrixSetDiag(Expr input, Expr diagonal, int k1, int k2, bool super_diag_right_align, | ||
| Expr MakeMatrixSetDiag(Expr input, Expr diagonal, Expr k1, Expr k2, bool super_diag_right_align, | ||
| bool sub_diag_right_align) { | ||
| auto attrs = make_object<MatrixSetDiagAttrs>(); | ||
| attrs->k1 = k1; | ||
| attrs->k2 = k2; | ||
| attrs->super_diag_right_align = super_diag_right_align; | ||
| attrs->sub_diag_right_align = sub_diag_right_align; | ||
| static const Op& op = Op::Get("matrix_set_diag"); | ||
| return Call(op, {input, diagonal}, Attrs(attrs), {}); | ||
| return Call(op, {input, diagonal, k1, k2}, Attrs(attrs), {}); | ||
| } | ||
|
|
||
| TVM_REGISTER_GLOBAL("relay.op._make.matrix_set_diag").set_body_typed(MakeMatrixSetDiag); | ||
|
|
@@ -3880,9 +3866,11 @@ RELAY_REGISTER_OP("matrix_set_diag") | |
| **sub_diag_right_align** Bool, true iff sub-diagonal is right aligned (left-padded). | ||
| )code" TVM_ADD_FILELINE) | ||
| .set_attrs_type<MatrixSetDiagAttrs>() | ||
| .set_num_inputs(2) | ||
| .set_num_inputs(4) | ||
| .add_argument("input", "Tensor", "Input Tensor.") | ||
| .add_argument("diagonal", "Tensor", "Values to be filled in the diagonal.") | ||
| .add_argument("k1", "Tensor", "ILower limit (included) of the range of diagonals.") | ||
|
||
| .add_argument("k2", "Tensor", "Upper limit (included) of the range of diagonals.") | ||
| .set_support_level(10) | ||
| .add_type_rel("MatrixSetDiag", MatrixSetDiagRel) | ||
| .set_attr<FTVMCompute>("FTVMCompute", MatrixSetDiagCompute) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -752,21 +752,38 @@ def check_device(target, dev): | |
| def verify_matrix_set_diag(input_shape, diagonal_shape, dtype, k=0, align="RIGHT_LEFT"): | ||
| input = te.placeholder(shape=input_shape, name="input", dtype=dtype) | ||
| diagonal = te.placeholder(shape=diagonal_shape, name="diagonal", dtype=dtype) | ||
| matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, k, align) | ||
| k1 = te.placeholder(shape=(1,), name="k1", dtype="int64") | ||
| k2 = te.placeholder(shape=(1,), name="k2", dtype="int64") | ||
| matrix_set_diag_result = topi.transform.matrix_set_diag(input, diagonal, (k1, k2), align) | ||
|
|
||
| k_one, k_two = None, None | ||
| if isinstance(k, (tuple, list)): | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add some comments to this test? It's a little hard to follow.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. I will add some comments and update the PR. |
||
| k_one = k[0] | ||
| if len(k) >= 2: | ||
| k_two = k[1] | ||
| else: | ||
| k_two = k[0] | ||
| else: | ||
| k_one = k | ||
| k_two = k | ||
|
|
||
| def check_device(target, dev): | ||
| dev = tvm.device(target, 0) | ||
| print("Running on target: %s" % target) | ||
| with tvm.target.Target(target): | ||
| s = tvm.topi.testing.get_injective_schedule(target)(matrix_set_diag_result) | ||
| fn = tvm.build(s, [input, diagonal, matrix_set_diag_result], target, name="matrix_set_diag") | ||
| fn = tvm.build( | ||
| s, [input, diagonal, k1, k2, matrix_set_diag_result], target, name="matrix_set_diag" | ||
| ) | ||
| input_npy = np.random.randint(-100, 100, size=input_shape).astype(dtype) | ||
| diagonal_npy = np.random.randint(-100, 100, size=diagonal_shape).astype(dtype) | ||
| out_npy = tvm.topi.testing.matrix_set_diag(input_npy, diagonal_npy, k, align) | ||
| input_nd = tvm.nd.array(input_npy, dev) | ||
| diagonal_nd = tvm.nd.array(diagonal_npy, dev) | ||
| k1_nd = tvm.nd.array(np.asarray([k_one]), dev) | ||
| k2_nd = tvm.nd.array(np.asarray([k_two]), dev) | ||
| out_nd = tvm.nd.array(np.empty(out_npy.shape).astype(matrix_set_diag_result.dtype), dev) | ||
| fn(input_nd, diagonal_nd, out_nd) | ||
| fn(input_nd, diagonal_nd, k1_nd, k2_nd, out_nd) | ||
| out_topi = out_nd.numpy() | ||
| tvm.testing.assert_allclose(out_topi, out_npy) | ||
|
|
||
|
|
@@ -1240,9 +1257,6 @@ def test_sparse_to_dense(): | |
| @tvm.testing.uses_gpu | ||
| def test_matrix_set_diag(): | ||
| for dtype in ["float32", "int32"]: | ||
| verify_matrix_set_diag((2, 2), (2,), dtype) | ||
| verify_matrix_set_diag((4, 3, 3), (4, 3), dtype) | ||
| verify_matrix_set_diag((2, 3, 4), (2, 3), dtype, 1) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just curious, why are these tests removed?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This has been removed in the original PR. I ended up not putting it back and testing it again. Thanks for pointing it out since the test_matrix_set_diag is outputting an error for these three cases. I will check it out.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Josh, these three TOPI test cases are still not passing. I am debugging it to solve the Tensor dimension mismatch issue. |
||
| verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_RIGHT") | ||
| verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "LEFT_LEFT") | ||
| verify_matrix_set_diag((2, 3, 4), (2, 4, 3), dtype, (-1, 2), "RIGHT_RIGHT") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.