Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

First batch of cherry-picks to 1.6 branch (#16711, #16737, #16408) #16763

Merged
merged 4 commits into from
Nov 8, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,69 @@ namespace mxnet {
namespace common {
/*! \brief common utils for cuda */
namespace cuda {
/*!
* \brief Converts between C++ datatypes and enums/constants needed by cuBLAS.
*/
template<typename DType>
struct CublasType;

// With CUDA v8, cuBLAS adopted use of cudaDataType_t instead of its own
// datatype cublasDataType_t. The older cudaDataType_t values could be
// included below, but since this class was introduced to support the cuBLAS v8
// call cublasGemmEx(), burdening the class with the legacy type values
// was not needed.

template<>
struct CublasType<float> {
static const int kFlag = mshadow::kFloat32;
#if CUDA_VERSION >= 8000
static const cudaDataType_t kCudaFlag = CUDA_R_32F;
#endif
typedef float ScaleType;
static const float one;
static const float zero;
};
template<>
struct CublasType<double> {
static const int kFlag = mshadow::kFloat64;
#if CUDA_VERSION >= 8000
static const cudaDataType_t kCudaFlag = CUDA_R_64F;
#endif
typedef double ScaleType;
static const double one;
static const double zero;
};
template<>
struct CublasType<mshadow::half::half_t> {
static const int kFlag = mshadow::kFloat16;
#if CUDA_VERSION >= 8000
static const cudaDataType_t kCudaFlag = CUDA_R_16F;
#endif
typedef float ScaleType;
static const mshadow::half::half_t one;
static const mshadow::half::half_t zero;
};
template<>
struct CublasType<uint8_t> {
static const int kFlag = mshadow::kUint8;
#if CUDA_VERSION >= 8000
static const cudaDataType_t kCudaFlag = CUDA_R_8I;
#endif
typedef uint8_t ScaleType;
static const uint8_t one = 1;
static const uint8_t zero = 0;
};
template<>
struct CublasType<int32_t> {
static const int kFlag = mshadow::kInt32;
#if CUDA_VERSION >= 8000
static const cudaDataType_t kCudaFlag = CUDA_R_32I;
#endif
typedef int32_t ScaleType;
static const int32_t one = 1;
static const int32_t zero = 0;
};

/*!
* \brief Get string representation of cuBLAS errors.
* \param error The error.
Expand Down Expand Up @@ -218,6 +281,17 @@ inline const char* CublasGetErrorString(cublasStatus_t error) {
return "Unknown cuBLAS status";
}

#if CUDA_VERSION >= 8000
/*!
* \brief Create the proper constant for indicating cuBLAS transposition, if desired.
* \param transpose Whether transposition should be performed.
* \return the yes/no transposition-indicating constant.
*/
inline cublasOperation_t CublasTransposeOp(bool transpose) {
return transpose ? CUBLAS_OP_T : CUBLAS_OP_N;
}
#endif

/*!
* \brief Get string representation of cuSOLVER errors.
* \param error The error.
Expand Down
36 changes: 36 additions & 0 deletions src/common/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -842,6 +842,42 @@ inline bool is_float(const int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16;
}

inline int more_precise_type(const int type1, const int type2) {
if (type1 == type2) return type1;
if (is_float(type1) && is_float(type2)) {
if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
return mshadow::kFloat64;
}
if (type1 == mshadow::kFloat32 || type2 == mshadow::kFloat32) {
return mshadow::kFloat32;
}
return mshadow::kFloat16;
} else if (is_float(type1) || is_float(type2)) {
return is_float(type1) ? type1 : type2;
}
if (type1 == mshadow::kInt64 || type2 == mshadow::kInt64) {
return mshadow::kInt64;
}
if (type1 == mshadow::kInt32 || type2 == mshadow::kInt32) {
return mshadow::kInt32;
}
CHECK(!((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)))
<< "1 is UInt8 and 1 is Int8 should not get here";
if (type1 == mshadow::kUint8 || type2 == mshadow::kUint8) {
return mshadow::kUint8;
}
return mshadow::kInt8;
}

inline int np_binary_out_type(const int type1, const int type2) {
if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
(type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
return mshadow::kInt32;
}
return more_precise_type(type1, type2);
}

} // namespace common
} // namespace mxnet
#endif // MXNET_COMMON_UTILS_H_
9 changes: 9 additions & 0 deletions src/operator/contrib/transformer-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@
namespace mxnet {
namespace op {

struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> {
int heads;
bool bwd_ignore_zero_init;
DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) {
DMLC_DECLARE_FIELD(heads)
.describe("Set number of heads");
}
};

template<typename xpu>
static void DivSqrtDimForward_(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
Expand Down
Loading