Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
e6f9d93
[BUILD] Update CMake target name MLIRGPU{Ops => Dialect}. (#1733)
ingomueller-net Jun 5, 2023
b1a9b62
[BACKEND] Fixes for https://github.com/llvm/llvm-project/commit/641b1…
chsigg Jun 12, 2023
1240e9d
[CI] add text file containing LLVM commit hash
ashay Jun 14, 2023
09170bd
[BACKEND] Fix for LLVM change b126ee65.
chsigg Jun 27, 2023
0e7ea31
Bump LLVM version.
chsigg Jun 27, 2023
c12fd7f
Remove python/, add BUILD files.
chsigg Jul 10, 2023
ee1163a
OpenXLA-specific changes.
chsigg Jul 10, 2023
99c1cdd
[BUILD] Update CMake target name MLIRGPU{Ops => Dialect}. (#1733)
ingomueller-net Jun 5, 2023
942c409
[BACKEND] Fixes for https://github.com/llvm/llvm-project/commit/641b1…
chsigg Jun 12, 2023
cb21d34
[CI] add text file containing LLVM commit hash
ashay Jun 14, 2023
f0a4947
[BACKEND] Fix for LLVM change b126ee65.
chsigg Jun 27, 2023
4508d87
Bump LLVM version.
chsigg Jun 27, 2023
33ccac4
Specify a name for "glob_lit_tests"
rupprecht Jul 13, 2023
336d0e9
Remove python/, add BUILD files.
chsigg Jul 10, 2023
8379a05
OpenXLA-specific changes.
chsigg Jul 10, 2023
afdd4c7
Merge branch 'llvm-head' into patch-1
rupprecht Jul 18, 2023
ef1a4fd
[BUILD] Update CMake target name MLIRGPU{Ops => Dialect}. (#1733)
ingomueller-net Jun 5, 2023
58d9922
[BACKEND] Fixes for https://github.com/llvm/llvm-project/commit/641b1…
chsigg Jun 12, 2023
61ffcd4
[CI] add text file containing LLVM commit hash
ashay Jun 14, 2023
f068eb5
[BACKEND] Fix for LLVM change b126ee65.
chsigg Jun 27, 2023
7930dc4
Bump LLVM version.
chsigg Jun 27, 2023
fb265d3
[BACKEND] Fixes for latest llvm changes.
chsigg Aug 5, 2023
8a313ed
Remove python/, add BUILD files.
chsigg Jul 10, 2023
dcc85b4
OpenXLA-specific changes.
chsigg Jul 10, 2023
763a686
Merge branch 'llvm-head' into patch-1
rupprecht Aug 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
647 changes: 647 additions & 0 deletions BUILD

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions include/triton/Analysis/Alias.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ class AliasInfo {
// Shared Memory Alias Analysis
//===----------------------------------------------------------------------===//
class SharedMemoryAliasAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
: public dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AliasInfo>> {
public:
using dataflow::SparseDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseDataFlowAnalysis;
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::SparseForwardDataFlowAnalysis;
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AliasInfo>>::getLatticeElement;

/// XXX(Keren): Compatible interface with MLIR AliasAnalysis for future use.
Expand Down
4 changes: 2 additions & 2 deletions include/triton/Analysis/AxisInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class AxisInfoVisitorList {
};

class AxisInfoAnalysis
: public dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
: public dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>> {
private:
AxisInfoVisitorList visitors;

Expand All @@ -284,7 +284,7 @@ class AxisInfoAnalysis

public:
AxisInfoAnalysis(DataFlowSolver &solver);
using dataflow::SparseDataFlowAnalysis<
using dataflow::SparseForwardDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

Expand Down
6 changes: 3 additions & 3 deletions include/triton/Analysis/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,9 +179,9 @@ SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort);

/// This uses the toplogicalSort above
SetVector<Operation *>
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
TransitiveFilter forwardFilter = nullptr);
SetVector<Operation *> multiRootGetSlice(
Operation *op, SliceOptions::TransitiveFilter backwardFilter = nullptr,
SliceOptions::TransitiveFilter forwardFilter = nullptr);

/// Create a basic DataFlowSolver with constant and dead code analysis included.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
Expand Down
3 changes: 3 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,9 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
operand_range getArgOperands() {
return {arg_operand_begin(), arg_operand_end()};
}
MutableOperandRange getArgOperandsMutable() {
return getOperandsMutable();
}

operand_iterator arg_operand_begin() { return operand_begin(); }
operand_iterator arg_operand_end() { return operand_end(); }
Expand Down
2 changes: 1 addition & 1 deletion lib/Analysis/AxisInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@ class MaxMinOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
//===----------------------------------------------------------------------===//

AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
: dataflow::SparseDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
: dataflow::SparseForwardDataFlowAnalysis<dataflow::Lattice<AxisInfo>>(solver) {
// UnrealizedConversionCast:
// This is needed by TritonGPUToLLVM, to get AxisInfo when the graph is
// in the process of a PartialConversion, where UnrealizedConversionCast
Expand Down
12 changes: 6 additions & 6 deletions lib/Analysis/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() {

unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ty.getIntOrFloatBitWidth() / 8;
bytesPerElem += (ty.getIntOrFloatBitWidth() + 7) / 8;
}
return bytesPerElem * elems;
}
Expand Down Expand Up @@ -473,9 +473,9 @@ multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
return res;
}

SetVector<Operation *> multiRootGetSlice(Operation *op,
TransitiveFilter backwardFilter,
TransitiveFilter forwardFilter) {
SetVector<Operation *> multiRootGetSlice(
Operation *op, SliceOptions::TransitiveFilter backwardFilter,
SliceOptions::TransitiveFilter forwardFilter) {
SetVector<Operation *> slice;
slice.insert(op);

Expand All @@ -486,12 +486,12 @@ SetVector<Operation *> multiRootGetSlice(Operation *op,
auto *currentOp = (slice)[currentIndex];
// Compute and insert the backwardSlice starting from currentOp.
backwardSlice.clear();
getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
getBackwardSlice(currentOp, &backwardSlice, {{backwardFilter}});
slice.insert(backwardSlice.begin(), backwardSlice.end());

// Compute and insert the forwardSlice starting from currentOp.
forwardSlice.clear();
getForwardSlice(currentOp, &forwardSlice, forwardFilter);
getForwardSlice(currentOp, &forwardSlice, {{forwardFilter}});
slice.insert(forwardSlice.begin(), forwardSlice.end());
++currentIndex;
}
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ add_mlir_conversion_library(TritonGPUToLLVM
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRGPUOps
MLIRGPUDialect
MLIRGPUToNVVMTransforms
MLIRGPUToROCDLTransforms
MLIRGPUTransforms
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ struct ExtractSliceOpConversion
// newShape = rank_reduce(shape)
// Triton only supports static tensor sizes
SmallVector<Value, 4> strideVals;
for (auto i = 0; i < op.static_sizes().size(); ++i) {
for (auto i = 0; i < op.getStaticSizes().size(); ++i) {
if (op.getStaticSize(i) == 1) {
offsetVals.erase(offsetVals.begin() + i);
} else {
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ struct FuncOpConversionBase : public ConvertOpToLLVMPattern<triton::FuncOp> {
}
auto newFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
funcOp.getLoc(), funcOp.getName(), llvmType, linkage,
/*dsoLocal*/ false, LLVM::CConv::C, attributes);
/*dsoLocal*/ false, LLVM::CConv::C, /*comdat=*/nullptr, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());
if (failed(rewriter.convertRegionTypes(&newFuncOp.getBody(), *typeConverter,
Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TritonGPU/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen

LINK_LIBS PUBLIC
MLIRGPUOps
MLIRGPUDialect
TritonIR
)
8 changes: 4 additions & 4 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1218,10 +1218,10 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
op->getLoc(), newType, extract_slice.getSource());
rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
op, resType, newArg.getResult(), extract_slice.offsets(),
extract_slice.sizes(), extract_slice.strides(),
extract_slice.static_offsets(), extract_slice.static_sizes(),
extract_slice.static_strides());
op, resType, newArg.getResult(), extract_slice.getOffsets(),
extract_slice.getSizes(), extract_slice.getStrides(),
extract_slice.getStaticOffsets(), extract_slice.getStaticSizes(),
extract_slice.getStaticStrides());
return mlir::success();
}

Expand Down
8 changes: 5 additions & 3 deletions lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
auto filter = [&dotOp](Operation *op) {
mlir::TransitiveFilter filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, {filter});
Expand Down Expand Up @@ -91,7 +91,7 @@ class BlockedToMMA : public mlir::RewritePattern {
int finalBitWidth = getElementTypeOrSelf(x).getIntOrFloatBitWidth();
int origBitWidth = finalBitWidth;
SetVector<Operation *> slice;
mlir::getBackwardSlice(x, &slice, bwdFilter);
mlir::getBackwardSlice(x, &slice, {{bwdFilter}});
Operation *firstOp = slice.empty() ? nullptr : *slice.begin();
if (firstOp)
if (Value arg = firstOp->getOperand(0))
Expand Down Expand Up @@ -137,7 +137,9 @@ class BlockedToMMA : public mlir::RewritePattern {
triton::gpu::MmaEncodingAttr mmaEnc;
if (versionMajor == 1) {
SetVector<Operation *> aBwdSlices, bBwdSlices;
auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
mlir::TransitiveFilter isCvt = [](Operation *op) {
return isa<ConvertLayoutOp>(op);
};
getBackwardSlice(a, &aBwdSlices, {isCvt});
getBackwardSlice(b, &bBwdSlices, {isCvt});
// get the source of the first conversion found in slices
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ class RematerializeForward : public mlir::RewritePattern {
return failure();

SetVector<Operation *> cvtSlices;
auto filter = [&](Operation *op) {
mlir::TransitiveFilter filter = [&](Operation *op) {
return op->getBlock() == cvt->getBlock() &&
!isa<triton::gpu::ConvertLayoutOp, scf::YieldOp>(op) &&
!(isa<triton::ReduceOp>(op) &&
Expand Down
5 changes: 5 additions & 0 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ bool isExpensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// same
if (isSingleValue(op->getOperand(0)))
return false;
// TODO(manany): Investigate with Openai why the change here
// https://github.com/openai/triton/commit/640f3c392184cd14291c1bca6a4795eb0f32a61a
// which introduces Case 2 causes breakage to this test
// //third_party/py/jax_triton/tests:pallas_test_sm80 --test_filter=test_fused_attention_bwd
return true;
// Case 2: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
Expand Down
2 changes: 1 addition & 1 deletion lib/Target/PTX/PTXTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
auto *shortPtr =
static_cast<llvm::cl::opt<bool> *>(options["nvptx-short-ptr"]);
assert(shortPtr);
shortPtr->setValue(true);
shortPtr->setValue(false);
std::string sm = cc == 90 ? "sm_90a" : "sm_" + std::to_string(cc);
// max PTX version
int ptxMajor = maxPTX / 10;
Expand Down
1 change: 1 addition & 0 deletions llvm-hash.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
50665511c79f62d97c0d6602e1577b0abe1b982f
5 changes: 0 additions & 5 deletions python/MANIFEST.in

This file was deleted.

18 changes: 0 additions & 18 deletions python/examples/copy_strided.py

This file was deleted.

13 changes: 0 additions & 13 deletions python/examples/empty.py

This file was deleted.

8 changes: 0 additions & 8 deletions python/pyproject.toml

This file was deleted.

Loading