Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][sparse] connect MapRef's lvl2dim with latest AffineMap computation #69540

Merged
merged 1 commit into from
Oct 19, 2023

Conversation

aartbik
Copy link
Contributor

@aartbik aartbik commented Oct 19, 2023

This makes sure

  • GEN MAP dim=2 lvl=4
    (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
    --
    (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)

is indeed encoded as

MAP-REF (dim=2, lvl=4) isperm=0
d2l = [ d0/2 d1/2 d0%2 d1%2 ]
ld2 = [ l2+2l0 l3+2l1 ]

@llvmbot llvmbot added mlir:sparse Sparse compiler in MLIR mlir labels Oct 19, 2023
@llvmbot
Copy link

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir-sparse

@llvm/pr-subscribers-mlir

Author: Aart Bik (aartbik)

Changes

This makes sure

  • GEN MAP dim=2 lvl=4
    (d0, d1) -> (d0 floordiv 2, d1 floordiv 2, d0 mod 2, d1 mod 2)
    --
    (d0, d1, d2, d3) -> (d0 * 2 + d2, d1 * 2 + d3)

is indeed encoded as

MAP-REF (dim=2, lvl=4) isperm=0
d2l = [ d0/2 d1/2 d0%2 d1%2 ]
ld2 = [ l2+2l0 l3+2l1 ]


Full diff: https://github.com/llvm/llvm-project/pull/69540.diff

1 Files Affected:

  • (modified) mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp (+42-20)
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
index 98b412c8ec9eb5b..b1b1d67ac2d420d 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -691,6 +691,7 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
   // This code deals with permutations as well as non-permutations that
   // arise from rank changing blocking.
   const auto dimToLvl = stt.getDimToLvl();
+  const auto lvlToDim = stt.getLvlToDim();
   SmallVector<Value> dim2lvlValues(lvlRank); // for each lvl, expr in dim vars
   SmallVector<Value> lvl2dimValues(dimRank); // for each dim, expr in lvl vars
   SmallVector<Value> lvlSizesValues(lvlRank);
@@ -705,34 +706,26 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     Dimension d = 0;
     uint64_t cf = 0, cm = 0;
     switch (exp.getKind()) {
-    case AffineExprKind::DimId:
+    case AffineExprKind::DimId: {
       d = exp.cast<AffineDimExpr>().getPosition();
       break;
-    case AffineExprKind::FloorDiv:
-      d = exp.cast<AffineBinaryOpExpr>()
-              .getLHS()
-              .cast<AffineDimExpr>()
-              .getPosition();
-      cf = exp.cast<AffineBinaryOpExpr>()
-               .getRHS()
-               .cast<AffineConstantExpr>()
-               .getValue();
+    }
+    case AffineExprKind::FloorDiv: {
+      auto floor = exp.cast<AffineBinaryOpExpr>();
+      d = floor.getLHS().cast<AffineDimExpr>().getPosition();
+      cf = floor.getRHS().cast<AffineConstantExpr>().getValue();
       break;
-    case AffineExprKind::Mod:
-      d = exp.cast<AffineBinaryOpExpr>()
-              .getLHS()
-              .cast<AffineDimExpr>()
-              .getPosition();
-      cm = exp.cast<AffineBinaryOpExpr>()
-               .getRHS()
-               .cast<AffineConstantExpr>()
-               .getValue();
+    }
+    case AffineExprKind::Mod: {
+      auto mod = exp.cast<AffineBinaryOpExpr>();
+      d = mod.getLHS().cast<AffineDimExpr>().getPosition();
+      cm = mod.getRHS().cast<AffineConstantExpr>().getValue();
       break;
+    }
     default:
       llvm::report_fatal_error("unsupported dim2lvl in sparse tensor type");
     }
     dim2lvlValues[l] = constantIndex(builder, loc, encodeDim(d, cf, cm));
-    lvl2dimValues[d] = constantIndex(builder, loc, l); // FIXME, use lvlToDim
     // Compute the level sizes.
     //    (1) l = d        : size(d)
     //    (2) l = d / c    : size(d) / c
@@ -751,6 +744,35 @@ Value sparse_tensor::genMapBuffers(OpBuilder &builder, Location loc,
     }
     lvlSizesValues[l] = lvlSz;
   }
+  // Generate lvl2dim.
+  assert(dimRank == lvlToDim.getNumResults());
+  for (Dimension d = 0; d < dimRank; d++) {
+    AffineExpr exp = lvlToDim.getResult(d);
+    // We expect:
+    //    (1) d = l
+    //    (2) d = l' * c + l
+    Level l = 0, ll = 0;
+    uint64_t c = 0;
+    switch (exp.getKind()) {
+    case AffineExprKind::DimId: {
+      l = exp.cast<AffineDimExpr>().getPosition();
+      break;
+    }
+    case AffineExprKind::Add: {
+      // Always mul on lhs, symbol/constant on rhs.
+      auto add = exp.cast<AffineBinaryOpExpr>();
+      assert(add.getLHS().getKind() == AffineExprKind::Mul);
+      auto mul = add.getLHS().cast<AffineBinaryOpExpr>();
+      ll = mul.getLHS().cast<AffineDimExpr>().getPosition();
+      c = mul.getRHS().cast<AffineConstantExpr>().getValue();
+      l = add.getRHS().cast<AffineDimExpr>().getPosition();
+      break;
+    }
+    default:
+      llvm::report_fatal_error("unsupported lvl2dim in sparse tensor type");
+    }
+    lvl2dimValues[d] = constantIndex(builder, loc, encodeLvl(l, c, ll));
+  }
   // Return buffers.
   dim2lvlBuffer = allocaBuffer(builder, loc, dim2lvlValues);
   lvl2dimBuffer = allocaBuffer(builder, loc, lvl2dimValues);

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:sparse Sparse compiler in MLIR mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants