@@ -967,39 +967,54 @@ struct MulHiUIOpConverter : public OpConversionPattern<triton::MulhiUIOp> {
struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
using OpConversionPattern<triton::DotOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto opa = adaptor.getA();
auto opb = adaptor.getB();
auto opc = adaptor.getC();
auto opcOrig = op.getC();

bool skipC = false;
if (auto splatOp = opcOrig.getDefiningOp<triton::SplatOp>()) {
if (auto val = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
if (cast<FloatAttr>(val.getValue()).getValueAsDouble() == 0.) {
skipC = true;
// Returns true if the tensor elements are known to be zero,
// false if they are non-zero or cannot be determined.
bool isZeroTensor(Value &v, bool integers) const {
if (auto splatOp = v.getDefiningOp<triton::SplatOp>()) {
if (auto constOp = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
return val.getValueAsDouble() == 0.;
}
if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
return val.getValue() == 0;
}
}
return false;
}
} else if (auto constOp = opcOrig.getDefiningOp<arith::ConstantOp>()) {
if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
if (denseAttr.isSplat() &&
denseAttr.getSplatValue<FloatAttr>().getValueAsDouble() == 0.) {
skipC = true;

if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
if (denseAttr.isSplat()) {
if (integers)
return denseAttr.getSplatValue<APInt>().isZero();
return denseAttr.getSplatValue<APFloat>().isZero();
}
}
}
}

auto dstType = cast<RankedTensorType>(op.getType());
auto elemType = dstType.getElementType();
return false;
}
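
For reviewer context, a minimal Triton-side sketch (hypothetical kernel, not part of this change) of the zero-accumulator pattern isZeroTensor is meant to recognize; whether the frontend emits a dense-splat arith.constant or a tt.splat of a scalar constant can vary by Triton version:

    import triton
    import triton.language as tl

    @triton.jit
    def dot_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
        # Typically lowers to an arith.constant with a splat
        # DenseElementsAttr (the second branch checked above).
        acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
        c = tl.dot(a, b, acc)
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)
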

LogicalResult
matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op.getLoc();
auto opa = op.getA();
auto opb = op.getB();
auto opc = op.getC();

auto dstType = cast<RankedTensorType>(op.getType());
auto elementType = dstType.getElementType();
bool integers = elementType.isInteger();
bool skipC = isZeroTensor(opc, integers);
auto init =
rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elemType);
rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elementType);
TypedAttr constantAttr = integers ?
static_cast<TypedAttr>(rewriter.getIntegerAttr(elementType, 0)) :
static_cast<TypedAttr>(rewriter.getFloatAttr(elementType, 0));

auto zero = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), elemType, rewriter.getFloatAttr(elemType, 0));
op.getLoc(), elementType, constantAttr);

auto zeroes =
rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init})
@@ -1011,7 +1026,11 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
.getResult(0);

if (!skipC) {
res = rewriter.create<arith::AddFOp>(loc, res, opc);
if (integers) {
res = rewriter.create<arith::AddIOp>(loc, res, opc);
} else {
res = rewriter.create<arith::AddFOp>(loc, res, opc);
}
}

rewriter.replaceOp(op, res);
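
With the integer path added, an int8 x int8 -> int32 dot like the hypothetical kernel below should lower through this pattern, and the final arith::AddIOp accumulate is skipped when the accumulator is a recognized zero tensor (names and shapes are illustrative only, not taken from this PR):

    import triton
    import triton.language as tl

    @triton.jit
    def int8_matmul(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])  # int8 tile
        b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])  # int8 tile
        acc = tl.zeros((BLOCK, BLOCK), dtype=tl.int32)  # zero accumulator
        c = tl.dot(a, b, acc)  # integer dot; the C add folds away via skipC
        tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)
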
python/examples/conftest.py (11 changes: 8 additions & 3 deletions)
@@ -64,7 +64,6 @@ def device(request):
"test_chain_reduce",
"test_generic_reduction",
"test_trans_4d",
"test_dot",
"test_dot3d",
"test_constexpr",
"test_arange",
@@ -90,6 +89,8 @@ def pytest_collection_modifyitems(config, items):
skip_marker = pytest.mark.skip(reason="CPU backend does not support it yet")
# There is a dependency issue on build machine which breaks bfloat16
skip_marker_bfloat = pytest.mark.skip(reason="bfloat16 linking issue")
skip_marker_tf32 = pytest.mark.skip(reason="tf32 is not supported on CPU")
skip_marker_float8 = pytest.mark.skip(reason="float8 is not supported on CPU")

for item in items:
test_func_name = item.originalname if item.originalname else item.name
@@ -100,5 +101,9 @@ def pytest_collection_modifyitems(config, items):

if "parametrize" in item.keywords:
for param_name, param_value in item.callspec.params.items():
if param_name.startswith('dtype') and param_value == 'bfloat16':
item.add_marker(skip_marker_bfloat)
if (param_name.startswith('dtype') or param_name.endswith('dtype')) and param_value == 'bfloat16':
item.add_marker(skip_marker_bfloat)
if param_name.startswith('input_precision') and param_value.startswith('tf32'):
item.add_marker(skip_marker_tf32)
if param_name.endswith('dtype') and ('float8' in str(param_value)):
item.add_marker(skip_marker_float8)
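
As a usage sketch, a parametrized test along these lines (hypothetical, not in the suite) would now have its float8 and tf32 combinations skipped at collection time on the CPU backend:

    import pytest

    @pytest.mark.parametrize("in_dtype", ["float16", "float8e5"])
    @pytest.mark.parametrize("input_precision", ["ieee", "tf32"])
    def test_dot_example(in_dtype, input_precision):
        # The float8e5 and tf32 combinations receive skip markers from
        # pytest_collection_modifyitems above, so only this case runs on CPU.
        assert in_dtype == "float16" and input_precision == "ieee"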