15 changes: 11 additions & 4 deletions python/test/unit/language/test_core.py
@@ -1987,6 +1987,7 @@ def deserialize_fp8(np_data, in_dtype):
 # ---------------
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 def test_max_returns_zero(device):
     # Simple test with a tl.max call that returns 0. The interpreter had a bug
@@ -2013,6 +2014,7 @@ def get_reduced_dtype(dtype_str, op):
     return dtype_str
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [
     'min',
@@ -2131,9 +2133,6 @@ def kernel(X, Z, BLOCK: tl.constexpr):
 def test_reduce(op, dtype_str, shape, axis, keep_dims, num_ctas, device):
     check_type_supported(dtype_str, device)  # bfloat16 on cc < 80 will not be tested
 
-    if is_cpu() and op in ('argmin', 'argmax'):
-        pytest.skip(f"Not yet implemented on CPU: {op}")
-
     @triton.jit
     def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr,
                AXIS: tl.constexpr, KEEP_DIMS: tl.constexpr):
@@ -2261,17 +2260,24 @@ def roll(a1, b1_last, b1_cur, a2, b2_last, b2_cur):
     return a1 + a2, tl.where(a2 == 1, b1_cur, 0) + b2_last, b2_cur
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 @pytest.mark.parametrize("op, dtype_str, shape, axis, reverse, num_warps", scan_configs + negative_config)
 def test_scan2d(op, dtype_str, shape, axis, reverse, num_warps, device):
     check_type_supported(dtype_str, device)
     if dtype_str == 'bfloat16':
-        if op == 'cummax':
+        if is_cuda() and op == 'cummax':
             pytest.skip("bfloat16 compare not suppoted before sm90")
         if op == 'linear_recurrence':
             pytest.skip("Skipping linear_recurrence scan on bfloat16 due to accuracy issues")
     numpy_dtype_str = 'float32' if dtype_str == 'bfloat16' else dtype_str
 
+    # bf16 vector cast is broken in LLVM for large vectors:
+    # https://github.com/llvm/llvm-project/issues/92471
+    # TODO: Remove the change after the bug is fixed.
+    if is_cpu() and dtype_str == 'bfloat16':
+        shape = (min(shape[0], 128), min(shape[1], 128))
+
     # triton kernel
     @triton.jit
     def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexpr):
@@ -2876,6 +2882,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis):
     np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
 
 
+@pytest.mark.cpu
 @pytest.mark.interpreter
 def test_generic_reduction(device):
 
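For context on test_scan2d above: a 2-D scan op such as cumsum sweeps a single axis, optionally back-to-front. The following is a minimal NumPy reference for the expected result, assuming the usual flip-scan-flip idiom; it is an illustration, and the helper name is ours, not the test suite's.

import numpy as np

def reference_scan2d(x, axis, reverse=False):
    # Emulate a reverse (suffix) scan by flipping, scanning, flipping back.
    if reverse:
        x = np.flip(x, axis=axis)
    y = np.cumsum(x, axis=axis)
    return np.flip(y, axis=axis) if reverse else y

x = np.arange(12, dtype=np.float32).reshape(3, 4)
print(reference_scan2d(x, axis=1, reverse=True))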
1 change: 1 addition & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.h
@@ -25,6 +25,7 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertDotOp();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertControlFlowOps();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertHistogramOp();
 std::unique_ptr<OperationPass<ModuleOp>> createConvertReductionOp();
+std::unique_ptr<OperationPass<ModuleOp>> createConvertScanOp();
 
 void tritonToTritonCPUPipelineBuilder(OpPassManager &pm);
 void registerTritonToTritonCPUPipeline();
14 changes: 14 additions & 0 deletions third_party/cpu/include/TritonToTritonCPU/Passes.td
@@ -100,4 +100,18 @@ def ConvertReductionOp : Pass<"triton-cpu-convert-reduction", "mlir::ModuleOp">
"mlir::triton::cpu::TritonCPUDialect"];
}

def ConvertScanOp : Pass<"triton-cpu-convert-scan", "mlir::ModuleOp"> {
let summary = "Convert Triton ScanOp.";
let description = [{

}];
let constructor = "mlir::triton::cpu::createConvertScanOp()";

let dependentDialects = ["mlir::arith::ArithDialect",
"mlir::vector::VectorDialect",
"mlir::scf::SCFDialect",
"mlir::triton::TritonDialect",
"mlir::triton::cpu::TritonCPUDialect"];
}

#endif
3 changes: 2 additions & 1 deletion third_party/cpu/lib/TritonToTritonCPU/CMakeLists.txt
@@ -4,8 +4,9 @@ add_triton_library(TritonToTritonCPU
   ConvertElementwiseOps.cpp
   ConvertHistogramOp.cpp
   ConvertMemoryOps.cpp
-  ConvertReductionOp.cpp
   ConvertPtrOps.cpp
+  ConvertReductionOp.cpp
+  ConvertScanOp.cpp
   Pipeline.cpp
   TypeConverter.cpp
 
112 changes: 90 additions & 22 deletions third_party/cpu/lib/TritonToTritonCPU/ConvertReductionOp.cpp
@@ -1,22 +1,17 @@
+#include "ReduceScanCommon.h"
 #include "TypeConverter.h"
 
 #include "cpu/include/TritonToTritonCPU/Passes.h"
 
-#include "mlir/Analysis/DataFlowFramework.h"
-#include "mlir/Dialect/Index/IR/IndexDialect.h"
-#include "mlir/Dialect/Index/IR/IndexOps.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
-#include "triton/Analysis/Allocation.h"
-#include "triton/Analysis/AxisInfo.h"
-#include "triton/Analysis/Membar.h"
-#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonCPU/IR/Dialect.h"
 
+#include <numeric>
+
 namespace mlir {
 namespace triton {
 #define GEN_PASS_DEF_CONVERTREDUCTIONOP
@@ -44,28 +39,91 @@ class ReductionConversionTarget : public ConversionTarget {
   }
 };
 
-struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> {
-  using OpConversionPattern::OpConversionPattern;
+struct ReduceOpConversion
+    : public ReduceScanOpConversionBase<triton::ReduceOp,
+                                        triton::ReduceReturnOp> {
+  using ReduceScanOpConversionBase::ReduceScanOpConversionBase;
 
   LogicalResult
   matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    MLIRContext *ctx = op.getContext();
-    // Currently, only simple reductions with a single input argumet are
-    // supported.
-    // TODO: support generic case.
+    // More simple cases with a single input and a single combine
+    // operation can utilize target-specific reduction operations like
+    // horizaontal vector operations. We detect such cases here and map
+    // them to the vector::MultiDimReductionOp.
+    if (succeeded(mapToMultiDimReductionOp(op, rewriter)))
+      return success();
+
+    return ReduceScanOpConversionBase::matchAndRewrite(op, adaptor, rewriter);
+  }
+
+  SmallVector<Value>
+  lower1DInput(ValueRange inputs, ReduceOp op,
+               ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Region &combineOp = op.getRegion();
+    int64_t vecSize = cast<VectorType>(inputs[0].getType()).getShape()[0];
+    SmallVector<int64_t> range(vecSize);
+    std::iota(range.begin(), range.end(), 0);
+
+    ArrayRef<Value> dummies = createShuffleDummies(loc, inputs, rewriter);
+    SmallVector<Value> res = inputs;
+    for (int64_t stride = vecSize / 2; stride > 0; stride = stride / 2) {
+      SmallVector<int64_t> shuffleIndices = range;
+      for (int64_t i = 0; i < stride; ++i) {
+        std::swap(shuffleIndices[i], shuffleIndices[i + stride]);
+      }
+      SmallVector<Value> shuffledInput;
+      for (auto [val, dummy] : llvm::zip(res, dummies)) {
+        shuffledInput.push_back(rewriter.create<vector::ShuffleOp>(
+            loc, val, dummy, shuffleIndices));
+      }
+
+      res = accumulate(shuffledInput, res, combineOp, rewriter);
+    }
+
+    // The results are in the first element of each produced vector.
+    Value zero =
+        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
+    for (size_t i = 0; i < res.size(); ++i) {
+      res[i] = rewriter.create<vector::ExtractElementOp>(loc, res[i], zero);
+    }
+    return res;
+  }
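A note on lower1DInput: each iteration shuffles lanes that are `stride` apart into place and combines them elementwise, so the reduction finishes in log2(vecSize) combines and the result lands in lane 0. A NumPy sketch of the same index dance, as an illustration only (assuming a power-of-two vector length, as in Triton blocks):

import numpy as np

def shuffle_tree_reduce(vec, combine=np.add):
    # vec: 1-D array of power-of-two length, one lane per block element.
    res = vec.copy()
    n = len(res)
    stride = n // 2
    while stride > 0:
        idx = np.arange(n)
        # Swap lanes [0, stride) with lanes [stride, 2*stride), identity
        # elsewhere -- the same indices the pattern feeds to ShuffleOp.
        idx[:stride] = np.arange(stride) + stride
        idx[stride:2 * stride] = np.arange(stride)
        res = combine(res[idx], res)  # elementwise combine, like accumulate()
        stride //= 2
    return res[0]  # the full reduction ends up in lane 0

print(shuffle_tree_reduce(np.arange(8)))  # 28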

+  SmallVector<Value>
+  lowerLeadingDimension(ValueRange inputs, ReduceOp op,
+                        ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    Region &combineOp = op.getRegion();
+    auto shape = cast<VectorType>(inputs[0].getType()).getShape();
+    SmallVector<Value> res;
+    for (int64_t idx = 0; idx < shape[0]; ++idx) {
+      SmallVector<Value> subInputs(inputs.size());
+      std::transform(inputs.begin(), inputs.end(), subInputs.begin(),
+                     [&](auto val) {
+                       return rewriter.create<vector::ExtractOp>(loc, val, idx);
+                     });
+
+      res = accumulate(subInputs, res, combineOp, rewriter);
+    }
+    return res;
+  }
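lowerLeadingDimension reduces along the leading axis by peeling one subvector per row and folding it into an accumulator. The NumPy equivalent, again only as an illustration:

import numpy as np

def reduce_leading_dim(x, combine=np.maximum):
    # Fold rows into an accumulator, one extract + combine per row,
    # mirroring the vector::ExtractOp loop above.
    acc = x[0]
    for row in x[1:]:
        acc = combine(row, acc)
    return acc

x = np.arange(12).reshape(4, 3)
print(reduce_leading_dim(x))  # elementwise max over rows: [ 9 10 11]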

+  LogicalResult
+  mapToMultiDimReductionOp(triton::ReduceOp op,
+                           ConversionPatternRewriter &rewriter) const {
     if (op.getNumOperands() != 1 || op.getNumResults() != 1)
       return failure();
 
     Value src = rewriter.getRemappedValue(op.getOperand(0));
-    VectorType srcTy = dyn_cast<VectorType>(src.getType());
-    assert(srcTy);
+    VectorType srcTy = cast<VectorType>(src.getType());
 
     Block *block = op.getBody();
     if (block->getNumArguments() != 2)
       return failure();
-    Value itArg = block->getArgument(0);
-    Value accArg = block->getArgument(1);
+    Value accArg = block->getArgument(0);
+    Value itArg = block->getArgument(1);
 
     auto &blockOps = block->getOperations();
     if (blockOps.size() != 2)
@@ -155,7 +213,18 @@ struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> {
           elemTy, static_cast<int64_t>(
                       (1UL << (elemTy.getIntOrFloatBitWidth() - 1)) - 1));
     else if (kind == vector::CombiningKind::MINIMUMF ||
-             kind == vector::CombiningKind::MINNUMF) {
+             kind == vector::CombiningKind::MAXIMUMF) {
+      if (elemTy.isF32())
+        initVal =
+            rewriter.getF32FloatAttr(std::numeric_limits<float>::quiet_NaN());
Collaborator: Looks good. It's like std::fmin and std::min.
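To unpack the fmin/min analogy: MINNUMF-style reductions ignore NaN (so an infinity seed is a safe identity), while the MINIMUMF/MAXIMUMF flavors propagate it. NumPy analogues of the two behaviors, as an illustration and not part of the PR:

import numpy as np

x = np.array([3.0, np.nan, 1.0])
# NaN-ignoring, like llvm.minnum / std::fmin; +inf works as the seed.
print(np.fmin.reduce(x, initial=np.inf))  # 1.0
# NaN-propagating, like llvm.minimum.
print(np.minimum.reduce(x))               # nan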

+      else if (elemTy.isF64())
+        initVal =
+            rewriter.getF64FloatAttr(std::numeric_limits<double>::quiet_NaN());
+      else
Collaborator: Not urgent, but maybe we can support F16/BF16 using their raw binary representations for quiet_NaN?

Collaborator (author): Yes, this needs to be added. I couldn't yet find examples of how such constants are created, used, and lowered. Did you see any examples?

Collaborator: I just saw one F16 test case in test_core.py regarding reduction. But it's not urgent at all.
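On the raw-bits suggestion: the canonical binary16 quiet NaN is the bit pattern 0x7E00 (0x7FC0 for bfloat16), so in principle the attribute could be built from those payloads. A NumPy sketch for binary16 only, since stock NumPy has no bfloat16; this is an illustration, not PR code:

import numpy as np

# Reinterpret the canonical binary16 qNaN bit pattern as a float16 value.
f16_qnan = np.array([0x7E00], dtype=np.uint16).view(np.float16)[0]
print(f16_qnan)  # nan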

+        llvm_unreachable("Unsupported type for acc init value.");
+    }
+
+    else if (kind == vector::CombiningKind::MINNUMF) {
       if (elemTy.isF32())
         initVal =
             rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity());
@@ -164,8 +233,7 @@ struct ReduceOpConversion : public OpConversionPattern<triton::ReduceOp> {
             rewriter.getF64FloatAttr(std::numeric_limits<double>::infinity());
       else
         llvm_unreachable("Unsupported type for acc init value.");
-    } else if (kind == vector::CombiningKind::MAXIMUMF ||
-               kind == vector::CombiningKind::MAXNUMF) {
+    } else if (kind == vector::CombiningKind::MAXNUMF) {
       if (elemTy.isF32())
         initVal =
             rewriter.getF32FloatAttr(-std::numeric_limits<float>::infinity());