diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index f4f571409f85..b697f581c6b9 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -2418,6 +2418,7 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.const # --------------- +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("M, N", [[2048, 2], [1024, 8], [1024, 128], [256, 512], [32, 512], [8, 512], [8, 2]]) def test_histogram(M, N, device): diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index 5893c99f250e..f67c2de7e2ce 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -23,6 +23,7 @@ std::unique_ptr> createConvertMemoryOps(); std::unique_ptr> createConvertPtrOps(); std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); +std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index a2663bea5589..6604ca4fcc12 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -74,6 +74,17 @@ def ConvertControlFlowOps : Pass<"triton-cpu-convert-control-flow-op", "mlir::Mo "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertHistogramOp : Pass<"triton-cpu-convert-histogram-op", "mlir::ModuleOp"> { + let summary = "Convert Triton HistogramOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertHistogramOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::memref::MemRefDialect", + "mlir::vector::VectorDialect", + def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> { let summary = "Convert Triton ReduceOp."; let description = [{ diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index d18488e5aef0..d7974fe63079 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonToTritonCPU ConvertControlFlowOps.cpp ConvertDotOp.cpp ConvertElementwiseOps.cpp + ConvertHistogramOp.cpp ConvertMemoryOps.cpp ConvertReductionOp.cpp ConvertPtrOps.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp new file mode 100644 index 000000000000..0bcbfcc9f264 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertHistogramOp.cpp @@ -0,0 +1,134 @@ +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Index/IR/IndexDialect.h" +#include "mlir/Dialect/Index/IR/IndexOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/IR/BuiltinDialect.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTHISTOGRAMOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class HistogramConversionTarget : public ConversionTarget { +public: + explicit HistogramConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + + addIllegalOp(); + } +}; + +struct HistogramOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::HistogramOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto src = rewriter.getRemappedValue(op.getSrc()); + auto srcTy = dyn_cast(src.getType()); + auto resTy = + dyn_cast(getTypeConverter()->convertType(op.getType())); + + if (srcTy.getRank() != 1) + llvm_unreachable("unsupported input for histogram op (rank != 1)"); + + Value zero = rewriter.create( + loc, resTy, rewriter.getZeroAttr(resTy)); + Value one = rewriter.create(loc, resTy, + rewriter.getOneAttr(resTy)); + VectorType cmpVecTy = + VectorType::get(resTy.getShape(), srcTy.getElementType()); + Value rangeVec = rewriter.create( + loc, resTy, makeRangeAttr(cmpVecTy, rewriter)); + Value res = zero; + for (int64_t i = 0; i < srcTy.getShape()[0]; ++i) { + Value idx = rewriter.create( + loc, rewriter.getIndexType(), rewriter.getIndexAttr(i)); + Value elem = rewriter.create(loc, src, idx); + Value elemVec = rewriter.create(loc, cmpVecTy, elem); + Value mask = rewriter.create(loc, arith::CmpIPredicate::eq, + elemVec, rangeVec); + Value delta = vector::selectPassthru(rewriter, mask, one, zero); + res = rewriter.create(loc, res, delta); + } + + rewriter.replaceOp(op, res); + + return success(); + } + + TypedAttr makeRangeAttr(VectorType resTy, + ConversionPatternRewriter &rewriter) const { + Type elemTy = resTy.getElementType(); + if (elemTy.isInteger(32)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI32VectorAttr(range); + } else if (elemTy.isInteger(64)) { + SmallVector range(resTy.getShape()[0]); + std::iota(range.begin(), range.end(), 0); + return rewriter.getI64VectorAttr(range); + } else { + llvm_unreachable( + "unsupported src elem type for histogram (expected i32 or i64)"); + } + } +}; + +struct ConvertHistogramOp + : public triton::impl::ConvertHistogramOpBase { + using ConvertHistogramOpBase::ConvertHistogramOpBase; + + ConvertHistogramOp() : ConvertHistogramOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + HistogramConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertHistogramOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index 87d72f7a6473..50d5814270d7 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -12,6 +12,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertPtrOps()); pm.addPass(mlir::triton::cpu::createConvertElementwiseOps()); pm.addPass(mlir::triton::cpu::createConvertDotOp()); + pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); pm.addPass(mlir::triton::cpu::createConvertReductionOp()); pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass());