[TOP] fix weight layout in conv2d_transpose (apache#220)
* update tvm

* [TOP] fix weight layout in conv2d_transpose
Huyuwei authored and tqchen committed Nov 6, 2017
1 parent 583a461 commit 969d79b
Showing 3 changed files with 7 additions and 10 deletions.
9 changes: 3 additions & 6 deletions src/top/nn/convolution.cc
```diff
@@ -150,14 +150,11 @@ inline bool Conv2DTransposeInferShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(param.dilation.ndim(), 2U)
       << "incorrect dilate size: " << param.dilation;
 
-  TShape wshape({param.channels / param.groups,
-                 dshape_nchw[1] / param.groups,
+  TShape wshape({dshape_nchw[1],
+                 param.channels / param.groups,
                  param.kernel_size[0],
                  param.kernel_size[1]});
-
   wshape = ConvertLayout(wshape, kNCHW, param.layout);
-  wshape[0] *= param.groups;
-
   NNVM_ASSIGN_INPUT_SHAPE(attrs, *in_shape, Conv2DTransposeParam::kWeight, wshape);
 
   if (param.use_bias) {
```
```diff
@@ -192,7 +189,7 @@ said convolution.
 - **data**: This depends on the `layout` parameter. Input is 4D array of shape
             (batch_size, in_channels, height, width) if `layout` is `NCHW`.
-- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1])
+- **weight**: (in_channels, channels, kernel_size[0], kernel_size[1])
 - **bias**: (channels,)
 - **out**: This depends on the `layout` parameter. Output is 4D array of shape
            (batch_size, channels, out_height, out_width) if `layout` is `NCHW`.
```
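The net effect: the conv2d_transpose weight now leads with the input-channel axis, (in_channels, channels/groups, kH, kW), which is also why the `wshape[0] *= param.groups` correction could be dropped. A minimal sketch of the new shape rule in plain Python (the helper name is illustrative, not part of the NNVM API):

```python
def conv2d_transpose_weight_shape(in_channels, channels, kernel_size, groups=1):
    """Weight shape for conv2d_transpose in NCHW layout after this fix.

    The leading axis is the *input* channel count, i.e.
    (in_channels, channels, kH, kW) for groups=1, as documented above;
    before this commit the first two axes were swapped.
    """
    return (in_channels,
            channels // groups,
            kernel_size[0],
            kernel_size[1])

# The kernel shape the updated test below expects: 3 input channels,
# channels=10, 3x3 kernel.
assert conv2d_transpose_weight_shape(3, 10, (3, 3)) == (3, 10, 3, 3)
```

This order matches the kernel layout that topi.testing.conv2d_transpose_nchw_python consumes in the test below.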
6 changes: 3 additions & 3 deletions tests/python/compiler/test_top_level2.py
```diff
@@ -60,20 +60,20 @@ def test_conv2d_transpose():
                              name="y", padding=(1,1), output_padding=(2,2))
     dtype = "float32"
     dshape = (1, 3, 18, 18)
-    kshape = (10, 3, 3, 3)
+    kshape = (3, 10, 3, 3)
     oshape = (1, 10, 37, 37)
     shape_dict = {"x": dshape}
     for target, ctx in ctx_list():
         graph, lib, _ = nnvm.compiler.build(y, target, shape_dict)
         m = graph_runtime.create(graph, lib, ctx)
         data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
         kernel = tvm.nd.array(np.random.uniform(size=kshape).astype(dtype))
-        bias = tvm.nd.array(np.random.uniform(size=kshape[0]).astype(dtype))
+        bias = tvm.nd.array(np.random.uniform(size=kshape[1]).astype(dtype))
         m.run(x=data, y_weight=kernel, y_bias=bias)
         out = m.get_output(0, tvm.nd.empty(oshape, dtype))
         c_np = topi.testing.conv2d_transpose_nchw_python(
             data.asnumpy(), kernel.asnumpy(), 2, 1)
-        c_np = c_np + bias.asnumpy().reshape(kshape[0], 1, 1)
+        c_np = c_np + bias.asnumpy().reshape(kshape[1], 1, 1)
         d_np = np.zeros(shape=oshape)
         d_np[:,:,0:c_np.shape[2],0:c_np.shape[3]] = c_np
         np.testing.assert_allclose(out.asnumpy(), d_np, rtol=1e-5)
```
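As a quick sanity check of the shapes in this test (stride 2 and padding 1 are visible in the conv2d_transpose_nchw_python call; strides=(2,2) on y is assumed from the elided construction above):

```python
# Transposed-convolution output size: (in - 1) * stride - 2 * pad + kernel.
# output_padding then extends the result with trailing zeros, which is what
# the zero-filled d_np embedding step in the test reproduces.
in_size, stride, pad, kernel, output_padding = 18, 2, 1, 3, 2
core = (in_size - 1) * stride - 2 * pad + kernel   # 35: spatial size of c_np
full = core + output_padding                       # 37: matches oshape
assert (core, full) == (35, 37)
```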
2 changes: 1 addition & 1 deletion tvm
Submodule tvm updated from a152a9 to 0fa4d9
