diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
index e3c23caf..6d95f878 100644
--- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
+++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp
@@ -967,39 +967,54 @@ struct MulHiUIOpConverter : public OpConversionPattern<triton::MulhiUIOp> {
 struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
   using OpConversionPattern<triton::DotOp>::OpConversionPattern;
 
-  LogicalResult
-  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto opa = adaptor.getA();
-    auto opb = adaptor.getB();
-    auto opc = adaptor.getC();
-    auto opcOrig = op.getC();
-
-    bool skipC = false;
-    if (auto splatOp = opcOrig.getDefiningOp<triton::SplatOp>()) {
-      if (auto val = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
-        if (cast<FloatAttr>(val.getValue()).getValueAsDouble() == 0.) {
-          skipC = true;
+  // true means the tensor elements are all zero;
+  // false means non-zero, or it cannot be determined
+  bool isZeroTensor(Value &v, bool integers) const {
+    if (auto splatOp = v.getDefiningOp<triton::SplatOp>()) {
+      if (auto constOp = splatOp.getSrc().getDefiningOp<arith::ConstantOp>()) {
+        if (auto val = dyn_cast<FloatAttr>(constOp.getValue())) {
+          return val.getValueAsDouble() == 0.;
+        }
+        if (auto val = dyn_cast<IntegerAttr>(constOp.getValue())) {
+          return val.getValue() == 0;
+        }
       }
+      return false;
     }
-    } else if (auto constOp = opcOrig.getDefiningOp<arith::ConstantOp>()) {
-      if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
-        if (denseAttr.isSplat() &&
-            denseAttr.getSplatValue<FloatAttr>().getValueAsDouble() == 0.) {
-          skipC = true;
+
+    if (auto constOp = v.getDefiningOp<arith::ConstantOp>()) {
+      if (auto denseAttr = dyn_cast<DenseElementsAttr>(constOp.getValue())) {
+        if (denseAttr.isSplat()) {
+          if (integers)
+            return denseAttr.getSplatValue<APInt>().isZero();
+          return denseAttr.getSplatValue<APFloat>().isZero();
+        }
       }
     }
-    }
 
-    auto dstType = cast<RankedTensorType>(op.getType());
-    auto elemType = dstType.getElementType();
+    return false;
+  }
+
+  LogicalResult
+  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
     auto loc = op.getLoc();
+    auto opa = op.getA();
+    auto opb = op.getB();
+    auto opc = op.getC();
+    auto dstType = cast<RankedTensorType>(op.getType());
+    auto elementType = dstType.getElementType();
+    bool integers = elementType.isInteger();
+    bool skipC = isZeroTensor(opc, integers);
 
     auto init =
-        rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elemType);
+        rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), elementType);
+    TypedAttr constantAttr = integers ?
+        static_cast<TypedAttr>(rewriter.getIntegerAttr(elementType, 0)) :
+        static_cast<TypedAttr>(rewriter.getFloatAttr(elementType, 0));
 
     auto zero = rewriter.create<arith::ConstantOp>(
-        op.getLoc(), elemType, rewriter.getFloatAttr(elemType, 0));
+        op.getLoc(), elementType, constantAttr);
 
     auto zeroes =
         rewriter.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init})
@@ -1011,7 +1026,11 @@ struct MatmulConverter : public OpConversionPattern<triton::DotOp> {
             .getResult(0);
 
     if (!skipC) {
-      res = rewriter.create<arith::AddFOp>(loc, res, opc);
+      if (integers) {
+        res = rewriter.create<arith::AddIOp>(loc, res, opc);
+      } else {
+        res = rewriter.create<arith::AddFOp>(loc, res, opc);
+      }
     }
 
     rewriter.replaceOp(op, res);
diff --git a/python/examples/conftest.py b/python/examples/conftest.py
index d8d048a1..77b4d46b 100644
--- a/python/examples/conftest.py
+++ b/python/examples/conftest.py
@@ -64,7 +64,6 @@ def device(request):
     "test_chain_reduce",
     "test_generic_reduction",
     "test_trans_4d",
-    "test_dot",
     "test_dot3d",
     "test_constexpr",
     "test_arange",
@@ -90,6 +89,8 @@ def pytest_collection_modifyitems(config, items):
     skip_marker = pytest.mark.skip(reason="CPU backend does not support it yet")
     # There is a dependency issue on build machine which breaks bfloat16
     skip_marker_bfloat = pytest.mark.skip(reason="bfloat16 linking issue")
+    skip_marker_tf32 = pytest.mark.skip(reason="tf32 is not supported on CPU")
+    skip_marker_float8 = pytest.mark.skip(reason="float8 is not supported on CPU")
 
     for item in items:
         test_func_name = item.originalname if item.originalname else item.name
@@ -100,5 +101,9 @@ def pytest_collection_modifyitems(config, items):
 
         if "parametrize" in item.keywords:
             for param_name, param_value in item.callspec.params.items():
-                if param_name.startswith('dtype') and param_value == 'bfloat16':
-                    item.add_marker(skip_marker_bfloat)
\ No newline at end of file
+                if (param_name.startswith('dtype') or param_name.endswith('dtype')) and param_value == 'bfloat16':
+                    item.add_marker(skip_marker_bfloat)
+                if param_name.startswith('input_precision') and param_value.startswith('tf32'):
+                    item.add_marker(skip_marker_tf32)
+                if param_name.endswith('dtype') and ('float8' in str(param_value)):
+                    item.add_marker(skip_marker_float8)
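For context (not part of the patch): the change above makes `MatmulConverter` lower `tt.dot` on integer operands to `linalg.matmul` seeded with an integer zero fill, using `arith.addi` instead of `arith.addf` for the C term. The sketch below shows the kind of Triton kernel this now supports on the CPU path; the kernel and tensor names are illustrative, not taken from the test suite, and it assumes the triton-shared CPU launcher accepts host tensors (on a GPU backend the tensors would need to be on-device).

```python
# Illustrative only -- not part of the patch; names are hypothetical.
import torch
import triton
import triton.language as tl

@triton.jit
def int8_matmul_kernel(a_ptr, b_ptr, c_ptr,
                       M: tl.constexpr, N: tl.constexpr, K: tl.constexpr):
    # Single-tile kernel: load full int8 operands, accumulate in int32.
    offs_m = tl.arange(0, M)
    offs_n = tl.arange(0, N)
    offs_k = tl.arange(0, K)
    a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
    b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
    # Integer tl.dot: int8 x int8 with an int32 accumulator. With this patch,
    # the conversion emits linalg.matmul with an i32 zero constant and
    # arith.addi on the accumulator path.
    c = tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], c)

# tl.dot requires each dimension to be at least 16.
a = torch.randint(-8, 8, (16, 16), dtype=torch.int8)
b = torch.randint(-8, 8, (16, 16), dtype=torch.int8)
c = torch.empty((16, 16), dtype=torch.int32)
int8_matmul_kernel[(1,)](a, b, c, M=16, N=16, K=16)
assert torch.equal(c, a.int() @ b.int())
```

This also matches the conftest change: `test_dot` is no longer skipped wholesale; instead, only the tf32 and float8 parametrizations are skipped, since those input precisions remain unsupported on the CPU backend.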