[MLIR][Conversion][XeGPU][XeVM] Add XeGPUToXeVM conversion pass and tests. #154556
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes: Add XeGPU to XeVM conversion pass and tests.

Patch is 80.54 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/154556.diff

13 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 91b2ecf8922a3..da061b269daf7 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -82,6 +82,7 @@
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRVPass.h"
#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h"
namespace mlir {
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 2058aba7f9e37..323af3e97e2d4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1555,4 +1555,16 @@ def ConvertXeVMToLLVMPass : Pass<"convert-xevm-to-llvm"> {
let dependentDialects = ["LLVM::LLVMDialect"];
}
+//===----------------------------------------------------------------------===//
+// XeGPUToXeVM
+//===----------------------------------------------------------------------===//
+
+def ConvertXeGPUToXeVMPass : Pass<"convert-xegpu-to-xevm"> {
+ let summary = "Convert XeGPU to XeVM dialect";
+ let dependentDialects = ["xevm::XeVMDialect", "vector::VectorDialect",
+ "memref::MemRefDialect", "arith::ArithDialect",
+ "LLVM::LLVMDialect", "index::IndexDialect",
+ "gpu::GPUDialect", "scf::SCFDialect"];
+}
+
#endif // MLIR_CONVERSION_PASSES
diff --git a/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
new file mode 100644
index 0000000000000..fb23d24b0161b
--- /dev/null
+++ b/mlir/include/mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h
@@ -0,0 +1,27 @@
+//===-- XeGPUToXeVM.h - Convert XeGPU to XeVM dialect ----------*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+#define MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
+
+#include <memory>
+
+namespace mlir {
+class DialectRegistry;
+class LLVMTypeConverter;
+class RewritePatternSet;
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+
+void populateXeGPUToXeVMConversionPatterns(
+ mlir::RewritePatternSet &patterns, mlir::LLVMTypeConverter &typeConverter);
+
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_XEGPUTOXEVM_XEGPUTOXEVMPASS_H_
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index 171f7169fd41d..134fe8e14ca38 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -76,3 +76,4 @@ add_subdirectory(VectorToSCF)
add_subdirectory(VectorToSPIRV)
add_subdirectory(VectorToXeGPU)
add_subdirectory(XeVMToLLVM)
+add_subdirectory(XeGPUToXeVM)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
new file mode 100644
index 0000000000000..ed54b0bb5ee81
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/CMakeLists.txt
@@ -0,0 +1,25 @@
+add_mlir_conversion_library(MLIRXeGPUToXeVM
+ XeGPUToXeVM.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/XeGPUToXeVM
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRFuncDialect
+ MLIRGPUDialect
+ MLIRLLVMCommonConversion
+ MLIRLLVMDialect
+ MLIRXeVMDialect
+ MLIRVectorDialect
+ MLIRArithDialect
+ MLIRIndexDialect
+ MLIRXeGPUDialect
+ MLIRPass
+ MLIRTransforms
+)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
new file mode 100644
index 0000000000000..380409afbc62e
--- /dev/null
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -0,0 +1,932 @@
+//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion -----*- C++ -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/Index/IR/IndexOps.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Transforms/Patterns.h"
+#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/FormatVariadic.h"
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Types.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+namespace {
+
+enum class NdDescI32Layout : uint32_t {
+ BasePtr = 0,
+ BaseShapeW = 2,
+ BaseShapeH = 3,
+ TensorOffsetW = 4,
+ TensorOffsetH = 5
+};
+
+static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
+ switch (xeGpuMemspace) {
+ case xegpu::MemorySpace::Global:
+ return static_cast<int>(xevm::AddrSpace::GLOBAL);
+ case xegpu::MemorySpace::SLM:
+ return static_cast<int>(xevm::AddrSpace::SHARED);
+ }
+ llvm_unreachable("Unknown XeGPU memory space.");
+}
+
+template <typename T>
+std::tuple<bool, int32_t, int32_t> checkAllLinear(SmallVector<T> denseAttr) {
+ assert(!denseAttr.empty());
+ const int32_t intercept{static_cast<int32_t>(denseAttr[0])};
+ if (denseAttr.size() < 2)
+ return {true, 0, intercept};
+ const T slope{denseAttr[1] - denseAttr[0]};
+ for (size_t i = 1; i < denseAttr.size(); ++i)
+ if (denseAttr[i] - denseAttr[i - 1] != slope)
+ return {false, 0, 0};
+ return {true, static_cast<int32_t>(slope), intercept};
+}
+
+VectorType encodeVectorTypeTo(VectorType currentVecType, Type toElemType) {
+ auto elemType = currentVecType.getElementType();
+ auto currentBitWidth = elemType.getIntOrFloatBitWidth();
+ auto newBitWidth = toElemType.getIntOrFloatBitWidth();
+ const int size =
+ currentVecType.getNumElements() * currentBitWidth / newBitWidth;
+ return VectorType::get(size, toElemType);
+}
+
+xevm::LoadCacheControl
+translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal =
+ L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
+ auto L3hintVal =
+ L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::CACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1C_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::CACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3C;
+ else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::LoadCacheControl::L1S_L2UC_L3UC;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::READ_INVALIDATE:
+ return xevm::LoadCacheControl::INVALIDATE_READ;
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+xevm::StoreCacheControl
+translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
+ std::optional<xegpu::CachePolicy> L3hint) {
+ auto L1hintVal =
+ L1hint.has_value() ? L1hint.value() : xegpu::CachePolicy::UNCACHED;
+ auto L3hintVal =
+ L3hint.has_value() ? L3hint.value() : xegpu::CachePolicy::UNCACHED;
+ switch (L1hintVal) {
+ case xegpu::CachePolicy::UNCACHED:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::STREAMING:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1S_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1S_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_BACK:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ case xegpu::CachePolicy::WRITE_THROUGH:
+ if (L3hintVal == xegpu::CachePolicy::UNCACHED)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
+ else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
+ return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
+ else
+ llvm_unreachable("Unsupported cache control.");
+ default:
+ llvm_unreachable("Unsupported cache control.");
+ }
+}
+
+class CreateNdDescToXeVMPattern
+ : public OpConversionPattern<xegpu::CreateNdDescOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::CreateNdDescOp op,
+ xegpu::CreateNdDescOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto source = op.getSource();
+ Type payloadElemTy = rewriter.getI32Type();
+ Type i64Ty = rewriter.getI64Type();
+ VectorType payloadTy = VectorType::get(8, payloadElemTy);
+ VectorType payloadI64Ty = VectorType::get(4, i64Ty);
+ Value payload = arith::ConstantOp::create(
+ rewriter, loc,
+ DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
+
+ Value baseAddr;
+ Value baseShapeW;
+ Value baseShapeH;
+ Value offsetW;
+ Value offsetH;
+
+ bool sourceIsMemref = false;
+ auto sourceTy = source.getType();
+ int64_t rank;
+ if (isa<MemRefType>(sourceTy)) {
+ sourceIsMemref = true;
+ baseAddr =
+ memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
+ auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+ if (!sourceMemrefTy.hasStaticShape()) {
+ op.emitError() << "Expected static memref shape.";
+ return failure();
+ }
+ rank = sourceMemrefTy.getRank();
+ if (rank != 2) {
+ op.emitError() << "Expected a 2D memref.";
+ return failure();
+ }
+ } else if (sourceTy == rewriter.getIntegerType(64, false)) {
+ rank = op.getMixedSizes().size();
+ } else {
+ op.emitError() << "Expected source to be a 2D memref or ui64.";
+ return failure();
+ }
+ auto createOffset = [&](unsigned idx) -> Value {
+ Value val;
+ OpFoldResult ofr = op.getMixedOffsets()[idx];
+ if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+ val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+ val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+ } else {
+ int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+ val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+ }
+ return val;
+ };
+ auto offsets = op.getMixedOffsets();
+ if (offsets.size() == 2) {
+ offsetW = createOffset(rank - 1);
+ offsetH = createOffset(rank - 2);
+ } else {
+ offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
+ }
+ auto createShape = [&](unsigned idx) -> Value {
+ Value val;
+ OpFoldResult ofr = op.getMixedSizes()[idx];
+ if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
+ val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
+ val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
+ } else {
+ int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
+ val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
+ }
+ return val;
+ };
+ if (sourceIsMemref) {
+ auto sourceMemrefTy = cast<MemRefType>(sourceTy);
+ baseShapeW = arith::ConstantIntOp::create(
+ rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 1));
+ baseShapeH = arith::ConstantIntOp::create(
+ rewriter, loc, payloadElemTy, sourceMemrefTy.getDimSize(rank - 2));
+ baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
+ } else {
+ baseShapeW = createShape(rank - 1);
+ baseShapeH = createShape(rank - 2);
+ baseAddr = adaptor.getSource();
+ }
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
+ payLoadAsI64 =
+ vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
+ static_cast<int>(NdDescI32Layout::BasePtr));
+ payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
+ static_cast<int>(NdDescI32Layout::BaseShapeW));
+ payload =
+ vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
+ static_cast<int>(NdDescI32Layout::BaseShapeH));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetW, payload,
+ static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ payload = vector::InsertOp::create(
+ rewriter, loc, offsetH, payload,
+ static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ rewriter.replaceOp(op, payload);
+ return success();
+ }
+};
+
+class UpdateNdOffsetToXeVMPattern
+ : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
+ using OpConversionPattern::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(xegpu::UpdateNdOffsetOp op,
+ xegpu::UpdateNdOffsetOp::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto offsets = op.getOffsets();
+ auto tdesc = adaptor.getTensorDesc();
+ for (size_t offsetDim = 0; offsetDim < offsets.size(); offsetDim++) {
+ auto offset = offsets[offsetDim];
+ if (auto cst =
+ dyn_cast_if_present<arith::ConstantOp>(offset.getDefiningOp()))
+ if (auto attr = dyn_cast_if_present<IntegerAttr>(cst.getValue());
+ attr && !attr.getInt())
+ continue;
+ const int offsetPos =
+ static_cast<int>(offsetDim ? NdDescI32Layout::TensorOffsetW
+ : NdDescI32Layout::TensorOffsetH);
+ auto oldOffset =
+ vector::ExtractOp::create(rewriter, loc, tdesc, offsetPos);
+ offset = arith::IndexCastUIOp::create(rewriter, loc,
+ rewriter.getI32Type(), offset);
+ auto newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
+ tdesc =
+ vector::InsertOp::create(rewriter, loc, newOffset, tdesc, offsetPos);
+ }
+ rewriter.replaceOp(op, tdesc);
+ return success();
+ }
+};
+
+template <
+ typename OpType,
+ typename = std::enable_if_t<llvm::is_one_of<
+ OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
+class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ auto loc = op.getLoc();
+ auto ctxt = rewriter.getContext();
+
+ auto tdesc = adaptor.getTensorDesc();
+ auto tdescTy = op.getTensorDescType();
+ if (tdescTy.getRank() != 2) {
+ return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
+ }
+
+ VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
+ Value payLoadAsI64 =
+ vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
+ Value basePtr =
+ vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
+ static_cast<int>(NdDescI32Layout::BasePtr));
+ Value baseShapeW = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeW));
+ Value baseShapeH = vector::ExtractOp::create(
+ rewriter, loc, tdesc, static_cast<int>(NdDescI32Layout::BaseShapeH));
+ // Offsets can come from three sources:
+ // 1. Constant offsets, which are provided by the op.
+ // 2. Offsets as operands, which are provided by the op.
+ // 3. Offsets extracted from the tensor descriptor.
+ Value offsetW;
+ Value offsetH;
+ auto cOffsets = op.getConstOffsets();
+ auto offsets = op.getOffsets();
+ if (cOffsets) {
+ offsetW = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), (*cOffsets)[0]);
+ offsetH = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), (*cOffsets)[1]);
+ } else if (offsets.size() != 0) {
+ // offsets are provided as operands
+ if (offsets[0].getType() != rewriter.getI32Type()) {
+ if (offsets[0].getType() != rewriter.getIndexType()) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected offsets to be of type i32 or index.");
+ }
+ offsetW = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), offsets[0]);
+ } else {
+ offsetW = offsets[0];
+ }
+ if (offsets[1].getType() != rewriter.getI32Type()) {
+ if (offsets[1].getType() != rewriter.getIndexType()) {
+ return rewriter.notifyMatchFailure(
+ op, "Expected offsets to be of type i32 or index.");
+ }
+ offsetH = arith::IndexCastUIOp::create(
+ rewriter, loc, rewriter.getI32Type(), offsets[1]);
+ } else {
+ offsetH = offsets[1];
+ }
+ } else {
+ // If offsets are not available, we need to extract them from the tensor
+ // descriptor.
+ offsetW = vector::ExtractOp::create(
+ rewriter, loc, tdesc,
+ static_cast<int>(NdDescI32Layout::TensorOffsetW));
+ offsetH = vector::ExtractOp::create(
+ rewriter, loc, tdesc,
+ static_cast<int>(NdDescI32Layout::TensorOffsetH));
+ }
+ auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
+ ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
+ Value basePtrLLVM =
+ LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
+ auto elemType = tdescTy.getElementType();
+ auto elemBitSize = elemType.getIntOrFloatBitWidth();
+ // auto elemBitSizeAttr = rewriter.getIntegerAttr(rewriter.getI32Type(),
+ // elemBitSize);
+ Value elemByteSize = arith::ConstantIntOp::create(
+ rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
+ Value surfaceW =
+ arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
+
+ auto tileW = tdescTy.getDimSize(1);
+ auto tileH = tdescTy.getDimSize(0);
+ int32_t vblocks = tdescTy.getArrayLength();
+ if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
+ VectorType srcVecTy = cast<VectorType>(op.getValue().getType());
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.get...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
Dead code?
Yes. Dead code. Removed.
Why do we deliberately avoid a more natural i64 (which is also part of XeGPU_BaseAddrType) and only support ui64? I would imagine the opposite. Did I miss a good reason to only rely on ui64?
Thanks for pointing that out.
That was my misunderstanding; I still thought ui64 was the only base address form other than memref.
Will need to add support for the other base address types.
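For concreteness, a sketch of the widened check, under the assumption that any 64-bit integer source can be treated as a raw base address (this is my reading of the intended fix, not the final code):

// Sketch only: accept i64 as well as ui64 sources.
auto intTy = dyn_cast<IntegerType>(sourceTy);
if (intTy && intTy.getWidth() == 64)
  rank = op.getMixedSizes().size();
else
  return rewriter.notifyMatchFailure(
      op, "expected a 2D memref or 64-bit integer source");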
Not sure it should really be an enum; since all usages involve static casts to int, maybe some constexprs?
Not sure individual constexprs are better; readability would get worse because the logical grouping of the offsets would be lost.
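For illustration, a minimal sketch (mine, not from the patch) of a middle ground that keeps the enum's grouping but hides the repeated casts behind a hypothetical helper:

// Sketch only; toI32 is an illustrative helper, not part of the patch.
static constexpr int toI32(NdDescI32Layout v) { return static_cast<int>(v); }
// Call site becomes:
// vector::ExtractOp::create(rewriter, loc, tdesc,
//                           toI32(NdDescI32Layout::BaseShapeW));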
The XeGPU definition says (mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td, lines 504 to 507 at fdfcebb):

* offsets: a vector containing offsets of each access point. Its size
  is fixed to the hardware supported subgroup size, e.g., 16 on PVC,
  implying each element in the vector corresponds to a work-item (SIMT lane)
  in the subgroup.
According to your tests and laneAddr calculation (without extract ops to get lane offset), you treat it as a single-element vector (i.e., SIMT offset), which contradicts the op description. Please clarify the conflict.
That part of the description needs to be updated.
A vector of offsets fits the subgroup level, but after SIMT distribution the ops consuming the descriptor are no longer cooperative and become regular per-lane load / store / prefetch ops.
vector<1xindex> fits that abstraction better.
I assume in this case vector<1xindex> will be converted by the LLVM type converter?
Then, please modify the op description with more explanation to avoid confusion, as there is no obvious intermediate step between the description (offsets vector) and the lowering (scalar offset).
Will create a separate PR to update the description. That one should merge before this one.
Same for mask. I think it is best to give a pseudo IR example of how a description-conformant op transforms into the expected lowering input.
The actual transform is done in the SIMT distribution pass, but I will add an example to the op description to help users understand the valid op form.
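To illustrate the two forms, a rough pseudo-IR sketch (my own illustration, not verbatim xegpu syntax; element types and operand names are assumptions):

// Subgroup level: one offset / mask lane per work-item, e.g. 16 on PVC.
%v0 = xegpu.load %src[%offsets], %mask : ..., vector<16xindex>, vector<16xi1> -> vector<16xf16>
// After SIMT distribution: each lane carries a single-element offset and mask.
%v1 = xegpu.load %src[%offset], %lane_mask : ..., vector<1xindex>, vector<1xi1> -> vector<1xf16>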
charithaintc left a comment
LGTM % comments
Add a comment describing what this class means.
Added comments.
I am not sure about this. Const offsets will always be present; if some offset is dynamic, it will be present as kDynamic.
+1. Using getMixedOffsets guarantees every dim contains meaningful values.
Using getMixedOffsets now.
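For context, each entry of getMixedOffsets() is an OpFoldResult, i.e., either an SSA Value or a constant attribute, so static and dynamic offsets can be handled uniformly. This sketch mirrors the patch's createOffset lambda (names unchanged):

OpFoldResult ofr = op.getMixedOffsets()[idx];
Value val;
if (auto v = llvm::dyn_cast_if_present<Value>(ofr)) {
  // Dynamic offset: cast index -> i64, then truncate to the i32 payload element.
  val = arith::IndexCastOp::create(rewriter, loc, i64Ty, v);
  val = arith::TruncIOp::create(rewriter, loc, payloadElemTy, val);
} else {
  // Static offset: materialize the attribute as an i32 constant.
  int32_t off = llvm::cast<IntegerAttr>(cast<Attribute>(ofr)).getInt();
  val = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, off);
}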
This can be done in a loop to make it future-proof.
Refactored the code. Not loop-based, but scalable.
Same as the above comment.
Added comment.
@akroviakov Created #154653, which cleans up XeGPU ops and updates descriptions.
Is it possible for the if (offsets) check to fail?
Yes. offsets is optional and returns a Value{}, which evaluates to false, when not present.
If source / dest is a tensor descriptor, offsets are not allowed; in that case the offsets come from the tensor descriptor.
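A minimal sketch of that guard, as I read the description (names assumed, not verbatim from the patch):

// Optional operand: getOffsets() yields a null Value when absent.
if (Value offs = op.getOffsets()) {
  // Offsets were passed explicitly as an operand.
} else {
  // No operand: read the offsets from the tensor-descriptor payload instead.
}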
It could be nice to expose these by moving the custom conversion into a separate populate*TypeConversions function. Then others could still use the conversion patterns in custom passes.
It is not yet clear whether these conversions are reusable.
I would try to refactor in a future PR after giving it some thought.
Hmm, are the patterns in populateXeGPUToXeVMConversionPatterns usable without these conversions?
The conversion might not be reusable in a generic sense, but without these type conversions I'm forced to use this pass and can't freely add the patterns to my own custom wrapper, etc.
I am curious: is it OK for an integer sourceTy to have rank != 2?
getMixedSizes also works on the memref type, so it is safe to use at the beginning, in case you need to check the rank for all kinds of sourceTy.
Thanks for the tip. Using getMixedSizes now.
Not sure whether some utils in arith could help simplify the implementation here, e.g., getValueOrCreateConstantIntOp and getValueOrCreateCastToIndexLike.
Thanks. Those utils are very useful. Now the code sequence is much shorter.
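As a rough sketch of the shortened sequence using those utilities (declared in mlir/Dialect/Arith/Utils/Utils.h; the exact final code may differ):

// getValueOrCreateConstantIntOp materializes an OpFoldResult as a Value;
// getValueOrCreateCastToIndexLike casts between index-like and integer types.
OpFoldResult ofr = op.getMixedOffsets()[idx];
Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofr);
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);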
Does this assume the rank could be larger than 2?
Added a restriction to stick to 2D.
It looks very similar to createOffset.
Combined them into a single lambda.
getMixedSizes should cover both memref and integer type sources, so there is no need to check sourceTy for baseShapeW and baseShapeH.
Moved the check up; using getMixedSizes now.
Is this assuming the offsets are 2D?
Yes. Added offset rank restriction and cleaned up the code. Should read better now.
It seems getConstantIntValue can be used here.
Updated code sequence.
chencha3 left a comment
I have a general question. Since we are retiring some op forms, e.g., CreateNdOp with offsets and LoadNd/StoreNd/PrefetchNd without offsets, is support for these cases still needed in this pass?
nit: braces are not needed for a single statement.
Removed redundant braces.
It seems BaseShapeH etc. represent positions within the payload. Not sure whether NdDescI32Layout is an appropriate name, but adding a comment for NdDescI32Layout seems necessary.
Renamed to NdTdescOffset since the enum describes offsets into the payload.
Added comments as well.
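For reference, a sketch of the renamed, commented enum (my rendering of the described change; slot values taken from the original patch):

// Offsets into the vector<8xi32> payload encoding an Nd tensor descriptor.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,       // 64-bit base address (i32 slots 0-1, accessed via i64 bitcast).
  BaseShapeW = 2,    // Width of the base 2D shape.
  BaseShapeH = 3,    // Height of the base 2D shape.
  TensorOffsetW = 4, // Current W offset into the tensor.
  TensorOffsetH = 5  // Current H offset into the tensor.
};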
By definition, the offsets are always of index type. Is this i32 type check still needed?
Updated the code sequence. The check no longer applies.
Since it shares similar code with offsets[0], it may be worth having a lambda for it. I am not sure whether using isa<I32>() is better than rewriter.getI32Type(); I had the impression, from my first PR, that Mehdi mentioned methods like rewriter.getI32Type() being heavy. Maybe I am wrong.
Updated code. Now using MLIR util functions.
Maybe use adaptor.getSource() here too, to make it consistent.
Now using adaptor.getSource() whenever used as an operand.
I'm not sure, but from a coverage point of view it would be better to have support.
@adam-smnk @chencha3 @akroviakov @charithaintc
chencha3 left a comment
LGTM % some comments
Maybe mention NdTdescOffset in the comment; it may help with understanding the payload a little.
Added a comment about the offset enum.
nit: move payloadTy one line up.
Done.
using adaptor.getSource()?
That is intentional: adaptor.getSource() for a memref yields an i64 base address, since the memref is used like a raw pointer. Whenever memref type info is required, we need to use op.getSource().getType() to get the original memref type.
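A two-line sketch of the distinction (variable names are mine, not from the patch):

Value rawBaseAddr = adaptor.getSource(); // type-converted operand, e.g. i64
auto srcMemrefTy = cast<MemRefType>(op.getSource().getType()); // original memref type info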
nit: move the check up? or to line 180?
Moved the check up.
nit: why not just use 0 and 1 to replace rank - 2 and rank - 1?
Done.
nit: this check may not be necessary; it seems to be handled by the op definition?
I checked the op definition and it doesn't restrict the rank.
Force-pushed from fbfd800 to b01086e.
charithaintc left a comment
I don't have additional concerns. Thanks for the great work.
// Below are uArch dependent values, should move away from hardcoding
constexpr int32_t systolicDepth{8};
constexpr int32_t executionSize{16};
I would use static constexpr and move them to the top of the file with a TODO note, so we remember to clean this up.
Done.
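For reference, the applied form would look roughly like this (my rendering of the agreed change):

// TODO: uArch-dependent values; query them from a target descriptor instead
// of hardcoding.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};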
unsigned sum = 1;
for (unsigned i = 0; i < rank; i++) {
  sum *= type.getShape()[i];
}
Suggested change:

int64_t sum = std::accumulate(type.getShape().begin(),
                              type.getShape().end(), int64_t{1},
                              std::multiplies<int64_t>());
Done.
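One small note from my side, not raised in the thread: std::accumulate and std::multiplies live in <numeric> and <functional> respectively, so those headers are needed if not already pulled in transitively:

#include <functional> // std::multiplies
#include <numeric>    // std::accumulate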
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/18234. Here is the relevant piece of the build log for reference.
Add missing dependency MLIRSCFDialect from #154556
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/204/builds/20154. Here is the relevant piece of the build log for reference.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/203/builds/21342. Here is the relevant piece of the build log for reference.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/205/builds/20131. Here is the relevant piece of the build log for reference.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/80/builds/15676. Here is the relevant piece of the build log for reference.
LLVM Buildbot has detected a new failure on a builder. Full details are available at: https://lab.llvm.org/buildbot/#/builders/130/builds/15137. Here is the relevant piece of the build log for reference.
Add XeGPU to XeVM conversion pass and tests.