
[X86] fix inefficient llvm.masked.gather mask generation #175385

Closed
folkertdev wants to merge 3 commits into llvm:main from folkertdev:portable-masked-gather

Conversation

@folkertdev
Contributor

An (incomplete) attempt at fixing #59789.

The issue describes inefficient mask generation when using the portable masked gather intrinsic. I've replicated it here.

https://godbolt.org/z/h7b7c5Tb1
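
Stripped down, the reproducer from the Godbolt link is IR of this shape (it matches the `gather_portable_bits` function named in the DAG dumps below and the tests added in this PR):

```llvm
; The i8 mask is truncated to i4 and bitcast to <4 x i1>; it is this
; bitcast whose expansion scalarizes on AVX2.
define <4 x i32> @gather_portable_bits(<4 x i32> %indices, i8 %maskbits, ptr %data) {
  %m4 = trunc i8 %maskbits to i4
  %m = bitcast i4 %m4 to <4 x i1>
  %idx64 = zext <4 x i32> %indices to <4 x i64>
  %ptrs = getelementptr i32, ptr %data, <4 x i64> %idx64
  %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> %m, <4 x i32> zeroinitializer)
  ret <4 x i32> %res
}
```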

The issue seems to be how the integer bitmask is converted into a mask vector. Based on the logs, that ultimately happens when masked_gather is lowered:

Optimized type-legalized selection DAG: %bb.0 'gather_portable_bits:'
SelectionDAG has 21 nodes:
  t0: ch,glue = EntryToken
      t17: v4i32 = BUILD_VECTOR Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>
            t4: i32,ch = CopyFromReg t0, Register:i32 %1
          t25: i8 = truncate t4
        t31: ch = store<(store (s8) into %stack.0), trunc to i4> t0, t25, FrameIndex:i64<0>, undef:i64
      t36: v4i32,ch = load<(load (s8) from %stack.0), sext from v4i1> t31, FrameIndex:i64<0>, undef:i64
      t7: i64,ch = CopyFromReg t0, Register:i64 %2
        t2: v4i32,ch = CopyFromReg t0, Register:v4i32 %0
      t10: v4i64 = zero_extend t2
    t24: v4i32,ch = masked_gather<(load unknown-size, align 4), unsigned scaled offset> t0, t17, t36, t7, t10, TargetConstant:i64<4>
  t22: ch,glue = CopyToReg t0, Register:v4i32 $xmm0, t24
  t23: ch = X86ISD::RET_GLUE t22, TargetConstant:i32<0>, Register:v4i32 $xmm0, t22:1

Vector-legalized selection DAG: %bb.0 'gather_portable_bits:'
SelectionDAG has 40 nodes:
  t0: ch,glue = EntryToken
      t17: v4i32 = BUILD_VECTOR Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>
            t41: i8 = and t39, Constant:i8<1>
          t42: i1 = truncate t41
        t43: i32 = sign_extend t42
              t44: i8 = srl t39, Constant:i8<1>
            t45: i8 = and t44, Constant:i8<1>
          t46: i1 = truncate t45
        t47: i32 = sign_extend t46
              t49: i8 = srl t39, Constant:i8<2>
            t50: i8 = and t49, Constant:i8<1>
          t51: i1 = truncate t50
        t52: i32 = sign_extend t51
              t54: i8 = srl t39, Constant:i8<3>
            t55: i8 = and t54, Constant:i8<1>
          t56: i1 = truncate t55
        t57: i32 = sign_extend t56
      t58: v4i32 = BUILD_VECTOR t43, t47, t52, t57
      t7: i64,ch = CopyFromReg t0, Register:i64 %2
        t2: v4i32,ch = CopyFromReg t0, Register:v4i32 %0
      t10: v4i64 = zero_extend t2
    t59: v4i32,ch = X86ISD::MGATHER<(load unknown-size, align 4)> t0, t17, t58, t7, t10, TargetConstant:i64<4>
  t22: ch,glue = CopyToReg t0, Register:v4i32 $xmm0, t59
        t4: i32,ch = CopyFromReg t0, Register:i32 %1
      t25: i8 = truncate t4
    t31: ch = store<(store (s8) into %stack.0), trunc to i4> t0, t25, FrameIndex:i64<0>, undef:i64
  t39: i8,ch = load<(load (s8) from %stack.0), anyext from i4> t31, FrameIndex:i64<0>, undef:i64
  t23: ch = X86ISD::RET_GLUE t22, TargetConstant:i32<0>, Register:v4i32 $xmm0, t22:1

What this branch implements is to rewrite, much earlier:

t9: v4i1 = bitcast t8

Into

t25: i32 = zero_extend t8
t26: v4i32 = X86ISD::VBROADCAST t25
t32: v4i32 = and t26, t31
t33: v4i32 = X86ISD::PCMPEQ t32, t31
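
For reference, that DAG sequence corresponds to IR roughly like the following (a sketch only; the actual transform builds X86ISD nodes, and the function name is illustrative):

```llvm
; Broadcast the mask bits to every lane, AND with per-lane bit constants,
; then compare for equality: lane i is all-ones iff bit i of %bits was set.
define <4 x i32> @expand_mask(i8 %bits) {
  %ext = zext i8 %bits to i32
  %ins = insertelement <4 x i32> poison, i32 %ext, i64 0
  %bc = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
  %and = and <4 x i32> %bc, <i32 1, i32 2, i32 4, i32 8>
  %cmp = icmp eq <4 x i32> %and, <i32 1, i32 2, i32 4, i32 8>
  %mask = sext <4 x i1> %cmp to <4 x i32>
  ret <4 x i32> %mask
}
```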

That transformation is sufficient to prevent the scalarization of the mask vector construction. It does work for the original issue, but I'm not really sure whether it is the right approach. It also runs into issues with avx512 code: I think it breaks the assumption that the mask is a boolean vector.

So maybe this should actually be solved in LowerMGATHER? I think I'm just kind of missing something that would make this simpler and more robust.

@llvmbot
Member

llvmbot commented Jan 10, 2026

@llvm/pr-subscribers-backend-x86

Author: Folkert de Vries (folkertdev)

Changes

Patch is 33.37 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175385.diff

2 Files Affected:

  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+75-14)
  • (added) llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll (+600)
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index 40ea3cb76bae4..92d7944d65a45 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -57397,14 +57397,18 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
   return SDValue();
 }
 
-static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
+static SDValue rebuildGatherScatter(SelectionDAG &DAG,
+                                    MaskedGatherScatterSDNode *GorS,
                                     SDValue Index, SDValue Base, SDValue Scale,
-                                    SelectionDAG &DAG) {
+                                    SDValue Mask = SDValue()) {
   SDLoc DL(GorS);
 
+  if (!Mask.getNode())
+    Mask = GorS->getMask();
+
   if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
-    SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
-                      Gather->getMask(), Base, Index, Scale } ;
+    SDValue Ops[] = {
+        Gather->getChain(), Gather->getPassThru(), Mask, Base, Index, Scale};
     return DAG.getMaskedGather(Gather->getVTList(),
                                Gather->getMemoryVT(), DL, Ops,
                                Gather->getMemOperand(),
@@ -57412,8 +57416,8 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
                                Gather->getExtensionType());
   }
   auto *Scatter = cast<MaskedScatterSDNode>(GorS);
-  SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
-                    Scatter->getMask(), Base, Index, Scale };
+  SDValue Ops[] = {
+      Scatter->getChain(), Scatter->getValue(), Mask, Base, Index, Scale};
   return DAG.getMaskedScatter(Scatter->getVTList(),
                               Scatter->getMemoryVT(), DL,
                               Ops, Scatter->getMemOperand(),
@@ -57460,7 +57464,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
                                          Index.getOperand(0), NewShAmt);
           SDValue NewScale =
               DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
-          return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
+          return rebuildGatherScatter(DAG, GorS, NewIndex, Base, NewScale);
         }
       }
     }
@@ -57478,7 +57482,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       // a split.
       if (SDValue TruncIndex =
               DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
-        return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);
+        return rebuildGatherScatter(DAG, GorS, TruncIndex, Base, Scale);
 
       // Shrink any sign/zero extends from 32 or smaller to larger than 32 if
       // there are sufficient sign bits. Only do this before legalize types to
@@ -57487,13 +57491,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
            Index.getOpcode() == ISD::ZERO_EXTEND) &&
           Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+        return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
       }
 
       // Shrink if we remove an illegal type.
       if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
         Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
-        return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+        return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
       }
     }
   }
@@ -57518,13 +57522,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
               SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
                                             DAG.getConstant(Adder, DL, PtrVT));
               SDValue NewIndex = Index.getOperand(1 - I);
-              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+              return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
             }
             // For non-constant cases, limit this to non-scaled cases.
             if (ScaleAmt == 1) {
               SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
               SDValue NewIndex = Index.getOperand(1 - I);
-              return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+              return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
             }
           }
         }
@@ -57539,7 +57543,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
           SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
                                          Index.getOperand(1 - I), Splat);
           SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
-          return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
+          return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
         }
       }
   }
@@ -57550,12 +57554,69 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
       MVT EltVT = IndexWidth > 32 ? MVT::i64 : MVT::i32;
       IndexVT = IndexVT.changeVectorElementType(*DAG.getContext(), EltVT);
       Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
-      return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
+      return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
     }
   }
 
   // With vector masks we only demand the upper bit of the mask.
   SDValue Mask = GorS->getMask();
+
+  // Replace a mask that looks like:
+  //
+  //   t9: v4i1 = bitcast t8
+  //
+  // With one that looks like:
+  //
+  //  t25: i32 = zero_extend t8
+  //  t26: v4i32 = X86ISD::VBROADCAST t25
+  //  t32: v4i32 = and t26, t31
+  //  t33: v4i32 = X86ISD::PCMPEQ t32, t31
+  //
+  // The default expansion from an integer to a mask vector generates a lot more
+  // instructions.
+  if (DCI.isBeforeLegalize()) {
+    EVT MaskVT = Mask.getValueType();
+
+    if (MaskVT.isVector() && MaskVT.getVectorElementType() == MVT::i1 &&
+        Mask.getOpcode() == ISD::BITCAST) {
+
+      SDValue Bits = Mask.getOperand(0);
+      if (Bits.getValueType().isScalarInteger()) {
+        unsigned NumElts = MaskVT.getVectorNumElements();
+        if (NumElts == 4 || NumElts == 8) {
+
+          EVT ValueVT = N->getValueType(0);
+          EVT IntMaskVT = ValueVT.changeVectorElementTypeToInteger();
+          if (!IntMaskVT.isSimple() || !TLI.isTypeLegal(IntMaskVT))
+            return SDValue();
+
+          MVT MaskVecVT = IntMaskVT.getSimpleVT();
+          MVT MaskEltVT = MaskVecVT.getVectorElementType();
+
+          if (MaskVecVT.getVectorNumElements() != NumElts)
+            return SDValue();
+
+          SDValue BitsElt = DAG.getZExtOrTrunc(Bits, DL, MaskEltVT);
+          SDValue Bc = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVecVT, BitsElt);
+
+          SmallVector<SDValue, 8> Lanes;
+          Lanes.reserve(NumElts);
+          for (unsigned i = 0; i < NumElts; ++i) {
+            uint64_t Bit = 1ull << i;
+            Lanes.push_back(DAG.getConstant(Bit, DL, MaskEltVT));
+          }
+
+          SDValue LaneBits = DAG.getBuildVector(MaskVecVT, DL, Lanes);
+          SDValue And = DAG.getNode(ISD::AND, DL, MaskVecVT, Bc, LaneBits);
+          SDValue NewMask =
+              DAG.getNode(X86ISD::PCMPEQ, DL, MaskVecVT, And, LaneBits);
+
+          return rebuildGatherScatter(DAG, GorS, Index, Base, Scale, NewMask);
+        }
+      }
+    }
+  }
+
   if (Mask.getScalarValueSizeInBits() != 1) {
     APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
     if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll b/llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll
new file mode 100644
index 0000000000000..016137ed7cc86
--- /dev/null
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter_portable.ll
@@ -0,0 +1,600 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc -mtriple=x86_64-unknown-unknown -O3 -mattr=+avx2 -mcpu=skylake < %s | FileCheck %s --check-prefix=AVX2
+
+define <4 x i32> @gather_avx_dd_128(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dd_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vmovaps %xmm0, %xmm1
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    movq %rsi, %rdi
+; AVX2-NEXT:    movl $4, %esi
+; AVX2-NEXT:    jmp llvm.x86.avx2.gather.d.d.128@PLT # TAILCALL
+  %m4 = trunc i8 %maskbits to i4
+  %m  = bitcast i4 %m4 to <4 x i1>
+  %m32 = sext <4 x i1> %m to <4 x i32>
+  %res = tail call <4 x i32> @llvm.x86.avx2.gather.d.d.128(<4 x i32> zeroinitializer, ptr %data, <4 x i32> %indices, <4 x i32> %m32, i8 4)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @gather_portable_dd_128(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dd_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm0
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+  %m4 = trunc i8 %maskbits to i4
+  %m  = bitcast i4 %m4 to <4 x i1>
+  %idx64 = zext <4 x i32> %indices to <4 x i64>
+  %ptrs = getelementptr i32, ptr %data, <4 x i64> %idx64
+  %res = tail call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> %m, <4 x i32> zeroinitializer)
+  ret <4 x i32> %res
+}
+
+define <8 x i32> @gather_avx_dd_256(<8 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dd_256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vmovd %edi, %xmm1
+; AVX2-NEXT:    vpbroadcastb %xmm1, %ymm1
+; AVX2-NEXT:    vmovdqa {{.*#+}} ymm2 = [1,2,4,8,16,32,64,128]
+; AVX2-NEXT:    vpand %ymm2, %ymm1, %ymm1
+; AVX2-NEXT:    vpcmpeqd %ymm2, %ymm1, %ymm2
+; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT:    vpgatherdd %ymm2, (%rsi,%ymm0,4), %ymm1
+; AVX2-NEXT:    vmovdqa %ymm1, %ymm0
+; AVX2-NEXT:    retq
+  %m  = bitcast i8 %maskbits to <8 x i1>
+  %m32 = sext <8 x i1> %m to <8 x i32>
+  %res = tail call <8 x i32> @llvm.x86.avx2.gather.d.d.256(<8 x i32> zeroinitializer, ptr %data, <8 x i32> %indices, <8 x i32> %m32, i8 4)
+  ret <8 x i32> %res
+}
+
+define <8 x i32> @gather_portable_dd_256(<8 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dd_256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT:    vextracti128 $1, %ymm0, %xmm0
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm0 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT:    vmovd %edi, %xmm2
+; AVX2-NEXT:    vpbroadcastd %xmm2, %ymm2
+; AVX2-NEXT:    vmovdqa {{.*#+}} ymm3 = [1,2,4,8,16,32,64,128]
+; AVX2-NEXT:    vpand %ymm3, %ymm2, %ymm2
+; AVX2-NEXT:    vpcmpeqd %ymm3, %ymm2, %ymm2
+; AVX2-NEXT:    vextracti128 $1, %ymm2, %xmm3
+; AVX2-NEXT:    vpxor %xmm4, %xmm4, %xmm4
+; AVX2-NEXT:    vpxor %xmm5, %xmm5, %xmm5
+; AVX2-NEXT:    vpgatherqd %xmm3, (%rsi,%ymm0,4), %xmm5
+; AVX2-NEXT:    vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm4
+; AVX2-NEXT:    vinserti128 $1, %xmm5, %ymm4, %ymm0
+; AVX2-NEXT:    retq
+  %m  = bitcast i8 %maskbits to <8 x i1>
+  %idx64 = zext <8 x i32> %indices to <8 x i64>
+  %ptrs = getelementptr i32, ptr %data, <8 x i64> %idx64
+  %res = tail call <8 x i32> @llvm.masked.gather.v8i32.v8p0(<8 x ptr> %ptrs, i32 4, <8 x i1> %m, <8 x i32> zeroinitializer)
+  ret <8 x i32> %res
+}
+
+define <2 x i32> @gather_avx_qd_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_qd_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vpbroadcastq {{.*#+}} xmm2 = [1,2,1,2]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    movq %rsi, %rdi
+; AVX2-NEXT:    movl $4, %esi
+; AVX2-NEXT:    jmp llvm.x86.avx2.gather.q.d.128@PLT # TAILCALL
+  %m2 = trunc i8 %maskbits to i2
+  %m  = bitcast i2 %m2 to <2 x i1>
+  %idx64 = zext <2 x i32> %indices to <2 x i64>
+  %m32 = sext <2 x i1> %m to <2 x i32>
+  %res = tail call <2 x i32> @llvm.x86.avx2.gather.q.d.128(<2 x i32> zeroinitializer, ptr %data, <2 x i64> %idx64, <2 x i32> %m32, i8 4)
+  ret <2 x i32> %res
+}
+
+define <2 x i32> @gather_portable_qd_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_qd_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    movl %edi, %eax
+; AVX2-NEXT:    andb $2, %al
+; AVX2-NEXT:    shrb %al
+; AVX2-NEXT:    andb $1, %dil
+; AVX2-NEXT:    vmovd %edi, %xmm1
+; AVX2-NEXT:    vpinsrb $8, %eax, %xmm1, %xmm1
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm2 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT:    vpshufd {{.*#+}} xmm0 = xmm1[0,2,2,3]
+; AVX2-NEXT:    vpslld $31, %xmm0, %xmm1
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vpgatherqd %xmm1, (%rsi,%xmm2,4), %xmm0
+; AVX2-NEXT:    retq
+  %m2 = trunc i8 %maskbits to i2
+  %m  = bitcast i2 %m2 to <2 x i1>
+  %idx64 = zext <2 x i32> %indices to <2 x i64>
+  %ptrs = getelementptr i32, ptr %data, <2 x i64> %idx64
+  %res = tail call <2 x i32> @llvm.masked.gather.v2i32.v2p0(<2 x ptr> %ptrs, i32 4, <2 x i1> %m, <2 x i32> zeroinitializer)
+  ret <2 x i32> %res
+}
+
+define <4 x i32> @gather_avx_qd_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_qd_256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm0
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+  %m4 = trunc i8 %maskbits to i4
+  %m  = bitcast i4 %m4 to <4 x i1>
+  %idx64 = zext <4 x i32> %indices to <4 x i64>
+  %m32 = sext <4 x i1> %m to <4 x i32>
+  %res = tail call <4 x i32> @llvm.x86.avx2.gather.q.d.256(<4 x i32> zeroinitializer, ptr %data, <4 x i64> %idx64, <4 x i32> %m32, i8 4)
+  ret <4 x i32> %res
+}
+
+define <4 x i32> @gather_portable_qd_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_qd_256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqd %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vpgatherqd %xmm2, (%rsi,%ymm1,4), %xmm0
+; AVX2-NEXT:    vzeroupper
+; AVX2-NEXT:    retq
+  %m4 = trunc i8 %maskbits to i4
+  %m  = bitcast i4 %m4 to <4 x i1>
+  %idx64 = zext <4 x i32> %indices to <4 x i64>
+  %ptrs = getelementptr i32, ptr %data, <4 x i64> %idx64
+  %res = tail call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %ptrs, i32 4, <4 x i1> %m, <4 x i32> zeroinitializer)
+  ret <4 x i32> %res
+}
+
+define <2 x i64> @gather_avx_dq_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dq_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vmovaps %xmm0, %xmm1
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [1,2]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqq %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    movq %rsi, %rdi
+; AVX2-NEXT:    movl $8, %esi
+; AVX2-NEXT:    jmp llvm.x86.avx2.gather.d.q.128@PLT # TAILCALL
+  %m2 = trunc i8 %maskbits to i2
+  %m  = bitcast i2 %m2 to <2 x i1>
+  %m64 = sext <2 x i1> %m to <2 x i64>
+  %res = tail call <2 x i64> @llvm.x86.avx2.gather.d.q.128(<2 x i64> zeroinitializer, ptr %data, <2 x i32> %indices, <2 x i64> %m64, i8 8)
+  ret <2 x i64> %res
+}
+
+define <2 x i64> @gather_portable_dq_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dq_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    movl %edi, %eax
+; AVX2-NEXT:    andl $1, %eax
+; AVX2-NEXT:    negq %rax
+; AVX2-NEXT:    vmovq %rax, %xmm1
+; AVX2-NEXT:    andb $2, %dil
+; AVX2-NEXT:    shrb %dil
+; AVX2-NEXT:    movzbl %dil, %eax
+; AVX2-NEXT:    negq %rax
+; AVX2-NEXT:    vmovq %rax, %xmm2
+; AVX2-NEXT:    vpunpcklqdq {{.*#+}} xmm1 = xmm1[0],xmm2[0]
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm2 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vpgatherqq %xmm1, (%rsi,%xmm2,8), %xmm0
+; AVX2-NEXT:    retq
+  %m2 = trunc i8 %maskbits to i2
+  %m  = bitcast i2 %m2 to <2 x i1>
+  %idx64 = zext <2 x i32> %indices to <2 x i64>
+  %ptrs = getelementptr i64, ptr %data, <2 x i64> %idx64
+  %res = tail call <2 x i64> @llvm.masked.gather.v2i64.v2p0(<2 x ptr> %ptrs, i32 8, <2 x i1> %m, <2 x i64> zeroinitializer)
+  ret <2 x i64> %res
+}
+
+define <4 x i64> @gather_avx_dq_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_dq_256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vmovd %edi, %xmm1
+; AVX2-NEXT:    vpbroadcastd %xmm1, %ymm1
+; AVX2-NEXT:    vmovdqa {{.*#+}} ymm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %ymm2, %ymm1, %ymm1
+; AVX2-NEXT:    vpcmpeqq %ymm2, %ymm1, %ymm2
+; AVX2-NEXT:    vpxor %xmm1, %xmm1, %xmm1
+; AVX2-NEXT:    vpgatherdq %ymm2, (%rsi,%xmm0,8), %ymm1
+; AVX2-NEXT:    vmovdqa %ymm1, %ymm0
+; AVX2-NEXT:    retq
+  %m4 = trunc i8 %maskbits to i4
+  %m  = bitcast i4 %m4 to <4 x i1>
+  %m64 = sext <4 x i1> %m to <4 x i64>
+  %res = tail call <4 x i64> @llvm.x86.avx2.gather.d.q.256(<4 x i64> zeroinitializer, ptr %data, <4 x i32> %indices, <4 x i64> %m64, i8 8)
+  ret <4 x i64> %res
+}
+
+define <4 x i64> @gather_portable_dq_256(<4 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_portable_dq_256:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} ymm1 = xmm0[0],zero,xmm0[1],zero,xmm0[2],zero,xmm0[3],zero
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %ymm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} ymm2 = [1,2,4,8]
+; AVX2-NEXT:    vpand %ymm2, %ymm0, %ymm0
+; AVX2-NEXT:    vpcmpeqq %ymm2, %ymm0, %ymm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    vpgatherqq %ymm2, (%rsi,%ymm1,8), %ymm0
+; AVX2-NEXT:    retq
+  %m4 = trunc i8 %maskbits to i4
+  %m  = bitcast i4 %m4 to <4 x i1>
+  %idx64 = zext <4 x i32> %indices to <4 x i64>
+  %ptrs = getelementptr i64, ptr %data, <4 x i64> %idx64
+  %res = tail call <4 x i64> @llvm.masked.gather.v4i64.v4p0(<4 x ptr> %ptrs, i32 8, <4 x i1> %m, <4 x i64> zeroinitializer)
+  ret <4 x i64> %res
+}
+
+define <2 x i64> @gather_avx_qq_128(<2 x i32> %indices, i8 %maskbits, ptr noundef readonly %data) nounwind {
+; AVX2-LABEL: gather_avx_qq_128:
+; AVX2:       # %bb.0:
+; AVX2-NEXT:    vpmovzxdq {{.*#+}} xmm1 = xmm0[0],zero,xmm0[1],zero
+; AVX2-NEXT:    vmovd %edi, %xmm0
+; AVX2-NEXT:    vpbroadcastd %xmm0, %xmm0
+; AVX2-NEXT:    vmovdqa {{.*#+}} xmm2 = [1,2]
+; AVX2-NEXT:    vpand %xmm2, %xmm0, %xmm0
+; AVX2-NEXT:    vpcmpeqq %xmm2, %xmm0, %xmm2
+; AVX2-NEXT:    vpxor %xmm0, %xmm0, %xmm0
+; AVX2-NEXT:    movq %rsi, %rdi
+; AVX2-NEXT:    movl $8, %esi
+; AVX2-NEXT:    jmp llvm.x86.avx2.gather.q.q.128@PLT # TAILCALL
+  %m2 = trunc i8 %maskbits to i2
+  %m  = bitcast i2 %m2 to <2 x i1>
+  %idx64 = zext <2 x i32> %indices to <2 x i64>
+  %m64 = sext <2 x i1> %m to <2 x i6...
[truncated]

@github-actions

github-actions bot commented Jan 10, 2026

🐧 Linux x64 Test Results

  • 188362 tests passed
  • 5005 tests skipped

✅ The build succeeded and all tests passed.

@github-actions

github-actions bot commented Jan 10, 2026

🪟 Windows x64 Test Results

  • 129359 tests passed
  • 2861 tests skipped

✅ The build succeeded and all tests passed.

@folkertdev
Contributor Author

Hmm, doing it in LowerMGATHER is not that simple either: the mask vector construction is lowered before the masked_gather, and that looks like a nasty pattern to match.

SelectionDAG has 42 nodes:
  t0: ch,glue = EntryToken
      t4: i32,ch = CopyFromReg t0, Register:i32 %1
    t25: i8 = truncate t4
  t31: ch = store<(store (s8) into %stack.0), trunc to i4> t0, t25, FrameIndex:i64<0>, undef:i64
  t36: v4i32,ch = load<(load (s8) from %stack.0), sext from v4i1> t31, FrameIndex:i64<0>, undef:i64
      t17: v4i32 = BUILD_VECTOR Constant:i32<0>, Constant:i32<0>, Constant:i32<0>, Constant:i32<0>
            t41: i8 = and t39, Constant:i8<1>
          t42: i1 = truncate t41
        t43: i32 = sign_extend t42
              t44: i8 = srl t39, Constant:i8<1>
            t45: i8 = and t44, Constant:i8<1>
          t46: i1 = truncate t45
        t47: i32 = sign_extend t46
              t49: i8 = srl t39, Constant:i8<2>
            t50: i8 = and t49, Constant:i8<1>
          t51: i1 = truncate t50
        t52: i32 = sign_extend t51
              t54: i8 = srl t39, Constant:i8<3>
            t55: i8 = and t54, Constant:i8<1>
          t56: i1 = truncate t55
        t57: i32 = sign_extend t56
      t58: v4i32 = BUILD_VECTOR t43, t47, t52, t57
      t7: i64,ch = CopyFromReg t0, Register:i64 %2
        t2: v4i32,ch = CopyFromReg t0, Register:v4i32 %0
      t10: v4i64 = zero_extend t2
    t24: v4i32,ch = masked_gather<(load unknown-size, align 4), unsigned scaled offset> t0, t17, t58, t7, t10, TargetConstant:i64<4>
  t22: ch,glue = CopyToReg t0, Register:v4i32 $xmm0, t24
  t39: i8,ch = load<(load (s8) from %stack.0), anyext from i4> t31, FrameIndex:i64<0>, undef:i64
  t40: i8 = Constant<0>
  t23: ch = X86ISD::RET_GLUE t22, TargetConstant:i32<0>, Register:v4i32 $xmm0, t22:1

@RKSimon RKSimon self-requested a review January 11, 2026 11:26
Collaborator

@RKSimon RKSimon left a comment


Please investigate the CodeGen/X86/masked_gather_scatter.ll failures

@folkertdev
Contributor Author

I will, but could you provide some guidance on the best place to tackle this? In the combine, or in the lower?

@folkertdev folkertdev force-pushed the portable-masked-gather branch from 264d4aa to 44f95a1 Compare January 12, 2026 21:29
@folkertdev
Contributor Author

So, a bit of a cop-out, but the tests in that existing file all assume avx512. The inefficient mask generation is only a problem when avx512 is not available (avx512 has dedicated mask registers). So the optimization in this PR now simply does not run if avx512 is available.

I also experimented further with performing the optimization in lowerMGATHER and it was miserable.

@folkertdev folkertdev requested a review from RKSimon January 12, 2026 21:36
RKSimon added a commit to RKSimon/llvm-project that referenced this pull request Jan 13, 2026
RKSimon added a commit that referenced this pull request Jan 13, 2026
… for "fast-gather" avx2 targets (#175736)

Test coverage to help #175385
RKSimon added a commit to RKSimon/llvm-project that referenced this pull request Jan 13, 2026
…l vectors for masked load/store

Test coverage to help llvm#175385
RKSimon added a commit that referenced this pull request Jan 13, 2026
…l vectors for masked load/store (#175746)

Test coverage to help #175385
@RKSimon
Collaborator

RKSimon commented Jan 13, 2026

I've improved the test coverage at #175736 and #175746 - masked load/stores are affected as well so ideally we need a more general fix as these are a lot more common than gathers on avx2.
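
A masked load hitting the same bitcast-mask pattern looks roughly like this (a sketch of the affected shape, not taken verbatim from the linked tests):

```llvm
; On pre-AVX512 targets the <8 x i1> mask produced by the bitcast gets
; scalarized in the same way as in the gather case.
define <8 x i32> @load_portable_bits(i8 %maskbits, ptr %data) {
  %m = bitcast i8 %maskbits to <8 x i1>
  %res = call <8 x i32> @llvm.masked.load.v8i32.p0(ptr %data, i32 4, <8 x i1> %m, <8 x i32> zeroinitializer)
  ret <8 x i32> %res
}
```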

RKSimon added a commit to RKSimon/llvm-project that referenced this pull request Jan 13, 2026
…nsion before it might get split by legalisation

Masked load/store/gathers often need to bitcast the mask from a bitcasted integer.

On pre-AVX512 targets this can lead to some rather nasty scalarization if we don't custom expand the mask first.

This patch uses the combineToExtendBoolVectorInReg helper function to canonicalise the masks, similar to what we already do for vselect expansion.

Alternative to llvm#175385

Fixes llvm#175385
@RKSimon
Collaborator

RKSimon commented Jan 13, 2026

@folkertdev I ended up coming up with #175769, which I think should address the issue reusing existing code.

@folkertdev
Contributor Author

Nice, thanks for looking into it!

@folkertdev folkertdev closed this Jan 13, 2026
@RKSimon
Collaborator

RKSimon commented Jan 13, 2026

Nice, thanks for looking into it!

Sorry for poaching it :(

RKSimon added a commit that referenced this pull request Jan 14, 2026
…or extension before it might get split by legalisation (#175769)

Masked load/store/gathers often need to bitcast the mask from a
bitcasted integer.

On pre-AVX512 targets this can lead to some rather nasty scalarization
if we don't custom expand the mask first.

This patch uses the canonicalizeBoolMask / combineToExtendBoolVectorInReg
helper functions to canonicalise the masks, similar to what we already
do for vselect expansion.

Alternative to #175385

Fixes #59789
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Jan 18, 2026
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Jan 18, 2026
Priyanshu3820 pushed a commit to Priyanshu3820/llvm-project that referenced this pull request Jan 18, 2026
BStott6 pushed a commit to BStott6/llvm-project that referenced this pull request Jan 22, 2026