
Commit 0225f2b

share op strategy between cuda and rocm
1 parent 762c7e8 commit 0225f2b

4 files changed: +22 additions, -241 deletions

python/tvm/relay/op/strategy/rocm.py

Lines changed: 20 additions & 161 deletions

@@ -24,178 +24,42 @@
 
 from .generic import *
 from .. import op as _op
-from .cuda import judge_winograd, naive_schedule
+from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda
 
 
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_rocm(attrs, inputs, out_type, target):
     """conv2d rocm strategy"""
-    strategy = _op.OpStrategy()
-    data, kernel = inputs
-    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     groups = attrs.groups
     layout = attrs.data_layout
-    stride_h, stride_w = attrs.get_int_tuple("strides")
-    kernel_layout = attrs.kernel_layout
     padding = attrs.get_int_tuple("padding")
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
 
-    if groups == 1:
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda",
-            )
-            _, _, kh, kw = get_const_tuple(kernel.shape)
-            if (
-                2 < kh < 8
-                and 2 < kw < 8
-                and kh == kw
-                and stride_h == 1
-                and stride_w == 1
-                and dilation_h == 1
-                and dilation_w == 1
-            ):
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
-                    name="conv2d_nchw_winograd.cuda",
-                    plevel=5,
-                )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
-                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.gpu",
-            )
-            N, H, W, _ = get_const_tuple(data.shape)
-            KH, KW, CI, CO = get_const_tuple(kernel.shape)
+    strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)
 
-            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
-                N,
-                H,
-                W,
-                KH,
-                KW,
-                CI,
-                CO,
-                padding,
-                stride_h,
-                stride_w,
-                dilation_h,
-                dilation_w,
-                data.dtype,
-                kernel.dtype,
-                pre_flag=False,
-            )
-
-            if judge_winograd_autotvm:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
-                    name="conv2d_nhwc_winograd_direct.cuda",
-                    plevel=5,
-                )
+    # add miopen implementation
+    if (
+        "miopen" in target.libs
+        and groups == 1
+        and layout == "NCHW"
+        and padding[0] == padding[2]
+        and padding[1] == padding[3]
+    ):
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+            name="conv2d_nchw_miopen.rocm",
+            plevel=50,
+        )
 
-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
-                    naive_schedule,  # this implementation should never be picked by autotvm
-                    name="conv2d_nhwc.winograd",
-                    plevel=15,
-                )
-        elif layout == "HWCN":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
-        # add miopen implementation
-        if (
-            "miopen" in target.libs
-            and layout == "NCHW"
-            and padding[0] == padding[2]
-            and padding[1] == padding[3]
-        ):
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
-                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
-                name="conv2d_nchw_miopen.rocm",
-                plevel=15,
-            )
-    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
-        if layout == "NCHW":
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda",
-            )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else:  # group_conv2d
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
 
 @dense_strategy.register("rocm")
 def dense_strategy_rocm(attrs, inputs, out_type, target):
     """Dense strategy for ROCM"""
     assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.rocm.dense),
-        wrap_topi_schedule(topi.rocm.schedule_dense),
-        name="dense.rocm",
-    )
-    data, weights = inputs
-    if (data.dtype == "int8"
-        and weights.dtype == "int8"
-        and out_type.dtype == "int32"
-    ):
-        strategy.add_implementation(
-            wrap_compute_dense(topi.cuda.dense_int8),
-            wrap_topi_schedule(topi.cuda.schedule_dense_int8),
-            name="dense_int8.rocm",
-        )
+    strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(

@@ -210,13 +74,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
 @batch_matmul_strategy.register("rocm")
 def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
     """Batch matmul strategy for ROCM"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True),
-        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
-        plevel=10,
-    )
+    strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
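
For context, a minimal usage sketch (not part of the commit) of how the shared strategy is exercised, assuming a TVM build with the ROCm runtime enabled and, for the second build, MIOpen available: a conv2d compiled for a plain "rocm" target now draws its implementations from conv2d_strategy_cuda, and adding "-libs=miopen" layers the conv2d_nchw_miopen.rocm implementation on top at plevel=50 so it is preferred.

import tvm
from tvm import relay

# NCHW input with an OIHW kernel, the layout the MIOpen path above handles.
data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# Plain ROCm target: implementations come from the shared CUDA strategy.
lib = relay.build(mod, target="rocm")

# With MIOpen in target.libs, conv2d_nchw_miopen.rocm (plevel=50) is preferred.
lib_miopen = relay.build(mod, target="rocm -libs=miopen")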

python/tvm/topi/cuda/dense.py

Lines changed: 1 addition & 1 deletion

@@ -175,7 +175,7 @@ def _schedule_dense_int8(cfg, s, output):
     do_tensorize = True
     # if "vulkan" in target.keys or "rocm" in target.keys:
     # do_tensorize = "+dotprod" in target.mattr or target.supports_integer_dot_product
-    assert False
+
     if do_tensorize:
         dtypes = (data.dtype, weight.dtype)
         s[CC].tensorize(ki, dp4a("shared", "shared", "local", dtypes))

python/tvm/topi/rocm/dense.py

Lines changed: 1 addition & 78 deletions

@@ -19,85 +19,8 @@
 from tvm import te
 from tvm import autotvm
 from tvm.contrib import rocblas
-from .. import generic, nn
+from .. import generic
 from .. import tag
-from ..utils import traverse_inline
-
-
-@autotvm.register_topi_compute("dense.rocm")
-def dense(cfg, data, weight, bias=None, out_dtype=None):
-    """Dense operator for rocm backend.
-
-    Parameters
-    ----------
-    data : tvm.te.Tensor
-        2-D with shape [batch, in_dim]
-
-    weight : tvm.te.Tensor
-        2-D with shape [out_dim, in_dim]
-
-    bias : tvm.te.Tensor, optional
-        1-D with shape [out_dim]
-
-    out_dtype : str
-        The output type. This is used for mixed precision.
-
-    Returns
-    -------
-    output : tvm.te.Tensor
-        2-D with shape [batch, out_dim]
-    """
-    assert len(data.shape) == 2 and len(weight.shape) == 2, "only support 2-dim dense"
-    if bias is not None:
-        assert len(bias.shape) == 1
-    if out_dtype is None:
-        out_dtype = data.dtype
-    return nn.dense(data, weight, bias, out_dtype)
-
-
-@autotvm.register_topi_schedule("dense.rocm")
-def schedule_dense(cfg, outs):
-    """Schedule for dense operator.
-
-    Parameters
-    ----------
-    outs: Array of Tensor
-        The computation graph description of dense
-        in the format of an array of tensors.
-
-    Returns
-    -------
-    s: Schedule
-        The computation schedule for dense.
-    """
-    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if op.tag == "dense":
-            Dense = op.output(0)
-            num_thread = 64
-            k = Dense.op.reduce_axis[0]
-            ko, kf = s[Dense].split(k, factor=num_thread)
-            DenseF = s.rfactor(Dense, kf)
-
-            if Dense.op in s.outputs:
-                Out = Dense
-            else:
-                Out = outs[0].op.output(0)
-                s[Dense].compute_at(s[Out], s[Out].op.axis[1])
-            s[Out].bind(s[Out].op.axis[0], te.thread_axis("blockIdx.y"))
-            s[Out].bind(s[Out].op.axis[1], te.thread_axis("blockIdx.x"))
-
-            tx = s[Dense].op.reduce_axis[0]
-            thread_x = te.thread_axis("threadIdx.x")
-            s[Dense].bind(tx, thread_x)
-            s[DenseF].compute_at(s[Dense], tx)
-            s[Dense].set_store_predicate(thread_x.var.equal(0))
-            s[Out].set_store_predicate(thread_x.var.equal(0))
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
 
 
 @autotvm.register_topi_compute("dense_rocblas.rocm")
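
A small sketch (an illustration, not from the commit) of what remains after the deletion, assuming a ROCm-enabled TVM build with rocBLAS: plain dense on a "rocm" target now lowers through the shared CUDA strategy, while the dense_rocblas.rocm compute kept below is only registered when "rocblas" appears in target.libs.

import tvm
from tvm import relay

x = relay.var("x", shape=(8, 128), dtype="float32")
w = relay.var("w", shape=(256, 128), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x, w], relay.nn.dense(x, w)))

# dense now resolves through the implementations registered by dense_strategy_cuda.
lib = relay.build(mod, target="rocm")

# dense_rocblas.rocm is added to the strategy only when rocblas is in target.libs.
lib_rocblas = relay.build(mod, target="rocm -libs=rocblas")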

tests/python/topi/python/test_topi_dense.py

Lines changed: 0 additions & 1 deletion

@@ -52,7 +52,6 @@
     ],
     "mali": [(topi.mali.dense, topi.mali.schedule_dense)],
     "bifrost": [(topi.bifrost.dense, topi.bifrost.schedule_dense)],
-    "rocm": [(topi.rocm.dense, topi.rocm.schedule_dense)],
     "hls": [(topi.nn.dense, topi.hls.schedule_dense)],
 }
 
