@@ -849,18 +849,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
849849 return rewriter.notifyMatchFailure (storeScatterOp,
850850 " Expected 1D offsets and mask vector" );
851851 VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType ());
852- assert (storeVecTy.getRank () <= 2 &&
853- " Expected at most 2D result at SG level" );
854- VectorType distStoreVecTy;
855- if (storeVecTy.getRank () == 2 )
856- distStoreVecTy = VectorType::Builder (storeVecTy).dropDim (0 );
857- else // rank 1
858- distStoreVecTy = VectorType::Builder (storeVecTy).setDim (0 , 1 );
859- // Assume offset and mask producers will be distributed as well.
860- VectorType distOffsetsTy =
861- VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
862- VectorType distMaskTy = VectorType::get (
863- {1 }, getElementTypeOrSelf (storeScatterOp.getMask ().getType ()));
852+ if (storeVecTy.getRank () > 2 )
853+ return rewriter.notifyMatchFailure (
854+ storeScatterOp, " Expected at most 2D result at SG level" );
855+
864856 std::string layoutPayloadName =
865857 xegpu::getLayoutName (storeScatterOp->getOpOperand (0 ));
866858 std::string layoutOffsetsName =
@@ -884,17 +876,20 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
884876 if (failed (distStoreVecByWarpOpOrFailure) ||
885877 failed (distOffsetsByWarpOpOrFailure) ||
886878 failed (distMaskByWarpOpOrFailure)) {
887- storeScatterOp.emitWarning (
879+ return rewriter.notifyMatchFailure (
880+ storeScatterOp,
888881 " Some vector operands have no layouts, using defaults instead." );
889882 }
890- distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or (distStoreVecTy );
891- distOffsetsTy = distOffsetsByWarpOpOrFailure. value_or (distOffsetsTy);
892- distMaskTy = distMaskByWarpOpOrFailure. value_or (distMaskTy );
883+ VectorType distPayloadTy = distStoreVecByWarpOpOrFailure.value ( );
884+ VectorType expectedPayloadTy = VectorType::get (
885+ {distPayloadTy. getNumElements ()}, distPayloadTy. getElementType () );
893886
894887 SmallVector<size_t > newRetIndices;
895888 SmallVector<Value> operands = storeScatterOp->getOperands ();
896889 SmallVector<Type> operandTypesToYield = {
897- distStoreVecTy, operands[1 ].getType (), distOffsetsTy, distMaskTy};
890+ expectedPayloadTy, operands[1 ].getType (),
891+ distOffsetsByWarpOpOrFailure.value (),
892+ distMaskByWarpOpOrFailure.value ()};
898893
899894 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
900895 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -958,10 +953,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
958953 return rewriter.notifyMatchFailure (loadGatherOp,
959954 " Expected 1D offsets and mask vector" );
960955 // Assume offset and mask producers will be distributed as well.
961- VectorType distOffsetsTy =
962- VectorType::get ({1 }, getElementTypeOrSelf (offsetsTy));
963- VectorType distMaskTy = VectorType::get ({1 }, getElementTypeOrSelf (maskTy));
964-
965956 std::string layoutOffsetsName =
966957 xegpu::getLayoutName (loadGatherOp->getOpOperand (1 ));
967958 std::string layoutMaskName =
@@ -978,16 +969,16 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
978969 getDistVecTypeBasedOnLaneLayout (layoutMask, maskTy);
979970 if (failed (distOffsetsByWarpOpOrFailure) ||
980971 failed (distMaskByWarpOpOrFailure)) {
981- loadGatherOp.emitWarning (
972+ return rewriter.notifyMatchFailure (
973+ loadGatherOp,
982974 " Some vector operands have no layouts, using defaults instead." );
983975 }
984- distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or (distOffsetsTy);
985- distMaskTy = distMaskByWarpOpOrFailure.value_or (distMaskTy);
986976
987977 SmallVector<size_t > newRetIndices;
988978 SmallVector<Value> operands = loadGatherOp->getOperands ();
989- SmallVector<Type> operandTypesToYield = {operands[0 ].getType (),
990- distOffsetsTy, distMaskTy};
979+ SmallVector<Type> operandTypesToYield = {
980+ operands[0 ].getType (), distOffsetsByWarpOpOrFailure.value (),
981+ distMaskByWarpOpOrFailure.value ()};
991982
992983 gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
993984 rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -998,7 +989,6 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
998989 const unsigned operandIdx = producedByLastLoad->getOperandNumber ();
999990 VectorType loadVecTy =
1000991 cast<VectorType>(warpOp.getResult (operandIdx).getType ());
1001- assert (loadVecTy.getRank () == 1 && " Expected a distributed vector" );
1002992
1003993 rewriter.setInsertionPointAfter (newWarpOp);
1004994 xegpu::LoadGatherOp newOp = rewriter.create <xegpu::LoadGatherOp>(
0 commit comments