 
 from .generic import *
 from .. import op as _op
-from .cuda import judge_winograd, naive_schedule
+from .cuda import batch_matmul_strategy_cuda, conv2d_strategy_cuda, dense_strategy_cuda
 
 
 @conv2d_strategy.register("rocm")
 def conv2d_strategy_rocm(attrs, inputs, out_type, target):
     """conv2d rocm strategy"""
-    strategy = _op.OpStrategy()
-    data, kernel = inputs
-    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
     groups = attrs.groups
     layout = attrs.data_layout
-    stride_h, stride_w = attrs.get_int_tuple("strides")
-    kernel_layout = attrs.kernel_layout
     padding = attrs.get_int_tuple("padding")
-    if dilation_h < 1 or dilation_w < 1:
-        raise ValueError("dilation should be positive value")
 
-    if groups == 1:
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw),
-                name="conv2d_nchw.cuda",
-            )
-            _, _, kh, kw = get_const_tuple(kernel.shape)
-            if (
-                2 < kh < 8
-                and 2 < kw < 8
-                and kh == kw
-                and stride_h == 1
-                and stride_w == 1
-                and dilation_h == 1
-                and dilation_w == 1
-            ):
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
-                    name="conv2d_nchw_winograd.cuda",
-                    plevel=5,
-                )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.gpu.conv2d_nhwc),
-                wrap_topi_schedule(topi.gpu.schedule_conv2d_nhwc),
-                name="conv2d_nhwc.gpu",
-            )
-            N, H, W, _ = get_const_tuple(data.shape)
-            KH, KW, CI, CO = get_const_tuple(kernel.shape)
+    strategy = conv2d_strategy_cuda(attrs, inputs, out_type, target)
 
-            (_, judge_winograd_autotvm, judge_winograd_auto_scheduler,) = judge_winograd(
-                N,
-                H,
-                W,
-                KH,
-                KW,
-                CI,
-                CO,
-                padding,
-                stride_h,
-                stride_w,
-                dilation_h,
-                dilation_w,
-                data.dtype,
-                kernel.dtype,
-                pre_flag=False,
-            )
-
-            if judge_winograd_autotvm:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_direct),
-                    wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc_winograd_direct),
-                    name="conv2d_nhwc_winograd_direct.cuda",
-                    plevel=5,
-                )
+    # add miopen implementation
+    if (
+        "miopen" in target.libs
+        and groups == 1
+        and layout == "NCHW"
+        and padding[0] == padding[2]
+        and padding[1] == padding[3]
+    ):
+        strategy.add_implementation(
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
+            name="conv2d_nchw_miopen.rocm",
+            plevel=50,
+        )
 
-            if is_auto_scheduler_enabled() and judge_winograd_auto_scheduler:
-                strategy.add_implementation(
-                    wrap_compute_conv2d(topi.nn.conv2d_winograd_nhwc),
-                    naive_schedule,  # this implementation should never be picked by autotvm
-                    name="conv2d_nhwc.winograd",
-                    plevel=15,
-                )
-        elif layout == "HWCN":
-            assert kernel_layout == "HWIO"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn),
-                name="conv2d_hwcn.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
-                name="conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
-        # add miopen implementation
-        if (
-            "miopen" in target.libs
-            and layout == "NCHW"
-            and padding[0] == padding[2]
-            and padding[1] == padding[3]
-        ):
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
-                wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
-                name="conv2d_nchw_miopen.rocm",
-                plevel=15,
-            )
-    elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
-        if layout == "NCHW":
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw),
-                name="depthwise_conv2d_nchw.cuda",
-            )
-        elif layout == "NHWC":
-            assert kernel_layout == "HWOI"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
-                wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
-                name="depthwise_conv2d_nhwc.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
-    else:  # group_conv2d
-        if layout == "NCHW":
-            # TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
-            assert kernel_layout == "OIHW"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
-                name="group_conv2d_nchw.cuda",
-            )
-        elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
-            assert kernel_layout == "OIHW4o4i"
-            strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
-                wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8),
-                name="group_conv2d_NCHWc_int8.cuda",
-            )
-        else:
-            raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
     return strategy
 
 
 @dense_strategy.register("rocm")
 def dense_strategy_rocm(attrs, inputs, out_type, target):
     """Dense strategy for ROCM"""
     assert len(inputs[0].shape) == 2 and len(inputs[1].shape) == 2, "Only support 2-dim dense"
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_dense(topi.rocm.dense),
-        wrap_topi_schedule(topi.rocm.schedule_dense),
-        name="dense.rocm",
-    )
-    data, weights = inputs
-    if (data.dtype == "int8"
-        and weights.dtype == "int8"
-        and out_type.dtype == "int32"
-    ):
-        strategy.add_implementation(
-            wrap_compute_dense(topi.cuda.dense_int8),
-            wrap_topi_schedule(topi.cuda.schedule_dense_int8),
-            name="dense_int8.rocm",
-        )
+    strategy = dense_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
@@ -210,13 +74,8 @@ def dense_strategy_rocm(attrs, inputs, out_type, target):
 @batch_matmul_strategy.register("rocm")
 def batch_matmul_strategy_rocm(attrs, inputs, out_type, target):
     """Batch matmul strategy for ROCM"""
-    strategy = _op.OpStrategy()
-    strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.cuda.batch_matmul, need_out_dtype=True),
-        wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
-        name="batch_matmul.cuda",
-        plevel=10,
-    )
+    strategy = batch_matmul_strategy_cuda(attrs, inputs, out_type, target)
+
     if target.kind.name == "rocm" and "rocblas" in target.libs:
         assert out_type.dtype == inputs[0].dtype, "Mixed precision not supported."
         strategy.add_implementation(
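
Usage note: the MIOpen and rocBLAS implementations registered above are only added when the library name appears in target.libs, so callers opt in through the target string. A minimal sketch of how the refactored strategies would be exercised; the module, shapes, and variable names are illustrative (not part of this patch), and building requires a TVM install compiled with ROCm, MIOpen, and rocBLAS support.

import tvm
from tvm import relay

# Tiny conv2d graph just to trigger the strategy lookup; shapes are arbitrary.
data = relay.var("data", shape=(1, 3, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(8, 3, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, padding=(1, 1, 1, 1), kernel_size=(3, 3))
mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))

# "-libs=miopen,rocblas" places both library names in target.libs, so
# conv2d_strategy_rocm can add conv2d_nchw_miopen.rocm (plevel=50) on top of
# the implementations inherited from conv2d_strategy_cuda.
target = tvm.target.Target("rocm -libs=miopen,rocblas")

# Only attempt compilation when a ROCm device is actually available.
if tvm.rocm().exist:
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target=target)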