diff --git a/src/target/codegen_cuda.cc b/src/target/codegen_cuda.cc index 641d4b332..98171cecf 100644 --- a/src/target/codegen_cuda.cc +++ b/src/target/codegen_cuda.cc @@ -216,6 +216,188 @@ CodeGenTileLangCUDA::CodeGenTileLangCUDA() { runtime::symbol::tvm_global_barrier_state); } +void CodeGenTileLangCUDA::ReserveKeywordsAsUnique_() { + CodeGenC::ReserveKeywordsAsUnique(); + name_supply_->ReserveName("max"); + name_supply_->ReserveName("min"); + name_supply_->ReserveName("isfinite"); + name_supply_->ReserveName("isinf"); + name_supply_->ReserveName("isnan"); + + // skip single precision mathematical functions + name_supply_->ReserveName("acosf"); + name_supply_->ReserveName("acoshf"); + name_supply_->ReserveName("asinf"); + name_supply_->ReserveName("asinhf"); + name_supply_->ReserveName("atan2f"); + name_supply_->ReserveName("atanf"); + name_supply_->ReserveName("atanhf"); + name_supply_->ReserveName("cbrtf"); + name_supply_->ReserveName("ceilf"); + name_supply_->ReserveName("copysignf"); + name_supply_->ReserveName("cosf"); + name_supply_->ReserveName("coshf"); + name_supply_->ReserveName("cospif"); + name_supply_->ReserveName("cyl_bessel_i0f"); + name_supply_->ReserveName("cyl_bessel_i1f"); + name_supply_->ReserveName("erfcf"); + name_supply_->ReserveName("erfcinvf"); + name_supply_->ReserveName("erfcxf"); + name_supply_->ReserveName("erff"); + name_supply_->ReserveName("erfinvf"); + name_supply_->ReserveName("exp10f"); + name_supply_->ReserveName("exp2f"); + name_supply_->ReserveName("expf"); + name_supply_->ReserveName("expm1f"); + name_supply_->ReserveName("fabsf"); + name_supply_->ReserveName("fdimf"); + name_supply_->ReserveName("fdividef"); + name_supply_->ReserveName("floorf"); + name_supply_->ReserveName("fmaf"); + name_supply_->ReserveName("fmaxf"); + name_supply_->ReserveName("fminf"); + name_supply_->ReserveName("fmodf"); + name_supply_->ReserveName("frexpf"); + name_supply_->ReserveName("hypotf"); + name_supply_->ReserveName("ilogbf"); + name_supply_->ReserveName("j0f"); + name_supply_->ReserveName("j1f"); + name_supply_->ReserveName("jnf"); + name_supply_->ReserveName("ldexpf"); + name_supply_->ReserveName("lgammaf"); + name_supply_->ReserveName("llrintf"); + name_supply_->ReserveName("llroundf"); + name_supply_->ReserveName("log10f"); + name_supply_->ReserveName("log1pf"); + name_supply_->ReserveName("log2f"); + name_supply_->ReserveName("logbf"); + name_supply_->ReserveName("logf"); + name_supply_->ReserveName("lrintf"); + name_supply_->ReserveName("lroundf"); + name_supply_->ReserveName("modff"); + name_supply_->ReserveName("nanf"); + name_supply_->ReserveName("nearbyintf"); + name_supply_->ReserveName("nextafterf"); + name_supply_->ReserveName("norm3df"); + name_supply_->ReserveName("norm4df"); + name_supply_->ReserveName("normcdff"); + name_supply_->ReserveName("normcdfinvf"); + name_supply_->ReserveName("normf"); + name_supply_->ReserveName("powf"); + name_supply_->ReserveName("rcbrtf"); + name_supply_->ReserveName("remainderf"); + name_supply_->ReserveName("remquof"); + name_supply_->ReserveName("rhypotf"); + name_supply_->ReserveName("rintf"); + name_supply_->ReserveName("rnorm3df"); + name_supply_->ReserveName("rnorm4df"); + name_supply_->ReserveName("rnormf"); + name_supply_->ReserveName("roundf"); + name_supply_->ReserveName("rsqrtf"); + name_supply_->ReserveName("scalblnf"); + name_supply_->ReserveName("scalbnf"); + name_supply_->ReserveName("signbit"); + name_supply_->ReserveName("sincosf"); + name_supply_->ReserveName("sincospif"); + name_supply_->ReserveName("sinf"); + name_supply_->ReserveName("sinhf"); + name_supply_->ReserveName("sinpif"); + name_supply_->ReserveName("sqrtf"); + name_supply_->ReserveName("tanf"); + name_supply_->ReserveName("tanhf"); + name_supply_->ReserveName("tgammaf"); + name_supply_->ReserveName("truncf"); + name_supply_->ReserveName("y0f"); + name_supply_->ReserveName("y1f"); + name_supply_->ReserveName("ynf"); + + // skip double precision mathematical functions + name_supply_->ReserveName("acos"); + name_supply_->ReserveName("acosh"); + name_supply_->ReserveName("asin"); + name_supply_->ReserveName("asinh"); + name_supply_->ReserveName("atan"); + name_supply_->ReserveName("atan2"); + name_supply_->ReserveName("atanh"); + name_supply_->ReserveName("cbrt"); + name_supply_->ReserveName("ceil"); + name_supply_->ReserveName("copysign"); + name_supply_->ReserveName("cos"); + name_supply_->ReserveName("cosh"); + name_supply_->ReserveName("cospi"); + name_supply_->ReserveName("cyl_bessel_i0"); + name_supply_->ReserveName("cyl_bessel_i1"); + name_supply_->ReserveName("erf"); + name_supply_->ReserveName("erfc"); + name_supply_->ReserveName("erfcinv"); + name_supply_->ReserveName("erfcx"); + name_supply_->ReserveName("erfinv"); + name_supply_->ReserveName("exp"); + name_supply_->ReserveName("exp10"); + name_supply_->ReserveName("exp2"); + name_supply_->ReserveName("expm1"); + name_supply_->ReserveName("fabs"); + name_supply_->ReserveName("fdim"); + name_supply_->ReserveName("floor"); + name_supply_->ReserveName("fma"); + name_supply_->ReserveName("fmax"); + name_supply_->ReserveName("fmin"); + name_supply_->ReserveName("fmod"); + name_supply_->ReserveName("frexp"); + name_supply_->ReserveName("hypot"); + name_supply_->ReserveName("ilogb"); + name_supply_->ReserveName("j0"); + name_supply_->ReserveName("j1"); + name_supply_->ReserveName("jn"); + name_supply_->ReserveName("ldexp"); + name_supply_->ReserveName("lgamma"); + name_supply_->ReserveName("llrint"); + name_supply_->ReserveName("llround"); + name_supply_->ReserveName("log"); + name_supply_->ReserveName("log10"); + name_supply_->ReserveName("log1p"); + name_supply_->ReserveName("log2"); + name_supply_->ReserveName("logb"); + name_supply_->ReserveName("lrint"); + name_supply_->ReserveName("lround"); + name_supply_->ReserveName("modf"); + name_supply_->ReserveName("nan"); + name_supply_->ReserveName("nearbyint"); + name_supply_->ReserveName("nextafter"); + name_supply_->ReserveName("norm"); + name_supply_->ReserveName("norm3d"); + name_supply_->ReserveName("norm4d"); + name_supply_->ReserveName("normcdf"); + name_supply_->ReserveName("normcdfinv"); + name_supply_->ReserveName("pow"); + name_supply_->ReserveName("rcbrt"); + name_supply_->ReserveName("remainder"); + name_supply_->ReserveName("remquo"); + name_supply_->ReserveName("rhypot"); + name_supply_->ReserveName("rint"); + name_supply_->ReserveName("rnorm"); + name_supply_->ReserveName("rnorm3d"); + name_supply_->ReserveName("rnorm4d"); + name_supply_->ReserveName("round"); + name_supply_->ReserveName("rsqrt"); + name_supply_->ReserveName("scalbln"); + name_supply_->ReserveName("scalbn"); + name_supply_->ReserveName("signbit"); + name_supply_->ReserveName("sin"); + name_supply_->ReserveName("sincos"); + name_supply_->ReserveName("sincospi"); + name_supply_->ReserveName("sinh"); + name_supply_->ReserveName("sinpi"); + name_supply_->ReserveName("sqrt"); + name_supply_->ReserveName("tan"); + name_supply_->ReserveName("tanh"); + name_supply_->ReserveName("tgamma"); + name_supply_->ReserveName("trunc"); + name_supply_->ReserveName("y0"); + name_supply_->ReserveName("y1"); + name_supply_->ReserveName("yn"); +} + void CodeGenTileLangCUDA::PrintFuncPrefix(std::ostream &os) { os << "extern \"C\" __global__ "; } @@ -3431,7 +3613,7 @@ void CodeGenTileLangCUDA::AddFunction(const GlobalVar &gvar, // clear previous generated state. this->InitFuncState(f); // reserve keywords - ReserveKeywordsAsUnique(); + ReserveKeywordsAsUnique_(); auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); ICHECK(global_symbol) diff --git a/src/target/codegen_cuda.h b/src/target/codegen_cuda.h index 9cf460213..837d25b37 100644 --- a/src/target/codegen_cuda.h +++ b/src/target/codegen_cuda.h @@ -65,6 +65,7 @@ class CodeGenTileLangCUDA final : public CodeGenC { const PrimFunc &func, std::ostream &os); protected: + void ReserveKeywordsAsUnique_(); virtual std::string GetBufferRef(DataType t, const BufferNode *buffer, PrimExpr index) final; void PrintCallExtern(Type ret_type, ffi::String global_symbol,