diff --git a/src/operator/quantization/quantize-inl.h b/src/operator/quantization/quantize-inl.h index 8b7a11cc5a89..3c8ad81566e4 100644 --- a/src/operator/quantization/quantize-inl.h +++ b/src/operator/quantization/quantize-inl.h @@ -95,6 +95,10 @@ void QuantizeCompute(const nnvm::NodeAttrs& attrs, const QuantizeParam& param = nnvm::get(attrs.parsed); if (param.out_type == mshadow::kUint8) { + if (std::is_same::value) { + LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, " + "please switch to the context of CPU or int8 data type for GPU."; + } Kernel::Launch(s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), inputs[0].dptr(), inputs[1].dptr(), inputs[2].dptr(), diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h index 5ae10a7e4fa8..03ee48fddb16 100644 --- a/src/operator/quantization/quantize_v2-inl.h +++ b/src/operator/quantization/quantize_v2-inl.h @@ -139,6 +139,10 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, auto out_type = GetOutputType(param); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { if (out_type == mshadow::kUint8) { + if (std::is_same::value) { + LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, " + "please switch to the context of CPU or int8 data type for GPU."; + } Kernel::Launch( s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), inputs[0].dptr(), param.min_calib_range.value(), @@ -170,6 +174,10 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, broadcast::Reduce( s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); if (out_type == mshadow::kUint8) { + if (std::is_same::value) { + LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, " + "please switch to the context of CPU or int8 data type for GPU."; + } Kernel::Launch( s, outputs[0].Size(), outputs[0].dptr(), outputs[1].dptr(), outputs[2].dptr(), inputs[0].dptr(), in_min_t.dptr(),