diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index b697f581c6b9..8c56d3458766 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1987,6 +1987,7 @@ def deserialize_fp8(np_data, in_dtype): # --------------- +@pytest.mark.cpu @pytest.mark.interpreter def test_max_returns_zero(device): # Simple test with a tl.max call that returns 0. The interpreter had a bug @@ -2013,6 +2014,7 @@ def get_reduced_dtype(dtype_str, op): return dtype_str +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ 'min', @@ -2131,9 +2133,6 @@ def kernel(X, Z, BLOCK: tl.constexpr): def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested - if is_cpu() and op in ('argmin', 'argmax'): - pytest.skip(f"Not yet implemented on CPU: {op}") - @triton.jit def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr): @@ -2261,17 +2260,24 @@ def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur): return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config) def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device): check_type_supported(dtype_str, device) if dtype_str == 'bfloat16': - if op == 'cummax': + if is_cuda() and op == 'cummax': pytest.skip("bfloat16 compare not suppoted before sm90") if op == 'linear_recurrence': pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues") numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str + # bf16 vector cast is broken in LLVM for large vectors: + # https://github.com/llvm/llvm-project/issues/92471 + # TODO: Remove 
the change after the bug is fixed. + if is_cpu() and dtype_str == 'bfloat16': + shape = (min(shape[0], 128), min(shape[1], 128)) + # triton kernel @triton.jit def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr): @@ -2876,6 +2882,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) +@pytest.mark.cpu @pytest.mark.interpreter def test_generic_reduction(device): diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.h b/third_party/cpu/include/TritonToTritonCPU/Passes.h index f67c2de7e2ce..c7d072ab9175 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.h +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.h @@ -25,6 +25,7 @@ std::unique_ptr> createConvertDotOp(); std::unique_ptr> createConvertControlFlowOps(); std::unique_ptr> createConvertHistogramOp(); std::unique_ptr> createConvertReductionOp(); +std::unique_ptr> createConvertScanOp(); void tritonToTritonCPUPipelineBuilder(OpPassManager &pm); void registerTritonToTritonCPUPipeline(); diff --git a/third_party/cpu/include/TritonToTritonCPU/Passes.td b/third_party/cpu/include/TritonToTritonCPU/Passes.td index 60b9942d08cd..28ad258c38c0 100644 --- a/third_party/cpu/include/TritonToTritonCPU/Passes.td +++ b/third_party/cpu/include/TritonToTritonCPU/Passes.td @@ -100,4 +100,18 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp"> "mlir::triton::cpu::TritonCPUDialect"]; } +def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> { + let summary = "Convert Triton ScanOp."; + let description = [{ + + }]; + let constructor = "mlir::triton::cpu::createConvertScanOp()"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::vector::VectorDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::cpu::TritonCPUDialect"]; +} + #endif diff --git a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt 
b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt index d7974fe63079..fc22e12b867d 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt +++ b/third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt @@ -4,8 +4,9 @@ add_triton_library(TritonToTritonCPU ConvertElementwiseOps.cpp ConvertHistogramOp.cpp ConvertMemoryOps.cpp - ConvertReductionOp.cpp ConvertPtrOps.cpp + ConvertReductionOp.cpp + ConvertScanOp.cpp Pipeline.cpp TypeConverter.cpp diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp index 62463c834fd4..a460d9834e4e 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp @@ -1,22 +1,17 @@ +#include "ReduceScanCommon.h" #include "TypeConverter.h" #include "cpu/include/TritonToTritonCPU/Passes.h" -#include "mlir/Analysis/DataFlowFramework.h" -#include "mlir/Dialect/Index/IR/IndexDialect.h" -#include "mlir/Dialect/Index/IR/IndexOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "triton/Analysis/Allocation.h" -#include "triton/Analysis/AxisInfo.h" -#include "triton/Analysis/Membar.h" -#include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonCPU/IR/Dialect.h" +#include + namespace mlir { namespace triton { #define GEN_PASS_DEF_CONVERTREDUCTIONOP @@ -44,28 +39,91 @@ class ReductionConversionTarget : public ConversionTarget { } }; -struct ReduceOpConversion : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +struct ReduceOpConversion + : public ReduceScanOpConversionBase { + using ReduceScanOpConversionBase::ReduceScanOpConversionBase; LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor 
adaptor, ConversionPatternRewriter &rewriter) const override { - MLIRContext *ctx = op.getContext(); - // Currently, only simple reductions with a single input argumet are - // supported. - // TODO: support generic case. + // More simple cases with a single input and a single combine + // operation can utilize target-specific reduction operations like + // horizontal vector operations. We detect such cases here and map + // them to the vector::MultiDimReductionOp. + if (succeeded(mapToMultiDimReductionOp(op, rewriter))) + return success(); + + return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter); + } + + SmallVector + lower1DInput(ValueRange inputs, ReduceOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; + SmallVector range(vecSize); + std::iota(range.begin(), range.end(), 0); + + ArrayRef dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector res = inputs; + for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) { + SmallVector shuffleIndices = range; + for (int64_t i = 0; i < stride; ++i) { + std::swap(shuffleIndices[i], shuffleIndices[i + stride]); + } + SmallVector shuffledInput; + for (auto [val, dummy] : llvm::zip(res, dummies)) { + shuffledInput.push_back(rewriter.create( + loc, val, dummy, shuffleIndices)); + } + + res = accumulate(shuffledInput, res, combineOp, rewriter); + } + + // The results are in the first element of each produced vector. 
+ Value zero = + rewriter.create(loc, rewriter.getIndexAttr(0)); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, res[i], zero); + } + return res; + } + + SmallVector + lowerLeadingDimension(ValueRange inputs, ReduceOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + auto shape = cast(inputs[0].getType()).getShape(); + SmallVector res; + for (int64_t idx = 0; idx < shape[0]; ++idx) { + SmallVector subInputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), subInputs.begin(), + [&](auto val) { + return rewriter.create(loc, val, idx); + }); + + res = accumulate(subInputs, res, combineOp, rewriter); + } + return res; + } + + LogicalResult + mapToMultiDimReductionOp(triton::ReduceOp op, + ConversionPatternRewriter &rewriter) const { if (op.getNumOperands() != 1 || op.getNumResults() != 1) return failure(); Value src = rewriter.getRemappedValue(op.getOperand(0)); - VectorType srcTy = dyn_cast(src.getType()); - assert(srcTy); + VectorType srcTy = cast(src.getType()); Block *block = op.getBody(); if (block->getNumArguments() != 2) return failure(); - Value itArg = block->getArgument(0); - Value accArg = block->getArgument(1); + Value accArg = block->getArgument(0); + Value itArg = block->getArgument(1); auto &blockOps = block->getOperations(); if (blockOps.size() != 2) @@ -155,7 +213,18 @@ struct ReduceOpConversion : public OpConversionPattern { elemTy, static_cast( (1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1)); else if (kind == vector::CombiningKind::MINIMUMF || - kind == vector::CombiningKind::MINNUMF) { + kind == vector::CombiningKind::MAXIMUMF) { + if (elemTy.isF32()) + initVal = + rewriter.getF32FloatAttr(std::numeric_limits::quiet_NaN()); + else if (elemTy.isF64()) + initVal = + rewriter.getF64FloatAttr(std::numeric_limits::quiet_NaN()); + else + llvm_unreachable("Unsupported type for acc init value."); + } + + else if (kind == 
vector::CombiningKind::MINNUMF) { if (elemTy.isF32()) initVal = rewriter.getF32FloatAttr(std::numeric_limits::infinity()); @@ -164,8 +233,7 @@ struct ReduceOpConversion : public OpConversionPattern { rewriter.getF64FloatAttr(std::numeric_limits::infinity()); else llvm_unreachable("Unsupported type for acc init value."); - } else if (kind == vector::CombiningKind::MAXIMUMF || - kind == vector::CombiningKind::MAXNUMF) { + } else if (kind == vector::CombiningKind::MAXNUMF) { if (elemTy.isF32()) initVal = rewriter.getF32FloatAttr(-std::numeric_limits::infinity()); diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp new file mode 100644 index 000000000000..5425b5dbf800 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertScanOp.cpp @@ -0,0 +1,156 @@ +#include "ReduceScanCommon.h" +#include "TypeConverter.h" + +#include "cpu/include/TritonToTritonCPU/Passes.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" +#include "mlir/Pass/Pass.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +#define GEN_PASS_DEF_CONVERTSCANOP +#include "cpu/include/TritonToTritonCPU/Passes.h.inc" +} // namespace triton +} // namespace mlir + +using namespace mlir; +using namespace mlir::triton; +using namespace mlir::triton::cpu; + +namespace { + +class ScanConversionTarget : public ConversionTarget { +public: + explicit ScanConversionTarget(MLIRContext &ctx, TypeConverter &converter) + : ConversionTarget(ctx) { + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalDialect(); + addLegalOp(); + + addIllegalOp(); + } +}; + +struct ScanOpConversion + : public ReduceScanOpConversionBase { + using ReduceScanOpConversionBase::ReduceScanOpConversionBase; + + SmallVector + lower1DInput(ValueRange inputs, ScanOp op, + ConversionPatternRewriter 
&rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + bool reverse = op.getReverse(); + int64_t vecSize = cast(inputs[0].getType()).getShape()[0]; + Type maskTy = VectorType::get(vecSize, rewriter.getI1Type()); + + ArrayRef dummies = createShuffleDummies(loc, inputs, rewriter); + SmallVector res = inputs; + for (int64_t stride = 1; stride < vecSize; stride *= 2) { + SmallVector shuffleIndices(vecSize, 0); + int64_t start = reverse ? vecSize - 1 - stride : stride; + int64_t end = reverse ? -1 : vecSize; + int64_t step = reverse ? -1 : 1; + for (int64_t i = start; i != end; i += step) { + shuffleIndices[i] = i - step * stride; + } + SmallVector shuffledInput; + for (auto [val, dummy] : llvm::zip(res, dummies)) { + shuffledInput.push_back(rewriter.create( + loc, val, dummy, shuffleIndices)); + } + + auto newRes = accumulate(res, shuffledInput, combineOp, rewriter); + + // Number of already computed elements is equal to the current + // stride. Mask them out using a constant mask. 
+ SmallVector maskVals(vecSize, true); + if (reverse) { + std::fill(maskVals.rbegin(), maskVals.rbegin() + stride, false); + } else { + std::fill(maskVals.begin(), maskVals.begin() + stride, false); + } + Value mask = rewriter.create( + loc, maskTy, rewriter.getBoolVectorAttr(maskVals)); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = vector::selectPassthru(rewriter, mask, newRes[i], res[i]); + } + } + + return res; + } + + SmallVector + lowerLeadingDimension(ValueRange inputs, ScanOp op, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Region &combineOp = op.getRegion(); + bool reverse = op.getReverse(); + auto shape = cast(inputs[0].getType()).getShape(); + SmallVector resTypes; + for (const auto &resTy : op.getResultTypes()) { + resTypes.push_back(VectorType::get( + shape, cast(resTy).getElementType())); + } + SmallVector res = makeEmptyResults(loc, resTypes, rewriter); + SmallVector acc; + int64_t start = reverse ? shape[0] - 1 : 0; + int64_t end = reverse ? -1 : shape[0]; + int64_t step = reverse ? 
-1 : 1; + for (int64_t idx = start; idx != end; idx += step) { + SmallVector subInputs(inputs.size()); + std::transform(inputs.begin(), inputs.end(), subInputs.begin(), + [&](auto val) { + return rewriter.create(loc, val, idx); + }); + + acc = accumulate(subInputs, acc, combineOp, rewriter); + + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, acc[i], res[i], idx); + } + } + return res; + } +}; + +struct ConvertScanOp : public triton::impl::ConvertScanOpBase { + using ConvertScanOpBase::ConvertScanOpBase; + + ConvertScanOp() : ConvertScanOpBase() {} + + void runOnOperation() override { + MLIRContext *context = &getContext(); + ModuleOp mod = getOperation(); + + TritonToTritonCPUTypeConverter typeConverter; + ScanConversionTarget convTarget(*context, typeConverter); + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +namespace mlir { +namespace triton { +namespace cpu { + +std::unique_ptr> createConvertScanOp() { + return std::make_unique(); +} + +} // namespace cpu +} // namespace triton +} // namespace mlir diff --git a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp index 50d5814270d7..2b26cec34248 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/Pipeline.cpp @@ -14,6 +14,7 @@ void tritonToTritonCPUPipelineBuilder(OpPassManager &pm) { pm.addPass(mlir::triton::cpu::createConvertDotOp()); pm.addPass(mlir::triton::cpu::createConvertHistogramOp()); pm.addPass(mlir::triton::cpu::createConvertReductionOp()); + pm.addPass(mlir::triton::cpu::createConvertScanOp()); pm.addPass(mlir::triton::cpu::createConvertControlFlowOps()); // pm.addPass(mlir::createReconcileUnrealizedCastsPass()); } diff --git a/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h 
b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h new file mode 100644 index 000000000000..b2edc5e98b36 --- /dev/null +++ b/third_party/cpu/lib/TritonToTritonCPU/ReduceScanCommon.h @@ -0,0 +1,244 @@ +#include "mlir/Transforms/DialectConversion.h" + +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonCPU/IR/Dialect.h" + +#include + +namespace mlir { +namespace triton { +namespace cpu { + +// Base class for converting scans and reductions. +// +// It provides an accumulation function that clones operations from the +// original combine region and applies them on provided vectors. +// Also, it handles multi-dimensional cases reducing them to two +// possible options: lowering for 1-D vector inputs and lowering +// the operation over the leading dimension. +// +// Specialized pattern should implement lower1DInput to handle +// trailing dimension case (commonly through shuffles + accumulate) +// and lowerLeadingDimension to handle the leading dimension case +// through accumulation of sub-vectors. 
+template +struct ReduceScanOpConversionBase : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::getTypeConverter; + using typename OpConversionPattern::OpAdaptor; + + virtual SmallVector + lower1DInput(ValueRange inputs, OpT op, + ConversionPatternRewriter &rewriter) const = 0; + virtual SmallVector + lowerLeadingDimension(ValueRange inputs, OpT op, + ConversionPatternRewriter &rewriter) const = 0; + + LogicalResult + matchAndRewrite(OpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto rank = cast(op.getOperand(0).getType()).getRank(); + if (op.getAxis() == (rank - 1)) + return lowerTrailingDimension(op, rewriter); + + return lowerNonTrailingDimension(op, rewriter); + } + + // To handle the trailing dimension case, we extract all input vectors + // and process them through lower1DInput, then build the resulting + // vector using inserts. + LogicalResult + lowerTrailingDimension(OpT op, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + SmallVector inputTys(inputs.size()); + std::transform(inputs.begin(), inputs.end(), inputTys.begin(), + [](auto val) { return cast(val.getType()); }); + + // 1-D input case. + if (inputTys.front().getRank() == 1) { + auto res = lower1DInput(inputs, op, rewriter); + rewriter.replaceOp(op, res); + return success(); + } + + SmallVector res = + makeEmptyResults(loc, op.getResultTypes(), rewriter); + auto shape = inputTys[0].getShape(); + int64_t numElems = inputTys[0].getNumElements(); + auto strides = computeStrides(shape); + // Remove the last stride to produce sub-vector indices. 
+ strides.pop_back(); + for (int64_t idx = 0; idx < numElems; idx += shape.back()) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, indices); + }); + + auto resElems = lower1DInput(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = rewriter.create(loc, resElems[i], res[i], + indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + // In this case we either call lowerLeadingDimension to process the input + // or extract sub-vectors, call lowerLeadingDimension, and then reconstruct + // the result. + LogicalResult + lowerNonTrailingDimension(OpT op, ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + SmallVector inputs; + if (failed(rewriter.getRemappedValues(op.getOperands(), inputs))) + return failure(); + + uint32_t axis = op.getAxis(); + if (axis == 0) { + rewriter.replaceOp(op, lowerLeadingDimension(inputs, op, rewriter)); + return success(); + } + + SmallVector res = + makeEmptyResults(loc, op.getResultTypes(), rewriter); + auto vecTy = cast(inputs[0].getType()); + auto shape = vecTy.getShape(); + auto strides = computeStrides(shape); + // Remove trailing elems to build indices of required rank. 
+ strides.erase(strides.begin() + axis, strides.end()); + int64_t numElems = vecTy.getNumElements(); + int64_t step = strides.back(); + for (int64_t idx = 0; idx < numElems; idx += step) { + auto indices = delinearize(idx, strides); + SmallVector subInputs(inputs.size()); + std::transform( + inputs.begin(), inputs.end(), subInputs.begin(), [&](auto val) { + return rewriter.create(loc, val, indices); + }); + auto resVecs = lowerLeadingDimension(subInputs, op, rewriter); + for (size_t i = 0; i < res.size(); ++i) { + res[i] = + rewriter.create(loc, resVecs[i], res[i], indices); + } + } + + rewriter.replaceOp(op, res); + return success(); + } + + // Accumulate inputs and existing accumulators into new accumulators + // applying operations from the combine region. + SmallVector accumulate(ValueRange inputs, ValueRange acc, + Region &combineOp, + ConversionPatternRewriter &rewriter) const { + if (acc.empty()) + return inputs; + + auto shape = cast(inputs[0].getType()).getShape(); + auto &block = combineOp.getBlocks().front(); + IRMapping map; + // Map block arguments to the current inputs and accumulators. + for (unsigned i = 0; i < acc.size(); ++i) { + map.map(block.getArgument(i), acc[i]); + map.map(block.getArgument(acc.size() + i), inputs[i]); + } + for (auto &op : block.getOperations()) { + // Returned values are a new accumulator. + if (isa(op)) { + SmallVector res; + for (auto operand : op.getOperands()) { + res.push_back(map.lookup(operand)); + } + return res; + } + + // Clone operation mapping its inputs and building vector + // result types using the input shape. + OperationState newState(op.getLoc(), op.getName()); + for (auto operand : op.getOperands()) { + newState.operands.push_back( + lookupMappedValue(map, operand, shape, rewriter)); + } + for (auto ty : op.getResultTypes()) { + newState.types.push_back(VectorType::get(shape, ty)); + } + newState.attributes = op.getAttrs(); + auto newOp = rewriter.create(newState); + + // Add new values to the map. 
+ for (auto [oldVal, newVal] : + llvm::zip(op.getResults(), newOp->getResults())) { + map.map(oldVal, newVal); + } + } + llvm_unreachable("No return op found in scan/reduce region"); + } + + Value lookupMappedValue(IRMapping &localMap, Value val, + ArrayRef shape, + ConversionPatternRewriter &rewriter) const { + + Value res = localMap.lookupOrNull(val); + if (!res) { + // If value is not found then it's an invariant defined in the outer + // region. We check if it has been already translated and add a splat + // operation if it hasn't. + res = invariantsMap.lookupOrNull(val); + if (!res) { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPointAfterValue(val); + res = rewriter.create( + val.getLoc(), VectorType::get(shape, val.getType()), val); + invariantsMap.map(val, res); + rewriter.restoreInsertionPoint(ip); + } + } + return res; + } + + SmallVector + makeEmptyResults(Location loc, TypeRange resTypes, + ConversionPatternRewriter &rewriter) const { + // Initialize results to zero values. + SmallVector res; + for (auto ty : resTypes) { + res.push_back(rewriter.create( + loc, rewriter.getZeroAttr(getTypeConverter()->convertType(ty)))); + } + return res; + } + + // Dummy vectors are required for shuffles that cannot work on a single + // vector. + ArrayRef + createShuffleDummies(Location loc, ValueRange inputs, + ConversionPatternRewriter &rewriter) const { + if (shuffleDummies.empty()) { + for (auto val : inputs) { + auto ty = cast(val.getType()); + shuffleDummies.push_back(rewriter.create( + loc, rewriter.getZeroAttr(ty.cloneWith(1, ty.getElementType())))); + } + } + return shuffleDummies; + } + +private: + mutable IRMapping invariantsMap; + mutable SmallVector shuffleDummies; +}; + +} // namespace cpu +} // namespace triton +} // namespace mlir