Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TOPI] Fix more pooling schedule #9021

Merged
merged 1 commit into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python/tvm/topi/cuda/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _schedule_non_global(Pool):
def traverse(OP):
"""Internal traverse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if tag.is_injective(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/topi/x86/pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def _schedule(PaddedInput, Pool):
def traverse(OP):
"""Internal traverse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if tag.is_injective(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
Expand Down Expand Up @@ -137,7 +137,7 @@ def schedule_adaptive_pool(outs):
def traverse(OP):
"""Internal traverse function"""
# inline all one-to-one-mapping operators except the last stage (output)
if tag.is_broadcast(OP.tag):
if tag.is_injective(OP.tag):
if OP not in s.outputs:
s[OP].compute_inline()
for tensor in OP.input_tensors:
Expand Down