Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -2328,9 +2328,11 @@ class BuildVectorSDNode : public SDNode {

LLVM_ABI bool isConstant() const;

/// If this BuildVector is constant and represents the numerical series
/// "<a, a+n, a+2n, a+3n, ...>" where a is integer and n is a non-zero integer,
/// the value "<a,n>" is returned.
/// If this BuildVector is constant and represents an arithmetic sequence
/// "<a, a+n, a+2n, a+3n, ...>" where a is integer and n is a non-zero
/// integer, the value "<a, n>" is returned. Arithmetic is performed modulo
/// 2^BitWidth, so this also matches sequences that wrap around. Poison
/// elements are ignored and can take any value.
LLVM_ABI std::optional<std::pair<APInt, APInt>> isConstantSequence() const;

/// Recast bit data \p SrcBitElements to \p DstEltSizeInBits wide elements.
Expand Down
61 changes: 47 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14037,26 +14037,59 @@ BuildVectorSDNode::isConstantSequence() const {
if (NumOps < 2)
return std::nullopt;

if (!isa<ConstantSDNode>(getOperand(0)) ||
!isa<ConstantSDNode>(getOperand(1)))
return std::nullopt;

unsigned EltSize = getValueType(0).getScalarSizeInBits();
APInt Start = getConstantOperandAPInt(0).trunc(EltSize);
APInt Stride = getConstantOperandAPInt(1).trunc(EltSize) - Start;

if (Stride.isZero())
return std::nullopt;
APInt Start, Stride;
int FirstIdx = -1, SecondIdx = -1;

for (unsigned i = 2; i < NumOps; ++i) {
if (!isa<ConstantSDNode>(getOperand(i)))
// Find the first two non-undef constant elements to determine Start and
// Stride, then verify all remaining elements match the sequence.
for (unsigned I = 0; I < NumOps; ++I) {
SDValue Op = getOperand(I);
if (Op->isUndef())
continue;
if (!isa<ConstantSDNode>(Op))
return std::nullopt;

APInt Val = getConstantOperandAPInt(i).trunc(EltSize);
if (Val != (Start + (Stride * i)))
return std::nullopt;
APInt Val = getConstantOperandAPInt(I).trunc(EltSize);
if (FirstIdx < 0) {
FirstIdx = I;
Start = Val;
} else if (SecondIdx < 0) {
SecondIdx = I;
// Compute stride using modular arithmetic. Simple division would handle
// common strides (1, 2, -1, etc.), but modular inverse maximizes matches.
// Example: <0, poison, poison, 0xFF> has stride 0x55 since 3*0x55 = 0xFF
// Note that modular arithmetic is agnostic to signed/unsigned.
unsigned IdxDiff = I - FirstIdx;
APInt ValDiff = Val - Start;

// Step 1: Factor out common powers of 2 from IdxDiff and ValDiff.
unsigned CommonPow2Bits = llvm::countr_zero(IdxDiff);
if (ValDiff.countr_zero() < CommonPow2Bits)
return std::nullopt; // ValDiff not divisible by 2^CommonPow2Bits
IdxDiff >>= CommonPow2Bits;
ValDiff.lshrInPlace(CommonPow2Bits);

// Step 2: IdxDiff is now odd, so its inverse mod 2^EltSize exists.
// TODO: There are 2^CommonPow2Bits valid strides; currently we only try
// one, but we could try all candidates to handle more cases.
Stride = ValDiff * APInt(EltSize, IdxDiff).multiplicativeInverse();
if (Stride.isZero())
return std::nullopt;

// Step 3: Adjust Start based on the first defined element's index.
Start -= Stride * FirstIdx;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: this is confusing for readers, since it changes the meaning of Start depending on whether SecondIdx < 0 or not.

} else {
// Verify this element matches the sequence.
if (Val != Start + Stride * I)
return std::nullopt;
}
}

// Need at least two defined elements.
if (SecondIdx < 0)
return std::nullopt;

return std::make_pair(Start, Stride);
}

Expand Down
254 changes: 254 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-build-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,258 @@ define void @build_vector_no_stride_v4i64(ptr %a) #0 {
ret void
}

; Sequence with trailing poison elements.
define void @build_vector_trailing_poison_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_trailing_poison_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.s, #0, #3
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 0, i32 3, i32 6, i32 9, i32 12, i32 15, i32 poison, i32 poison>, ptr %a, align 4
ret void
}

; Sequence with leading poison elements.
define void @build_vector_leading_poison_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_leading_poison_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.s, #0, #3
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 poison, i32 poison, i32 6, i32 9, i32 12, i32 15, i32 18, i32 21>, ptr %a, align 4
ret void
}

; Sequence with poison elements in the middle.
define void @build_vector_middle_poison_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_middle_poison_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.s, #0, #3
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 0, i32 3, i32 poison, i32 poison, i32 12, i32 15, i32 18, i32 21>, ptr %a, align 4
ret void
}

; Sequence with poison elements scattered throughout.
define void @build_vector_scattered_poison_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_scattered_poison_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.s, #0, #3
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 poison, i32 3, i32 poison, i32 9, i32 poison, i32 15, i32 poison, i32 21>, ptr %a, align 4
ret void
}

; Sequence with only two defined elements (minimum required).
define void @build_vector_two_defined_v4i64(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_two_defined_v4i64:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.d, #5, #7
; VBITS_GE_256-NEXT: ptrue p0.d, vl4
; VBITS_GE_256-NEXT: st1d { z0.d }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <4 x i64> <i64 poison, i64 12, i64 poison, i64 26>, ptr %a, align 8
ret void
}

; Sequence with negative stride and poison elements.
define void @build_vector_neg_stride_poison_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_neg_stride_poison_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.s, #0, #-2
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 poison, i32 -2, i32 -4, i32 poison, i32 -8, i32 -10, i32 poison, i32 -14>, ptr %a, align 4
ret void
}

; Only one defined element - cannot determine stride, so no index instruction.
define void @build_vector_single_defined_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_single_defined_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov z0.s, #42 // =0x2a
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 poison, i32 poison, i32 poison, i32 42, i32 poison, i32 poison, i32 poison, i32 poison>, ptr %a, align 4
ret void
}

; Fractional stride: elements at indices 1 and 3 differ by 3, so stride would be 3/2.
define void @build_vector_fractional_stride_v8i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_fractional_stride_v8i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: ptrue p0.s, vl8
; VBITS_GE_256-NEXT: adrp x8, .LCPI12_0
; VBITS_GE_256-NEXT: add x8, x8, :lo12:.LCPI12_0
; VBITS_GE_256-NEXT: ld1w { z0.s }, p0/z, [x8]
; VBITS_GE_256-NEXT: st1w { z0.s }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i32> <i32 poison, i32 0, i32 poison, i32 3, i32 poison, i32 poison, i32 poison, i32 poison>, ptr %a, align 4
ret void
}

; zip1 pattern: constant <0, 1, 2, 3> is expanded to <0, 1, 2, 3, poison, poison, poison, poison>
; to match the shuffle result width. isConstantSequence recognizes this as a sequence.
define <8 x i8> @zip_const_seq_with_variable(i8 %x) #0 {
; VBITS_GE_256-LABEL: zip_const_seq_with_variable:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.b, #0, #1
; VBITS_GE_256-NEXT: dup v1.8b, w0
; VBITS_GE_256-NEXT: zip1 v0.8b, v0.8b, v1.8b
; VBITS_GE_256-NEXT: ret
%ins = insertelement <4 x i8> poison, i8 %x, i32 0
%splat = shufflevector <4 x i8> %ins, <4 x i8> poison, <4 x i32> zeroinitializer
%interleave = shufflevector <4 x i8> <i8 0, i8 1, i8 2, i8 3>, <4 x i8> %splat, <8 x i32> <i32 0, i32 4, i32 1, i32 5, i32 2, i32 6, i32 3, i32 7>
ret <8 x i8> %interleave
}

; zip2 pattern: constant <0, 1, 2, 3, 4, 5, 6, 7> is transformed by the DAG combiner to
; <poison, poison, poison, poison, 4, 5, 6, 7> since zip2 only uses elements 4-7.
define <8 x i8> @zip2_const_seq_with_variable(<8 x i8> %x) #0 {
; VBITS_GE_256-LABEL: zip2_const_seq_with_variable:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z1.b, #0, #1
; VBITS_GE_256-NEXT: zip2 v0.8b, v1.8b, v0.8b
; VBITS_GE_256-NEXT: ret
%interleave = shufflevector <8 x i8> <i8 0, i8 1, i8 2, i8 3, i8 4, i8 5, i8 6, i8 7>, <8 x i8> %x, <8 x i32> <i32 4, i32 12, i32 5, i32 13, i32 6, i32 14, i32 7, i32 15>
ret <8 x i8> %interleave
}

; Modular arithmetic: <0, poison, poison, 0xFF> has IdxDiff=3, ValDiff=0xFF.
; Stride = ValDiff * inverse(IdxDiff) mod 2^8 = 0xFF * 0xAB = 0x55.
; Verify: 0 + 3*85 = 255 mod 256.
define void @build_vector_mod_inverse_v4i8(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_mod_inverse_v4i8:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov w8, #85 // =0x55
; VBITS_GE_256-NEXT: ptrue p0.h, vl4
; VBITS_GE_256-NEXT: index z0.h, #0, w8
; VBITS_GE_256-NEXT: st1b { z0.h }, p0, [x0]
; VBITS_GE_256-NEXT: ret
store <4 x i8> <i8 0, i8 poison, i8 poison, i8 255>, ptr %a
ret void
}

; Modular arithmetic: <poison, 0, poison, poison, 0xFE, ...> has IdxDiff=3, ValDiff=0xFE.
; Stride = ValDiff * inverse(IdxDiff) mod 2^8 = 0xFE * 0xAB = 0xAA.
; Verify: 86 + 3*170 = 596 = 254 mod 256.
define void @build_vector_mod_inverse_v8i8_0xAA(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_mod_inverse_v8i8_0xAA:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov w8, #170 // =0xaa
; VBITS_GE_256-NEXT: index z0.b, #0, w8
; VBITS_GE_256-NEXT: add z0.b, z0.b, #86 // =0x56
; VBITS_GE_256-NEXT: str d0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i8> <i8 poison, i8 0, i8 poison, i8 poison, i8 254, i8 poison, i8 poison, i8 poison>, ptr %a
ret void
}

; Modular arithmetic: <poison, poison, 0, poison, poison, 0xFD, ...> has IdxDiff=3, ValDiff=0xFD.
; Stride = ValDiff * inverse(IdxDiff) mod 2^8 = 0xFD * 0xAB = 0xFF.
; Verify: 2 + 3*255 = 767 = 253 mod 256.
define void @build_vector_mod_inverse_v8i8_neg1(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_mod_inverse_v8i8_neg1:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: index z0.b, #2, #-1
; VBITS_GE_256-NEXT: str d0, [x0]
; VBITS_GE_256-NEXT: ret
store <8 x i8> <i8 poison, i8 poison, i8 0, i8 poison, i8 poison, i8 253, i8 poison, i8 poison>, ptr %a
ret void
}

; Modular arithmetic: <poison, 0xAA, poison, 0x54, poison, 0xFE, poison> has IdxDiff=2, ValDiff=0xAA.
; Stride = (ValDiff/2) * inverse(IdxDiff/2) mod 2^8 = 0x55 * 0x01 = 0x55.
; Verify: 85 + 1*85 = 170, 85 + 3*85 = 340 = 84, 85 + 5*85 = 510 = 254 mod 256.
define void @build_vector_mod_inverse_v7i8(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_mod_inverse_v7i8:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov w8, #85 // =0x55
; VBITS_GE_256-NEXT: index z0.b, #0, w8
; VBITS_GE_256-NEXT: add z0.b, z0.b, #85 // =0x55
; VBITS_GE_256-NEXT: mov h1, v0.h[2]
; VBITS_GE_256-NEXT: str s0, [x0]
; VBITS_GE_256-NEXT: str h1, [x0, #4]
; VBITS_GE_256-NEXT: ret
store <7 x i8> <i8 poison, i8 170, i8 poison, i8 84, i8 poison, i8 254, i8 poison>, ptr %a
ret void
}

; Modular arithmetic: <0, poison, poison, 0xFFFF> has IdxDiff=3, ValDiff=0xFFFF.
; Stride = ValDiff * inverse(IdxDiff) mod 2^16 = 0xFFFF * 0xAAAB = 0x5555.
; Verify: 0 + 3*21845 = 65535 mod 65536.
define void @build_vector_mod_inverse_i16(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_mod_inverse_i16:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov w8, #21845 // =0x5555
; VBITS_GE_256-NEXT: index z0.h, #0, w8
; VBITS_GE_256-NEXT: str d0, [x0]
; VBITS_GE_256-NEXT: ret
store <4 x i16> <i16 0, i16 poison, i16 poison, i16 -1>, ptr %a
ret void
}

; Modular arithmetic: <1, poison, poison, 0> has IdxDiff=3, ValDiff=0xFFFFFFFF.
; Stride = ValDiff * inverse(IdxDiff) mod 2^32 = 0xFFFFFFFF * 0xAAAAAAAB = 0x55555555.
; Verify: 1 + 3*1431655765 = 4294967296 = 0 mod 2^32.
define void @build_vector_mod_inverse_i32(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_mod_inverse_i32:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: mov w8, #1431655765 // =0x55555555
; VBITS_GE_256-NEXT: index z0.s, #1, w8
; VBITS_GE_256-NEXT: str q0, [x0]
; VBITS_GE_256-NEXT: ret
store <4 x i32> <i32 1, i32 poison, i32 poison, i32 0>, ptr %a
ret void
}

; TODO: Multiple stride candidates (simple): IdxDiff=2 gives 2 candidates {64, 192}.
; Val[2]=128, Val[3]=64. Stride 64 fails at index 3, stride 192 would work.
; Currently falls back since we only try one stride candidate.
define void @build_vector_multi_stride_2cand(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_multi_stride_2cand:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: fmov v0.4s, #4.00000000
; VBITS_GE_256-NEXT: str q0, [x0]
; VBITS_GE_256-NEXT: ret
store <16 x i8> <i8 0, i8 poison, i8 128, i8 64, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison>, ptr %a
ret void
}

; TODO: Multiple stride candidates (complex): IdxDiff=4 gives 4 candidates {2, 66, 130, 194}.
; Val[6]=140 filters to {66, 194}. Val[7]=78 filters to {194}. Stride 194 would work.
; Currently falls back since we only try one stride candidate.
define void @build_vector_multi_stride_4cand(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_multi_stride_4cand:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: adrp x8, .LCPI22_0
; VBITS_GE_256-NEXT: ldr q0, [x8, :lo12:.LCPI22_0]
; VBITS_GE_256-NEXT: str q0, [x0]
; VBITS_GE_256-NEXT: ret
store <16 x i8> <i8 0, i8 poison, i8 poison, i8 poison, i8 8, i8 poison, i8 140, i8 78, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison>, ptr %a
ret void
}

; Multiple stride candidates (failure): IdxDiff=4 gives 4 candidates {2, 66, 130, 194}.
; Val[5]=74 filters to {66}. Val[6]=12 requires {2, 130}. No stride satisfies both.
; Falls back to constant pool load since no valid stride exists.
define void @build_vector_multi_stride_fail(ptr %a) #0 {
; VBITS_GE_256-LABEL: build_vector_multi_stride_fail:
; VBITS_GE_256: // %bb.0:
; VBITS_GE_256-NEXT: adrp x8, .LCPI23_0
; VBITS_GE_256-NEXT: ldr q0, [x8, :lo12:.LCPI23_0]
; VBITS_GE_256-NEXT: str q0, [x0]
; VBITS_GE_256-NEXT: ret
store <16 x i8> <i8 0, i8 poison, i8 poison, i8 poison, i8 8, i8 74, i8 12, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison, i8 poison>, ptr %a
ret void
}

attributes #0 = { "target-features"="+sve" }