python/tvm/relay/frontend/tflite.py: 8 changes (4 additions, 4 deletions)
@@ -3006,7 +3006,7 @@ def convert_batch_matmul(self, op):
             rank_diff = rank_a - rank_b
             new_b_shape = _op.concatenate(
                 [
-                    _expr.const([1] * rank_diff, dtype=_infer_type(b_shape).checked_type.dtype),
+                    _expr.const([1] * rank_diff, dtype=_infer_type(new_b_shape).checked_type.dtype),
                     shape_b,
                 ],
                 0,
@@ -3015,7 +3015,7 @@ def convert_batch_matmul(self, op):
             rank_diff = rank_b - rank_a
             new_a_shape = _op.concatenate(
                 [
-                    _expr.const([1] * rank_diff, dtype=_infer_type(a_shape).checked_type.dtype),
+                    _expr.const([1] * rank_diff, dtype=_infer_type(new_a_shape).checked_type.dtype),
                     shape_a,
                 ],
                 0,
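Note on the two hunks above: both branches pad the lower-ranked operand's shape with leading 1s (via _expr.const([1] * rank_diff) concatenated in front of the inferred shape tensor) so that the two operands reach the same rank before broadcasting. As a plain-Python illustration of that shape arithmetic only (a hypothetical sketch, not code from this PR or from the frontend; the helper name pad_rank is made up):

def pad_rank(shape_a, shape_b):
    """Prefix the shorter shape with 1s until both shapes have the same rank."""
    rank_a, rank_b = len(shape_a), len(shape_b)
    if rank_a > rank_b:
        shape_b = (1,) * (rank_a - rank_b) + tuple(shape_b)
    elif rank_b > rank_a:
        shape_a = (1,) * (rank_b - rank_a) + tuple(shape_a)
    return tuple(shape_a), tuple(shape_b)

# For example, a of shape (2, 3, 4, 5) and b of shape (5, 6): b is padded to (1, 1, 5, 6).
print(pad_rank((2, 3, 4, 5), (5, 6)))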
@@ -3041,9 +3041,9 @@ def convert_batch_matmul(self, op):
             _op.concatenate([out_batch, _op.strided_slice(shape_b, [rank_b - 2], [rank_b])], 0)
         )
         if not tvm.ir.structural_equal(shape_a, a_broadcasted_shape):
-            input_a = _op.transform.broadcast_to(a, a_broadcasted_shape)
+            input_a = _op.transform.broadcast_to(input_a, a_broadcasted_shape)
         if not tvm.ir.structural_equal(shape_b, b_broadcasted_shape):
-            input_b = _op.transform.broadcast_to(b, b_broadcasted_shape)
+            input_b = _op.transform.broadcast_to(input_b, b_broadcasted_shape)
 
         input_a = self.flatten_to_nd(input_a, shape_a, 3)
         input_b = self.flatten_to_nd(input_b, shape_b, 3)
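Note on the last hunk: the fix passes the already-bound input_a / input_b expressions to broadcast_to instead of a and b, which do not appear to be bound in this scope; each operand is then broadcast to the shared batch shape and flattened to rank 3 for batch_matmul. A rough plain-Python sketch of the target-shape computation (a simplified illustration under NumPy-style broadcasting assumptions; broadcast_shapes_for_matmul is a made-up helper and skips incompatible-dimension checks):

def broadcast_shapes_for_matmul(shape_a, shape_b):
    """Return the shapes each operand is broadcast to before the batched matmul."""
    batch_a, batch_b = shape_a[:-2], shape_b[:-2]
    # Pad the shorter batch prefix with 1s, then take the larger extent per axis.
    pad = max(len(batch_a), len(batch_b))
    batch_a = (1,) * (pad - len(batch_a)) + tuple(batch_a)
    batch_b = (1,) * (pad - len(batch_b)) + tuple(batch_b)
    out_batch = tuple(max(da, db) for da, db in zip(batch_a, batch_b))
    # Each operand keeps its own last two (matrix) dimensions.
    return out_batch + tuple(shape_a[-2:]), out_batch + tuple(shape_b[-2:])

# For example, a of shape (2, 1, 4, 5) and b of shape (3, 5, 6):
# a is broadcast to (2, 3, 4, 5) and b to (2, 3, 5, 6).
print(broadcast_shapes_for_matmul((2, 1, 4, 5), (3, 5, 6)))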