
Commit

Merge branch 'developing' into parser
Shixiaowei02 authored Sep 19, 2018
2 parents 4df42ab + 689b210 commit 78dcce1
Showing 9 changed files with 117 additions and 45 deletions.
framework/core/net/net.cpp (4 changes: 3 additions & 1 deletion)
@@ -393,9 +393,11 @@ void Net<Ttype, Dtype, Ptype, RunType>::prediction() {
#define RECORD_INNER
#if defined(RECORD_INNER) && defined(USE_X86_PLACE)
record_tensor_to_file(*out,("record_"+executer.name).c_str());
if(executer.name=="")
#endif
LOG(INFO) <<executer.name <<" d_tensor_out_p :" <<out->data();
#ifdef USE_CUDA
record_tensor_to_file(*out,("record_"+executer.name).c_str());
#endif
#ifdef USE_X86_PLACE
// for (int i = 0; i < 10; ++i) {
// std::cout << out->data()[i]<<" ";
framework/operators/fusion_ops/conv_3x3_batchnorm_scale_relu.cpp (26 changes: 18 additions & 8 deletions)
@@ -88,12 +88,17 @@ Status SassConvBatchnormScaleReluHelper<Ttype, Dtype, Ptype>::InitParam() {

// get relu param
auto alpha = GET_PARAMETER(float, relu_0_alpha);
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_relu);//, alpha); // TEMP


ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param, batchnorm_param,
scale_param);
_param_conv_batchnorm_scale_relu = conv_act_param;
if (alpha != 0) {
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_prelu, alpha); // TEMP
ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param, batchnorm_param,
scale_param);
_param_conv_batchnorm_scale_relu = conv_act_param;
} else {
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_relu); // TEMP
ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param, batchnorm_param,
scale_param);
_param_conv_batchnorm_scale_relu = conv_act_param;
}

return Status::OK();
}
@@ -102,8 +107,13 @@ template<typename Ttype, DataType Dtype, Precision Ptype>
Status SassConvBatchnormScaleReluHelper<Ttype, Dtype, Ptype>::Init(OpContext<Ttype>& ctx,
const std::vector<Tensor4dPtr<Ttype, Dtype> >& ins,
std::vector<Tensor4dPtr<Ttype, Dtype> >& outs) {
_funcs_conv_batchnorm_scale_relu.init(ins, outs, _param_conv_batchnorm_scale_relu, SPECIFY,
SABER_IMPL, ctx);
if (_param_conv_batchnorm_scale_relu.activation_param.active == Active_relu) {
_funcs_conv_batchnorm_scale_relu.init(ins, outs, _param_conv_batchnorm_scale_relu, SPECIFY,
SABER_IMPL, ctx);
} else {
_funcs_conv_batchnorm_scale_relu.init(ins, outs, _param_conv_batchnorm_scale_relu, SPECIFY,
VENDER_IMPL, ctx);
}
return Status::OK();
}
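
The same selection appears in this file and in the two fusion ops below: a non-zero relu_0_alpha turns the fused activation into PReLU (leaky ReLU with a scalar slope), and only the plain-ReLU case keeps the fused Saber kernel. A condensed sketch of that pattern, assuming the Anakin framework headers are available (the helper name is hypothetical; the two ActivationParam constructors match the diff above):

// Hypothetical helper condensing the branch repeated in these fusion ops:
// a non-zero relu_0_alpha switches the fused activation from ReLU to PReLU.
template <typename TensorT>
ActivationParam<TensorT> make_fused_activation(float alpha) {
    if (alpha != 0) {
        return ActivationParam<TensorT>(Active_prelu, alpha);  // leaky slope
    }
    return ActivationParam<TensorT>(Active_relu);               // plain ReLU
}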

framework/operators/fusion_ops/conv_batchnorm_scale_relu.cpp (19 changes: 14 additions & 5 deletions)
@@ -83,12 +83,19 @@ Status ConvBatchnormScaleReluHelper<Ttype, Dtype, Ptype>::InitParam() {

// get relu param
auto alpha = GET_PARAMETER(float, relu_0_alpha);
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_relu);//, alpha); // TEMP
if (alpha != 0) {
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_prelu, alpha); // TEMP
ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param, batchnorm_param,
scale_param);
_param_conv_batchnorm_scale_relu = conv_act_param;
} else {
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_relu); // TEMP
ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param, batchnorm_param,
scale_param);
_param_conv_batchnorm_scale_relu = conv_act_param;
}


ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param, batchnorm_param,
scale_param);
_param_conv_batchnorm_scale_relu = conv_act_param;

return Status::OK();
}
@@ -115,7 +122,9 @@ Status ConvBatchnormScaleReluHelper<Ttype, Dtype, Ptype>::InferShape(const
template <>
Status ConvBatchnormScaleReluHelper<NV, AK_FLOAT, Precision::FP32>::Init(OpContext<NV> &ctx, \
const std::vector<Tensor4dPtr<NV, AK_FLOAT> >& ins, std::vector<Tensor4dPtr<NV, AK_FLOAT> >& outs) {
if (_param_conv_batchnorm_scale_relu.conv_param.group == ins[0]->channel() && \
bool use_saber = true;
use_saber = use_saber && (_param_conv_batchnorm_scale_relu.activation_param.active == Active_relu);
if (use_saber && _param_conv_batchnorm_scale_relu.conv_param.group == ins[0]->channel() && \
_param_conv_batchnorm_scale_relu.conv_param.group == outs[0]->channel()) {
_funcs_conv_batchnorm_scale_relu.init(ins, outs, _param_conv_batchnorm_scale_relu, SPECIFY,
SABER_IMPL, ctx);
framework/operators/fusion_ops/conv_relu.cpp (14 changes: 10 additions & 4 deletions)
@@ -62,11 +62,16 @@ Status ConvReluHelper<Ttype, Dtype, Ptype>::InitParam() {

// get relu param
auto alpha = GET_PARAMETER(float, relu_0_alpha);
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_relu);//, alpha); // TEMP

if (alpha != 0) {
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_prelu, alpha); // TEMP
ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param);
_param_conv_relu = conv_act_param;
} else {
ActivationParam<Tensor4d<Ttype, Dtype>> active_param(Active_relu); // TEMP
ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param);
_param_conv_relu = conv_act_param;
}

ConvActiveParam<Tensor4d<Ttype, Dtype>> conv_act_param(_conv_param, active_param);
_param_conv_relu = conv_act_param;

return Status::OK();

@@ -100,6 +105,7 @@ Status ConvReluHelper<NV, AK_FLOAT, Precision::FP32>::Init(OpContext<NV>& ctx, \
use_saber = use_saber && (_param_conv_relu.conv_param.weight()->width()==3);
use_saber = use_saber && (_param_conv_relu.conv_param.dilation_h == 1);
use_saber = use_saber && (_param_conv_relu.conv_param.dilation_w == 1);
use_saber = use_saber && (_param_conv_relu.activation_param.active == Active_relu);
if (((_param_conv_relu.conv_param.group == 1) && use_saber)|| (_param_conv_relu.conv_param.group == ins[0]->channel() && \
_param_conv_relu.conv_param.group == outs[0]->channel())) {
_funcs_conv_relu.init(ins, outs, _param_conv_relu, SPECIFY, SABER_IMPL, ctx);
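
The hunk above keeps the Saber fused kernel only when every listed condition holds; the activation check added here excludes PReLU/leaky cases from the Saber path. A sketch of that gating (the predicate is hypothetical, and the surrounding function may apply further checks not shown in this hunk):

// Hypothetical predicate mirroring the checks visible in this hunk:
// Saber's fused conv+relu is kept only for 3x3, dilation-1, plain-ReLU convs.
bool keep_saber_conv_relu(int kernel_w, int dilation_h, int dilation_w,
                          bool activation_is_relu) {
    bool use_saber = true;
    use_saber = use_saber && (kernel_w == 3);
    use_saber = use_saber && (dilation_h == 1);
    use_saber = use_saber && (dilation_w == 1);
    use_saber = use_saber && activation_is_relu;
    return use_saber;
}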
saber/funcs/impl/cuda/base/cuda_c/saber_activation.cu (49 changes: 43 additions & 6 deletions)
Expand Up @@ -175,6 +175,34 @@ __global__ void ker_prelu_fwd(Dtype * out_data,
}
}

template<typename Dtype>
__global__ void ker_prelu_fwd(Dtype * out_data,
const Dtype* in_data, const int count,
const Dtype slope, bool is_channel_shared,
int in_n, int in_c, int in_h, int in_w,
int in_n_stride, int in_c_stride, int in_h_stride, int in_w_stride,
int out_n_stride, int out_c_stride, int out_h_stride, int out_w_stride) {
CUDA_KERNEL_LOOP(tid, count){
int w = tid % in_w;
int h = (tid / (in_w)) % in_h;
int c = (tid / (in_h * in_w)) % in_c;
int n = (tid / (in_c * in_h * in_w)) % in_n;

int in_idx = n * in_n_stride
+ c * in_c_stride
+ h * in_h_stride
+ w * in_w_stride;

int out_idx = n * out_n_stride
+ c * out_c_stride
+ h * out_h_stride
+ w * out_w_stride;

Dtype in_var = in_data[in_idx];
out_data[out_idx] = in_var > 0 ? in_var : slope * in_var;
}
}

template <>
SaberStatus SaberActivation<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, \
NCHW, NCHW, NCHW>::dispatch( \
@@ -248,12 +276,21 @@ SaberStatus SaberActivation<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, \
break;
case Active_prelu:
auto prelu_param = param.prelu_param;
ker_prelu_fwd<InDataType>
<<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, cuda_stream>>>(
out_data, in_data, count, prelu_param.slope->data(), prelu_param.channel_shared,
in_shape[0], in_shape[1], in_shape[2], in_shape[3],
stride_in[0], stride_in[1], stride_in[2], stride_in[3],
stride_out[0], stride_out[1], stride_out[2], stride_out[3]);
if (param.prelu_param.slope == nullptr) {
ker_prelu_fwd<InDataType>
<< < CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, cuda_stream >> > (
out_data, in_data, count, param.negative_slope, prelu_param.channel_shared,
in_shape[0], in_shape[1], in_shape[2], in_shape[3],
stride_in[0], stride_in[1], stride_in[2], stride_in[3],
stride_out[0], stride_out[1], stride_out[2], stride_out[3]);
} else {
ker_prelu_fwd<InDataType>
<< < CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, cuda_stream >> > (
out_data, in_data, count, prelu_param.slope->data(), prelu_param.channel_shared,
in_shape[0], in_shape[1], in_shape[2], in_shape[3],
stride_in[0], stride_in[1], stride_in[2], stride_in[3],
stride_out[0], stride_out[1], stride_out[2], stride_out[3]);
}
break;
}
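
The new kernel overload takes a single scalar slope and is dispatched when prelu_param.slope is null, falling back to param.negative_slope; per element it applies the leaky-ReLU rule out = x > 0 ? x : slope * x with one slope shared across all channels. A plain CPU reference of that computation (illustrative only, not part of the commit):

// CPU reference of what the scalar-slope ker_prelu_fwd computes over a
// contiguous tensor: leaky ReLU with a single shared negative slope.
#include <cstddef>

void prelu_scalar_reference(float* out, const float* in,
                            std::size_t count, float slope) {
    for (std::size_t i = 0; i < count; ++i) {
        float x = in[i];
        out[i] = x > 0.f ? x : slope * x;
    }
}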

saber/funcs/impl/cuda/base/cuda_c/saber_resize.cu (26 changes: 10 additions & 16 deletions)
@@ -73,9 +73,9 @@ __global__ void resize_bilinear_2d_kernel(const int wout, const int hout,
dtype br = (w > win || h > hin)? 0 : src[src_indexBR];
#else
dtype tl = src[src_indexTL];
dtype tr = w > win? 0 : src[src_indexTR];//w > win? 0 :
dtype bl = h > hin? 0 : src[src_indexBL];//h > hin? 0 :
dtype br = (w > win || h > hin)? 0 : src[src_indexBR];//(w > win || h > hin)? 0 :
dtype tr = w < win? src[src_indexTR]:0;//w > win? 0 :
dtype bl = h < hin? src[src_indexBL]:0;//h > hin? 0 :
dtype br = (w < win && h < hin) ? src[src_indexBR]: 0;//(w > win || h > hin)? 0 :
#endif
dst[dst_index] = static_cast<dtype>(w_00 * tl + w_01 * tr + w_10 * bl + w_11 * br);
src_indexBR += src_stride_c;
@@ -152,19 +152,13 @@ SaberStatus SaberResize<NV, OpDtype, inDtype, outDtype,\
int dst_stride_h = dst_real_shape.count(height_idx + 1);//outputs[0]->count(height_idx + 1, dims);
int dst_stride_channel = dst_real_shape.count(channel_idx + 1);//outputs[0]->count(channel_idx + 1, dims);
int dst_stride_batch = dst_real_shape.count(num_idx + 1);//outputs[0]->count(num_idx + 1, dims);
const InDataType* in_data_batch = in_data;
OutDataType* out_data_batch = out_data;
for (int i = 0; i < n_out; ++i) {
resize_bilinear_2d_kernel<OpDataType><<<grid, block, 0, stream>>>(
w_out, h_out, n_out, c_out,
dst_stride_w, dst_stride_h, dst_stride_channel, dst_stride_batch,
w_in, h_in,
src_stride_w, src_stride_h, src_stride_channel, src_stride_batch,
1 / param.width_scale, 1 / param.height_scale,
in_data, out_data);
in_data_batch += src_stride_batch;
out_data_batch += dst_stride_batch;
}
resize_bilinear_2d_kernel<OpDataType><<<grid, block, 0, stream>>>(
w_out, h_out, n_out, c_out,
dst_stride_w, dst_stride_h, dst_stride_channel, dst_stride_batch,
w_in, h_in,
src_stride_w, src_stride_h, src_stride_channel, src_stride_batch,
1 / param.width_scale, 1 / param.height_scale,
in_data, out_data);
//outputs[0]->record_event(stream);
return SaberSuccess;
}
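
The per-batch launch loop is removed: the kernel already offsets through src_stride_batch and dst_stride_batch, so a single launch covers the whole batch (note that the removed loop passed in_data/out_data rather than the advanced batch pointers, so each of its iterations launched over the same data). An illustrative NCHW indexing sketch, not the real kernel:

// With strides derived from the shape, one flat index reaches every element
// of every image, so no per-batch pointer advancing or launch loop is needed.
int flat_index_nchw(int n, int c, int h, int w,
                    int channels, int height, int width) {
    int stride_batch   = channels * height * width;
    int stride_channel = height * width;
    int stride_row     = width;
    return n * stride_batch + c * stride_channel + h * stride_row + w;
}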
saber/funcs/impl/cuda/saber_activation.h (2 changes: 1 addition & 1 deletion)
@@ -65,7 +65,7 @@ class SaberActivation<NV, OpDtype, inDtype, outDtype,\
virtual SaberStatus dispatch(const std::vector<DataTensor_in*>& inputs,
std::vector<DataTensor_out*>& outputs,
ActivationParam<OpTensor>& param);

OpTensor _slope;
};

//template class SaberActivation<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, NCHW, NCHW, NCHW>;
saber/funcs/impl/cuda/vender_conv_act.cpp (12 changes: 9 additions & 3 deletions)
@@ -1,4 +1,5 @@
#include "saber/funcs/impl/cuda/vender_conv_act.h"
#include "saber/funcs/impl/cuda/saber_activation.h"
#include "cuda_fp16.h"

namespace anakin {
@@ -10,6 +11,9 @@ SaberStatus VenderConv2DAct<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, NCHW, NCHW, NCHW>:
std::vector<DataTensor_out *>& outputs,
ConvActiveParam<OpTensor>& param, Context<NV>& ctx) {

if (_use_saber_act) {
_saber_act.create(inputs, outputs, param.activation_param, ctx);
}
if (!(&ctx == this->_ctx)) {
if (_handle != NULL) {
CUDNN_CHECK(cudnnDestroy(_handle));
@@ -65,7 +69,7 @@ SaberStatus VenderConv2DAct<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, NCHW, NCHW, NCHW>:
inputs[0]->dims() - 2, pad_a,
filter_stride_a, dilation_a);
// set activation descriptor
if(param.has_active) {
if(param.has_active && !_use_saber_act) {
cudnn::set_activation_des<OpDataType>(&_active_descs, param.activation_param.active);
}

@@ -113,12 +117,11 @@ SaberStatus VenderConv2DAct<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, NCHW, NCHW, NCHW>:
dispatch(const std::vector<DataTensor_in*>& inputs,
std::vector<DataTensor_out*>& outputs,
ConvActiveParam<OpTensor>& param) {

const InDataType *in_data = (const InDataType*)inputs[0]->data();
InDataType *out_data = (InDataType*)outputs[0]->mutable_data();

const float *weight_data = (const float *) param.conv_param.weight()->data();
if (param.has_active == false) {
if (_use_saber_act || param.has_active == false) {
CUDNN_CHECK(cudnnConvolutionForward(_handle,
cudnn::cudnnTypeWrapper<float>::kOne(),
_input_descs, in_data,
@@ -140,6 +143,9 @@ SaberStatus VenderConv2DAct<NV, AK_FLOAT, AK_FLOAT, AK_FLOAT, NCHW, NCHW, NCHW>:
_output_descs, out_data));

}
if (_use_saber_act) {
_saber_act.dispatch(outputs, outputs, param.activation_param);
}
return SaberSuccess;
}
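
Together with the header change below, the flow is: when the fused activation is anything other than plain ReLU, _use_saber_act is set during init, the cuDNN activation descriptor is skipped, the convolution runs unfused, and a separate SaberActivation pass is dispatched in place on the conv output. A self-contained model of that control flow (stand-in functions only, not the real cuDNN or Saber calls):

#include <iostream>

enum class Act { Relu, Prelu };

// Stand-in for cudnnConvolutionForward / the fused conv+activation call.
void run_conv(bool fuse_relu) {
    std::cout << (fuse_relu ? "cuDNN conv with fused ReLU\n"
                            : "cuDNN conv, no fused activation\n");
}

// Stand-in for _saber_act.dispatch(outputs, outputs, activation_param).
void run_saber_activation_inplace() {
    std::cout << "separate Saber activation pass on the conv output\n";
}

void dispatch(Act act, bool has_active) {
    bool use_saber_act = has_active && act != Act::Relu;   // decided at init
    run_conv(/*fuse_relu=*/has_active && !use_saber_act);
    if (use_saber_act) {
        run_saber_activation_inplace();
    }
}

int main() {
    dispatch(Act::Relu,  true);   // fused cuDNN path
    dispatch(Act::Prelu, true);   // unfused conv, then Saber PReLU
    return 0;
}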

saber/funcs/impl/cuda/vender_conv_act.h (10 changes: 9 additions & 1 deletion)
@@ -17,7 +17,8 @@
#define ANAKIN_SABER_FUNCS_IMPL_CUDA_CUDNN_CONV_ACT_H

#include "saber/funcs/impl/impl_conv_act.h"
#include "saber/funcs/impl/cuda/cudnn_helper.h"
#include "saber/funcs/impl/cuda/cudnn_helper.h"
#include "saber/funcs/impl/cuda/saber_activation.h"
#include <cudnn.h>

namespace anakin{
@@ -119,6 +120,10 @@ class VenderConv2DAct<NV, OpDtype, inDtype, outDtype,\
std::vector<DataTensor_out *>& outputs,
ConvActiveParam<OpTensor>& param, Context<NV>& ctx) {
// ---- init cudnn resources ----
if (param.activation_param.active!= Active_relu) {
_use_saber_act = true;
_saber_act.init(inputs, outputs, param.activation_param, ctx);
}

_workspaceSizeInBytes = 0;
_workspaceData = NULL;
@@ -192,6 +197,9 @@ class VenderConv2DAct<NV, OpDtype, inDtype, outDtype,\
cudnnTensorDescriptor_t _input_nchw_descs;
cudnnTensorDescriptor_t _output_nchw_descs;

bool _use_saber_act{false};
SaberActivation<NV, OpDtype, inDtype, outDtype,\
LayOutType_op, LayOutType_in, LayOutType_out> _saber_act;
void *x8_data;
void *y8_data;

