diff --git a/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td index 00cfd806e3a6..434a1c5d62d7 100644 --- a/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td +++ b/include/triton/Dialect/Triton/IR/TritonOpInterfaces.td @@ -69,7 +69,6 @@ def DotOpInterface : OpInterface<"DotOpInterface"> { auto aTy = cast($_op.getA().getType()); auto bTy = cast($_op.getB().getType()); auto cTy = cast($_op->getOperand(2).getType()); - auto dTy = cast($_op.getD().getType()); auto aShape = aTy.getShape(); auto bShape = bTy.getShape(); auto cShape = cTy.getShape(); diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 1560b4bb54db..31be7fad66d1 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -599,7 +599,6 @@ struct MapElementwiseOpConversion } auto &scalarOp = op.getScalarOp(); - Region &parent = *rewriter.getBlock()->getParent(); auto nOutputs = op.getNumResults(); SmallVector scalarOutputs(nOutputs * nElems); diff --git a/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp b/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp index c7b775cb7ab4..ab7d855b5ea0 100644 --- a/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp +++ b/lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp @@ -50,7 +50,6 @@ class GluonResolveAutoEncodingsPass using BaseT::BaseT; void runOnOperation() override { - MLIRContext *context = &getContext(); ModuleOp m = getOperation(); // Do layout inference diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 09151ea5b7a6..de0c0cd4d4cf 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -702,7 +702,7 @@ static SmallVector repeatInterleave(const SmallVectorImpl &vs, SmallVector result; result.reserve(vs.size() * nRepeat); for (auto v : vs) - for (auto _ : llvm::seq(nRepeat)) + for (int i = 0; i < nRepeat; ++i) result.push_back(v); return result; } diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a37c3dc4c8e5..05cbad9b196e 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -4356,9 +4356,6 @@ bool triton::gpu::areLayoutsEquivalent(ArrayRef shape, } bool triton::gpu::isInnermostContiguous(MemDescType type, unsigned numElems) { - Attribute enc = type.getEncoding(); - MLIRContext *ctx = enc.getContext(); - LinearLayout actual = toLinearLayout(type); // Flatten actual outs in reverse order to produce a row-major flattening diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 6c7bbb9bfd6e..36ac0617e432 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1428,7 +1428,6 @@ LinearLayout chooseScaledWmmaScaleLayout( MLIRContext *ctx, int dotOperandIdx, ArrayRef dotOperandShape, unsigned wmmaMDim, unsigned wmmaNDim, bool isTransposed, unsigned scaleFactor, LinearLayout ctaLayout, CGAEncodingAttr cgaLayout) { - using basisT = std::vector>; unsigned rank = dotOperandShape.size(); bool hasBatchDim = rank == 3; auto outDimNames = standardOutDimNames(ctx, rank); @@ -1568,7 +1567,6 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx, unsigned mfmaMDim, ArrayRef tilesPerWarp, ArrayRef warpsPerCTA) { - using basisT = std::vector>; unsigned rank = dotOperandShape.size(); auto order = mlir::triton::gpu::getMatrixOrder(rank, /*rowMajor=*/true); auto standardOutDims = standardOutDimNames(ctx, rank); diff --git a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp index 5a06e6441fdf..4dda93bcecda 100644 --- a/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp +++ b/lib/Dialect/TritonGPU/Transforms/CombineTensorSelectAndIf.cpp @@ -81,7 +81,6 @@ class CombineTensorSelectAndIfPass CombineTensorSelectAndIfPass> { public: void runOnOperation() override { - MLIRContext *context = &getContext(); ModuleOp m = getOperation(); canonicalizeSelectUsersInSCFIf(m); diff --git a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp index 45c9475fac26..d38b2b6b4333 100644 --- a/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/FuseNestedLoops.cpp @@ -709,9 +709,6 @@ static void fuseOneLevel(LoopNestNode *parent, mlir::DominanceInfo &domInfo) { Value curI = fused.getRegionIterArg(1); Value i; - auto lenInnersIt = - ValueRange(fused.getRegionIterArgs()).begin() + lenInnersStartIdx; - ArrayRef ivars = fused.getRegionIterArgs().slice(ivarStartIdx); auto bodyOutsIt = ValueRange(fused.getRegionIterArgs()).begin() + innerOutsStartIdx; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp index 2d7fa5487389..cac30c0dfec6 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp @@ -21,7 +21,6 @@ class TMEMAllocWithUnusedInit LogicalResult matchAndRewrite(triton::nvidia_gpu::TMEMAllocOp op, PatternRewriter &rewriter) const override { - MLIRContext *ctx = op.getContext(); if (op.getSrc() == nullptr) return failure(); SmallVector users(op.getResult().getUsers().begin(), diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp index bcf274fb7bdb..b713f9606e7f 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp @@ -308,8 +308,6 @@ class TritonGPUOptimizeThreadLocalityPass auto srcEncoding = srcType.getEncoding(); assert(isa(srcEncoding) && "Thread locality optimization only supports blocked encoding"); - auto elemsPerThread = - triton::gpu::getElemsPerThread(srcType)[reduce.getAxis()]; auto rank = srcShape.size(); // create new layouts auto blocked3d = getThreadLocalityOptimizedEncoding(reduce); @@ -354,8 +352,8 @@ class TritonGPUOptimizeThreadLocalityPass // create new accum update auto newUpdate = createUpdate(builder, newLoop, newReduce, oldUpdate); // create new yield - auto newYield = createYield(builder, newLoop, oldYield, - newUpdate->getResult(0), blockArgNum); + createYield(builder, newLoop, oldYield, newUpdate->getResult(0), + blockArgNum); // create post loop reduction on the original reduce axis auto newReduce2 = createPostLoopReduce(builder, newLoop, reduce); // add convert_layout to get back to original layout, the result layout diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp index 10ac84aea67a..c0fc413bffb6 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/LowerLoops.cpp @@ -209,8 +209,6 @@ void createTMAAsyncCopy( Operation *firstUse = getFirstUseOfPipelinedOp({loadOp}, forOp, schedule); assert(firstUse && "LoadOp has no users"); - Attribute sharedMemorySpace = - ttg::SharedMemorySpaceAttr::get(forOp.getContext()); builder.setInsertionPoint(loadOp); builder.setStageCluster(schedule[loadOp]); @@ -957,9 +955,6 @@ void multibufferTensorMemory(scf::ForOp forOp, CoarseSchedule &schedule, scf::ForOp lowerMMA(ttng::MMAv5OpInterface mma, scf::ForOp forOp, CoarseSchedule &schedule) { - auto isLoadToBePipelined = [&](Operation *op) { - return schedule[mma].first > schedule[op].first; - }; Value alloc = mma.getAccumulator(); int mmaSelfLatency = getSelfLatencyFromAttr(mma.getOperation()); diff --git a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp index b3794f4ec56e..a0e78d4fcd37 100644 --- a/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp +++ b/lib/Dialect/TritonGPU/Transforms/WarpSpecialization/PartitionScheduling.cpp @@ -106,12 +106,10 @@ std::unique_ptr buildGraph(Operation *region) { // init iter args { - size_t idx = 0; - for (auto operand : forOp.getInitArgs()) { + for (size_t idx = 0; idx < forOp.getInitArgs().size(); ++idx) { auto iter_arg_node = node->getDefines()[idx + 1]; operands[std::make_pair(op, idx + 3)] = InputPort(iter_arg_node, 0); - idx++; } } diff --git a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp index 4939b9f2010a..d1b44c3b92b1 100644 --- a/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp +++ b/lib/Dialect/TritonInstrument/IR/FunctionBuilder.cpp @@ -436,7 +436,7 @@ void FunctionBuilder::createSetWaitingCall(ImplicitLocOpBuilder &b, Value mbar, createCallToCachedFunction( b, "set_waiting", args, /*assertInfo=*/std::nullopt, {barriersType, waitingType}, - [barriersType, waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value baseThread = entryBlock->getArgument(2); @@ -537,7 +537,7 @@ void FunctionBuilder::createClearWaitingCall(ImplicitLocOpBuilder &b, createCallToCachedFunction( b, "clear_waiting", args, /*assertInfo=*/std::nullopt, {barriersType, waitingType}, - [barriersType, waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [waitingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value baseThread = entryBlock->getArgument(2); @@ -785,8 +785,7 @@ void FunctionBuilder::createVerifyBarrierCanInitCall(ImplicitLocOpBuilder &b, createCallToCachedFunction( b, "verify_barrier_can_init", args, assertInfo, {barriersType, barrierStatesType}, - [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -841,8 +840,7 @@ void FunctionBuilder::createVerifyBarrierInitializedCall( createCallToCachedFunction( b, "verify_barrier_initialized", args, assertInfo, {barriersType, barrierStatesType}, - [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -899,8 +897,7 @@ void FunctionBuilder::createInitBarrierStateCall(ImplicitLocOpBuilder &b, createCallToCachedFunction( b, "init_barrier_state", args, /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, - [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value count = entryBlock->getArgument(2); @@ -975,8 +972,8 @@ void FunctionBuilder::createInvalidateBarrierStateCall(ImplicitLocOpBuilder &b, b, "invalidate_barrier_state", args, /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType, waitingType}, - [barriersType, barrierStatesType, waitingType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [barrierStatesType, waitingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1058,8 +1055,7 @@ void FunctionBuilder::createVerifyBarrierArriveCall( createCallToCachedFunction( b, "verify_barrier_arrive", args, assertInfo, {barriersType, barrierStatesType}, - [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value count = entryBlock->getArgument(2); @@ -1176,8 +1172,7 @@ void FunctionBuilder::createUpdateBarrierStateCall( createCallToCachedFunction( b, "update_barrier_state", args, /*assertInfo=*/std::nullopt, {barriersType, barrierStatesType}, - [barriersType, barrierStatesType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [barrierStatesType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value count = entryBlock->getArgument(2); @@ -1307,8 +1302,7 @@ void FunctionBuilder::createSetWriteVisibilityCall( b, "set_write_visibility", args, /*assertInfo=*/std::nullopt, {buffersType, writeVisibilityType, (uint64_t)memType}, - [buffersType, writeVisibilityType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [writeVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1368,8 +1362,7 @@ void FunctionBuilder::createSetReadVisibilityCall( b, "set_read_visibility", args, /*assertInfo=*/std::nullopt, {buffersType, readVisibilityType, (uint64_t)memType}, - [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [readVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1434,8 +1427,7 @@ void FunctionBuilder::createClearWriteTrackingCall(ImplicitLocOpBuilder &b, b, "clear_write_tracking", args, /*assertInfo=*/std::nullopt, {buffersType, writeTrackingType, (uint64_t)memType}, - [buffersType, writeTrackingType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1491,8 +1483,7 @@ void FunctionBuilder::createClearReadVisibilityCall(ImplicitLocOpBuilder &b, b, "clear_read_visibility", args, /*assertInfo=*/std::nullopt, {buffersType, readVisibilityType, (uint64_t)memType}, - [buffersType, readVisibilityType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [readVisibilityType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1549,8 +1540,7 @@ void FunctionBuilder::createClearReadTrackingCall(ImplicitLocOpBuilder &b, b, "clear_read_tracking", args, /*assertInfo=*/std::nullopt, {buffersType, readTrackingType, (uint64_t)memType}, - [buffersType, readTrackingType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1614,8 +1604,8 @@ void FunctionBuilder::createTrackVisibleWritesCall(ImplicitLocOpBuilder &b, b, "track_visible_writes", args, /*assertInfo=*/std::nullopt, {barriersType, writeVisibilityType, writeTrackingType, (uint64_t)memType}, - [barriersType, writeVisibilityType, - writeTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [writeVisibilityType, writeTrackingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1699,8 +1689,8 @@ void FunctionBuilder::createTrackVisibleReadsCall(ImplicitLocOpBuilder &b, b, "track_visible_reads", args, /*assertInfo=*/std::nullopt, {barriersType, readVisibilityType, readTrackingType, (uint64_t)memType}, - [barriersType, readVisibilityType, - readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [readVisibilityType, readTrackingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1791,7 +1781,7 @@ void FunctionBuilder::createTrackBarrierWriteForBufferCall( /*assertInfo=*/std::nullopt, {barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, (uint64_t)memType, (uint64_t)diagonalEffectRecipientCTAs}, - [barriersType, buffersType, writeTrackingType, barrierWriteRecipientsType, + [writeTrackingType, barrierWriteRecipientsType, diagonalEffectRecipientCTAs](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); @@ -1909,8 +1899,8 @@ void FunctionBuilder::createClearBarrierWriteTrackingCall( /*assertInfo=*/std::nullopt, {barriersType, writeTrackingType, barrierWriteRecipientsType, (uint64_t)memType}, - [barriersType, writeTrackingType, barrierWriteRecipientsType]( - ImplicitLocOpBuilder &fb, Block *entryBlock) { + [writeTrackingType, barrierWriteRecipientsType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -1989,8 +1979,7 @@ void FunctionBuilder::createClearBarrierReadTrackingCall( b, "clear_barrier_read_tracking", args, /*assertInfo=*/std::nullopt, {barriersType, readTrackingType, (uint64_t)memType}, - [barriersType, readTrackingType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -2063,9 +2052,8 @@ void FunctionBuilder::createTransferVisibleWritesCall( /*assertInfo=*/std::nullopt, {barriersType, writeVisibilityType, writeTrackingType, barrierWriteRecipientsType, (uint64_t)memType}, - [barriersType, writeVisibilityType, writeTrackingType, - barrierWriteRecipientsType](ImplicitLocOpBuilder &fb, - Block *entryBlock) { + [writeVisibilityType, writeTrackingType, barrierWriteRecipientsType]( + ImplicitLocOpBuilder &fb, Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -2170,8 +2158,8 @@ void FunctionBuilder::createTransferVisibleReadsCall( b, "transfer_visible_reads", args, /*assertInfo=*/std::nullopt, {barriersType, readVisibilityType, readTrackingType, (uint64_t)memType}, - [barriersType, readVisibilityType, - readTrackingType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [readVisibilityType, readTrackingType](ImplicitLocOpBuilder &fb, + Block *entryBlock) { Value mbarOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); @@ -2627,7 +2615,7 @@ void FunctionBuilder::createStageAccessForCommitCall( createCallToCachedFunction( b, "stage_access_for_commit", args, /*assertInfo=*/std::nullopt, {buffersType, commitsType}, - [buffersType, commitsType](ImplicitLocOpBuilder &fb, Block *entryBlock) { + [commitsType](ImplicitLocOpBuilder &fb, Block *entryBlock) { Value bufOffset = entryBlock->getArgument(0); Value lengthVal = entryBlock->getArgument(1); Value pred = entryBlock->getArgument(2); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp index 29c1955eacd3..f3d5ec2293a8 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/InterleaveTMem.cpp @@ -71,7 +71,7 @@ findBufferAccessMemdescSubview(Operation *subview) { src = indexOp.getSrc(); shape = to_vector(indexOp.getType().getShape()); offsets = {indexOp.getIndex()}; - for (auto i : llvm::seq(std::max(0, shape.size() - 1))) + for (int i = 0, e = std::max(0, shape.size() - 1); i < e; ++i) offsets.push_back(arith::ConstantIntOp::create(builder, loc, 0, 32)); } else { auto subsliceOp = cast(subview); @@ -261,7 +261,6 @@ struct TritonNvidiaGPUInterleaveTMemPass TritonNvidiaGPUInterleaveTMemPass>::TritonNvidiaGPUInterleaveTMemPassBase; void runOnOperation() override { - MLIRContext *context = &getContext(); ModuleOp m = getOperation(); SmallVector> opsToSink; m.walk([&](Operation *op) { diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp index df133f1faa42..bf255e9b2d1c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/MMALowering.cpp @@ -82,7 +82,6 @@ struct TCGen5MMAScaleSharedToTmemConversion LogicalResult matchAndRewrite(TCGen5MMAScaledOp op, PatternRewriter &rewriter) const override { - MLIRContext *context = op->getContext(); auto aScaleType = op.getAScale().getType(); auto bScaleType = op.getBScale().getType(); if (aScaleType.getShape() != aScaleType.getAllocShape() || diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp index b04b1044f1d5..6ebd9783e4d3 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp @@ -102,7 +102,6 @@ class TritonNvidiaGPUOptimizeDescriptorEncodingPass using BaseT::BaseT; void runOnOperation() override { - MLIRContext *context = &getContext(); ModuleOp m = getOperation(); NvidiaGPUAssignDescriptorMemoryLayouts assignMemoryLayouts; assignMemoryLayouts.assignMemoryLayouts(m); diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index a0c43e2137b3..2c004ee379c1 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -168,7 +168,6 @@ class TMACreateDescLowering : public OpRewritePattern { LogicalResult matchAndRewrite(MakeTensorDescOp op, PatternRewriter &rewriter) const override { - MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); auto alloc = triton::gpu::GlobalScratchAllocOp::create( rewriter, loc, getPointerType(rewriter.getI8Type()), TMA_SIZE_BYTES, diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp index b9da56d30c23..bd53aa0a87cf 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.cpp @@ -138,7 +138,6 @@ FailureOr getTMAElementType(Location loc, tt::TensorDescInterface ty) { LogicalResult createTMADesc(Value tmaPtr, MakeTensorDescOp op, OpBuilder &builder) { using namespace mlir; - MLIRContext *ctx = op.getContext(); auto loc = op.getLoc(); auto mkI32Constant = [&](int32_t val) { return arith::ConstantOp::create(builder, loc, builder.getI32Type(), diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp index 695c834063e2..814876deca0c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TensorMemoryAllocation.cpp @@ -398,7 +398,6 @@ class TritonTensorMemoryAllocationPass void runOnOperation() override { ModuleOp mod = getOperation(); - MLIRContext *ctx = &getContext(); DenseMap offsets; // TODO: handle cases with multiple function with TMEMAllocOp. diff --git a/lib/Target/LLVMIR/LLVMDILocalVariable.cpp b/lib/Target/LLVMIR/LLVMDILocalVariable.cpp index 514b68a3db48..d7e6aec28764 100644 --- a/lib/Target/LLVMIR/LLVMDILocalVariable.cpp +++ b/lib/Target/LLVMIR/LLVMDILocalVariable.cpp @@ -90,8 +90,8 @@ struct LLVMDILocalVariablePass // a subclass of mlir::Value, which is the value defined by this operation OpResult opResult = op->getResult(0); // create and insert this call-dbg-value intrinsic after the op - Operation *dbgOp = LLVM::DbgValueOp::create(builder, childLoc, opResult, - diLocalVarAttr, diExprAttr); + LLVM::DbgValueOp::create(builder, childLoc, opResult, diLocalVarAttr, + diExprAttr); } } @@ -117,7 +117,7 @@ struct LLVMDILocalVariablePass // Filename, line and colmun to associate to the function. LLVM::DIFileAttr fileAttr; - int64_t line = 1, col = 1; + int64_t line = 1; FileLineColLoc fileLoc = extractFileLoc(loc); if (!fileLoc && compileUnitAttr) { fileAttr = compileUnitAttr.getFile(); @@ -125,7 +125,6 @@ struct LLVMDILocalVariablePass fileAttr = LLVM::DIFileAttr::get(context, "", ""); } else { line = fileLoc.getLine(); - col = fileLoc.getColumn(); StringRef inputFilePath = fileLoc.getFilename().getValue(); fileAttr = LLVM::DIFileAttr::get( context, llvm::sys::path::filename(inputFilePath), @@ -367,8 +366,6 @@ struct LLVMDILocalVariablePass LLVM::DISubprogramAttr diSubprogramAttr = {}; void runOnOperation() override { - Operation *op = getOperation(); - getOperation()->walk([&](Operation *op) -> void { if (isa(op)) { auto funcOp = cast(op); diff --git a/lib/Target/LLVMIR/LLVMDIScope.cpp b/lib/Target/LLVMIR/LLVMDIScope.cpp index a6bb9cd9b784..e741b7ad3cce 100644 --- a/lib/Target/LLVMIR/LLVMDIScope.cpp +++ b/lib/Target/LLVMIR/LLVMDIScope.cpp @@ -44,7 +44,7 @@ struct LLVMDIScopePass : public impl::LLVMDIScopeBase { // Filename, line and colmun to associate to the function. LLVM::DIFileAttr fileAttr; - int64_t line = 1, col = 1; + int64_t line = 1; FileLineColLoc fileLoc = extractFileLoc(loc); if (!fileLoc && compileUnitAttr) { fileAttr = compileUnitAttr.getFile(); @@ -52,7 +52,6 @@ struct LLVMDIScopePass : public impl::LLVMDIScopeBase { fileAttr = LLVM::DIFileAttr::get(context, "", ""); } else { line = fileLoc.getLine(); - col = fileLoc.getColumn(); StringRef inputFilePath = fileLoc.getFilename().getValue(); fileAttr = LLVM::DIFileAttr::get( context, llvm::sys::path::filename(inputFilePath), diff --git a/python/src/interpreter.cc b/python/src/interpreter.cc index 747a0cc17191..b9b3bd7727f3 100644 --- a/python/src/interpreter.cc +++ b/python/src/interpreter.cc @@ -610,8 +610,6 @@ makeAtomicRMWOp(pybind11::dtype dtype, const uint64_t *ptr, const void *val, } // namespace void init_triton_interpreter(py::module &&m) { - using ret = py::return_value_policy; - py::enum_(m, "MEM_SEMANTIC", py::module_local()) .value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE) .value("ACQUIRE", MemSemantic::ACQUIRE) diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 795a60300e4c..e7961dfe09d3 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -468,8 +468,6 @@ LogicalResult ConcatOp::verify() { // 3. find, which input tile holds the dst value auto multiDimOperandIdx = LLVM::AMD::multiDimElementwise( elemCoordsArray, srcShape, std::divides()); - auto linearOperandIdx = - mlir::LLVM::linearize(multiDimOperandIdx, srcToDstShape, defaultOrder); // 4. subtract dst coordinates and start coordinates of the tile diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp index 0543b3ed36f2..cd25df397d4c 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp @@ -29,7 +29,6 @@ struct ConcatOpConversion : public ConvertOpToLLVMPattern { RankedTensorType srcType = cast(srcVal.getType()); ArrayRef srcShape = srcType.getShape(); - MLIRContext *context = resultType.getContext(); auto linearLayoutSrc = triton::gpu::toLinearLayout(srcType); auto outDimNames = llvm::to_vector(linearLayoutSrc.getOutDimNames()); // Call transposeOuts, to ensure that order of input and output tensor diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp index 5e2f8ab79d28..92795c4cc59b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp @@ -130,7 +130,6 @@ BufferEmitter::emitLoadToLds(Type type, Value byteWidth, Value rsrcDesc, SmallVector commonArgs; fillCommonArgs(type, rsrcDesc, offset, pred, cm, /*isBufferLoad=*/true, commonArgs); - Type bufferType = getBufferOpType(type, false); // buffer_load_to_lds is only supported on gfx942/gfx950 which always use // asyncmark. Emit the async intrinsic so LLVM's SIInsertWaitcnts tracks diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index b8ddc1f70c46..d974d8442f34 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -273,10 +273,8 @@ class ConvertLayoutOpInThreadSwap public: ConvertLayoutOpInThreadSwap(LLVMTypeConverter &typeConverter, - const TargetInfoBase &targetInfo, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) { - } + const TargetInfoBase &, PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} struct ByteLocation { int regIdx; @@ -720,9 +718,6 @@ class ConvertLayoutOpInThreadSwap LogicalResult matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto &amdTargInfo = - static_cast(targetInfo); - auto srcTy = op.getSrc().getType(); auto dstTy = op.getType(); @@ -768,9 +763,6 @@ class ConvertLayoutOpInThreadSwap transferWithVPerm(op, conversion, adaptor, rewriter); return success(); } - -private: - const TargetInfoBase &targetInfo; }; } // namespace diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp index 13039a4b96af..3d708dd36eee 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertWarpPipeline.cpp @@ -139,8 +139,7 @@ class ConvertPipelinedForPattern : public OpRewritePattern { // Set barrier before starting the loop. This resolves any outstanding // synchronization before beginning the specialized asymmetric // synchronization. - auto preBarrier = mlir::triton::gpu::BarrierOp::create( - b, loc, triton::gpu::AddrSpace::Local); + mlir::triton::gpu::BarrierOp::create(b, loc, triton::gpu::AddrSpace::Local); // Insert condbarrier::second_half before starting the loop // FIXME : correctly calculate numbers per the arch diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a21aa602551e..d8758166beb7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1820,7 +1820,6 @@ struct BufferAtomicCASOpConversion } mlir::Operation *lastCASOp; - MLIRContext *ctx = rewriter.getContext(); GCNBuilder waitcntBuilder; // Check if the op has users, if it does we set GLC=1, otherwise GLC=0 @@ -2528,8 +2527,6 @@ struct TDMPrefetchConversion // Return offsets Type llvmResultStructTy = getTypeConverter()->convertType(op.getType(0)); - auto structType = dyn_cast( - getTypeConverter()->convertType(op.getType(0))); Value resultStruct = packLLElements(loc, getTypeConverter(), offsets, rewriter, llvmResultStructTy); rewriter.replaceOp(op, {resultStruct}); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index feda8157617d..b02882b8e5f3 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -48,9 +48,7 @@ Operation *createIglpOpt(PatternRewriter &rewriter, Location loc, int value) { struct InstructionSchedHintsRewriter : public OpRewritePattern { - InstructionSchedHintsRewriter(MLIRContext *ctx, StringRef arch, - int32_t numStages) - : OpRewritePattern(ctx), numStages(numStages) {} + InstructionSchedHintsRewriter(MLIRContext *ctx) : OpRewritePattern(ctx) {} LogicalResult matchAndRewrite(triton::amdgpu::InstructionSchedHint instructionSchedHint, @@ -94,9 +92,6 @@ struct InstructionSchedHintsRewriter rewriter.eraseOp(instructionSchedHint); return success(); } - -private: - int32_t numStages; }; struct TritonAMDGPULowerInstructionSchedHints @@ -121,8 +116,7 @@ struct TritonAMDGPULowerInstructionSchedHints RewritePatternSet patterns(ctx); - patterns.add(ctx, this->gfxArch, - this->numStages); + patterns.add(ctx); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp index eeafa1fb60d7..d356553ce28a 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TensorPtrOpsToLLVM.cpp @@ -43,34 +43,6 @@ LogicalResult validateStridesAndSharedOrder(triton::MakeTensorDescOp op, return success(); } -// Collects all users of the value beyond the basic block boundaries -// defining a given value. -void collectUsers(Value value, llvm::SetVector &users) { - for (OpOperand &use : value.getUses()) { - Operation *userOp = use.getOwner(); - if (users.contains(userOp)) { - // stop recursion; avoid loops - return; - } - users.insert(userOp); - const unsigned argIdx = use.getOperandNumber(); - - if (auto unrealCast = dyn_cast(userOp)) { - collectUsers(unrealCast->getResult(argIdx), users); - } - - if (auto branch = dyn_cast(userOp)) { - auto successors = branch->getSuccessors(); - for (auto [idx, successor] : llvm::enumerate(successors)) { - auto operands = branch.getSuccessorOperands(idx); - if (argIdx < operands.size()) { - collectUsers(successor->getArgument(argIdx), users); - } - } - } - } -} - struct MakeTensorDescOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -113,7 +85,6 @@ struct MakeTensorDescOpConversion } auto sharedOrder = triton::gpu::getOrder( cast(sharedEnc), shapePerCTA); - bool isRowMajor = sharedOrder[0] == (sharedOrder.size() - 1); // Create TDM descriptor for 2D-5D tensors auto tdmDesc = LLVM::AMD::createTDMDescriptor( rewriter, loc, getTypeConverter(), elementType, shapePerCTA, numWarps, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index f628eab5bd60..e148fe4c92de 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -156,20 +156,10 @@ struct ConvertTritonAMDGPUToLLVM // Make benefit for AMD specific patterns higher so they apply before common // patterns int AMDBenefit = commonBenefit + 1; - auto populatePatterns1 = [&](auto populateFunc, int benefit) { - populateFunc(typeConverter, patterns, axisInfoAnalysis, allocation, - benefit); - }; - auto populatePatterns5 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, benefit); }; - auto populatePatterns6 = [&](auto populateFunc, int benefit) { - populateFunc(typeConverter, patterns, axisInfoAnalysis, allocation, - targetInfo, benefit); - }; - auto populatePatterns7 = [&](auto populateFunc, int benefit) { populateFunc(typeConverter, patterns, targetInfo, benefit); }; @@ -274,9 +264,10 @@ struct ConvertTritonAMDGPUToLLVM // Ask for 16B alignment on global_smem because that's the largest we should // ever need (4xi32). auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0); - auto global = LLVM::GlobalOp::create( + LLVM::GlobalOp::create( b, loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External, - "global_smem", /*value=*/Attribute(), /*alignment=*/16, + "global_smem", + /*value=*/Attribute(), /*alignment=*/16, // Add ROCm support. static_cast(NVVM::NVVMMemorySpace::Shared)); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index f5d6883b0660..7387d429625b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -980,7 +980,6 @@ class ScaledBlockedToScaledMFMAF8F6F4 final // kWidth is 16 for fp4. const unsigned kWidth = kBase; assert(kWidth == 32); - using basisT = std::vector>; auto aShape = a.getType().getShape(); auto bShape = b.getType().getShape(); @@ -1179,8 +1178,6 @@ class ScaledBlockedToScaledWMMAF8F6F4 final auto order = ttg::getMatrixOrder(rank, /*rowMajor=*/true); auto standardOutDims = standardOutDimNames(ctx, rank); - using basisT = std::vector>; - RankedTensorType aType = a.getType(); RankedTensorType bType = b.getType(); auto aCgaLayout = ttg::getCGALayout(aType.getEncoding()); @@ -1383,9 +1380,6 @@ FailureOr chooseWmmaInstruction(Location loc, int wmmaVersion, // number of matrix elements along k dim per one WMMA instruction unsigned kDim = 0; - auto resShape = cType.getShape(); - auto rank = resShape.size(); - unsigned mDim = 16; unsigned nDim = 16; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index b314bc9ba65b..be5b0871f9dc 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -939,13 +939,11 @@ void Pingponger::addAsymmetricSyncToLoop(OpBuilder &builder, Location loc) { warpIDX, constZero); auto warpHigh = arith::CmpIOp::create(builder, loc, arith::CmpIPredicate::ne, warpIDX, constZero); - auto condBarrierHigh = - tt::amdgpu::CondBarrierOp::create(builder, loc, warpHigh); + tt::amdgpu::CondBarrierOp::create(builder, loc, warpHigh); // Insert condbarrier::first_half after the end of the loop builder.setInsertionPointAfter(forOp); - auto condBarrierLow = - tt::amdgpu::CondBarrierOp::create(builder, loc, warpLow); + tt::amdgpu::CondBarrierOp::create(builder, loc, warpLow); } void Pingponger::getDotPingponged() { @@ -957,7 +955,6 @@ void Pingponger::getDotPingponged() { } OpBuilder builder(forOp); - MLIRContext *ctx = forOp.getContext(); Location loc = forOp.getLoc(); forOp->walk([&](Operation *op) { diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 465f1e7275b9..cbe353a9dfc6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -344,8 +344,6 @@ createDecomposeOffsetFromAdd(RewriterBase &rewriter, Location loc, Value expr, createAddOffsetsOfSameKind(rewriter, loc, uniformOffsetL, uniformOffsetR); Value nonUniformAdd = createAddOffsetsOfSameKind( rewriter, loc, nonUniformOffsetL, nonUniformOffsetR); - Value maybeDeadValue[] = {nonUniformOffsetL, nonUniformOffsetR, - uniformOffsetL, uniformOffsetR}; return {uniformAdd, nonUniformAdd}; } @@ -460,16 +458,6 @@ struct FatPointers { // for map default insert FatPtrAttrs() = default; - friend bool operator==(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { - return lhs.canNarrow == rhs.canNarrow && - lhs.isSmallTensor == rhs.isSmallTensor && - lhs.attributes.getArrayRef() == rhs.attributes.getArrayRef(); - } - - friend bool operator!=(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { - return !(lhs == rhs); - } - static FatPtrAttrs intersect(const FatPtrAttrs &lhs, const FatPtrAttrs &rhs) { FatPtrAttrs result; diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp index 15ce4ca8c7ee..3ae300c99abc 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization.cpp @@ -49,9 +49,7 @@ class NVGPUWarpSpecializationPass bool hasElse = false; funcOp->walk([&](scf::IfOp ifOp) { if (ifOp.elseBlock()) { - for (Operation &op : ifOp.elseBlock()->getOperations()) { - hasElse = true; - } + hasElse = true; } }); if (hasElse) diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.cpp index 9cf21cacaa1d..fe93e5534cb6 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/Utility.cpp @@ -30,8 +30,6 @@ void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), asyncTaskIds.end()); sort(sortedAsyncTaskIds); - auto i32Ty = IntegerType::get(op->getContext(), 32); - auto size = static_cast(sortedAsyncTaskIds.size()); op->setAttr("async_task_id", DenseI32ArrayAttr::get(op->getContext(), sortedAsyncTaskIds)); } diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSBuffer.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSBuffer.cpp index a8dd013cada2..817e2d0a8b16 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSBuffer.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSBuffer.cpp @@ -142,8 +142,7 @@ scf::IfOp rewriteIfOp(scf::IfOp ifOp, SmallVector &taskTopOps, // Go through region ops in the thenBlock. updateAccumLoopCount takes current // accumCnt value and returns the value at the end of the thenBlock. - Value endAccum = - updateAccumLoopCount(opList, taskTopOps, regionsWithChannels, prevAccum); + updateAccumLoopCount(opList, taskTopOps, regionsWithChannels, prevAccum); SmallVector ifYieldOperands = newIfOp.thenYield().getOperands(); @@ -482,8 +481,7 @@ scf::ForOp createNewLoopWrapper(scf::ForOp origForOp, if (auto tOp = dyn_cast(&op)) opList.push_back(&op); } - Value endAccum = - updateAccumLoopCount(opList, taskTopOps, regionsWithChannels, prevAccum); + updateAccumLoopCount(opList, taskTopOps, regionsWithChannels, prevAccum); LLVM_DEBUG({ LDBG("-- before replacing yieldOp "); newForOp.dump(); diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp index 0ac39f6e854f..eb4068c198ff 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSDataPartition.cpp @@ -1080,8 +1080,7 @@ static Operation *sliceOp(Operation *op, int offset, IRMapping &mappings, for (unsigned i = 0; i < forOp.getInitArgs().size(); i++) { auto initArg = forOp.getInitArgs()[i]; Value newInitArg; - auto newInitArgOp = - sliceOp(initArg, offset, mappings, reverseMappings, partitionScheme); + sliceOp(initArg, offset, mappings, reverseMappings, partitionScheme); if (auto bbArg = dyn_cast(initArg)) { // find the corresponding new block argument Block *parentBlock = bbArg.getOwner(); @@ -1095,8 +1094,6 @@ static Operation *sliceOp(Operation *op, int offset, IRMapping &mappings, assert(argIndex < parentBlock->getNumArguments() && "new init argment not found"); Region *parentRegion = parentBlock->getParent(); - Region &newParentRegion = - newInitArgOp->getRegion(parentRegion->getRegionNumber()); newInitArg = parentRegion->getArgument(argIndex); } else { newInitArg = mappings.lookupOrNull(initArg); diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp index 52b2eabc99ff..9a4463d4b682 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSLowerMem.cpp @@ -204,8 +204,8 @@ Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, auto prodBarrier = getBarrierForPipelineStage(builder, barrierAlloc, bufferIdx); auto pred = builder.createWithAsyncTaskIds(loc, 1, 1); - auto expect = builder.createWithAsyncTaskIds( - loc, prodBarrier, sizeInBytes, pred); + builder.createWithAsyncTaskIds(loc, prodBarrier, + sizeInBytes, pred); // Convert all the producers to async_tma_copy_global_to_local Operation *copy = nullptr; @@ -225,8 +225,7 @@ Operation *optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, getBarrierForPipelineStage(builder, barrierAlloc, bufferIdxExtract); phase = builder.createWithAsyncTaskIds( loc, builder.getI32Type(), phase); - auto wait = builder.createWithAsyncTaskIds( - loc, consBarrier, phase); + builder.createWithAsyncTaskIds(loc, consBarrier, phase); // Convert all the consumers to local_load for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { diff --git a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTaskPartition.cpp b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTaskPartition.cpp index 4952e4d346eb..20bc60964101 100644 --- a/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTaskPartition.cpp +++ b/third_party/nvidia/hopper/lib/Transforms/WarpSpecialization/WSTaskPartition.cpp @@ -52,19 +52,6 @@ void doTaskPartition(triton::FuncOp &funcOp, unsigned numWarpGroups) { if (loops.empty() || loads.empty() || dots.empty()) return; - auto getLoopLevel = [&](Operation *op) { - // Compute loop depth - unsigned depth = 0; - Operation *parent = op->getParentOp(); - while (parent) { - if (isa(parent)) { - ++depth; - } - parent = parent->getParentOp(); - } - return depth; - }; - // Step 1. Select loads into the first task, which is the producer task by // default. Place dots into the second task, which is the consumer. // Only consider loads that are connected to a dot op in a loop. diff --git a/third_party/nvidia/include/cublas_instance.h b/third_party/nvidia/include/cublas_instance.h index d943f6a385f9..3d69ba13ce1a 100644 --- a/third_party/nvidia/include/cublas_instance.h +++ b/third_party/nvidia/include/cublas_instance.h @@ -209,8 +209,7 @@ class CublasLtInstance { cublasOperation_t transa = CUBLAS_OP_T; cublasOperation_t transb = CUBLAS_OP_N; - cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, - Ddesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL; // Use FP32 compute and accumulation cublasComputeType_t computeType = CUBLAS_COMPUTE_32F; @@ -266,8 +265,6 @@ class CublasLtInstance { successOrExit(cublasLtMatrixLayoutCreate(&Adesc, dataType, k, m, k)); successOrExit(cublasLtMatrixLayoutCreate(&Bdesc, dataType, k, n, k)); successOrExit(cublasLtMatrixLayoutCreate(&Cdesc, CUDA_R_16F, m, n, m)); - Ddesc = Cdesc; - float alpha = 1.0f; float beta = 0.0f; // No bias diff --git a/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp b/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp index 082d5c96f090..d41d16fc82c6 100644 --- a/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp +++ b/third_party/nvidia/lib/Dialect/NVWS/Transforms/AssignStagePhase.cpp @@ -550,7 +550,6 @@ class NVWSAssignStagePhase : public impl::NVWSAssignStagePhaseBase { public: void runOnOperation() override { - MLIRContext *context = &getContext(); mlir::ModuleOp m = getOperation(); m.walk([&](triton::FuncOp funcOp) { diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp index 3a8e28f922f2..71cc98c11a1a 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM.cpp @@ -30,10 +30,9 @@ struct ScaledDotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - ScaledDotOpConversion(LLVMTypeConverter &converter, int computeCapability, + ScaledDotOpConversion(LLVMTypeConverter &converter, int, PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit), - computeCapability(computeCapability) {} + : ConvertOpToLLVMPattern(converter, benefit) {} LogicalResult matchAndRewrite(triton::DotScaledOp op, triton::DotScaledOp::Adaptor adaptor, @@ -46,24 +45,18 @@ struct ScaledDotOpConversion "linear layout"); return convertMMADotScaled(op, adaptor, getTypeConverter(), rewriter); } - -private: - int computeCapability; }; struct DotOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - DotOpConversion(LLVMTypeConverter &converter, int computeCapability, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(converter, benefit), - computeCapability(computeCapability) {} + DotOpConversion(LLVMTypeConverter &converter, int, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} LogicalResult matchAndRewrite(triton::DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // D = A * B + C - Value D = op.getResult(); auto dType = op.getResult().getType(); auto dEncoding = dType.getEncoding(); @@ -90,9 +83,6 @@ struct DotOpConversion : public ConvertOpToLLVMPattern { llvm::report_fatal_error( "Unsupported DotOp found when converting TritonGPU to LLVM."); } - -private: - int computeCapability; }; struct WarpGroupDotOpConversion diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index 140d9d92ec82..77745b61cd7e 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -1103,7 +1103,6 @@ static LinearLayout getMsgToPackedOffsetLayout(ttg::MemDescType ty, ttg::TMAMode mode) { auto ctx = ty.getContext(); auto kMsg = str_attr("msg"); - auto kBlock = str_attr("block"); auto shapePerCTA = ttg::getShapePerCTA(ty); int rank = shapePerCTA.size(); auto blockShape = ttng::getTMABlockShape(ty, /*packedSize=*/true, mode); @@ -1194,7 +1193,6 @@ struct AsyncTMACopyGlobalToLocalOpConversion auto kMsg = str_attr("msg"); auto kBlock = str_attr("block"); const auto numCopies = msgToOffset.getInDimSize(kMsg); - auto zero = b.i32_val(0); auto ctaId = nvgpu::ClusterCTAIdOp::create(rewriter, loc); // We multicast if the flag is on and the block layout has broadcasting bool multicast = op.getMulticast(); diff --git a/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp b/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp index 1c3fb1d4938f..9a7f9e03269b 100644 --- a/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp +++ b/third_party/proton/Dialect/lib/ProtonToProtonGPU/ProtonToProtonGPUPass.cpp @@ -142,8 +142,6 @@ LogicalResult replaceProtonRecordOp(OpBuilder &builder, FuncOp func, int getAllocSharedMemSize(int maxSharedMemSize, int sharedMemUsed, int segmentNum) { const int bytesPerEntry = gpu::getBytesPerClockEntry(); - const int wordsPerEntry = bytesPerEntry / 4; // 1 word = 4 bytes - const int circularHeaderSize = gpu::getCircularHeaderSize(); // byte size sharedMemUsed = llvm::alignTo(sharedMemUsed, bytesPerEntry); if (sharedMemUsed >= maxSharedMemSize) { // We just assume there's enough shared memory and error out if not during diff --git a/third_party/proton/csrc/Proton.cpp b/third_party/proton/csrc/Proton.cpp index f54d5a0b6105..f5e60e23569c 100644 --- a/third_party/proton/csrc/Proton.cpp +++ b/third_party/proton/csrc/Proton.cpp @@ -36,7 +36,6 @@ std::map convertPythonMetrics( } // namespace static void initProton(pybind11::module &&m) { - using ret = pybind11::return_value_policy; using namespace pybind11::literals; // Accept raw integer pointers from Python (e.g., Tensor.data_ptr()) instead diff --git a/third_party/proton/csrc/include/Session/Session.h b/third_party/proton/csrc/include/Session/Session.h index ec587b5f4065..087877198f67 100644 --- a/third_party/proton/csrc/include/Session/Session.h +++ b/third_party/proton/csrc/include/Session/Session.h @@ -36,11 +36,11 @@ class Session { Profiler *getProfiler() const { return profiler; } private: - Session(size_t id, const std::string &path, Profiler *profiler, + Session(const std::string &path, Profiler *profiler, std::unique_ptr contextSource, std::unique_ptr data) - : id(id), path(path), profiler(profiler), - contextSource(std::move(contextSource)), data(std::move(data)) {} + : path(path), profiler(profiler), contextSource(std::move(contextSource)), + data(std::move(data)) {} template std::vector getInterfaces() { std::vector interfaces; @@ -60,7 +60,6 @@ class Session { } const std::string path{}; - size_t id{}; Profiler *profiler{}; std::unique_ptr contextSource{}; std::unique_ptr data{}; @@ -136,7 +135,7 @@ class SessionManager : public Singleton { Profiler *validateAndSetProfilerMode(Profiler *profiler, const std::string &mode); - std::unique_ptr makeSession(size_t id, const std::string &path, + std::unique_ptr makeSession(const std::string &path, const std::string &profilerName, const std::string &contextSourceName, const std::string &dataName, diff --git a/third_party/proton/csrc/lib/Data/TraceData.cpp b/third_party/proton/csrc/lib/Data/TraceData.cpp index 0ea4a71e9893..3ae7b0144292 100644 --- a/third_party/proton/csrc/lib/Data/TraceData.cpp +++ b/third_party/proton/csrc/lib/Data/TraceData.cpp @@ -539,12 +539,6 @@ json buildCallStackJson(const std::vector &contexts) { return callStack; } -json buildCallStackJson(const Context &context) { - json callStack = json::array(); - callStack.push_back(context.name); - return callStack; -} - json buildFlexibleMetricsJson( const DataEntry::FlexibleMetricMap &flexibleMetrics) { json metrics = json::object(); diff --git a/third_party/proton/csrc/lib/Session/Session.cpp b/third_party/proton/csrc/lib/Session/Session.cpp index 74eeb0f55915..125fb9b764ee 100644 --- a/third_party/proton/csrc/lib/Session/Session.cpp +++ b/third_party/proton/csrc/lib/Session/Session.cpp @@ -88,15 +88,15 @@ Profiler *SessionManager::validateAndSetProfilerMode(Profiler *profiler, } std::unique_ptr SessionManager::makeSession( - size_t id, const std::string &path, const std::string &profilerName, + const std::string &path, const std::string &profilerName, const std::string &contextSourceName, const std::string &dataName, const std::string &mode) { auto *profiler = makeProfiler(profilerName); profiler = validateAndSetProfilerMode(profiler, mode); auto contextSource = makeContextSource(contextSourceName); auto data = makeData(dataName, path, contextSource.get()); - auto *session = new Session(id, path, profiler, std::move(contextSource), - std::move(data)); + auto *session = + new Session(path, profiler, std::move(contextSource), std::move(data)); return std::unique_ptr(session); } @@ -191,8 +191,8 @@ size_t SessionManager::addSession(const std::string &path, return sessionId; } auto sessionId = nextSessionId++; - auto newSession = makeSession(sessionId, path, profilerName, - contextSourceName, dataName, mode); + auto newSession = + makeSession(path, profilerName, contextSourceName, dataName, mode); sessionPaths[path] = sessionId; sessions[sessionId] = std::move(newSession); return sessionId;