-
Notifications
You must be signed in to change notification settings - Fork 41
Support generic reduction and scan cases. #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,22 +1,17 @@ | ||
| #include "ReduceScanCommon.h" | ||
| #include "TypeConverter.h" | ||
|
|
||
| #include "cpu/include/TritonToTritonCPU/Passes.h" | ||
|
|
||
| #include "mlir/Analysis/DataFlowFramework.h" | ||
| #include "mlir/Dialect/Index/IR/IndexDialect.h" | ||
| #include "mlir/Dialect/Index/IR/IndexOps.h" | ||
| #include "mlir/Dialect/MemRef/IR/MemRef.h" | ||
| #include "mlir/Dialect/Utils/IndexingUtils.h" | ||
| #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||
| #include "mlir/Pass/Pass.h" | ||
| #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
|
||
| #include "triton/Analysis/Allocation.h" | ||
| #include "triton/Analysis/AxisInfo.h" | ||
| #include "triton/Analysis/Membar.h" | ||
| #include "triton/Conversion/TritonGPUToLLVM/Utility.h" | ||
| #include "triton/Dialect/Triton/IR/Dialect.h" | ||
| #include "triton/Dialect/TritonCPU/IR/Dialect.h" | ||
|
|
||
| #include <numeric> | ||
|
|
||
| namespace mlir { | ||
| namespace triton { | ||
| #define GEN_PASS_DEF_CONVERTREDUCTIONOP | ||
|
|
@@ -44,28 +39,91 @@ class ReductionConversionTarget : public ConversionTarget { | |
| } | ||
| }; | ||
|
|
||
| struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> { | ||
| using OpConversionPattern::OpConversionPattern; | ||
| struct ReduceOpConversion | ||
| : public ReduceScanOpConversionBase<triton::ReduceOp, | ||
| triton::ReduceReturnOp> { | ||
| using ReduceScanOpConversionBase::ReduceScanOpConversionBase; | ||
|
|
||
| LogicalResult | ||
| matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| MLIRContext *ctx = op.getContext(); | ||
| // Currently, only simple reductions with a single input argumet are | ||
| // supported. | ||
| // TODO: support generic case. | ||
| // More simple cases with a single input and a single combine | ||
| // operation can utilize target-specific reduction operations like | ||
| // horizaontal vector operations. We detect such cases here and map | ||
| // them to the vector::MultiDimReductionOp. | ||
| if (succeeded(mapToMultiDimReductionOp(op, rewriter))) | ||
| return success(); | ||
|
|
||
| return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter); | ||
| } | ||
|
|
||
| SmallVector<Value> | ||
| lower1DInput(ValueRange inputs, ReduceOp op, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto loc = op.getLoc(); | ||
| Region &combineOp = op.getRegion(); | ||
| int64_t vecSize = cast<VectorType>(inputs[0].getType()).getShape()[0]; | ||
| SmallVector<int64_t> range(vecSize); | ||
| std::iota(range.begin(), range.end(), 0); | ||
|
|
||
| ArrayRef<Value> dummies = createShuffleDummies(loc, inputs, rewriter); | ||
| SmallVector<Value> res = inputs; | ||
| for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) { | ||
| SmallVector<int64_t> shuffleIndices = range; | ||
| for (int64_t i = 0; i < stride; ++i) { | ||
| std::swap(shuffleIndices[i], shuffleIndices[i + stride]); | ||
| } | ||
| SmallVector<Value> shuffledInput; | ||
| for (auto [val, dummy] : llvm::zip(res, dummies)) { | ||
| shuffledInput.push_back(rewriter.create<vector::ShuffleOp>( | ||
| loc, val, dummy, shuffleIndices)); | ||
| } | ||
|
|
||
| res = accumulate(shuffledInput, res, combineOp, rewriter); | ||
| } | ||
|
|
||
| // The results are in the first element of each produced vector. | ||
| Value zero = | ||
| rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)); | ||
| for (size_t i = 0; i < res.size(); ++i) { | ||
| res[i] = rewriter.create<vector::ExtractElementOp>(loc, res[i], zero); | ||
| } | ||
| return res; | ||
| } | ||
|
|
||
| SmallVector<Value> | ||
| lowerLeadingDimension(ValueRange inputs, ReduceOp op, | ||
| ConversionPatternRewriter &rewriter) const override { | ||
| auto loc = op.getLoc(); | ||
| Region &combineOp = op.getRegion(); | ||
| auto shape = cast<VectorType>(inputs[0].getType()).getShape(); | ||
| SmallVector<Value> res; | ||
| for (int64_t idx = 0; idx < shape[0]; ++idx) { | ||
| SmallVector<Value> subInputs(inputs.size()); | ||
| std::transform(inputs.begin(), inputs.end(), subInputs.begin(), | ||
| [&](auto val) { | ||
| return rewriter.create<vector::ExtractOp>(loc, val, idx); | ||
| }); | ||
|
|
||
| res = accumulate(subInputs, res, combineOp, rewriter); | ||
| } | ||
| return res; | ||
| } | ||
|
|
||
| LogicalResult | ||
| mapToMultiDimReductionOp(triton::ReduceOp op, | ||
| ConversionPatternRewriter &rewriter) const { | ||
| if (op.getNumOperands() != 1 || op.getNumResults() != 1) | ||
| return failure(); | ||
|
|
||
| Value src = rewriter.getRemappedValue(op.getOperand(0)); | ||
| VectorType srcTy = dyn_cast<VectorType>(src.getType()); | ||
| assert(srcTy); | ||
| VectorType srcTy = cast<VectorType>(src.getType()); | ||
|
|
||
| Block *block = op.getBody(); | ||
| if (block->getNumArguments() != 2) | ||
| return failure(); | ||
| Value itArg = block->getArgument(0); | ||
| Value accArg = block->getArgument(1); | ||
| Value accArg = block->getArgument(0); | ||
| Value itArg = block->getArgument(1); | ||
|
|
||
| auto &blockOps = block->getOperations(); | ||
| if (blockOps.size() != 2) | ||
|
|
@@ -155,7 +213,18 @@ struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> { | |
| elemTy, static_cast<int64_t>( | ||
| (1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1)); | ||
| else if (kind == vector::CombiningKind::MINIMUMF || | ||
| kind == vector::CombiningKind::MINNUMF) { | ||
| kind == vector::CombiningKind::MAXIMUMF) { | ||
| if (elemTy.isF32()) | ||
| initVal = | ||
| rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN()); | ||
| else if (elemTy.isF64()) | ||
| initVal = | ||
| rewriter.getF64FloatAttr(std::numeric_limits<double>::quiet_NaN()); | ||
| else | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not urgent, maybe we can support F16/BF16 using its raw binary representations for quite_NaN?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, this needs to be added. I couldn't yet find examples of how such constants are created, used, and lowered. Did you see any examples?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just saw one F16 testing cases in the test_core.py regarding reduction. But not urgent at all. |
||
| llvm_unreachable("Unsupported type for acc init value."); | ||
| } | ||
|
|
||
| else if (kind == vector::CombiningKind::MINNUMF) { | ||
| if (elemTy.isF32()) | ||
| initVal = | ||
| rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity()); | ||
|
|
@@ -164,8 +233,7 @@ struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> { | |
| rewriter.getF64FloatAttr(std::numeric_limits<double>::infinity()); | ||
| else | ||
| llvm_unreachable("Unsupported type for acc init value."); | ||
| } else if (kind == vector::CombiningKind::MAXIMUMF || | ||
| kind == vector::CombiningKind::MAXNUMF) { | ||
| } else if (kind == vector::CombiningKind::MAXNUMF) { | ||
| if (elemTy.isF32()) | ||
| initVal = | ||
| rewriter.getF32FloatAttr(-std::numeric_limits<float>::infinity()); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. It's like
std::fminandstd::min.