Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add symbol.SwapAxis operator, just can do Forward(). #502

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,4 @@ List of Contributors
* [Xiaodong](https://github.com/XD-DENG)
* [Nan Xiao](https://github.com/road2stat)
* [Junyuan Xie](https://github.com/piiswrong)
* [Ming Zhang](https://github.com/starimpact)
6 changes: 3 additions & 3 deletions make/config.mk
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,15 @@ ADD_CFLAGS =
#---------------------------------------------

# whether use CUDA during compile
USE_CUDA = 0
USE_CUDA = 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please change this back to default, as most users don't have cuda


# add the path to CUDA libary to link and compile flag
# if you have already add them to enviroment variable, leave it as NONE
# USE_CUDA_PATH = /usr/local/cuda
USE_CUDA_PATH = NONE
USE_CUDA_PATH = /usr/local/cuda
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do not change config.mk, copy it to root, and change it on project root


# whether use CUDNN R3 library
USE_CUDNN = 0
USE_CUDNN = 1

# whether use opencv during compilation
# you can disable it, however, you will not able to use
Expand Down
252 changes: 252 additions & 0 deletions src/operator/swapaxis-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
/*!
* Copyright (c) 2015 by Contributors
* \file swapaxis-inl.h
* \brief
* \author Ming Zhang
*/
#ifndef MXNET_OPERATOR_SWAPAXIS_INL_H_
#define MXNET_OPERATOR_SWAPAXIS_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <algorithm>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include "./operator_common.h"

#define SWAPAXIS_DBG 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove debug macro


namespace mxnet {
namespace op {

struct SwapAxis{
enum SwapAxisOpInputs {kData};
enum SwapAxisOpOutputs {kOut};
};

using namespace mshadow;
using namespace mshadow::expr;

struct SwapAxisParam : public dmlc::Parameter<SwapAxisParam> {
// use int for enumeration
uint32_t dim1, dim2;
DMLC_DECLARE_PARAMETER(SwapAxisParam) {
DMLC_DECLARE_FIELD(dim1)
.set_default(0)
.describe("the first axis to be swapped.");
DMLC_DECLARE_FIELD(dim2)
.set_default(0)
.describe("the second axis to be swapped.");
}
};

//a1 higher dimension to be swapped, assert a1 > a2
//a2 lower dimension to be swapped
template<typename xpu>
class SwapAxisOp : public Operator {
public:
explicit SwapAxisOp(SwapAxisParam p) {
CHECK_LT(p.dim1, p.dim2) << "dim1 must be lower than dim2.";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider automatically swap if not the case

#if SWAPAXIS_DBG
printf("hello swapaxis SwapAxisOp:dim1:%d, dim2:%d!\n", p.dim1, p.dim2);
#endif
this->param_ = p;
}


void Reshape2Five(Shape<5> &inter_shape, TShape &shape, uint32_t dim1, uint32_t dim2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use google C style pass non-const ref as pointer, and other as const reference

{
int ndim_in = shape.ndim();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

index_t

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function can be static

int si;
for (si = 0; si < 5; si++)
{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style: use for() {

inter_shape[si] = 1;
}
//dim_0
for (si = 0; si < dim1; si++)
{
inter_shape[0] *= shape[si];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

}
//dim_1
inter_shape[1] = shape[dim1];
//dim_2
for (si = dim1+1; si < dim2; si++)
{
inter_shape[2] *= shape[si];
}
//dim_3
inter_shape[3] = shape[dim2];
//dim_4
for (si = dim2+1; si < ndim_in; si++)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

space between dim2 and +1.

{
inter_shape[4] *= shape[si];
}
}

void __swapaxis(Stream<xpu> *s, const std::vector<TBlob> &in_data, const std::vector<TBlob> &out_data)
{
uint32_t dim1 = param_.dim1;
uint32_t dim2 = param_.dim2;

TBlob data_in = in_data[SwapAxis::kData];
TBlob data_out = out_data[SwapAxis::kData];

TShape shape_in = data_in.shape_;
TShape shape_out = data_out.shape_;

Shape<5> inter_shape;

Reshape2Five(inter_shape, shape_in, dim1, dim2);

Tensor<xpu, 5> inter_data_in = data_in.get_with_shape<xpu, 5, real_t>(inter_shape, s);

int dwTmp = 0;

Shape<5> inter_shape2 = inter_shape;
inter_shape2[1] = inter_shape[3];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use std::swap

inter_shape2[3] = inter_shape[1];

Tensor<xpu, 5> inter_data_out = data_out.get_with_shape<xpu, 5, real_t>(inter_shape2, s);

TShape shape_tmp = shape_in;
dwTmp = shape_tmp[dim1];
shape_tmp[dim1] = shape_tmp[dim2];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

std::swap(shape_tmp[dim1], shap_tmp[dim2])

shape_tmp[dim2] = dwTmp;

CHECK(shape_out == shape_tmp);

inter_data_out = swapaxis<3, 1>(inter_data_in);
}

virtual void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &aux_args) {

#if SWAPAXIS_DBG
printf("hello swapaxis Forward!\n");
#endif
Stream<xpu> *s = ctx.get_stream<xpu>();

__swapaxis(s, in_data, out_data);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use Google code style SwapAxis

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rename _swapaxis to SwapAxis


#if 0
delete []aDims_in;
#endif
}

virtual void Backward(const OpContext &ctx,
const std::vector<TBlob> &out_grad,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &aux_args) {
#if SWAPAXIS_DBG
printf("hello swapaxis Backward!\n");
#endif
Stream<xpu> *s = ctx.get_stream<xpu>();
// __swapaxis(s, in_data, out_data);
__swapaxis(s, out_grad, in_grad);
}

SwapAxisParam param_;
};


template<typename xpu>
Operator* CreateOp(SwapAxisParam param);


#if DMLC_USE_CXX11
class SwapAxisProp : public OperatorProperty {
public:
std::vector<std::string> ListArguments() const override {
return {"data"};
}

void Init(const std::vector<std::pair<std::string, std::string> >& kwargs) override {
param_.Init(kwargs);
}

std::map<std::string, std::string> GetParams() const override {
return param_.__DICT__();
}

bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape,
std::vector<TShape> *aux_shape) const override {
int input_num = in_shape->size();
if (input_num == 0)
{
std::cout << "Have no input data.\n";
return false;
}
TShape &shape0 = (*in_shape)[SwapAxis::kData];
#if SWAPAXIS_DBG
printf("in_shape_num:%d\n", input_num);
printf("in_shape_0, dim:%d, size:%d\n", (int)shape0.ndim(), (int)shape0.Size());
#endif
if (shape0.ndim() != 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this restriction is no longer needed

{
std::cout << "Input data should be 4D.\n";
return false;
}
out_shape->clear();
out_shape->push_back(shape0);
TShape &shape1 = (*out_shape)[SwapAxis::kOut];
#if 1
int tmp = 0;
tmp = shape1[param_.dim1];
shape1[param_.dim1] = shape1[param_.dim2];
shape1[param_.dim2] = tmp;
#endif
#if SWAPAXIS_DBG
for (int i = 0; i < 4; i++)
{
printf("%d[%d], ", shape1[i], shape0[i]);
}
printf("\n");
#endif
return true;
}

OperatorProperty* Copy() const override {
auto ptr = new SwapAxisProp();
ptr->param_ = param_;
return ptr;
}

std::string TypeString() const override {
return "SwapAxis";
}
/*
std::vector<int> DeclareBackwardDependency(
const std::vector<int> &out_grad,
const std::vector<int> &in_data,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

override DeclarebackwardDependency

return out_grad[SwapAxis::kOut];

const std::vector<int> &out_data) const override;

std::vector<ResourceRequest> ForwardResource(
const std::vector<TShape> &in_shape) const override;

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the commented fields

std::vector<ResourceRequest> BackwardResource(
const std::vector<TShape> &in_shape) const override;
*/
Operator* CreateOperator(Context ctx) const override;

private:
SwapAxisParam param_;
}; // class SwapAxisProp
#endif // DMLC_USE_CXX11


}
}

#endif


35 changes: 35 additions & 0 deletions src/operator/swapaxis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*!
* Copyright (c) 2015 by Contributors
* \file swapaxis.cc
* \brief
* \author Ming Zhang
*/

#include "./swapaxis-inl.h"

namespace mxnet {
namespace op {

template<>
Operator* CreateOp<cpu>(SwapAxisParam param) {
return new SwapAxisOp<cpu>(param);
}

Operator* SwapAxisProp::CreateOperator(Context ctx) const {
#if SWAPAXIS_DBG
printf("hello swapaxis CreateOperator!\n");
#endif
DO_BIND_DISPATCH(CreateOp, param_);
}


DMLC_REGISTER_PARAMETER(SwapAxisParam);

MXNET_REGISTER_OP_PROPERTY(SwapAxis, SwapAxisProp)
.add_argument("data", "Symbol", "Input data to the SwapAxisOp.")
.add_arguments(SwapAxisParam::__FIELDS__())
.describe("Apply swapaxis to input.");


}
}
20 changes: 20 additions & 0 deletions src/operator/swapaxis.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*!
* Copyright (c) 2015 by Contributors
* \file swapaxis.cu
* \brief
* \author Ming Zhang
*/

#include "./swapaxis-inl.h"

namespace mxnet {
namespace op {

template<>
Operator *CreateOp<gpu>(SwapAxisParam param) {
return new SwapAxisOp<gpu>(param);
}

} // namespace op
} // namespace mxnet