Add backwardusenone
ptrendx committed Mar 5, 2020
1 parent 6b89506 commit 37d81c8
Showing 2 changed files with 176 additions and 93 deletions.
src/operator/tensor/elemwise_binary_op_basic.cu: 4 changes (2 additions, 2 deletions)
@@ -227,7 +227,7 @@ NNVM_REGISTER_OP(_grad_add)
 
 NNVM_REGISTER_OP(_backward_add)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, mshadow_op::identity,
+                    VectorizedBackwardUseNoneCompute<mshadow_op::identity,
                                                      mshadow_op::identity>);
 
 NNVM_REGISTER_OP(elemwise_sub)
@@ -236,7 +236,7 @@ NNVM_REGISTER_OP(elemwise_sub)
 
 NNVM_REGISTER_OP(_backward_sub)
.set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, mshadow_op::identity,
+                    VectorizedBackwardUseNoneCompute<mshadow_op::identity,
                                                      mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
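A note on the functor pairs above: "backward use none" means the backward pass needs only the incoming output gradient, never the forward inputs. For c = a + b, dL/da = dL/dc and dL/db = dL/dc, so _backward_add instantiates <identity, identity>; for c = a - b, dL/db = -dL/dc, hence <identity, negation>. A minimal sketch of the functor shape (illustrative stand-ins, not MXNet's mshadow_op definitions):

struct identity {
  // dL/dx = dL/dc for the side that enters the op with a plus sign.
  template <typename T>
  static T Map(T grad) { return grad; }
};

struct negation {
  // dL/db = -dL/dc for the subtrahend in c = a - b.
  template <typename T>
  static T Map(T grad) { return -grad; }
};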
src/operator/tensor/elemwise_op.cuh: 265 changes (174 additions, 91 deletions)
@@ -50,29 +50,6 @@ class VectorizedStorage {
   } scratch_;
 };
 
-template <typename LType>
-MSHADOW_XINLINE void ldg(LType* dst, const LType* src) {
-  *dst = *src;
-}
-
-template <>
-MSHADOW_XINLINE void ldg(double* dst, const double* src) {
-  double temp;
-  asm volatile ("ld.global.f64 %0, [%1];" :
-                "=d"(temp) :
-                "l"(src));
-  *dst = temp;
-}
-
-/*template <>*/
-/*MSHADOW_XINLINE void ldg(uint64_t* dst, const uint64_t* src) {*/
-/*uint64_t temp;*/
-/*asm volatile ("ld.global.u64 %0, [%1];" :*/
-/*"=l"(temp) :*/
-/*"l"(src));*/
-/**dst = temp;*/
-/*}*/
-
 template <typename DType, typename LType, bool aligned = false>
 class VectorizedAccessor {
  public:
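The ldg helpers deleted above forced loads through hand-written PTX; the next hunk replaces their call sites with plain dereferences, which the compiler already turns into a single wide global load for suitably aligned types, so the helpers (and the commented-out uint64_t variant) added nothing. A hedged before/after sketch, not taken from the repository:

// Before: load forced through inline PTX.
__device__ double via_asm(const double* src) {
  double temp;
  asm volatile ("ld.global.f64 %0, [%1];" :
                "=d"(temp) :
                "l"(src));
  return temp;
}

// After: a plain dereference; nvcc typically emits the same single
// 64-bit global load here, without the asm escape hatch.
__device__ double via_deref(const double* src) {
  return *src;
}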
@@ -104,12 +81,10 @@ class VectorizedAccessor {
 
   MSHADOW_XINLINE void load(const index_t id, const index_t N) {
     if (aligned) {
-      ldg<typename std::remove_const<LType>::type>(&(storage_.scratch_.aligned),
-                                                   aligned_ptr_ + id);
+      storage_.scratch_.aligned = aligned_ptr_[id];
     } else {
       if (id > 0 && id < n_elems_ - 1) {
-        ldg<typename std::remove_const<LType>::type>(&(storage_.scratch_.aligned),
-                                                     aligned_ptr_ + id);
+        storage_.scratch_.aligned = aligned_ptr_[id];
       } else {
 #pragma unroll
         for (int j = 0; j < storage_.nvec; ++j) {
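The plain assignment that replaces ldg here is the core of the vectorization scheme: one aligned load of LType moves sizeof(LType)/sizeof(DType) elements at once, which the kernels then read back element-wise. A self-contained sketch of that union trick, with the storage layout assumed from the VectorizedStorage class at the top of this file:

template <typename DType, typename LType>
struct StorageSketch {
  static constexpr int nvec = sizeof(LType) / sizeof(DType);
  union {
    LType aligned;         // written by one wide load
    DType separate[nvec];  // read element by element afterwards
  } scratch;
};

__device__ float load_first_element(const float4* aligned_ptr, int id) {
  StorageSketch<float, float4> s;
  s.scratch.aligned = aligned_ptr[id];  // one 16-byte load, four floats
  return s.scratch.separate[0];
}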
@@ -161,11 +136,19 @@ class VectorizedStorer : public VectorizedAccessor<DType, LType, aligned> {
 
 namespace {
 
+template <typename DType, int NumInputs, int NumOutputs>
+struct VectorizedKernelParams {
+  const DType* inputs[NumInputs];
+  DType* outputs[NumOutputs];
+};
+
 
 template <bool aligned, typename DType, typename LType, typename OP, int req>
-__global__ void VectorizedElementwiseKernel(DType* output, const DType* input0, const DType* input1, index_t N) {
-  VectorizedLoader<DType, LType, aligned> loader0(input0, N);
-  VectorizedLoader<DType, LType, aligned> loader1(input1, N);
-  VectorizedStorer<DType, LType, aligned> storer(output, N);
+__global__ void VectorizedBinaryKernelFwd(const VectorizedKernelParams<DType, 2, 1> params,
+                                          const index_t N) {
+  VectorizedLoader<DType, LType, aligned> loader0(params.inputs[0], N);
+  VectorizedLoader<DType, LType, aligned> loader1(params.inputs[1], N);
+  VectorizedStorer<DType, LType, aligned> storer(params.outputs[0], N);
 
   const index_t M = loader0.num_aligned_elements();
 
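A design note on the struct introduced above: the kernels take VectorizedKernelParams by value, so the input and output pointers travel inside the kernel's argument buffer and no device-side pointer array has to be allocated or copied. A hedged sketch of a caller filling it in (dummy_kernel and the float pointers are placeholders; index_t as in the surrounding file):

__global__ void dummy_kernel(const VectorizedKernelParams<float, 2, 1> params,
                             const index_t N) { /* ... */ }

void launch_sketch(const float* a, const float* b, float* c, index_t n) {
  VectorizedKernelParams<float, 2, 1> params;
  params.inputs[0] = a;
  params.inputs[1] = b;
  params.outputs[0] = c;
  dummy_kernel<<<1, 128>>>(params, n);  // struct passed by value
}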
@@ -192,6 +175,82 @@ __global__ void VectorizedElementwiseKernel(DType* output, const DType* input0,
   }
 }
 
+template <bool aligned, typename DType, typename LType,
+          typename LOP, typename ROP, int lreq, int rreq>
+__global__ void VectorizedBinaryKernelBwdUseNone(const VectorizedKernelParams<DType, 1, 2> params,
+                                                 const index_t N) {
+  VectorizedLoader<DType, LType, aligned> loader(params.inputs[0], N);
+  VectorizedStorer<DType, LType, aligned> lstorer(params.outputs[0], N);
+  VectorizedStorer<DType, LType, aligned> rstorer(params.outputs[1], N);
+
+  const index_t M = loader.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    loader.load(tid, N);
+    if (lreq == kAddTo) {
+      lstorer.load(tid, N);
+    }
+    if (rreq == kAddTo) {
+      rstorer.load(tid, N);
+    }
+#pragma unroll
+    for (int i = 0; i < loader.storage_.nvec; ++i) {
+      DType inp = loader.storage_.scratch_.separate[i];
+      if (!((std::is_same<LOP, mshadow_op::identity>::value && lreq == kWriteInplace) ||
+            lreq == kNullOp)) {
+        DType ltemp = LOP::Map(inp);
+        if (lreq == kAddTo) {
+          lstorer.storage_.scratch_.separate[i] += ltemp;
+        } else {
+          lstorer.storage_.scratch_.separate[i] = ltemp;
+        }
+        lstorer.store(tid, N);
+      }
+      if (!((std::is_same<ROP, mshadow_op::identity>::value && rreq == kWriteInplace) ||
+            rreq == kNullOp)) {
+        DType rtemp = ROP::Map(inp);
+
+        if (rreq == kAddTo) {
+          rstorer.storage_.scratch_.separate[i] += rtemp;
+        } else {
+          rstorer.storage_.scratch_.separate[i] = rtemp;
+        }
+        rstorer.store(tid, N);
+      }
+    }
+  }
+}
+
+template <typename DType, typename OP, int req>
+class VectorizedBinaryFwd {
+ public:
+  using ParamType = VectorizedKernelParams<DType, 2, 1>;
+
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t N) {
+    VectorizedBinaryKernelFwd<aligned, DType, LType, OP, req>
+      <<<blocks, threads, 0, stream>>>(params, N);
+  }
+};
+
+template <typename DType, typename LOP, typename ROP, int lreq, int rreq>
+class VectorizedBinaryBwdUseNone {
+ public:
+  using ParamType = VectorizedKernelParams<DType, 1, 2>;
+
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t N) {
+    VectorizedBinaryKernelBwdUseNone<aligned, DType, LType, LOP, ROP, lreq, rreq>
+      <<<blocks, threads, 0, stream>>>(params, N);
+  }
+};
+
 enum class Alignment {
   SAME_ALIGNED,
   SAME_UNALIGNED,
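In the backward kernel above, lreq and rreq bake MXNet's OpReqType write modes in as template parameters, so the req branches fold away at instantiation. For orientation (the enum members are MXNet's; the comments paraphrase their effect here):

// kNullOp:       this output is not written at all.
// kWriteTo:      out[i] = OP::Map(grad[i]).
// kWriteInplace: like kWriteTo but out aliases the input, so with an
//                identity OP the store can be skipped entirely; that is
//                the condition tested before each store above.
// kAddTo:        out[i] += OP::Map(grad[i]); the storer pre-loads the
//                old values so the read-modify-write stays vectorized.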
@@ -204,10 +263,22 @@ int CalcAlignment(const DType* ptr) {
   return ptr_as_number % sizeof(LType);
 }
 
-template <typename LType, typename DType>
-Alignment CheckAlignment(const std::vector<DType*>& pointers) {
+template <typename LType, typename DType, int N, int M>
+Alignment CheckAlignment(const VectorizedKernelParams<DType, N, M>& params) {
   int align = -1;
-  for (const DType* ptr : pointers) {
+
+  for (const DType* ptr : params.inputs) {
+    int new_align = CalcAlignment<LType>(ptr);
+    if (align == -1) {
+      align = new_align;
+    } else {
+      if (align != new_align) {
+        return Alignment::DIFFERENT;
+      }
+    }
+  }
+
+  for (const DType* ptr : params.outputs) {
     int new_align = CalcAlignment<LType>(ptr);
     if (align == -1) {
       align = new_align;
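A worked example of the classification, assuming LType = uint4 (16 bytes): a pointer at an address ending in 0x40 has remainder 0, one ending in 0x48 has remainder 8. All remainders zero gives SAME_ALIGNED; all equal but non-zero gives SAME_UNALIGNED (the accessors peel the ragged first and last vectors); mixed remainders give DIFFERENT and vectorization is disabled. The computation itself is only:

#include <cstdint>

template <typename LType, typename DType>
int CalcAlignmentSketch(const DType* ptr) {
  // Same computation as CalcAlignment above: address modulo vector width.
  return static_cast<int>(reinterpret_cast<uintptr_t>(ptr) % sizeof(LType));
}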
@@ -222,80 +293,92 @@ Alignment CheckAlignment(const std::vector<DType*>& pointers) {
                     : Alignment::SAME_UNALIGNED;
 }
 
-size_t minthree(const size_t a, const size_t b, const size_t c) {
-  return a < b ? (a < c ? a : c) : (b < c ? b : c);
+template <typename DType, typename LType, typename Kernel>
+void VectorizedKernelLauncher(const index_t size, mshadow::Stream<gpu>* s, typename Kernel::ParamType params) {
+  static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than operand type");
+  if (size != 0) {
+    cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+    constexpr int nvec = sizeof(LType) / sizeof(DType);
+    VectorizedLoader<DType, LType> l(params.inputs[0], size);
+    size_t num_elements = l.num_aligned_elements();
+    constexpr int threads = 512;
+    constexpr int max_blocks = 65535;
+    index_t blocks = std::min(static_cast<int>((num_elements + threads - 1) / threads),
+                              max_blocks);
+    auto align = CheckAlignment<LType, DType>(params);
+    if (align == Alignment::SAME_ALIGNED && (size % nvec == 0)) {
+      Kernel::template Launch<true, LType>(blocks, threads, stream, params, size);
+    } else {
+      if (align != Alignment::DIFFERENT) {
+        Kernel::template Launch<false, LType>(blocks, threads, stream, params, size);
+      } else {
+        index_t blocks = std::min(static_cast<int>((size + threads - 1) /
+                                                   threads),
+                                  max_blocks);
+        // If the pointers are aligned differently we cannot vectorize
+        Kernel::template Launch<true, DType>(blocks, threads, stream, params, size);
+      }
+    }
+  }
 }
 
+} // namespace
+
 template<typename OP>
 void VectorizedCompute(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
-  using namespace mxnet_op;
   if (req[0] == kNullOp) return;
-  Stream<gpu> *s = ctx.get_stream<gpu>();
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
   MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-    if (dmlc::GetEnv("DEBUG_VECTOR", false)) {
-      MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
-                             + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-        if (size != 0) {
-          Kernel<mxnet_op::op_with_req<OP, Req>, gpu>::Launch(s, size,
-                                                              outputs[0].dptr<DType>(),
-                                                              inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
-        }
-      });
-    } else {
-      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-        using LType = uint4;
-        static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than operand type");
-        if (outputs[0].Size() != 0) {
-          cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
-          constexpr int nvec = sizeof(LType) / sizeof(DType);
-          VectorizedLoader<DType, LType> l(outputs[0].dptr<DType>(), outputs[0].Size());
-          size_t num_elements = l.num_aligned_elements();
-          constexpr int threads = 512;
-          index_t blocks = std::min(static_cast<int>((num_elements + threads - 1) / threads),
-                                    65535);
-          auto align = CheckAlignment<LType, DType>({outputs[0].dptr<DType>(),
-                                                     inputs[0].dptr<DType>(),
-                                                     inputs[1].dptr<DType>()});
-          if (align == Alignment::SAME_ALIGNED && (outputs[0].Size() % nvec == 0)) {
-            VectorizedElementwiseKernel<true, DType, LType, OP, Req>
-              <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
-                                               inputs[0].dptr<DType>(),
-                                               inputs[1].dptr<DType>(),
-                                               outputs[0].Size());
-          } else {
-            if (align != Alignment::DIFFERENT) {
-              VectorizedElementwiseKernel<false, DType, LType, OP, Req>
-                <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
-                                                 inputs[0].dptr<DType>(),
-                                                 inputs[1].dptr<DType>(),
-                                                 outputs[0].Size());
-            } else {
-              index_t blocks = std::min(static_cast<int>((outputs[0].Size() + threads - 1) /
-                                                         threads),
-                                        65535);
-              // If the pointers are aligned differently we cannot vectorize
-              VectorizedElementwiseKernel<true, DType, DType, OP, Req>
-                <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
-                                                 inputs[0].dptr<DType>(),
-                                                 inputs[1].dptr<DType>(),
-                                                 outputs[0].Size());
-            }
-          }
-        }
-      });
-    }
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using LType = uint4;
+      using Kernel = VectorizedBinaryFwd<DType, OP, Req>;
+
+      const index_t size = outputs[0].Size();
+      typename Kernel::ParamType params;
+      params.inputs[0] = inputs[0].dptr<DType>();
+      params.inputs[1] = inputs[1].dptr<DType>();
+      params.outputs[0] = outputs[0].dptr<DType>();
+
+      VectorizedKernelLauncher<DType, LType, Kernel>(size, s, params);
+    });
   });
 }
 
-} // namespace
-
+template<typename LOP, typename ROP>
+void VectorizedBackwardUseNoneCompute(const nnvm::NodeAttrs &attrs,
+                                      const OpContext &ctx,
+                                      const std::vector<TBlob> &inputs,
+                                      const std::vector<OpReqType> &req,
+                                      const std::vector<TBlob> &outputs) {
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    const index_t size = inputs[0].Size();
+    if (req[0] != kNullOp || req[1] != kNullOp) {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], lreq, {
+        MXNET_ASSIGN_REQ_SWITCH(req[1], rreq, {
+          using LType = uint4;
+          using Kernel = VectorizedBinaryBwdUseNone<DType, LOP, ROP, lreq, rreq>;
+
+          typename Kernel::ParamType params;
+          params.inputs[0] = inputs[0].dptr<DType>();
+          params.outputs[0] = outputs[0].dptr<DType>();
+          params.outputs[1] = outputs[1].dptr<DType>();
+
+          VectorizedKernelLauncher<DType, LType, Kernel>(size, s, params);
+        });
+      });
+    }
+  });
+}
+
 } // namespace op
 } // namespace mxnet
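Putting the two files together: the registrations in elemwise_binary_op_basic.cu choose the functor pair, and the templates above do the launch. A rough trace for _backward_sub, assuming MXNet's usual layout of one output gradient in and two input gradients out:

// _backward_sub -> VectorizedBackwardUseNoneCompute<identity, negation>
//   params.inputs[0]  = ograd;     // gradient arriving from the output
//   params.outputs[0] = grad_lhs;  // identity::Map(ograd) =  ograd
//   params.outputs[1] = grad_rhs;  // negation::Map(ograd) = -ograd
//   VectorizedKernelLauncher<DType, uint4, Kernel>(size, s, params);
// Both input gradients come from a single vectorized read of ograd.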
