Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3105,6 +3105,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
assert "Unsupported float8 dtype"


@pytest.mark.cpu
@pytest.mark.interpreter
@pytest.mark.parametrize(
"M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dtype, out_dtype, kpack",
Expand All @@ -3117,7 +3118,7 @@ def convert_fp8_to_fp32(x, device, dtype_str):
[(*shape_nw, col_a, col_b, 'none', input_precision, in_dtype, out_dtype, kpack)
for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4],
[32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]]
for input_precision in ["ieee" if is_hip() else "tf32"]
for input_precision in ["ieee" if is_hip() or is_cpu() else "tf32"]
for col_a in [True, False]
for col_b in [True, False]
for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32',
Expand All @@ -3133,6 +3134,18 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, input_precision, in_dty
if is_interpreter():
if in_dtype == 'bfloat16':
pytest.skip("bfloat16 is not supported in the interpreter")
elif is_cpu():
if input_precision != "ieee":
pytest.skip(f"{input_precision} not supported on CPU")
if in_dtype == 'float8e4nv' or in_dtype == 'float8e5':
pytest.skip("float8e4nv and float8e5 not supported on CPU")
# This test kernel runs in a single thread and can take a long time
# for bigger sizes with the current codegen on CPU. Limit input sizes
# by default to keep test execution time reasonable.
if os.environ.get('TRITON_CPU_TEST_DOT_FULL_SIZE', '0') != '1':
M = min(M, 64)
N = min(N, 64)
K = min(K, 32)
else:
if is_cuda():
capability = torch.cuda.get_device_capability()
Expand Down
7 changes: 5 additions & 2 deletions third_party/cpu/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@ def make_tttcir(self, mod, metadata, opt):
# TTCIR -> Target TTCIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
if self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features:
cpu.passes.ttcpuir.add_convert_unsupported_ops(pm)
promote_bf16_to_fp32 = self.cpu_arch == "x86_64" and "avx512bf16" not in self.cpu_features
# We don't have any lowering for mixed precision matmuls, so always use casts for now
convert_mixed_precision_matmul = True
cpu.passes.ttcpuir.add_convert_unsupported_ops(pm, promote_bf16_to_fp32, convert_mixed_precision_matmul)
if promote_bf16_to_fp32:
cpu.passes.ttcpuir.add_decompose_fp_conversions(pm)
passes.common.add_cse(pm)
passes.common.add_symbol_dce(pm)
Expand Down
3 changes: 3 additions & 0 deletions third_party/cpu/include/TritonCPUTransforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ namespace cpu {
#include "cpu/include/TritonCPUTransforms/Passes.h.inc"

// Creates the unsupported-ops conversion pass with its default option values.
std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps();
// Creates the unsupported-ops conversion pass with both options given
// explicitly:
//   promoteBf16ToFp32           - enable the BF16 -> FP32 conversion patterns.
//   convertMixedPrecisionMatmul - enable the mixed-precision matmul
//                                 conversion pattern.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertUnsupportedOps(bool promoteBf16ToFp32,
bool convertMixedPrecisionMatmul);
// Creates the FP conversions decomposition pass.
std::unique_ptr<OperationPass<ModuleOp>> createDecomposeFpConversions();

#define GEN_PASS_REGISTRATION
Expand Down
11 changes: 10 additions & 1 deletion third_party/cpu/include/TritonCPUTransforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,16 @@ def ConvertUnsupportedOps : Pass<"triton-cpu-add-casts-for-unsupported-ops", "ml
by the target natively. Operations are converted to a supported data type
with casts added for inputs and the result.
}];
// TODO: add options to specify which operations to convert.

let options = [
Option<"promoteBf16ToFp32", "promote-bf16-to-fp32",
"bool", /*default*/"false",
"Convert BF16 operations to FP32.">,
Option<"convertMixedPrecisionMatmul", "convert-mixed-precision-matmul",
"bool", /*default*/"false",
"Convert inputs of a mixed-precision matmul to a destination type.">,
];

let constructor = "mlir::triton::cpu::createConvertUnsupportedOps()";

let dependentDialects = ["mlir::arith::ArithDialect",
Expand Down
103 changes: 92 additions & 11 deletions third_party/cpu/lib/TritonCPUTransforms/ConvertUnsupportedOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

namespace mlir {
namespace triton {
namespace cpu {
#define GEN_PASS_DEF_CONVERTUNSUPPORTEDOPS
#include "cpu/include/TritonCPUTransforms/Passes.h.inc"
} // namespace cpu
} // namespace triton
} // namespace mlir

Expand Down Expand Up @@ -165,24 +167,96 @@ struct ConvertBf16Abs : public OpRewritePattern<math::AbsFOp> {
}
};

/// Rewrites a vector.contract whose lhs, rhs or accumulator element type
/// differs from the result element type into a contraction that uses a single
/// common element type for all operands, with casts inserted on the inputs and
/// on the result.
///
/// The common element type is the widest (by bit width) of the lhs, rhs,
/// accumulator and result element types. Contractions that already use one
/// element type everywhere are left untouched.
struct ConvertMixedPrecisionMatmul
    : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    Value acc = op.getAcc();
    auto lhsTy = cast<VectorType>(lhs.getType());
    auto rhsTy = cast<VectorType>(rhs.getType());
    auto accTy = cast<VectorType>(acc.getType());
    auto resTy = cast<VectorType>(op.getType());
    Type resElemTy = resTy.getElementType();

    // Nothing to convert when every operand already has the result element
    // type. This is also what stops the pattern from matching the contraction
    // it creates below.
    if (lhsTy.getElementType() == resElemTy &&
        rhsTy.getElementType() == resElemTy &&
        accTy.getElementType() == resElemTy)
      return failure();

    // Choose the widest element type among the result and all operands.
    // The *element* type must be taken here: assigning the vector type itself
    // would make the bit-width queries and the casts below operate on a
    // non-scalar type.
    Type commonElemTy = resElemTy;
    if (lhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
      commonElemTy = lhsTy.getElementType();
    if (rhsTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
      commonElemTy = rhsTy.getElementType();
    if (accTy.getElementTypeBitWidth() > commonElemTy.getIntOrFloatBitWidth())
      commonElemTy = accTy.getElementType();

    lhs = castElemTy(loc, lhs, commonElemTy, rewriter);
    rhs = castElemTy(loc, rhs, commonElemTy, rewriter);
    acc = castElemTy(loc, acc, commonElemTy, rewriter);

    Value newRes = rewriter.create<vector::ContractionOp>(
        loc, lhs, rhs, acc, op.getIndexingMaps(), op.getIteratorTypes());
    // Cast back so that users of the original op still see the original
    // result element type.
    newRes = castElemTy(loc, newRes, resElemTy, rewriter);

    rewriter.replaceOp(op, newRes);
    return success();
  }

  /// Returns `val` converted to element type `elemTy`, or `val` itself when
  /// it already has that element type.
  ///
  /// Integer values are sign-extended or truncated; all other values use the
  /// floating-point extend/truncate ops. The direction is chosen by comparing
  /// bit widths.
  /// NOTE(review): the source and destination are assumed to be of the same
  /// kind (both integer or both float); an integer <-> float conversion is
  /// not handled here - confirm callers never need it.
  Value castElemTy(Location loc, Value val, Type elemTy,
                   PatternRewriter &rewriter) const {
    auto valTy = cast<VectorType>(val.getType());
    if (valTy.getElementType() == elemTy)
      return val;

    // Presumably yields a vector shaped like valTy with element type elemTy -
    // verify against the helper's definition.
    auto resTy = toTyOrVectorOf(valTy, elemTy);
    if (valTy.getElementType().isInteger()) {
      if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth())
        return rewriter.create<arith::TruncIOp>(loc, resTy, val);
      else
        return rewriter.create<arith::ExtSIOp>(loc, resTy, val);
    } else {
      if (valTy.getElementTypeBitWidth() > elemTy.getIntOrFloatBitWidth())
        return rewriter.create<arith::TruncFOp>(loc, resTy, val);
      else
        return rewriter.create<arith::ExtFOp>(loc, resTy, val);
    }
  }
};

struct ConvertUnsupportedOps
: public triton::impl::ConvertUnsupportedOpsBase<ConvertUnsupportedOps> {
using ConvertUnsupportedOpsBase::ConvertUnsupportedOpsBase;
: public triton::cpu::impl::ConvertUnsupportedOpsBase<
ConvertUnsupportedOps> {
ConvertUnsupportedOps() = default;

ConvertUnsupportedOps(bool promoteBf16ToFp32,
bool convertMixedPrecisionMatmul) {
this->promoteBf16ToFp32 = promoteBf16ToFp32;
this->convertMixedPrecisionMatmul = convertMixedPrecisionMatmul;
}

void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

RewritePatternSet patterns(context);
patterns.add<ConvertBf16ToFp32<arith::AddFOp>>(context);
patterns.add<ConvertBf16ToFp32<arith::SubFOp>>(context);
patterns.add<ConvertBf16ToFp32<arith::MulFOp>>(context);
patterns.add<ConvertIToBf16ToFp32<arith::SIToFPOp>>(context);
patterns.add<ConvertIToBf16ToFp32<arith::UIToFPOp>>(context);
patterns.add<ConvertBf16MaskedLoadOp>(context);
patterns.add<ConvertBf16MaskedStoreOp>(context);

patterns.add<ConvertBf16Abs>(context);
if (promoteBf16ToFp32) {
patterns.add<ConvertBf16ToFp32<arith::AddFOp>>(context);
patterns.add<ConvertBf16ToFp32<arith::SubFOp>>(context);
patterns.add<ConvertBf16ToFp32<arith::MulFOp>>(context);
patterns.add<ConvertIToBf16ToFp32<arith::SIToFPOp>>(context);
patterns.add<ConvertIToBf16ToFp32<arith::UIToFPOp>>(context);
patterns.add<ConvertBf16MaskedLoadOp>(context);
patterns.add<ConvertBf16MaskedStoreOp>(context);
patterns.add<ConvertBf16Abs>(context);
}
if (convertMixedPrecisionMatmul) {
patterns.add<ConvertMixedPrecisionMatmul>(context);
}

if (failed(mlir::applyPatternsAndFoldGreedily(mod, std::move(patterns))))
return signalPassFailure();
Expand All @@ -199,6 +273,13 @@ std::unique_ptr<OperationPass<ModuleOp>> createConvertUnsupportedOps() {
return std::make_unique<ConvertUnsupportedOps>();
}

/// Builds a ConvertUnsupportedOps pass instance, forwarding both option
/// values to its two-argument constructor.
std::unique_ptr<OperationPass<ModuleOp>>
createConvertUnsupportedOps(bool promoteBf16ToFp32,
                            bool convertMixedPrecisionMatmul) {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<ConvertUnsupportedOps>(promoteBf16ToFp32,
                                              convertMixedPrecisionMatmul);
  return pass;
}

} // namespace cpu
} // namespace triton
} // namespace mlir
9 changes: 6 additions & 3 deletions third_party/cpu/triton_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,12 @@ void init_triton_cpu_passes_ttcpuir(py::module &&m) {
m.def("add_triton_cpu_to_llvmir_pipeline", [](mlir::PassManager &pm) {
mlir::triton::cpu::tritonCPUToLLVMPipelineBuilder(pm);
});
m.def("add_convert_unsupported_ops", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps());
});
m.def("add_convert_unsupported_ops",
[](mlir::PassManager &pm, bool promote_bf16_to_fp32,
bool convert_mixed_precision_matmul) {
pm.addPass(mlir::triton::cpu::createConvertUnsupportedOps(
promote_bf16_to_fp32, convert_mixed_precision_matmul));
});
m.def("add_decompose_fp_conversions", [](mlir::PassManager &pm) {
pm.addPass(mlir::triton::cpu::createDecomposeFpConversions());
});
Expand Down