Always call intel.set_fast_math from make_llir and move the environment
handling into set_fast_math itself:
  * TRITON_INTEL_FAST_MATH parsed as false: return without touching any flags.
  * TRITON_INTEL_FAST_MATH parsed as true: set the full fast-math flag set.
  * TRITON_INTEL_FAST_MATH with no parsed value: set only allow-contract,
    unless TRITON_DEFAULT_FP_FUSION is parsed as false, in which case an
    empty flag set is applied.

diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 646ed28e8a..5b146b0559 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -363,9 +363,7 @@ def make_llir(src, metadata, options):
         context = llvm.context()
         llvm_mod = llvm.to_module(mod, context)
         intel.set_spv_target_triple(llvm_mod)
-        if os.getenv("TRITON_INTEL_FAST_MATH", "0") == "1":
-            intel.set_fast_math(llvm_mod)
-
+        intel.set_fast_math(llvm_mod)
         if options.extern_libs:
             paths = [path for (name, path) in options.extern_libs]
             llvm.link_extern_libs(llvm_mod, paths)
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index 6ce8477337..e80c59e27c 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -273,12 +273,24 @@ void init_triton_intel(py::module &&m) {
   // fast math semantics on all arithmetic operations.
   // https://github.com/intel/intel-xpu-backend-for-triton/issues/3862
   m.def("set_fast_math", [](llvm::Module *mod) {
+    std::optional<bool> fastMath = mlir::triton::tools::isEnvValueBool(
+        mlir::triton::tools::getStrEnv("TRITON_INTEL_FAST_MATH"));
+    std::optional<bool> enableFpFusion = mlir::triton::tools::isEnvValueBool(
+        mlir::triton::tools::getStrEnv("TRITON_DEFAULT_FP_FUSION"));
+    if (fastMath.has_value() && !fastMath.value())
+      return;
+
     using namespace llvm;
     for (Function &func : *mod) {
       for (Instruction &inst : instructions(func)) {
         if (auto *op = dyn_cast<FPMathOperator>(&inst)) {
           FastMathFlags FMF;
-          FMF.setFast(true);
+          // Default to allow contract when default fp fusion is not disabled.
+          if ((!enableFpFusion.has_value() || enableFpFusion.value()) &&
+              !fastMath.has_value())
+            FMF.setAllowContract(true);
+          else if (fastMath.has_value() && fastMath.value())
+            FMF.setFast(true);
           inst.setFastMathFlags(FMF);
         }
       }