-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][sparse] simplify reader construction of new sparse tensor #69036
Conversation
Making the materialize-from-reader method part of the swiss army knife suite agains removes a lot of redundant boiler plate code and unifies the parameter setup into a single centralized utility. Furthermore, we now have minimized the number of entry points into the library that need a non-permutation map setup, simplifying what comes next
@llvm/pr-subscribers-mlir-sparse @llvm/pr-subscribers-mlir Author: Aart Bik (aartbik) ChangesMaking the materialize-from-reader method part of the swiss army knife suite agains removes a lot of redundant boiler plate code and unifies the parameter setup into a single centralized utility. Furthermore, we now have minimized the number of entry points into the library that need a non-permutation map setup, simplifying what comes next Full diff: https://github.com/llvm/llvm-project/pull/69036.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
index 1434c649acd29b4..0caf83a63b531f2 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/Enums.h
@@ -146,6 +146,7 @@ enum class Action : uint32_t {
kEmptyForward = 1,
kFromCOO = 2,
kSparseToSparse = 3,
+ kFromReader = 4,
kToCOO = 5,
kPack = 7,
kSortCOOInPlace = 8,
diff --git a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
index e8dd50d6730c784..a470afc2f0c8cd1 100644
--- a/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
+++ b/mlir/include/mlir/ExecutionEngine/SparseTensorRuntime.h
@@ -115,16 +115,6 @@ MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_createCheckedSparseTensorReader(
char *filename, StridedMemRefType<index_type, 1> *dimShapeRef,
PrimaryType valTp);
-/// Constructs a new sparse-tensor storage object with the given encoding,
-/// initializes it by reading all the elements from the file, and then
-/// closes the file.
-MLIR_CRUNNERUTILS_EXPORT void *_mlir_ciface_newSparseTensorFromReader(
- void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
- StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
- OverheadType crdTp, PrimaryType valTp);
-
/// SparseTensorReader method to obtain direct access to the
/// dimension-sizes array.
MLIR_CRUNNERUTILS_EXPORT void _mlir_ciface_getSparseTensorReaderDimSizes(
@@ -197,24 +187,9 @@ MLIR_SPARSETENSOR_FOREVERY_V(DECL_DELCOO)
/// defined with the naming convention ${TENSOR0}, ${TENSOR1}, etc.
MLIR_CRUNNERUTILS_EXPORT char *getTensorFilename(index_type id);
-/// Helper function to read the header of a file and return the
-/// shape/sizes, without parsing the elements of the file.
-MLIR_CRUNNERUTILS_EXPORT void readSparseTensorShape(char *filename,
- std::vector<uint64_t> *out);
-
-/// Returns the rank of the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderRank(void *p);
-
-/// Returns the is_symmetric bit for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT bool getSparseTensorReaderIsSymmetric(void *p);
-
/// Returns the number of stored elements for the sparse tensor being read.
MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderNSE(void *p);
-/// Returns the size of a dimension for the sparse tensor being read.
-MLIR_CRUNNERUTILS_EXPORT index_type getSparseTensorReaderDimSize(void *p,
- index_type d);
-
/// Releases the SparseTensorReader and closes the associated file.
MLIR_CRUNNERUTILS_EXPORT void delSparseTensorReader(void *p);
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index a76f81410aa87a0..638475a80343d91 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -199,12 +199,15 @@ class NewCallParams final {
/// type-level information such as the encoding and sizes), generating
/// MLIR buffers as needed, and returning `this` for method chaining.
NewCallParams &genBuffers(SparseTensorType stt,
- ArrayRef<Value> dimSizesValues) {
+ ArrayRef<Value> dimSizesValues,
+ Value dimSizesBuffer = Value()) {
assert(dimSizesValues.size() == static_cast<size_t>(stt.getDimRank()));
// Sparsity annotations.
params[kParamLvlTypes] = genLvlTypesBuffer(builder, loc, stt);
// Construct dimSizes, lvlSizes, dim2lvl, and lvl2dim buffers.
- params[kParamDimSizes] = allocaBuffer(builder, loc, dimSizesValues);
+ params[kParamDimSizes] = dimSizesBuffer
+ ? dimSizesBuffer
+ : allocaBuffer(builder, loc, dimSizesValues);
params[kParamLvlSizes] =
genMapBuffers(builder, loc, stt, dimSizesValues, params[kParamDimSizes],
params[kParamDim2Lvl], params[kParamLvl2Dim]);
@@ -342,33 +345,15 @@ class SparseTensorNewConverter : public OpConversionPattern<NewOp> {
const auto stt = getSparseTensorType(op);
if (!stt.hasEncoding())
return failure();
- // Construct the reader opening method calls.
+ // Construct the `reader` opening method calls.
SmallVector<Value> dimShapesValues;
Value dimSizesBuffer;
Value reader = genReader(rewriter, loc, stt, adaptor.getOperands()[0],
dimShapesValues, dimSizesBuffer);
- // Now construct the lvlSizes, dim2lvl, and lvl2dim buffers.
- Value dim2lvlBuffer;
- Value lvl2dimBuffer;
- Value lvlSizesBuffer =
- genMapBuffers(rewriter, loc, stt, dimShapesValues, dimSizesBuffer,
- dim2lvlBuffer, lvl2dimBuffer);
// Use the `reader` to parse the file.
- Type opaqueTp = getOpaquePointerType(rewriter);
- Type eltTp = stt.getElementType();
- Value valTp = constantPrimaryTypeEncoding(rewriter, loc, eltTp);
- SmallVector<Value, 8> params{
- reader,
- lvlSizesBuffer,
- genLvlTypesBuffer(rewriter, loc, stt),
- dim2lvlBuffer,
- lvl2dimBuffer,
- constantPosTypeEncoding(rewriter, loc, stt.getEncoding()),
- constantCrdTypeEncoding(rewriter, loc, stt.getEncoding()),
- valTp};
- Value tensor = createFuncCall(rewriter, loc, "newSparseTensorFromReader",
- opaqueTp, params, EmitCInterface::On)
- .getResult(0);
+ Value tensor = NewCallParams(rewriter, loc)
+ .genBuffers(stt, dimShapesValues, dimSizesBuffer)
+ .genNewCall(Action::kFromReader, reader);
// Free the memory for `reader`.
createFuncCall(rewriter, loc, "delSparseTensorReader", {}, {reader},
EmitCInterface::Off);
diff --git a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
index ae33a869497a01c..fbd98f6cf183793 100644
--- a/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
+++ b/mlir/lib/ExecutionEngine/SparseTensorRuntime.cpp
@@ -138,6 +138,12 @@ extern "C" {
dimRank, dimSizes, lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim, \
dimRank, tensor); \
} \
+ case Action::kFromReader: { \
+ assert(ptr && "Received nullptr for SparseTensorReader object"); \
+ SparseTensorReader &reader = *static_cast<SparseTensorReader *>(ptr); \
+ return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
+ lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim)); \
+ } \
case Action::kToCOO: { \
assert(ptr && "Received nullptr for SparseTensorStorage object"); \
auto &tensor = *static_cast<SparseTensorStorage<P, C, V> *>(ptr); \
@@ -442,113 +448,6 @@ void _mlir_ciface_getSparseTensorReaderDimSizes(
MLIR_SPARSETENSOR_FOREVERY_V_O(IMPL_GETNEXT)
#undef IMPL_GETNEXT
-void *_mlir_ciface_newSparseTensorFromReader(
- void *p, StridedMemRefType<index_type, 1> *lvlSizesRef,
- StridedMemRefType<DimLevelType, 1> *lvlTypesRef,
- StridedMemRefType<index_type, 1> *dim2lvlRef,
- StridedMemRefType<index_type, 1> *lvl2dimRef, OverheadType posTp,
- OverheadType crdTp, PrimaryType valTp) {
- assert(p);
- SparseTensorReader &reader = *static_cast<SparseTensorReader *>(p);
- ASSERT_NO_STRIDE(lvlSizesRef);
- ASSERT_NO_STRIDE(lvlTypesRef);
- ASSERT_NO_STRIDE(dim2lvlRef);
- ASSERT_NO_STRIDE(lvl2dimRef);
- const uint64_t dimRank = reader.getRank();
- const uint64_t lvlRank = MEMREF_GET_USIZE(lvlSizesRef);
- ASSERT_USIZE_EQ(lvlTypesRef, lvlRank);
- ASSERT_USIZE_EQ(dim2lvlRef, dimRank);
- ASSERT_USIZE_EQ(lvl2dimRef, lvlRank);
- (void)dimRank;
- const index_type *lvlSizes = MEMREF_GET_PAYLOAD(lvlSizesRef);
- const DimLevelType *lvlTypes = MEMREF_GET_PAYLOAD(lvlTypesRef);
- const index_type *dim2lvl = MEMREF_GET_PAYLOAD(dim2lvlRef);
- const index_type *lvl2dim = MEMREF_GET_PAYLOAD(lvl2dimRef);
-#define CASE(p, c, v, P, C, V) \
- if (posTp == OverheadType::p && crdTp == OverheadType::c && \
- valTp == PrimaryType::v) \
- return static_cast<void *>(reader.readSparseTensor<P, C, V>( \
- lvlRank, lvlSizes, lvlTypes, dim2lvl, lvl2dim));
-#define CASE_SECSAME(p, v, P, V) CASE(p, p, v, P, P, V)
- // Rewrite kIndex to kU64, to avoid introducing a bunch of new cases.
- // This is safe because of the static_assert above.
- if (posTp == OverheadType::kIndex)
- posTp = OverheadType::kU64;
- if (crdTp == OverheadType::kIndex)
- crdTp = OverheadType::kU64;
- // Double matrices with all combinations of overhead storage.
- CASE(kU64, kU64, kF64, uint64_t, uint64_t, double);
- CASE(kU64, kU32, kF64, uint64_t, uint32_t, double);
- CASE(kU64, kU16, kF64, uint64_t, uint16_t, double);
- CASE(kU64, kU8, kF64, uint64_t, uint8_t, double);
- CASE(kU32, kU64, kF64, uint32_t, uint64_t, double);
- CASE(kU32, kU32, kF64, uint32_t, uint32_t, double);
- CASE(kU32, kU16, kF64, uint32_t, uint16_t, double);
- CASE(kU32, kU8, kF64, uint32_t, uint8_t, double);
- CASE(kU16, kU64, kF64, uint16_t, uint64_t, double);
- CASE(kU16, kU32, kF64, uint16_t, uint32_t, double);
- CASE(kU16, kU16, kF64, uint16_t, uint16_t, double);
- CASE(kU16, kU8, kF64, uint16_t, uint8_t, double);
- CASE(kU8, kU64, kF64, uint8_t, uint64_t, double);
- CASE(kU8, kU32, kF64, uint8_t, uint32_t, double);
- CASE(kU8, kU16, kF64, uint8_t, uint16_t, double);
- CASE(kU8, kU8, kF64, uint8_t, uint8_t, double);
- // Float matrices with all combinations of overhead storage.
- CASE(kU64, kU64, kF32, uint64_t, uint64_t, float);
- CASE(kU64, kU32, kF32, uint64_t, uint32_t, float);
- CASE(kU64, kU16, kF32, uint64_t, uint16_t, float);
- CASE(kU64, kU8, kF32, uint64_t, uint8_t, float);
- CASE(kU32, kU64, kF32, uint32_t, uint64_t, float);
- CASE(kU32, kU32, kF32, uint32_t, uint32_t, float);
- CASE(kU32, kU16, kF32, uint32_t, uint16_t, float);
- CASE(kU32, kU8, kF32, uint32_t, uint8_t, float);
- CASE(kU16, kU64, kF32, uint16_t, uint64_t, float);
- CASE(kU16, kU32, kF32, uint16_t, uint32_t, float);
- CASE(kU16, kU16, kF32, uint16_t, uint16_t, float);
- CASE(kU16, kU8, kF32, uint16_t, uint8_t, float);
- CASE(kU8, kU64, kF32, uint8_t, uint64_t, float);
- CASE(kU8, kU32, kF32, uint8_t, uint32_t, float);
- CASE(kU8, kU16, kF32, uint8_t, uint16_t, float);
- CASE(kU8, kU8, kF32, uint8_t, uint8_t, float);
- // Two-byte floats with both overheads of the same type.
- CASE_SECSAME(kU64, kF16, uint64_t, f16);
- CASE_SECSAME(kU64, kBF16, uint64_t, bf16);
- CASE_SECSAME(kU32, kF16, uint32_t, f16);
- CASE_SECSAME(kU32, kBF16, uint32_t, bf16);
- CASE_SECSAME(kU16, kF16, uint16_t, f16);
- CASE_SECSAME(kU16, kBF16, uint16_t, bf16);
- CASE_SECSAME(kU8, kF16, uint8_t, f16);
- CASE_SECSAME(kU8, kBF16, uint8_t, bf16);
- // Integral matrices with both overheads of the same type.
- CASE_SECSAME(kU64, kI64, uint64_t, int64_t);
- CASE_SECSAME(kU64, kI32, uint64_t, int32_t);
- CASE_SECSAME(kU64, kI16, uint64_t, int16_t);
- CASE_SECSAME(kU64, kI8, uint64_t, int8_t);
- CASE_SECSAME(kU32, kI64, uint32_t, int64_t);
- CASE_SECSAME(kU32, kI32, uint32_t, int32_t);
- CASE_SECSAME(kU32, kI16, uint32_t, int16_t);
- CASE_SECSAME(kU32, kI8, uint32_t, int8_t);
- CASE_SECSAME(kU16, kI64, uint16_t, int64_t);
- CASE_SECSAME(kU16, kI32, uint16_t, int32_t);
- CASE_SECSAME(kU16, kI16, uint16_t, int16_t);
- CASE_SECSAME(kU16, kI8, uint16_t, int8_t);
- CASE_SECSAME(kU8, kI64, uint8_t, int64_t);
- CASE_SECSAME(kU8, kI32, uint8_t, int32_t);
- CASE_SECSAME(kU8, kI16, uint8_t, int16_t);
- CASE_SECSAME(kU8, kI8, uint8_t, int8_t);
- // Complex matrices with wide overhead.
- CASE_SECSAME(kU64, kC64, uint64_t, complex64);
- CASE_SECSAME(kU64, kC32, uint64_t, complex32);
-
- // Unsupported case (add above if needed).
- MLIR_SPARSETENSOR_FATAL(
- "unsupported combination of types: <P=%d, C=%d, V=%d>\n",
- static_cast<int>(posTp), static_cast<int>(crdTp),
- static_cast<int>(valTp));
-#undef CASE_SECSAME
-#undef CASE
-}
-
void _mlir_ciface_outSparseTensorWriterMetaData(
void *p, index_type dimRank, index_type nse,
StridedMemRefType<index_type, 1> *dimSizesRef) {
@@ -635,34 +534,10 @@ char *getTensorFilename(index_type id) {
return env;
}
-void readSparseTensorShape(char *filename, std::vector<uint64_t> *out) {
- assert(out && "Received nullptr for out-parameter");
- SparseTensorReader reader(filename);
- reader.openFile();
- reader.readHeader();
- reader.closeFile();
- const uint64_t dimRank = reader.getRank();
- const uint64_t *dimSizes = reader.getDimSizes();
- out->reserve(dimRank);
- out->assign(dimSizes, dimSizes + dimRank);
-}
-
-index_type getSparseTensorReaderRank(void *p) {
- return static_cast<SparseTensorReader *>(p)->getRank();
-}
-
-bool getSparseTensorReaderIsSymmetric(void *p) {
- return static_cast<SparseTensorReader *>(p)->isSymmetric();
-}
-
index_type getSparseTensorReaderNSE(void *p) {
return static_cast<SparseTensorReader *>(p)->getNSE();
}
-index_type getSparseTensorReaderDimSize(void *p, index_type d) {
- return static_cast<SparseTensorReader *>(p)->getDimSize(d);
-}
-
void delSparseTensorReader(void *p) {
delete static_cast<SparseTensorReader *>(p);
}
diff --git a/mlir/test/Dialect/SparseTensor/conversion.mlir b/mlir/test/Dialect/SparseTensor/conversion.mlir
index 96300a98a6a4bc5..2ff4887dae7b8c9 100644
--- a/mlir/test/Dialect/SparseTensor/conversion.mlir
+++ b/mlir/test/Dialect/SparseTensor/conversion.mlir
@@ -78,11 +78,11 @@ func.func @sparse_dim3d_const(%arg0: tensor<10x20x30xf64, #SparseTensor>) -> ind
// CHECK-DAG: %[[DimShape0:.*]] = memref.alloca() : memref<1xindex>
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<1xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
-// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<1xi8>
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<1xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<1xindex>
+// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<1xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimShape]], %[[DimShape]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector> {
@@ -96,11 +96,11 @@ func.func @sparse_new1d(%arg0: !llvm.ptr<i8>) -> tensor<128xf64, #SparseVector>
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<2xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
-// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<2xi8>
// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<2xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[Iota0:.*]] = memref.alloca() : memref<2xindex>
+// CHECK-DAG: %[[Iota:.*]] = memref.cast %[[Iota0]] : memref<2xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[DimSizes]], %[[LvlTypes]], %[[Iota]], %[[Iota]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
@@ -114,15 +114,15 @@ func.func @sparse_new2d(%arg0: !llvm.ptr<i8>) -> tensor<?x?xf32, #CSR> {
// CHECK-DAG: %[[DimShape:.*]] = memref.cast %[[DimShape0]] : memref<3xindex> to memref<?xindex>
// CHECK: %[[Reader:.*]] = call @createCheckedSparseTensorReader(%[[A]], %[[DimShape]], %{{.*}})
// CHECK: %[[DimSizes:.*]] = call @getSparseTensorReaderDimSizes(%[[Reader]])
-// CHECK: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
-// CHECK: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
-// CHECK: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
-// CHECK: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
-// CHECK: %[[T:.*]] = call @newSparseTensorFromReader(%[[Reader]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}})
+// CHECK-DAG: %[[LvlTypes0:.*]] = memref.alloca() : memref<3xi8>
+// CHECK-DAG: %[[LvlTypes:.*]] = memref.cast %[[LvlTypes0]] : memref<3xi8> to memref<?xi8>
+// CHECK-DAG: %[[Dim2Lvl0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[Dim2Lvl:.*]] = memref.cast %[[Dim2Lvl0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[Lvl2Dim0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[Lvl2Dim:.*]] = memref.cast %[[Lvl2Dim0]] : memref<3xindex> to memref<?xindex>
+// CHECK-DAG: %[[LvlSizes0:.*]] = memref.alloca() : memref<3xindex>
+// CHECK-DAG: %[[LvlSizes:.*]] = memref.cast %[[LvlSizes0]] : memref<3xindex> to memref<?xindex>
+// CHECK: %[[T:.*]] = call @newSparseTensor(%[[DimSizes]], %[[LvlSizes]], %[[LvlTypes]], %[[Dim2Lvl]], %[[Lvl2Dim]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %[[Reader]])
// CHECK: call @delSparseTensorReader(%[[Reader]])
// CHECK: return %[[T]] : !llvm.ptr<i8>
func.func @sparse_new3d(%arg0: !llvm.ptr<i8>) -> tensor<?x?x?xf32, #SparseTensor> {
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
One nit: in the PR description, you have a typo "agains".
Making the materialize-from-reader method part of the Swiss army knife suite again removes a lot of redundant boiler plate code and unifies the parameter setup into a single centralized utility. Furthermore, we now have minimized the number of entry points into the library that need a non-permutation map setup, simplifying what comes next