@@ -844,9 +844,10 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       return rewriter.notifyMatchFailure(
           storeScatterOp, "Store op must have a vector of offsets argument");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    if (offsetsTy.getRank() != 1)
+    VectorType maskTy = cast<VectorType>(storeScatterOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
       return rewriter.notifyMatchFailure(storeScatterOp,
-                                         "Expected 1D offsets vector");
+                                         "Expected 1D offsets and mask vector");
     VectorType storeVecTy = cast<VectorType>(storeScatterOp.getValueType());
     assert(storeVecTy.getRank() <= 2 &&
            "Expected at most 2D result at SG level");
@@ -855,17 +856,45 @@ struct StoreDistribution final : public gpu::WarpDistributionPattern {
       distStoreVecTy = VectorType::Builder(storeVecTy).dropDim(0);
     else // rank 1
       distStoreVecTy = VectorType::Builder(storeVecTy).setDim(0, 1);
-
-    SmallVector<size_t> newRetIndices;
-    SmallVector<Value> operands = storeScatterOp->getOperands();
-    SmallVector<Type> operandTypesToYield =
-        llvm::to_vector_of<Type>(storeScatterOp->getOperandTypes());
-    operandTypesToYield[0] = distStoreVecTy;
     // Assume offset and mask producers will be distributed as well.
-    operandTypesToYield[2] =
+    VectorType distOffsetsTy =
         VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypesToYield[3] = VectorType::get(
+    VectorType distMaskTy = VectorType::get(
         {1}, getElementTypeOrSelf(storeScatterOp.getMask().getType()));
+    std::string layoutPayloadName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(0));
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(2));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(storeScatterOp->getOpOperand(3));
+
+    xegpu::LayoutAttr layoutPayload =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutPayloadName);
+    xegpu::LayoutAttr layoutOffsets =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        storeScatterOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distStoreVecByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutPayload, storeVecTy);
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distStoreVecByWarpOpOrFailure) ||
+        failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      storeScatterOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    distStoreVecTy = distStoreVecByWarpOpOrFailure.value_or(distStoreVecTy);
+    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);
+
+    SmallVector<size_t> newRetIndices;
+    SmallVector<Value> operands = storeScatterOp->getOperands();
+    SmallVector<Type> operandTypesToYield = {
+        distStoreVecTy, operands[1].getType(), distOffsetsTy, distMaskTy};

     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -918,23 +947,47 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     auto loadGatherOp =
         producedByLastLoad->get().getDefiningOp<xegpu::LoadGatherOp>();
     auto offsets = loadGatherOp.getOffsets();
-    if (!offsets || !isa<VectorType>(offsets.getType()))
+    if (!offsets || !isa<VectorType>(offsets.getType()) ||
+        !isa<VectorType>(loadGatherOp.getMask().getType()))
       return rewriter.notifyMatchFailure(
-          loadGatherOp, "Load op must have a vector of offsets argument");
+          loadGatherOp,
+          "Load op must have vector arguments for offsets and mask");
     VectorType offsetsTy = cast<VectorType>(offsets.getType());
-    if (offsetsTy.getRank() != 1)
+    VectorType maskTy = cast<VectorType>(loadGatherOp.getMask().getType());
+    if (offsetsTy.getRank() != 1 || maskTy.getRank() != 1)
       return rewriter.notifyMatchFailure(loadGatherOp,
-                                         "Expected 1D offsets vector");
+                                         "Expected 1D offsets and mask vector");
+    // Assume offset and mask producers will be distributed as well.
+    VectorType distOffsetsTy =
+        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
+    VectorType distMaskTy = VectorType::get({1}, getElementTypeOrSelf(maskTy));
+
+    std::string layoutOffsetsName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(1));
+    std::string layoutMaskName =
+        xegpu::getLayoutName(loadGatherOp->getOpOperand(2));
+
+    xegpu::LayoutAttr layoutOffsets =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutOffsetsName);
+    xegpu::LayoutAttr layoutMask =
+        loadGatherOp->getAttrOfType<xegpu::LayoutAttr>(layoutMaskName);
+
+    FailureOr<VectorType> distOffsetsByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutOffsets, offsetsTy);
+    FailureOr<VectorType> distMaskByWarpOpOrFailure =
+        getDistVecTypeBasedOnLaneLayout(layoutMask, maskTy);
+    if (failed(distOffsetsByWarpOpOrFailure) ||
+        failed(distMaskByWarpOpOrFailure)) {
+      loadGatherOp.emitWarning(
+          "Some vector operands have no layouts, using defaults instead.");
+    }
+    distOffsetsTy = distOffsetsByWarpOpOrFailure.value_or(distOffsetsTy);
+    distMaskTy = distMaskByWarpOpOrFailure.value_or(distMaskTy);

     SmallVector<size_t> newRetIndices;
     SmallVector<Value> operands = loadGatherOp->getOperands();
-    SmallVector<Type> operandTypesToYield =
-        llvm::to_vector_of<Type>(loadGatherOp->getOperandTypes());
-    // Assume offset and mask producers will be distributed as well.
-    operandTypesToYield[1] =
-        VectorType::get({1}, getElementTypeOrSelf(offsetsTy));
-    operandTypesToYield[2] =
-        VectorType::get({1}, getElementTypeOrSelf(loadGatherOp.getMaskType()));
+    SmallVector<Type> operandTypesToYield = {operands[0].getType(),
+                                             distOffsetsTy, distMaskTy};

     gpu::WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
         rewriter, warpOp, operands, operandTypesToYield, newRetIndices);
@@ -951,6 +1004,7 @@ struct LoadDistribution final : public gpu::WarpDistributionPattern {
     xegpu::LoadGatherOp newOp = rewriter.create<xegpu::LoadGatherOp>(
         newWarpOp.getLoc(), loadVecTy, newLoadGatherOperands,
         loadGatherOp->getAttrs());
+    xegpu::removeLayoutAttrs(newOp);
     Value distributedVal = newWarpOp.getResult(operandIdx);
     rewriter.replaceAllUsesWith(distributedVal, newOp->getResult(0));
     return success();
@@ -990,7 +1044,7 @@ void XeGPUSubgroupDistributePass::runOnOperation() {

       // Vector operands of these ops have a fixed and implicit layout.
       if (isa<xegpu::LoadGatherOp, xegpu::StoreScatterOp>(op))
-      continue;
+        continue;
       auto layout =
           xegpu::getDistributeLayoutAttrOfType<xegpu::LayoutAttr>(operand);
       if (!layout) {
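Aside from the diff itself, the core idea in both patterns is the same: each subgroup-level vector operand (payload, offsets, mask) is narrowed to a per-lane vector type by dividing each dimension by the matching lane_layout entry, falling back to a default 1-element type when no layout attribute is attached. Below is a minimal, standalone sketch of that shape arithmetic under those assumptions; distributeShape is a hypothetical helper written only for illustration and is not part of the MLIR/XeGPU API.

// Standalone sketch (not the MLIR helper itself) of the shape arithmetic that
// getDistVecTypeBasedOnLaneLayout is assumed to perform: each dimension of the
// subgroup-level vector is divided by the matching lane_layout entry, so every
// lane owns an equal slice. Divisibility is assumed here; the real patterns
// fall back to a default single-element type when no layout is attached.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

using Shape = std::vector<int64_t>;

// Returns the per-lane shape, or std::nullopt when no layout is available,
// mirroring the FailureOr<> result that the patterns handle with value_or.
std::optional<Shape> distributeShape(const Shape &sgShape,
                                     const std::optional<Shape> &laneLayout) {
  if (!laneLayout)
    return std::nullopt;
  assert(sgShape.size() == laneLayout->size() && "rank mismatch");
  Shape perLane;
  for (size_t i = 0; i < sgShape.size(); ++i) {
    assert(sgShape[i] % (*laneLayout)[i] == 0 && "dim not divisible by lanes");
    perLane.push_back(sgShape[i] / (*laneLayout)[i]);
  }
  return perLane;
}

int main() {
  // A 16-element offsets vector with lane_layout = [16]: one element per lane,
  // matching the vector<1x...> default used by the patterns.
  auto distOffsets = distributeShape(Shape{16}, Shape{16});
  // No layout attached: fall back to the default single-element shape,
  // analogous to distMaskByWarpOpOrFailure.value_or(distMaskTy).
  auto distMask = distributeShape(Shape{16}, std::nullopt).value_or(Shape{1});

  std::cout << "per-lane offsets dim0 = " << (*distOffsets)[0] << "\n"; // 1
  std::cout << "per-lane mask dim0    = " << distMask[0] << "\n";       // 1
  return 0;
}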