diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 5cf1c3a25707..d3e0ee102a51 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -40,5 +40,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) { mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, - mlir::gpu::GPUDialect>(); + mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, + mlir::triton::nvgpu::NVGPUDialect>(); } diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 6ba2fe711816..24d971701242 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -7,7 +7,6 @@ #include "mlir/IR/Dialect.h" // TritonGPU depends on Triton -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index 7533044b41a6..136b90ee65e5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -16,7 +16,6 @@ def TritonGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", - "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index f074b250a7bb..13aef00a6d0b 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -6,6 +6,7 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" // 
Shortcuts for some commonly used LLVM ops to keep code simple and intuitive // Operators diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/lib/Dialect/NVGPU/IR/CMakeLists.txt index 4e9e1ada172c..24a93ce58ea3 100644 --- a/lib/Dialect/NVGPU/IR/CMakeLists.txt +++ b/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -6,4 +6,5 @@ add_mlir_dialect_library(NVGPUIR NVGPUAttrDefsIncGen LINK_LIBS PUBLIC + MLIRLLVMDialect ) diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index be59fc42ab57..12739e56722c 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -36,5 +36,7 @@ def test_op(M, N, dtype, mode): x.grad = None th_y.backward(dy) th_dx = x.grad.clone() - - torch.testing.assert_close(th_dx, tt_dx) + if dtype == 'float16': + torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) + else: + torch.testing.assert_close(th_dx, tt_dx) diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp index 90e6ef8c3d1d..20603cd2e41c 100644 --- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp +++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp @@ -22,6 +22,7 @@ */ #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "DumpLayout.h"