diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index d3e0ee102a51..29ba31eaf1f3 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,4 +1,5 @@ #pragma once +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index 680af81ac41c..d07f0743615d 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -30,7 +30,6 @@ #include "mlir/IR/Dialect.h" // TritonNvidiaGPU depends on Triton -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Traits.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td index 08ff21f523f0..f2ab288c1799 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -38,7 +38,6 @@ def TritonNvidiaGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "triton::gpu::TritonGPUDialect", - "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 69545a00d83e..fcce9884bc6e 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -6,6 +6,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index ba0423203a0b..2b12e727025f 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -11,6 +11,7 @@ #include "Utility.h" #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Target/PTX/TmaMetadata.h" #include diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 4ef4f65ea1c7..fb2f46f9936a 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -18,6 +18,7 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -387,6 +388,11 @@ struct ConvertTritonGPUToLLVM using ConvertTritonGPUToLLVMBase< ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + ConvertTritonGPUToLLVM(int32_t computeCapability, Target target, mlir::triton::gpu::TMAMetadataTy *tmaMetadata) : ConvertTritonGPUToLLVMBase({computeCapability, target}), diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index f89cfe7b8c7d..06d338685909 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -1,5 +1,6 @@ #include "Utility.h" #include "TypeConverter.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" namespace mlir { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 5a7adeab492f..9dd072d0c942 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -6,7 +6,6 @@ #include "triton/Analysis/Utility.h" #include "triton/Conversion/MLIRTypes.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" -#include "triton/Dialect/NVGPU/IR/Dialect.h" // Shortcuts for some commonly used LLVM ops to keep code simple and intuitive // Operators diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index 12739e56722c..f6ae42ac3e9a 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -36,7 +36,7 @@ def test_op(M, N, dtype, mode): x.grad = None th_y.backward(dy) th_dx = x.grad.clone() - if dtype == 'float16': + if dtype == torch.float16: torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) else: torch.testing.assert_close(th_dx, tt_dx)