Skip to content

Commit 49c8e68

Browse files
committed
[Accuracy diff] Fix accuracy diff for conv2d_transpose API with NHWC format
Fix gradient calculation error in conv2d_transpose when using NHWC format with padding > 0. The issue was in im2col_cfo_cpu.h where incorrect index calculation caused gradients to be shifted to wrong positions. Key changes: - Replace incorrect ternary operator index calculation with direct calculation and boundary checking in NHWC branches - Add TestWithSAMEPad_NHWC and TestWithSAMEPadGroups_NHWC test cases - Ensure gradients match PyTorch reference implementation - Fix code formatting to meet clang-format requirements
1 parent f2933f3 commit 49c8e68

File tree

2 files changed

+60
-20
lines changed

2 files changed

+60
-20
lines changed

paddle/phi/kernels/funcs/im2col_cfo_cpu.h

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -210,11 +210,15 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const phi::DenseTensor& im,
210210
std::memcpy(dst_data + plw, src_data, copy_size);
211211
} else {
212212
for (int kow = 0; kow < output_width - plw - prw; ++kow) {
213-
dst_data[plw + kow] =
214-
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
215-
kow) *
216-
im_channels +
217-
ic];
213+
int im_row = oh - plh + kh;
214+
int im_col = kow;
215+
if (im_row >= 0 && im_row < im_height && im_col >= 0 &&
216+
im_col < im_width) {
217+
dst_data[plw + kow] =
218+
im_data[(im_row * im_width + im_col) * im_channels + ic];
219+
} else {
220+
dst_data[plw + kow] = static_cast<T>(0);
221+
}
218222
}
219223
}
220224
dst_data = dst_data + col_matrix_width;
@@ -269,11 +273,15 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const phi::DenseTensor& im,
269273
sizeof(T) * (output_width - (plw - kw)));
270274
} else {
271275
for (int kow = 0; kow < output_width - (plw - kw); ++kow) {
272-
dst_data[plw - kw + kow] =
273-
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
274-
kow) *
275-
im_channels +
276-
ic];
276+
int im_row = oh - plh + kh;
277+
int im_col = kow;
278+
if (im_row >= 0 && im_row < im_height && im_col >= 0 &&
279+
im_col < im_width) {
280+
dst_data[plw - kw + kow] =
281+
im_data[(im_row * im_width + im_col) * im_channels + ic];
282+
} else {
283+
dst_data[plw - kw + kow] = static_cast<T>(0);
284+
}
277285
}
278286
}
279287
dst_data = dst_data + col_matrix_width;
@@ -284,11 +292,15 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const phi::DenseTensor& im,
284292
dst_data, src_data + (kw - plw), sizeof(T) * output_width);
285293
} else {
286294
for (int kow = 0; kow < output_width; ++kow) {
287-
dst_data[kow] =
288-
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
289-
kw - plw + kow) *
290-
im_channels +
291-
ic];
295+
int im_row = oh - plh + kh;
296+
int im_col = kw - plw + kow;
297+
if (im_row >= 0 && im_row < im_height && im_col >= 0 &&
298+
im_col < im_width) {
299+
dst_data[kow] =
300+
im_data[(im_row * im_width + im_col) * im_channels + ic];
301+
} else {
302+
dst_data[kow] = static_cast<T>(0);
303+
}
292304
}
293305
}
294306
dst_data = dst_data + col_matrix_width;
@@ -301,11 +313,15 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const phi::DenseTensor& im,
301313
sizeof(T) * (output_width - i));
302314
} else {
303315
for (int kow = 0; kow < output_width - i; ++kow) {
304-
dst_data[kow] =
305-
im_data[(((oh - plh > 0 ? oh - plh : 0) + kh) * im_width +
306-
kw - plw + kow) *
307-
im_channels +
308-
ic];
316+
int im_row = oh - plh + kh;
317+
int im_col = kw - plw + kow;
318+
if (im_row >= 0 && im_row < im_height && im_col >= 0 &&
319+
im_col < im_width) {
320+
dst_data[kow] =
321+
im_data[(im_row * im_width + im_col) * im_channels + ic];
322+
} else {
323+
dst_data[kow] = static_cast<T>(0);
324+
}
309325
}
310326
}
311327
dst_data = dst_data + col_matrix_width;

test/legacy_test/test_conv2d_transpose_op.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1575,5 +1575,29 @@ def init_data(self):
15751575
self.np_out = np.zeros([4, 0, 6, 6])
15761576

15771577

1578+
class TestWithSAMEPad_NHWC(TestConv2DTransposeOp):
1579+
def init_test_case(self):
1580+
self.stride = [1, 1]
1581+
self.dilations = [1, 1]
1582+
self.groups = 1
1583+
self.input_size = [1, 3, 3, 1] # NHWC
1584+
f_c = self.input_size[-1]
1585+
self.filter_size = [f_c, 2, 3, 3]
1586+
self.data_format = 'NHWC'
1587+
self.padding_algorithm = 'SAME'
1588+
1589+
1590+
class TestWithSAMEPadGroups_NHWC(TestConv2DTransposeOp):
1591+
def init_test_case(self):
1592+
self.stride = [1, 1]
1593+
self.dilations = [1, 1]
1594+
self.groups = 2
1595+
self.input_size = [1, 3, 3, 2] # NHWC
1596+
f_c = self.input_size[-1]
1597+
self.filter_size = [f_c, 1, 3, 3]
1598+
self.data_format = 'NHWC'
1599+
self.padding_algorithm = 'SAME'
1600+
1601+
15781602
if __name__ == '__main__':
15791603
unittest.main()

0 commit comments

Comments
 (0)