Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding large tensor support for dropout operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Oct 15, 2019
1 parent ac0030c commit b80b3a3
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
12 changes: 6 additions & 6 deletions src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ class DropoutOp {
* \param input_data Input data to perform the dropout on
* \param pkeep Dropout rate (keep when the generated random number is less than this value)
*/
MSHADOW_XINLINE static void Map(int id,
MSHADOW_XINLINE static void Map(index_t id,
RandGenerator<xpu, DType> gen,
const int N,
const int step,
const index_t N,
const index_t step,
DType *dropout_out,
DType *mask_out,
const DType *input_data,
Expand All @@ -199,10 +199,10 @@ class DropoutOp {
};
struct BernoulliKernel {
/*! \brief Bernoulli kernel for generating mask */
MSHADOW_XINLINE static void Map(int id,
MSHADOW_XINLINE static void Map(index_t id,
RandGenerator<xpu, DType> gen,
const int N,
const int step,
const index_t N,
const index_t step,
DType *mask_out,
const real_t pkeep) {
RNG_KERNEL_LOOP(xpu, DType, id, gen, N, step, {
Expand Down
5 changes: 3 additions & 2 deletions src/operator/tensor/elemwise_binary_broadcast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,10 @@ inline int BinaryBroadcastShapeCompact(const mxnet::TShape& lshape, const mxnet:
*new_oshape = mxnet::TShape(odim, 1);
int bl = oshape.ndim() - lshape.ndim();
int br = oshape.ndim() - rshape.ndim();
int j = 0, lprod = 1, rprod = 1, oprod = 1;
int j = 0;
index_t lprod = 1, rprod = 1, oprod = 1;
for (int i = 0; i < oshape.ndim(); ++i) {
int l = 1, r = 1, o = oshape[i];
index_t l = 1, r = 1, o = oshape[i];
if (i >= bl) l = lshape[i-bl];
if (i >= br) r = rshape[i-br];
if ((lprod != rprod || l != r) &&
Expand Down

0 comments on commit b80b3a3

Please sign in to comment.