
[MXNET-1185] Support large array in several operators (part 1) #13418

Merged 13 commits on Dec 1, 2018.
8 changes: 4 additions & 4 deletions src/operator/elemwise_op_common.h
@@ -100,7 +100,7 @@ inline bool ElemwiseStorageAttr(const nnvm::NodeAttrs& attrs,
* \tparam rsp whether row sparse stype is supported
* \tparam rsp whether csr stype is supported
*/
-template<int n_in, int n_out, bool cpu_only, bool rsp, bool csr>
+template<index_t n_in, index_t n_out, bool cpu_only, bool rsp, bool csr>
inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
@@ -115,7 +115,7 @@ inline bool ElemwiseStorageType(const nnvm::NodeAttrs& attrs,
template<typename AttrType, bool (*is_none)(const AttrType&),
bool (*assign)(AttrType*, const AttrType&), bool reverse_infer,
std::string (*attr_string)(const AttrType&),
-int n_in = -1, int n_out = -1>
+index_t n_in = -1, index_t n_out = -1>
inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs,
@@ -154,7 +154,7 @@ inline bool ElemwiseAttr(const nnvm::NodeAttrs& attrs,
return true;
}

-template<int n_in, int n_out>
+template<index_t n_in, index_t n_out>
inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
@@ -168,7 +168,7 @@ inline bool ElemwiseShape(const nnvm::NodeAttrs& attrs,
attrs, in_attrs, out_attrs, TShape());
}

-template<int n_in, int n_out>
+template<index_t n_in, index_t n_out>
inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
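The int → index_t switch in these template parameters tracks the PR's core issue: a tensor with more than 2^31 - 1 elements overflows any 32-bit index or count. A minimal sketch of the failure mode, assuming index_t is a 64-bit signed integer as in MXNet's large-tensor builds (the typedef itself is outside this diff):

```cpp
#include <cstdint>
#include <iostream>
#include <limits>

// Sketch only: this index_t stands in for mshadow::index_t, which is
// 64 bits wide in large-tensor builds.
using index_t = int64_t;

int main() {
  // A (3, 10^9) tensor holds 3e9 elements -- more than INT_MAX (~2.1e9) --
  // so an int-typed flattened index or loop bound would overflow.
  const index_t total = index_t(3) * 1000000000;
  std::cout << (total > std::numeric_limits<int>::max()) << std::endl;  // prints 1
  return 0;
}
```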
68 changes: 34 additions & 34 deletions src/operator/mxnet_op.h
@@ -289,8 +289,8 @@ inline int get_num_threads<cpu>(const int N) {

/* \brief Compute flattened index given coordinates and shape. */
template<int ndim>
-MSHADOW_XINLINE int ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
-int ret = 0;
+MSHADOW_XINLINE index_t ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {
+index_t ret = 0;
#pragma unroll
for (int i = 0; i < ndim; ++i) {
ret = ret * shape[i] + (shape[i] > coord[i]) * coord[i];
@@ -301,11 +301,11 @@ MSHADOW_XINLINE int ravel(const Shape<ndim>& coord, const Shape<ndim>& shape) {

/* Compute coordinates from flattened index given shape */
template<int ndim>
-MSHADOW_XINLINE Shape<ndim> unravel(const int idx, const Shape<ndim>& shape) {
+MSHADOW_XINLINE Shape<ndim> unravel(const index_t idx, const Shape<ndim>& shape) {
Shape<ndim> ret;
#pragma unroll
-for (int i = ndim-1, j = idx; i >=0; --i) {
-int tmp = j / shape[i];
+for (index_t i = ndim-1, j = idx; i >=0; --i) {
+auto tmp = j / shape[i];
ret[i] = j - tmp*shape[i];
j = tmp;
}
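ravel flattens an ndim coordinate into a row-major linear index; unravel inverts it with one division per axis. A self-contained round-trip check of the same pair — a sketch, not MXNet code, with std::array standing in for Shape<ndim> and ravel's bounds clamp ((shape[i] > coord[i]) * coord[i]) dropped for brevity:

```cpp
#include <array>
#include <cassert>
#include <cstdint>

using index_t = int64_t;  // stand-in for mshadow::index_t (sketch only)

template <int ndim>
index_t ravel(const std::array<index_t, ndim>& coord,
              const std::array<index_t, ndim>& shape) {
  index_t ret = 0;
  for (int i = 0; i < ndim; ++i) ret = ret * shape[i] + coord[i];
  return ret;
}

template <int ndim>
std::array<index_t, ndim> unravel(index_t idx,
                                  const std::array<index_t, ndim>& shape) {
  std::array<index_t, ndim> ret;
  for (int i = ndim - 1; i >= 0; --i) {
    const index_t tmp = idx / shape[i];
    ret[i] = idx - tmp * shape[i];  // idx % shape[i], reusing the quotient
    idx = tmp;
  }
  return ret;
}

int main() {
  // A flat index near 3e11 -- far beyond what a 32-bit int can hold.
  const std::array<index_t, 2> shape{3, 100000000000LL};
  const std::array<index_t, 2> coord{2, 99999999999LL};
  assert((unravel<2>(ravel<2>(coord, shape), shape) == coord));
  return 0;
}
```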
@@ -315,8 +315,8 @@ MSHADOW_XINLINE Shape<ndim> unravel(const int idx, const Shape<ndim>& shape) {

/* Compute dot product of two vector */
template<int ndim>
-MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
-int ret = 0;
+MSHADOW_XINLINE index_t dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {
+index_t ret = 0;
#pragma unroll
for (int i = 0; i < ndim; ++i) {
ret += coord[i] * stride[i];
@@ -327,12 +327,12 @@ MSHADOW_XINLINE int dot(const Shape<ndim>& coord, const Shape<ndim>& stride) {

/* Combining unravel and dot */
template<int ndim>
-MSHADOW_XINLINE int unravel_dot(const int idx, const Shape<ndim>& shape,
+MSHADOW_XINLINE index_t unravel_dot(const index_t idx, const Shape<ndim>& shape,
const Shape<ndim>& stride) {
-int ret = 0;
+index_t ret = 0;
#pragma unroll
-for (int i = ndim-1, j = idx; i >=0; --i) {
-int tmp = j / shape[i];
+for (index_t i = ndim-1, j = idx; i >=0; --i) {
+auto tmp = j / shape[i];
ret += (j - tmp*shape[i])*stride[i];
j = tmp;
}
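The fused helper preserves the identity unravel_dot(idx, shape, stride) == dot(unravel(idx, shape), stride) without materializing the intermediate coordinate, so its accumulator, the quotient j, and the return type all need the same 64-bit widening as the two helpers it combines.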
@@ -433,51 +433,51 @@ struct op_with_req {

/*! \brief input is one tensor */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in) {
+MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in) {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i]));
}

/*! \brief inputs are two tensors */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *lhs, const DType *rhs) {
+MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *lhs, const DType *rhs) {
KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
}

/*! \brief input is tensor and a scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in, const DType value) {
+MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in, const DType value) {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
}

/*! \brief input is tensor and two scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType *in,
+MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType *in,
const DType value_1, const DType value_2) {
KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value_1, value_2));
}

/*! \brief No inputs (ie fill to constant value) */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out) {
+MSHADOW_XINLINE static void Map(index_t i, DType *out) {
KERNEL_ASSIGN(out[i], req, OP::Map());
}

/*! \brief input is single scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out, const DType value) {
+MSHADOW_XINLINE static void Map(index_t i, DType *out, const DType value) {
KERNEL_ASSIGN(out[i], req, OP::Map(value));
}

/*! \brief inputs are two tensors and a scalar value */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out,
+MSHADOW_XINLINE static void Map(index_t i, DType *out,
const DType *input_1, const DType *input_2, const DType value) {
KERNEL_ASSIGN(out[i], req, OP::Map(input_1[i], input_2[i], value));
}

/*! \brief inputs are three tensors (ie backward grad with binary grad function) */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out,
+MSHADOW_XINLINE static void Map(index_t i, DType *out,
const DType *input_1,
const DType *input_2,
const DType *input_3) {
@@ -503,21 +503,21 @@ struct Kernel<OP, cpu> {
* \param args Varargs to eventually pass to the OP::Map() function
*/
template<typename ...Args>
-inline static bool Launch(mshadow::Stream<cpu> *, const int N, Args... args) {
+inline static bool Launch(mshadow::Stream<cpu> *, const size_t N, Args... args) {
#ifdef _OPENMP
const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (omp_threads < 2) {
-for (int i = 0; i < N; ++i) {
+for (size_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
} else {
#pragma omp parallel for num_threads(omp_threads)
-for (int i = 0; i < N; ++i) {
+for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
OP::Map(i, args...);
}
}
#else
-for (int i = 0; i < N; ++i) {
+for (size_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
#endif
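One detail of the Launch change above: the serial fallback paths now count with size_t, while the OpenMP path counts with a signed index_t cast from N, since `#pragma omp parallel for` under OpenMP 2.0 (still the MSVC baseline) requires a signed loop variable. A stripped-down sketch of the same contract — Map receiving a signed 64-bit element index, as in the op_with_req::Map overloads above — with the real Kernel machinery omitted:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using index_t = int64_t;  // stand-in for mshadow::index_t (sketch only)

struct plus_one {
  static void Map(index_t i, float* out, const float* in) {
    out[i] = in[i] + 1.0f;
  }
};

template <typename OP, typename... Args>
void Launch(size_t N, Args... args) {
  // Signed counter: required by OpenMP 2.0's canonical loop form.
  // Compile with -fopenmp to parallelize; correct either way.
  #pragma omp parallel for
  for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
    OP::Map(i, args...);
  }
}

int main() {
  std::vector<float> in(16, 1.0f), out(16);
  Launch<plus_one>(in.size(), out.data(), in.data());  // out[i] == 2.0f
  return 0;
}
```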
@@ -567,22 +567,22 @@ struct Kernel<OP, cpu> {
* \param args Varargs to eventually pass to the OP::Map() function
*/
template<typename PRIMITIVE_OP, typename DType, typename ...Args>
-static void LaunchTuned(mshadow::Stream<cpu> *, const int N, Args... args) {
+static void LaunchTuned(mshadow::Stream<cpu> *, const size_t N, Args... args) {
#ifdef _OPENMP
const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (omp_threads < 2 || !tuned_op<PRIMITIVE_OP, DType>::UseOMP(
-static_cast<size_t>(N), static_cast<size_t>(omp_threads))) {
-for (int i = 0; i < N; ++i) {
+N, static_cast<size_t>(omp_threads))) {
+for (size_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
} else {
#pragma omp parallel for num_threads(omp_threads)
-for (int i = 0; i < N; ++i) {
+for (index_t i = 0; i < static_cast<index_t>(N); ++i) {
OP::Map(i, args...);
}
}
#else
-for (int i = 0; i < N; ++i) {
+for (size_t i = 0; i < N; ++i) {
OP::Map(i, args...);
}
#endif
@@ -596,15 +596,15 @@ struct Kernel<OP, cpu> {
* \param args Varargs to eventually pass to the UseOMP() and OP::Map() functions
*/
template<typename ...Args>
-inline static void LaunchEx(mshadow::Stream<cpu> *s, const int N, Args... args) {
+inline static void LaunchEx(mshadow::Stream<cpu> *s, const size_t N, Args... args) {
#ifdef _OPENMP
const int omp_threads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
if (omp_threads < 2) {
OP::Map(0, N, args...);
} else {
-const int length = (N + omp_threads - 1) / omp_threads;
+const auto length = (N + omp_threads - 1) / omp_threads;
#pragma omp parallel for num_threads(omp_threads)
-for (int i = 0; i < N; i += length) {
+for (index_t i = 0; i < static_cast<index_t>(N); i += length) {
OP::Map(i, i + length > N ? N - i : length, args...);
}
}
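A worked instance of the chunking arithmetic above: for N = 10 and omp_threads = 4, length = (10 + 4 - 1) / 4 = 3, so OP::Map receives the (start, count) pairs (0, 3), (3, 3), (6, 3), and (9, 1), the last clipped to N - i. Deriving length with auto from the size_t N, rather than forcing it into an int, keeps this arithmetic exact when N exceeds INT_MAX.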
@@ -626,7 +626,7 @@ struct Kernel<OP, cpu> {
template<typename DType, typename T = OP, typename ...Args>
static MSHADOW_CINLINE
typename std::enable_if<std::is_base_of<tunable, T>::value, bool>::type
-Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+Launch(mshadow::Stream<cpu> *s, const size_t N, DType *dest, Args... args) {
LaunchTuned<T, DType>(s, N, dest, args...);
return true;
}
@@ -644,7 +644,7 @@ struct Kernel<OP, cpu> {
template<typename DType, typename T = OP, typename ...Args>
static MSHADOW_CINLINE
typename std::enable_if<std::is_base_of<tunable, typename T::Operation>::value, bool>::type
-Launch(mshadow::Stream<cpu> *s, const int N, DType *dest, Args... args) {
+Launch(mshadow::Stream<cpu> *s, const size_t N, DType *dest, Args... args) {
LaunchTuned<typename T::Operation, DType>(s, N, dest, args...);
return true;
}
@@ -700,7 +700,7 @@ template<int val>
struct set_to_int : public tunable {
// mxnet_op version (when used directly with Kernel<>::Launch()) */
template<typename DType>
-MSHADOW_XINLINE static void Map(int i, DType *out) {
+MSHADOW_XINLINE static void Map(index_t i, DType *out) {
out[i] = DType(val);
}
// mshadow_op version (when used with op_with_req<>)
43 changes: 22 additions & 21 deletions src/operator/random/sampler.h
@@ -43,32 +43,33 @@ namespace op {
template<typename OP, typename xpu, typename GType, typename ...Args>
inline static void LaunchRNG(mshadow::Stream<xpu> *s,
common::random::RandGenerator<xpu, GType> *gen,
-const int N, Args... args) {
+const index_t N, Args... args) {
// minimal check to avoid division by zero, below.
// if `N` is zero the map operation is a no-op in any case.
if (N <= 0) {
return;
}
-const int nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
+const index_t nloop = (N + RandGenerator<xpu>::kMinNumRandomPerThread - 1) /
RandGenerator<xpu>::kMinNumRandomPerThread;
-const int nthread = std::min(nloop, RandGenerator<xpu>::kNumRandomStates);
-const int step = (N + nthread - 1) / nthread;
+const index_t nthread = std::min(nloop,
+static_cast<index_t>(RandGenerator<xpu>::kNumRandomStates));
+const index_t step = (N + nthread - 1) / nthread;
Kernel<OP, xpu>::Launch(s, nthread, *gen, N, step, args...);
}
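The three lines above cap parallelism at kNumRandomStates generator states and hand each kernel thread a contiguous run of step samples. A worked sketch of that arithmetic with hypothetical constants (the real values live in RandGenerator<xpu> and are not shown in this diff):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>

using index_t = int64_t;  // stand-in for mshadow::index_t (sketch only)

// Hypothetical stand-ins for the RandGenerator<xpu> constants.
const index_t kMinNumRandomPerThread = 64;
const index_t kNumRandomStates = 32 * 1024;

int main() {
  const index_t N = 5000000000LL;  // 5e9 samples: the case int could not index
  const index_t nloop = (N + kMinNumRandomPerThread - 1) / kMinNumRandomPerThread;
  const index_t nthread = std::min(nloop, kNumRandomStates);
  const index_t step = (N + nthread - 1) / nthread;
  // Thread id covers [id * step, min((id + 1) * step, N)) with its own state.
  std::cout << nthread << " threads, step " << step << std::endl;
  // Prints: 32768 threads, step 152588
  return 0;
}
```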

#define RNG_KERNEL_LOOP(xpu, GType, thread_id, gen, N, step, ...) \
-const int start = thread_id * step; \
-const int end = start + step; \
+const index_t start = thread_id * step; \
+const index_t end = start + step; \
typename RandGenerator<xpu, GType>::Impl genImpl(&gen, thread_id); \
-for (int i = start; i < end && i < N; ++i) { \
+for (index_t i = start; i < end && i < N; ++i) { \
{__VA_ARGS__} \
}
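Expanded inside a sampler's Map, the macro defines start = id * step and end = start + step, constructs a per-thread generator Impl bound to thread_id, and loops over i in [start, min(end, N)). The widening to index_t matters even when each thread's range is small: for a late thread, the product id * step exceeds INT_MAX as soon as N does.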

template<typename xpu>
struct SampleUniformKernel {
template<typename IType, typename OType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *lower, const IType *upper, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -127,8 +128,8 @@ struct RandIntSampler {
template<typename xpu>
struct SampleNormalKernel {
template<typename IType, typename OType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *mean, const IType *std, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -154,8 +155,8 @@ struct NormalSampler {
template<typename xpu>
struct SampleExponentialKernel {
template<typename IType, typename OType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, OType> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, OType> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *lambda, OType *out) {
RNG_KERNEL_LOOP(xpu, OType, id, gen, N, step, {
@@ -202,8 +203,8 @@ MSHADOW_XINLINE OType SampleGamma(IType a, IType b, typename RandGenerator<xpu,
template<typename xpu>
struct SampleGammaKernel {
template<typename IType, typename OType, typename FType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, FType> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, FType> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *alpha, const IType *beta, OType *out) {
RNG_KERNEL_LOOP(xpu, FType, id, gen, N, step, {
@@ -264,8 +265,8 @@ MSHADOW_XINLINE int SamplePoisson(float lambda, typename RandGenerator<xpu, floa
template<typename xpu>
struct SamplePoissonKernel {
template<typename IType, typename OType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, float> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *lambda, OType *out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
@@ -291,8 +292,8 @@ struct PoissonSampler {
template<typename xpu>
struct SampleNegativeBinomialKernel {
template<typename IType, typename OType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, float> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *k, const IType *p, OType *out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {
@@ -323,8 +324,8 @@ struct NegativeBinomialSampler {
template<typename xpu>
struct SampleGeneralizedNegativeBinomialKernel {
template<typename IType, typename OType>
-MSHADOW_XINLINE static void Map(int id, RandGenerator<xpu, float> gen,
-const int N, const int step,
+MSHADOW_XINLINE static void Map(index_t id, RandGenerator<xpu, float> gen,
+const index_t N, const index_t step,
index_t nParm, index_t nSample,
const IType *mu, const IType *alpha, OType *out) {
RNG_KERNEL_LOOP(xpu, float, id, gen, N, step, {