diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 633d8c4c49c6..9d9c1906ec50 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -1669,6 +1669,7 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr, ARG_HASH: tl.constex np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) +@pytest.mark.cpu @pytest.mark.interpreter @pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) diff --git a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp index 218dd827619a..cadec818910b 100644 --- a/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp +++ b/third_party/cpu/lib/TritonToTritonCPU/ConvertElementwiseOps.cpp @@ -60,6 +60,7 @@ class ElementwiseOpConversionTarget : public ConversionTarget { addIllegalOp(); addIllegalOp(); addIllegalOp(); + addIllegalOp(); } }; @@ -236,6 +237,24 @@ struct JoinOpConversion : public OpConversionPattern { } }; +struct CatOpConversion : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto lhs = rewriter.getRemappedValue(op.getLhs()); + auto rhs = rewriter.getRemappedValue(op.getRhs()); + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + SmallVector indices(lhsTy.getShape()[0] + rhsTy.getShape()[0]); + std::iota(indices.begin(), indices.end(), 0); + rewriter.replaceOpWithNewOp(op, lhs, rhs, indices); + return success(); + } +}; + struct ConvertElementwiseOps : public triton::impl::ConvertElementwiseOpsBase { using ConvertElementwiseOpsBase::ConvertElementwiseOpsBase; @@ -320,6 +339,7 @@ struct ConvertElementwiseOps patterns.add(typeConverter, context); patterns.add(typeConverter, context); patterns.add(typeConverter, context); + patterns.add(typeConverter, context); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure();