Skip to content

Commit

Permalink
add cuda context
Browse files Browse the repository at this point in the history
  • Loading branch information
hypercubestart committed Apr 25, 2021
1 parent 81b68d6 commit 50ac65b
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm
from tvm import te
from tvm import autotvm
from tvm.target import Target
from tvm.topi.cuda.injective import schedule_injective_from_existing
from ..utils import get_const_tuple, traverse_inline, simplify, tag
from ..nn.pad import pad
Expand Down Expand Up @@ -224,12 +225,13 @@ def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
s[packed_data].pragma(s[packed_data].op.axis[0], "debug_skip_region")
s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
else:
if (
isinstance(packed_kernel.op, te.tensor.ComputeOp)
and packed_kernel.name == "packed_kernel"
):
schedule_injective_from_existing(s, packed_kernel)
schedule_injective_from_existing(s, packed_data)
with Target("cuda"):
if (
isinstance(packed_kernel.op, te.tensor.ComputeOp)
and packed_kernel.name == "packed_kernel"
):
schedule_injective_from_existing(s, packed_kernel)
schedule_injective_from_existing(s, packed_data)

if pad_data != packed_data:
s[pad_data].compute_inline()
Expand Down

0 comments on commit 50ac65b

Please sign in to comment.