diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 436e85247ffe..c891ec5a28cf 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -579,6 +579,23 @@ void CodeGenCUDA::PrintStorageScope(const std::string& scope, std::ostream& os) } } +std::string CodeGenCUDA::CastFromTo(std::string value, DataType from, DataType target) { + if (from == target) return value; + std::ostringstream os; + os << "(("; + this->PrintType(target, os); + os << ")"; + if (from.is_float16() && (target.is_int() || target.is_uint()) && target.bits() == 8) { + os << "("; + if (target.is_uint()) { + os << "u"; + } + os << "int)"; + } + os << value << ")"; + return os.str(); +} + void CodeGenCUDA::VisitExpr_(const CastNode* op, std::ostream& os) { DataType from_ty = op->value.dtype(); DataType target_ty = op->dtype; diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 0fef15c7a7f3..bb507c179993 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -58,6 +58,7 @@ class CodeGenCUDA final : public CodeGenC { void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; + std::string CastFromTo(std::string value, DataType from, DataType target) final; // overload visitor void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)