@@ -81,10 +81,7 @@ def __init__(self):
8181 auto dilated_w = (filter_size.w() - 1) * dilation.column() + 1;
8282 auto h = (input_size.h() + padding.n() + padding.h() - dilated_h) / conv_stride.row() + 1;
8383 auto w = (input_size.w() + padding.w() + padding.c() - dilated_w) / conv_stride.column() + 1;
84- return cutlass::Tensor4DCoord(
85- input_size.n(),
86- h, w,
87- filter_size.n());
84+ return cutlass::Tensor4DCoord(input_size.n(), h, w, filter_size.n());
8885 }
8986};
9087
@@ -98,31 +95,6 @@ def __init__(self):
9895 cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(oshape);
9996 cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(oshape);
10097
101- cutlass::reference::host::TensorFillRandomUniform(
102- tensor_a.host_view(),
103- 1,
104- ElementInputA(7),
105- ElementInputA(-8),
106- 0);
107-
108- cutlass::reference::host::TensorFillRandomUniform(
109- tensor_b.host_view(),
110- 1,
111- ElementInputB(7),
112- ElementInputB(-8),
113- 0);
114-
115- cutlass::reference::host::TensorFill(
116- tensor_c.host_view());
117-
118- cutlass::reference::host::TensorFill(
119- tensor_ref_c.host_view());
120-
121- tensor_a.sync_device();
122- tensor_b.sync_device();
123- tensor_c.sync_device();
124- tensor_ref_c.sync_device();
125-
12698 cutlass::conv::Conv2dProblemSize problem_size(
12799 options.input_size,
128100 options.filter_size,
@@ -137,12 +109,12 @@ def __init__(self):
137109 using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
138110 typename ImplicitGemm::Arguments arguments{
139111 problem_size,
140- tensor_a.device_ref(),
141- tensor_b.device_ref(),
142- tensor_c.device_ref(),
143- tensor_c.device_ref(),
144- {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
145- };
112+ tensor_a.device_ref(),
113+ tensor_b.device_ref(),
114+ tensor_c.device_ref(),
115+ tensor_c.device_ref(),
116+ {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
117+ };
146118
147119 ImplicitGemm implicit_gemm_op;
148120 size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
0 commit comments