Commit cc0f27a

Author: Josh Fromm
[Relay] Remove overwriting of matmul shapes when they are static (#13615)
In the Relay Matmul shape relation, we are a little overenthusiastic about unifying dynamic shapes: if one of the shapes is already static, it does not need to be unified. This change rewrites only dynamic shapes to the static constraints required of them (a sketch of the effect follows the commit metadata below).

* Remove overwriting of matmul shapes when they are static
* Simplify nesting
* Add shape check to dense tests
1 parent 7fd0cdb commit cc0f27a
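
As an illustration of the fixed behavior, here is a minimal sketch, assuming a standard TVM build; run_infer_type mirrors the helper pattern used in the Relay test suite, and the (4, 2) / (3, Any) shapes are invented for the example:

import tvm
from tvm import relay

def run_infer_type(expr):
    # Helper mirroring the pattern used in the Relay tests.
    mod = relay.transform.InferType()(tvm.IRModule.from_expr(expr))
    return mod["main"].body

dtype = "float32"
# x has a fully static shape; w's reduction axis is dynamic.
x = relay.var("x", relay.TensorType((4, 2), dtype))
w = relay.var("w", relay.TensorType((3, relay.Any()), dtype))
yy = run_infer_type(relay.nn.dense(x, w))

# With this change, x keeps its static (4, 2) shape; previously the
# unconditional swap could rewrite the static axis to a dynamic one.
assert tuple(int(d) for d in yy.type_args[0].shape) == (4, 2)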

File tree

2 files changed: +24 -12 lines changed

src/relay/op/nn/nn.h

Lines changed: 21 additions & 12 deletions
@@ -113,23 +113,32 @@ bool MatmulRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     std::vector<PrimExpr> B_shape(tensor_b->shape.begin(), tensor_b->shape.end());
     auto sa = A_shape.size();
     auto sb = B_shape.size();
+    size_t index_swap_A;
+    size_t index_swap_B;
     if (transpose_a && transpose_b) {
-      auto tmp = A_shape[sa - 2];
-      A_shape[sa - 2] = B_shape[sb - 1];
-      B_shape[sb - 1] = tmp;
+      index_swap_A = sa - 2;
+      index_swap_B = sb - 1;
     } else if (transpose_a) {
-      auto tmp = A_shape[sa - 2];
-      A_shape[sa - 2] = B_shape[sb - 2];
-      B_shape[sb - 2] = tmp;
+      index_swap_A = sa - 2;
+      index_swap_B = sb - 2;
     } else if (transpose_b) {
-      auto tmp = A_shape[sa - 1];
-      A_shape[sa - 1] = B_shape[sb - 1];
-      B_shape[sb - 1] = tmp;
+      index_swap_A = sa - 1;
+      index_swap_B = sb - 1;
     } else {
-      auto tmp = A_shape[sa - 1];
-      A_shape[sa - 1] = B_shape[sb - 2];
-      B_shape[sb - 2] = tmp;
+      index_swap_A = sa - 1;
+      index_swap_B = sb - 2;
     }
+
+    // Rewrite dynamic axes to static where constraints allow.
+    auto tmp = A_shape[index_swap_A];
+    if (A_shape[index_swap_A].as<tir::AnyNode>()) {
+      A_shape[index_swap_A] = B_shape[index_swap_B];
+    }
+    if (B_shape[index_swap_B].as<tir::AnyNode>()) {
+      B_shape[index_swap_B] = tmp;
+    }
+
+    // Update input types with new constrained shapes.
     reporter->Assign(types[0], TensorType(A_shape, tensor_a->dtype));
     reporter->Assign(types[1], TensorType(B_shape, tensor_b_dtype));
   }
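
The .as<tir::AnyNode>() checks above are what detect a dynamic dimension. On the Python side the same marker is tvm.tir.Any, which relay.Any() produces; a quick hedged check, assuming the standard TVM API:

from tvm import relay, tir

t = relay.TensorType((relay.Any(), 4), "float32")
assert isinstance(t.shape[0], tir.Any)  # dynamic axis
assert int(t.shape[1]) == 4             # static axis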

tests/python/relay/test_op_level1.py

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@
 import tvm.topi.testing
 from tvm.contrib.nvcc import have_fp16
 import tvm.testing
+from tvm.topi.utils import get_const_tuple
 
 executor_kind = tvm.testing.parameter("graph", "vm")
 
@@ -695,6 +696,8 @@ def test_dense(executor_kind):
     w = relay.var("w", relay.TensorType((k, n), dtype))
     y = relay.nn.dense(x, w)
     yy = run_infer_type(y)
+    # Confirm that input shape has not been rewritten to become dynamic.
+    assert get_const_tuple(yy.type_args[0].shape) == (4, 2)
 
     n, c, h, w = te.size_var("n"), te.size_var("c"), te.size_var("h"), 2
     x = relay.var("x", relay.TensorType((n, c, h, w), dtype))
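
As a usage note, get_const_tuple converts a shape (a TVM array of IntImm nodes) into a plain Python tuple of ints, which is what makes the equality check in the new assert possible; a minimal sketch:

from tvm import relay
from tvm.topi.utils import get_const_tuple

t = relay.TensorType((4, 2), "float32")
assert get_const_tuple(t.shape) == (4, 2)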
