Skip to content

Commit

Permalink
Relax dot operand constraints with FMA-based dot
Browse files Browse the repository at this point in the history
This PR:
- Refactors FMA dot implementation
- Supports dot3d in FMA path
- Fixes several issues in operand offset computation
- Enables small dot operands
  • Loading branch information
binarman committed Aug 15, 2024
1 parent 45af9a9 commit 68350e9
Show file tree
Hide file tree
Showing 12 changed files with 338 additions and 308 deletions.
16 changes: 16 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,22 @@ inline bool isLayoutMmaV1(Attribute layout) {
return isMmaV1;
}

// Returns a rank-3 view of `smemObj`. A rank-3 object is returned unchanged;
// otherwise a leading batch dimension is prepended with offset 0 and a stride
// spanning one full 2D slice (shape[0] * shape[1] elements).
inline SharedMemoryObject
getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
                              SharedMemoryObject smemObj,
                              ArrayRef<int64_t> shape) {
  auto newStrides = smemObj.getStrides();
  auto newOffsets = smemObj.getOffsets();
  // Already has a batch dimension: nothing to expand.
  if (newStrides.size() == 3)
    return smemObj;
  // Prepend the batch dimension: stride covers a whole 2D slice, offset is 0.
  newStrides.insert(newStrides.begin(), i32_val(shape[0] * shape[1]));
  newOffsets.insert(newOffsets.begin(), i32_val(0));
  return SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(),
                            newStrides, newOffsets);
}

} // namespace mlir

#endif
16 changes: 16 additions & 0 deletions include/triton/Dialect/TritonGPU/IR/Dialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ void dumpHWLayout(RankedTensorType tensorType);
// Return a string representation of the layout of the tensor.
std::string getLayoutStr(RankedTensorType tensorType, bool useHWPointOfView);

// Pads the matrix shape `s` on the left with 1s so the result has rank 3,
// i.e. prepends unit batch dimension(s) to a lower-rank shape.
template <typename T>
llvm::SmallVector<T> expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
  llvm::SmallVector<T> result;
  for (size_t i = s.size(); i < 3; ++i)
    result.push_back(1);
  result.append(s.begin(), s.end());
  return result;
}

// Rewrites the dimension order `o` for a rank-3 tensor: every existing
// dimension index is shifted up by the number of prepended batch dimensions,
// and the trailing (outermost) entries are left as 0 for the batch dimension.
template <typename T>
llvm::SmallVector<T> expandMatrixOrderWithBatch(llvm::ArrayRef<T> o) {
  const int batchDims = 3 - static_cast<int>(o.size());
  llvm::SmallVector<T> result(3, 0);
  for (int i = 0; i < static_cast<int>(o.size()); ++i)
    result[i] = o[i] + batchDims;
  return result;
}

} // namespace gpu
} // namespace triton
} // namespace mlir
Expand Down
12 changes: 9 additions & 3 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,12 +480,18 @@ bool supportMMA(triton::DotOp op, int version) {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16
auto aElemTy = op.getA().getType().getElementType();
auto bElemTy = op.getB().getType().getElementType();
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
auto rank = retShapePerCTA.size();
auto aTensorTy = cast<RankedTensorType>(op.getA().getType());
auto aShape = aTensorTy.getShape();
auto encoding = cast<DotOperandEncodingAttr>(aTensorTy.getEncoding());
if (retShapePerCTA[rank - 2] < 16 || retShapePerCTA[rank - 1] < 16 ||
aShape[rank - 1] < 16)
return false;
if (version == 3) {
if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
return false;
auto retType = op.getType();
auto retShapePerCTA = getShapePerCTA(retType);
auto rank = retShapePerCTA.size();
auto mod = op->getParentOfType<ModuleOp>();
int numWarps = TritonGPUDialect::getNumWarps(mod);
// TODO(Keren): for now, fallback to MMAv2 if handling batch matmul.
Expand Down
Loading

0 comments on commit 68350e9

Please sign in to comment.