Commit

Merge pull request #1615 from longjon/deconv-layer
Add deconvolution layer with refactoring of convolution layer to share code
shelhamer committed Feb 1, 2015
2 parents 92a66e3 + 25c2e3f commit 9767b99
Showing 8 changed files with 794 additions and 341 deletions.
165 changes: 133 additions & 32 deletions include/caffe/vision_layers.hpp
@@ -16,6 +16,101 @@

namespace caffe {

/**
* @brief Abstract base class that factors out the BLAS code common to
* ConvolutionLayer and DeconvolutionLayer.
*/
template <typename Dtype>
class BaseConvolutionLayer : public Layer<Dtype> {
public:
explicit BaseConvolutionLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

virtual inline int MinBottomBlobs() const { return 1; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline bool EqualNumBottomTopBlobs() const { return true; }

protected:
// Helper functions that abstract away the column buffer and gemm arguments.
// The last argument in forward_cpu_gemm is so that we can skip the im2col if
// we just called weight_cpu_gemm with the same input.
void forward_cpu_gemm(const Dtype* input, const Dtype* weights,
Dtype* output, bool skip_im2col = false);
void forward_cpu_bias(Dtype* output, const Dtype* bias);
void backward_cpu_gemm(const Dtype* input, const Dtype* weights,
Dtype* output);
void weight_cpu_gemm(const Dtype* input, const Dtype* output, Dtype*
weights);
void backward_cpu_bias(Dtype* bias, const Dtype* input);

#ifndef CPU_ONLY
void forward_gpu_gemm(const Dtype* col_input, const Dtype* weights,
Dtype* output, bool skip_im2col = false);
void forward_gpu_bias(Dtype* output, const Dtype* bias);
void backward_gpu_gemm(const Dtype* input, const Dtype* weights,
Dtype* col_output);
void weight_gpu_gemm(const Dtype* col_input, const Dtype* output, Dtype*
weights);
void backward_gpu_bias(Dtype* bias, const Dtype* input);
#endif

// reverse_dimensions should return true iff we are implementing deconv, so
// that conv helpers know which dimensions are which.
virtual bool reverse_dimensions() = 0;
// Compute height_out_ and width_out_ from other parameters.
virtual void compute_output_shape() = 0;

int kernel_h_, kernel_w_;
int stride_h_, stride_w_;
int num_;
int channels_;
int pad_h_, pad_w_;
int height_, width_;
int group_;
int num_output_;
int height_out_, width_out_;
bool bias_term_;
bool is_1x1_;

private:
// wrap im2col/col2im so we don't have to remember the (long) argument lists
inline void conv_im2col_cpu(const Dtype* data, Dtype* col_buff) {
im2col_cpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
}
inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
col2im_cpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
}
#ifndef CPU_ONLY
inline void conv_im2col_gpu(const Dtype* data, Dtype* col_buff) {
im2col_gpu(data, conv_in_channels_, conv_in_height_, conv_in_width_,
kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, col_buff);
}
inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
col2im_gpu(col_buff, conv_in_channels_, conv_in_height_, conv_in_width_,
kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_, data);
}
#endif

int conv_out_channels_;
int conv_in_channels_;
int conv_out_spatial_dim_;
int conv_in_height_;
int conv_in_width_;
int kernel_dim_;
int weight_offset_;
int col_offset_;
int output_offset_;

Blob<Dtype> col_buffer_;
Blob<Dtype> bias_multiplier_;
};

/**
* @brief Convolves the input image with a bank of learned filters,
* and (optionally) adds biases.
@@ -33,7 +128,7 @@ namespace caffe {
* the output channel N' columns of the output matrix.
*/
template <typename Dtype>
class ConvolutionLayer : public Layer<Dtype> {
class ConvolutionLayer : public BaseConvolutionLayer<Dtype> {
public:
/**
* @param param provides ConvolutionParameter convolution_param,
@@ -64,18 +159,10 @@ class ConvolutionLayer : public Layer<Dtype> {
* kernels + stream parallelism) engines.
*/
explicit ConvolutionLayer(const LayerParameter& param)
: Layer<Dtype>(param) {}
virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);

: BaseConvolutionLayer<Dtype>(param) {}
virtual inline LayerParameter_LayerType type() const {
return LayerParameter_LayerType_CONVOLUTION;
}
virtual inline int MinBottomBlobs() const { return 1; }
virtual inline int MinTopBlobs() const { return 1; }
virtual inline bool EqualNumBottomTopBlobs() const { return true; }

protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
@@ -86,30 +173,44 @@ class ConvolutionLayer : public Layer<Dtype> {
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual inline bool reverse_dimensions() { return false; }
virtual void compute_output_shape();
};

int kernel_h_, kernel_w_;
int stride_h_, stride_w_;
int num_;
int channels_;
int pad_h_, pad_w_;
int height_, width_;
int group_;
int num_output_;
int height_out_, width_out_;
bool bias_term_;
bool is_1x1_;
/**
* @brief Convolve the input with a bank of learned filters, and (optionally)
* add biases, treating filters and convolution parameters in the
* opposite sense as ConvolutionLayer.
*
* ConvolutionLayer computes each output value by dotting an input window with
* a filter; DeconvolutionLayer multiplies each input value by a filter
* elementwise, and sums over the resulting output windows. In other words,
* DeconvolutionLayer is ConvolutionLayer with the forward and backward passes
* reversed. DeconvolutionLayer reuses ConvolutionParameter for its
* parameters, but they take the opposite sense as in ConvolutionLayer (so
* padding is removed from the output rather than added to the input, and
* stride results in upsampling rather than downsampling).
*/
template <typename Dtype>
class DeconvolutionLayer : public BaseConvolutionLayer<Dtype> {
public:
explicit DeconvolutionLayer(const LayerParameter& param)
: BaseConvolutionLayer<Dtype>(param) {}
virtual inline LayerParameter_LayerType type() const {
return LayerParameter_LayerType_DECONVOLUTION;
}

/// M_ is the channel dimension of the output for a single group, which is the
/// leading dimension of the filter matrix.
int M_;
/// K_ is the dimension of an unrolled input for a single group, which is the
/// leading dimension of the data matrix.
int K_;
/// N_ is the spatial dimension of the output, the H x W, which are the last
/// dimensions of the data and filter matrices.
int N_;
Blob<Dtype> col_buffer_;
Blob<Dtype> bias_multiplier_;
protected:
virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top);
virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
virtual inline bool reverse_dimensions() { return true; }
virtual void compute_output_shape();
};

#ifdef USE_CUDNN
[Remaining diff of this file and the other changed files not shown.]
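
A note on the two compute_output_shape() overrides declared above: they encode the "opposite sense" of the shared convolution parameters. The standalone sketch below (illustrative helper names and example values, not code from this commit) shows the two formulas for a single spatial dimension:

// Illustrative helpers mirroring the two compute_output_shape() overrides
// for one spatial dimension (names and values are examples, not commit code).
#include <cstdio>

// ConvolutionLayer: padding is added to the input, stride downsamples.
int conv_out_dim(int in_dim, int kernel, int pad, int stride) {
  return (in_dim + 2 * pad - kernel) / stride + 1;
}

// DeconvolutionLayer: the same parameters act in the opposite sense,
// so stride upsamples and padding is cropped from the output.
int deconv_out_dim(int in_dim, int kernel, int pad, int stride) {
  return stride * (in_dim - 1) + kernel - 2 * pad;
}

int main() {
  // kernel 4, stride 2, pad 1: a 4-wide input convolves down to 2, and
  // deconvolving the 2 with the same parameters restores width 4.
  printf("conv:   %d\n", conv_out_dim(4, 4, 1, 2));    // prints 2
  printf("deconv: %d\n", deconv_out_dim(2, 4, 1, 2));  // prints 4
  return 0;
}

With kernel 4, stride 2, pad 1 the two formulas are exact inverses, a configuration often used for 2x upsampling with DeconvolutionLayer.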
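
The DeconvolutionLayer comment above states that it is ConvolutionLayer with the forward and backward passes reversed. The self-contained 1-D toy below (independent of Caffe; all names are illustrative) makes that concrete: the deconvolution forward pass is the adjoint (transpose) of the convolution forward pass, i.e. exactly the operation convolution uses to backpropagate to its input, so <conv(x), z> equals <x, deconv(z)> for inputs of matching sizes:

// A 1-D toy (not Caffe code) illustrating the relationship described in the
// DeconvolutionLayer comment: its forward pass is the transpose (adjoint) of
// convolution's forward pass, i.e. convolution's backward pass w.r.t. its input.
#include <cstdio>
#include <vector>

// Forward convolution: out[i] = sum_j w[j] * x[i*stride - pad + j].
std::vector<float> conv_forward(const std::vector<float>& x,
                                const std::vector<float>& w,
                                int stride, int pad) {
  const int n = static_cast<int>(x.size());
  const int k = static_cast<int>(w.size());
  const int m = (n + 2 * pad - k) / stride + 1;
  std::vector<float> y(m, 0.0f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < k; ++j) {
      const int src = i * stride - pad + j;
      if (src >= 0 && src < n) y[i] += w[j] * x[src];
    }
  return y;
}

// Forward deconvolution: each input value is multiplied by the kernel and
// scattered into the larger output -- stride upsamples, padding crops the output.
std::vector<float> deconv_forward(const std::vector<float>& x,
                                  const std::vector<float>& w,
                                  int stride, int pad) {
  const int m = static_cast<int>(x.size());
  const int k = static_cast<int>(w.size());
  const int n = stride * (m - 1) + k - 2 * pad;
  std::vector<float> y(n, 0.0f);
  for (int i = 0; i < m; ++i)
    for (int j = 0; j < k; ++j) {
      const int dst = i * stride - pad + j;
      if (dst >= 0 && dst < n) y[dst] += w[j] * x[i];
    }
  return y;
}

float dot(const std::vector<float>& a, const std::vector<float>& b) {
  float s = 0.0f;
  for (int i = 0; i < static_cast<int>(a.size()); ++i) s += a[i] * b[i];
  return s;
}

int main() {
  const std::vector<float> w = {1.0f, 2.0f, 3.0f, 4.0f};    // kernel of size 4
  const std::vector<float> x = {0.5f, -1.0f, 2.0f, 0.25f};  // convolution input
  const std::vector<float> z = {3.0f, -2.0f};               // same size as conv output
  const int stride = 2, pad = 1;
  // Both lines print the same value: deconv forward is conv's adjoint.
  printf("<conv(x), z>   = %g\n", dot(conv_forward(x, w, stride, pad), z));
  printf("<x, deconv(z)> = %g\n", dot(x, deconv_forward(z, w, stride, pad)));
  return 0;
}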
