Skip to content

Commit 0898dbf

Browse files
author
Sergey Shtin
committed
Change strategy for cuda to fix tests.
1 parent 1bc3b03 commit 0898dbf

File tree

1 file changed

+9
-5
lines changed
  • python/tvm/relay/op/strategy

1 file changed

+9
-5
lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,15 @@ def schedule_reduce_cuda(attrs, outs, target):
4242
return topi.cuda.schedule_reduce(outs)
4343

4444

45-
@schedule_concatenate.register(["cuda", "gpu"])
46-
def schedule_concatenate_cuda(attrs, outs, target):
47-
"""schedule concatenate for cuda"""
48-
with target:
49-
return topi.cuda.schedule_injective(outs)
45+
@concatenate_strategy.register(["cuda", "gpu"])
46+
def concatenate_strategy_cuda(attrs, inputs, out_type, target):
47+
strategy = _op.OpStrategy()
48+
strategy.add_implementation(
49+
wrap_compute_concat(topi.transform.concatenate),
50+
wrap_topi_schedule(topi.cuda.schedule_injective),
51+
name="concatenate.cuda",
52+
)
53+
return strategy
5054

5155

5256
@schedule_pool.register(["cuda", "gpu"])

0 commit comments

Comments
 (0)