Skip to content

Commit

Permalink
rename to n_group
Browse files Browse the repository at this point in the history
  • Loading branch information
petrex committed Sep 20, 2019
1 parent 292e95d commit 7c49cfb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/contrib/miopen/conv_forward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
MIOpenThreadEntry* entry_ptr = MIOpenThreadEntry::ThreadLocal();
// assertion for group count;
assert(n_group > 0 && "Group Size > 0 is expected");
if(n_group>1)
mode = 2;
if(n_group > 1)
assert(mode > 1 && "Group /Depthwise Conv mode when num of groups > 1");
// Set Mode
entry_ptr->conv_entry.mode = static_cast<miopenConvolutionMode_t>(mode);
// Set Ctx
entry_ptr->conv_entry.ctx = TVMContext{kDLROCM, 0};
// Set Data Type
entry_ptr->conv_entry.data_type = static_cast<miopenDataType_t>(
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf) at
dtype); // MIOpen supports fp32(miopenFloat), fp16(miopenHalf), int32, int8 at
// this moment.
// Set Desc
MIOPEN_CALL(miopenInitConvolutionDescriptor(entry_ptr->conv_entry.conv_desc,
Expand All @@ -81,7 +81,7 @@ TVM_REGISTER_GLOBAL("tvm.contrib.miopen.conv2d.setup")
MIOPEN_CALL(miopenSet4dTensorDescriptor(entry_ptr->conv_entry.filter_desc,
entry_ptr->conv_entry.data_type,
w_dim0,
w_dim1/group,
w_dim1/n_group,
w_dim2,
w_dim3));
// Set Input
Expand Down

0 comments on commit 7c49cfb

Please sign in to comment.