@Arghnews Arghnews commented Aug 27, 2025

Fixes #154293

Implement constexpr support in VectorExprEvaluator::VisitCallExpr for the shift-by-immediate intrinsics (logical left, logical right, and arithmetic right) across MMX, SSE2, AVX2, and AVX-512:

_mm*_slli_epi*
_mm*_srli_epi*
_mm*_srai_epi*
_mm*_mask_slli_epi*
_mm*_maskz_slli_epi*

NOTE: not every intrinsic exists at every width, e.g. there is a _mm_srli_pi32 but no _mm_srli_pi64.

Let me know if you have any other feedback on the PR as it stands, thanks.

Implements (full list):

_mm256_mask_slli_epi16
_mm256_mask_slli_epi32
_mm256_mask_slli_epi64
_mm256_mask_srai_epi16
_mm256_mask_srai_epi32
_mm256_mask_srai_epi64
_mm256_mask_srli_epi16
_mm256_mask_srli_epi32
_mm256_mask_srli_epi64
_mm256_maskz_slli_epi16
_mm256_maskz_slli_epi32
_mm256_maskz_slli_epi64
_mm256_maskz_srai_epi16
_mm256_maskz_srai_epi32
_mm256_maskz_srai_epi64
_mm256_maskz_srli_epi16
_mm256_maskz_srli_epi32
_mm256_maskz_srli_epi64
_mm256_slli_epi16
_mm256_slli_epi32
_mm256_slli_epi64
_mm256_srai_epi16
_mm256_srai_epi32
_mm256_srai_epi64
_mm256_srli_epi16
_mm256_srli_epi32
_mm256_srli_epi64
_mm512_mask_slli_epi16
_mm512_mask_slli_epi32
_mm512_mask_slli_epi64
_mm512_mask_srai_epi16
_mm512_mask_srai_epi32
_mm512_mask_srai_epi64
_mm512_mask_srli_epi16
_mm512_mask_srli_epi32
_mm512_mask_srli_epi64
_mm512_maskz_slli_epi16
_mm512_maskz_slli_epi32
_mm512_maskz_slli_epi64
_mm512_maskz_srai_epi16
_mm512_maskz_srai_epi32
_mm512_maskz_srai_epi64
_mm512_maskz_srli_epi16
_mm512_maskz_srli_epi32
_mm512_maskz_srli_epi64
_mm512_slli_epi16
_mm512_slli_epi32
_mm512_slli_epi64
_mm512_srai_epi16
_mm512_srai_epi32
_mm512_srai_epi64
_mm512_srli_epi16
_mm512_srli_epi32
_mm512_srli_epi64
_mm_slli_epi16
_mm_slli_epi32
_mm_slli_epi64
_mm_slli_pi16
_mm_slli_pi32
_mm_srai_epi16
_mm_srai_epi32
_mm_srai_epi64
_mm_srai_pi16
_mm_srai_pi32
_mm_srli_epi16
_mm_srli_epi32
_mm_srli_epi64
_mm_srli_pi16
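
The three shift kinds listed above differ mainly in how a lane behaves once the immediate reaches the lane width. The following is a minimal, portable sketch of that per-lane rule, assuming the usual x86 behaviour: logical shifts produce 0 for counts at or above the lane width, while arithmetic right shifts fill the lane with the sign bit. `ShiftKind` and `shiftLaneByImm` are hypothetical names used only for illustration; this is not the code from the patch.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical names for illustration; this is not Clang's evaluator code.
enum class ShiftKind { LogicalLeft, LogicalRight, ArithmeticRight };

// Shift one signed lane (int16_t, int32_t or int64_t) by an immediate count.
template <typename Lane>
constexpr Lane shiftLaneByImm(Lane value, std::uint64_t count, ShiftKind kind) {
  static_assert(std::is_integral_v<Lane> && std::is_signed_v<Lane>,
                "lanes are modelled as signed integers");
  using U = std::make_unsigned_t<Lane>;
  constexpr unsigned width = sizeof(Lane) * 8;
  // Work on the raw bit pattern in 64 bits to avoid promotion surprises.
  constexpr std::uint64_t laneMask = ~std::uint64_t{0} >> (64 - width);
  const std::uint64_t bits = static_cast<U>(value);

  std::uint64_t result = 0;
  if (kind == ShiftKind::ArithmeticRight) {
    // Out-of-range counts act like width - 1: the lane fills with sign bits.
    const std::uint64_t n = count < width ? count : width - 1;
    const bool negative = ((bits >> (width - 1)) & 1) != 0;
    // The top n bits of the lane, set only when the sign must be replicated.
    const std::uint64_t fill = negative ? (laneMask & ~(laneMask >> n)) : 0;
    result = (bits >> n) | fill;
  } else if (count < width) {
    result = kind == ShiftKind::LogicalLeft ? (bits << count) & laneMask
                                            : bits >> count;
  } // Otherwise a logical shift with count >= width leaves result at 0.
  return static_cast<Lane>(static_cast<U>(result));
}

static_assert(shiftLaneByImm<std::int16_t>(-16, 2, ShiftKind::ArithmeticRight) == -4);
static_assert(shiftLaneByImm<std::int16_t>(-16, 99, ShiftKind::ArithmeticRight) == -1);
static_assert(shiftLaneByImm<std::int16_t>(-16, 99, ShiftKind::LogicalRight) == 0);
```

Because the function is constexpr, the expected values can be checked with static_assert, which is the same style of check the constexpr intrinsics are meant to allow.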

@llvmbot added the clang, backend:X86, clang:frontend, and clang:headers labels on Aug 27, 2025

llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-clang

Author: Justin Riddell (Arghnews)

Changes

Implement constexpr support in VectorExprEvaluator::VisitCallExpr for the shift-by-immediate intrinsics (logical left, logical right, and arithmetic right) across MMX, SSE2, AVX2, and AVX-512.

NOTE: this is currently incomplete; I have a question about what I've got so far.

I think I need to split the tests into separate files with specific -target-feature flags, along the lines of the list below. Or can it be done more simply than this? Some existing test files already enable these target features, although avx512dq-builtins-constrained.c, for example, also passes -triple=x86_64-apple-darwin. Is the triple only there to control what code gets emitted, so that it doesn't matter for these tests? I'm not sure and would appreciate direction on:

  • the exact kind of // RUN: %clang_cc1 line I should be using;
  • how restrictive the +avx512 features need to be;
  • whether I need REQUIRES: x86-registered-target, or whether that only applies when actual backend code is emitted (I only partly understand what this means).

MMX (_mm_*_pi*): -mmmx

128-bit epi16/32/64: -msse2 (except _srai_epi64: needs -mavx512dq -mavx512vl)

AVX2 (256-bit): -mavx2 (except _srai_epi64: needs -mavx512dq -mavx512vl)

AVX-512 512-bit:

epi16 → -mavx512bw

epi32/epi64 → -mavx512f

srai epi64 → add -mavx512dq

256-/128-bit masked forms → -mavx512vl

--

_mm*_slli_epi*
_mm*_srli_epi*
_mm*_srai_epi*
_mm*_mask_slli_epi*
_mm*_maskz_slli_epi*

NOTE: not every intrinsic exists at every width, e.g. there is a _mm_srli_pi32 but no _mm_srli_pi64.
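
The mask and maskz forms listed above differ from the plain intrinsics only in a final per-lane selection, which can be sketched as follows. This assumes the usual AVX-512 convention that bit i of the mask controls lane i: the mask form keeps the pass-through operand where the bit is clear, and the maskz form writes zero there. `selectLanes` and the sample arrays are hypothetical, and the shifted values are supplied directly instead of calling a real intrinsic.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Hypothetical helper: picks, lane by lane, between a computed result and a
// fallback value. Bit i of `mask` controls lane i.
template <typename Lane, std::size_t N>
constexpr std::array<Lane, N> selectLanes(std::uint64_t mask,
                                          const std::array<Lane, N>& computed,
                                          const std::array<Lane, N>& fallback) {
  static_assert(N <= 64, "one mask bit per lane");
  std::array<Lane, N> out{};
  for (std::size_t i = 0; i < N; ++i)
    out[i] = ((mask >> i) & 1) ? computed[i] : fallback[i];
  return out;
}

// Four 32-bit lanes {1, 2, 3, 4} already shifted left by 1, standing in for
// the result of an unmasked shift.
constexpr std::array<std::int32_t, 4> shifted{2, 4, 6, 8};
constexpr std::array<std::int32_t, 4> passthrough{-1, -1, -1, -1};
constexpr std::array<std::int32_t, 4> zeros{};

// mask form: lanes with a clear bit keep the pass-through operand.
constexpr auto merged = selectLanes(0b0101, shifted, passthrough);
// maskz form: lanes with a clear bit become zero.
constexpr auto zeroed = selectLanes(0b0101, shifted, zeros);

static_assert(merged[0] == 2 && merged[1] == -1 && merged[2] == 6 && merged[3] == -1);
static_assert(zeroed[0] == 2 && zeroed[1] == 0 && zeroed[2] == 6 && zeroed[3] == 0);
```

Modelling both forms with one selection helper means the shift itself only has to be evaluated once, on the unmasked operands.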

Also, some test cases in the current test file are split over multiple lines for readability. That doesn't seem to be common in the repo's other tests, but those cases are very hard to read on a single line, so I've only kept the split where it's needed.

Let me know if you have any other feedback on the PR as it stands, thanks.

@RKSimon

Patch is 91.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155542.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsX86.td (+51-32)
  • (modified) clang/lib/AST/ExprConstant.cpp (+200-52)
  • (modified) clang/lib/Headers/avx2intrin.h (+14-18)
  • (modified) clang/lib/Headers/avx512bwintrin.h (+16-21)
  • (modified) clang/lib/Headers/avx512fintrin.h (+30-41)
  • (modified) clang/lib/Headers/avx512vlbwintrin.h (+20-26)
  • (modified) clang/lib/Headers/avx512vlintrin.h (+35-46)
  • (modified) clang/lib/Headers/emmintrin.h (+13-13)
  • (modified) clang/lib/Headers/mmintrin.h (+21-31)
  • (added) clang/test/CodeGen/X86/shift-immediate-constexpr.c (+441)
diff --git a/clang/include/clang/Basic/BuiltinsX86.td b/clang/include/clang/Basic/BuiltinsX86.td
index 527acd9ef086e..fb194fbb667fa 100644
--- a/clang/include/clang/Basic/BuiltinsX86.td
+++ b/clang/include/clang/Basic/BuiltinsX86.td
@@ -275,15 +275,8 @@ let Features = "sse2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] i
   def psrlq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
   def psllw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
   def pslld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
-  def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
-  def psllwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
-  def pslldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
-  def psllqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
-  def psrlwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
-  def psrldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
-  def psrlqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
-  def psrawi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
-  def psradi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
+  def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long "
+                            "int>, _Vector<2, long long int>)">;
   def pmaddwd128 : X86Builtin<"_Vector<4, int>(_Vector<8, short>, _Vector<8, short>)">;
   def pslldqi128_byteshift : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant int)">;
   def psrldqi128_byteshift : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant int)">;
@@ -291,6 +284,18 @@ let Features = "sse2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] i
 
 let Features = "sse2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
   def pmuludq128 : X86Builtin<"_Vector<2, long long int>(_Vector<4, int>, _Vector<4, int>)">;
+
+  def psllwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
+  def pslldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
+  def psllqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
+
+  def psrlwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
+  def psrldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
+  def psrlqi128
+      : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
+
+  def psrawi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
+  def psradi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
 }
 
 let Features = "sse3", Attributes = [NoThrow] in {
@@ -595,23 +600,20 @@ let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] i
   def psignb256 : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>)">;
   def psignw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<16, short>)">;
   def psignd256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
-  def psllwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
   def psllw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
   def pslldqi256_byteshift : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
-  def pslldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
   def pslld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
-  def psllqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
-  def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
-  def psrawi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
-  def psraw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
-  def psradi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+  def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long "
+                            "int>, _Vector<2, long long int>)">;
+  def psraw256
+      : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
   def psrad256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
-  def psrldqi256_byteshift : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
-  def psrlwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
-  def psrlw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
-  def psrldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
-  def psrld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
-  def psrlqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
+  def psrldqi256_byteshift : X86Builtin<"_Vector<4, long long int>(_Vector<4, "
+                                        "long long int>, _Constant int)">;
+  def psrlw256
+      : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
+  def psrld256
+      : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
   def psrlq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
   def pblendd128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>, _Constant int)">;
   def pblendd256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>, _Constant int)">;
@@ -628,6 +630,17 @@ let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWi
   def pmuldq256 : X86Builtin<"_Vector<4, long long int>(_Vector<8, int>, _Vector<8, int>)">;
   def pmuludq256 : X86Builtin<"_Vector<4, long long int>(_Vector<8, int>, _Vector<8, int>)">;
 
+  def psllwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
+  def pslldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+  def psllqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
+
+  def psrlwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
+  def psrldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+  def psrlqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
+
+  def psrawi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
+  def psradi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+
   def pmulhuw256 : X86Builtin<"_Vector<16, unsigned short>(_Vector<16, unsigned short>, _Vector<16, unsigned short>)">;
   def pmulhw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<16, short>)">;
 
@@ -2098,7 +2111,6 @@ let Features = "avx512bw,evex512", Attributes = [NoThrow, Const, RequiredVectorW
   def pshuflw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Constant int)">;
   def psllv32hi : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<32, short>)">;
   def psllw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
-  def psllwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
 }
 
 let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
@@ -2109,7 +2121,8 @@ let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, RequiredVector
   def psllv8hi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
+  def psllwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
   def pslldi512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, int)">;
   def psllqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
 }
@@ -2126,7 +2139,9 @@ let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, RequiredVector
   def psrlv8hi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
+  def psrlwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
   def psrldi512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, int)">;
   def psrlqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
 }
@@ -2152,10 +2167,10 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256
 }
 
 let Features = "avx512bw,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
-  def psraw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
-  def psrawi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
-  def psrlw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
-  def psrlwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
+  def psraw512
+      : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
+  def psrlw512
+      : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
   def pslldqi512_byteshift : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Constant int)">;
   def psrldqi512_byteshift : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Constant int)">;
 }
@@ -2487,7 +2502,9 @@ let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<128>
   def scalefss_round_mask : X86Builtin<"_Vector<4, float>(_Vector<4, float>, _Vector<4, float>, _Vector<4, float>, unsigned char, _Constant int)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
+  def psrawi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
   def psradi512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, int)">;
   def psraqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
 }
@@ -2500,11 +2517,13 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256
   def psraq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
 }
 
-let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
+let Features = "avx512vl",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
   def psraqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
 }
 
-let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
+let Features = "avx512vl",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
   def psraqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
 }
 
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 19703e40d2696..85a3283568fdd 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11621,6 +11621,7 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
   case clang::X86::BI__builtin_ia32_pmulhw128:
   case clang::X86::BI__builtin_ia32_pmulhw256:
   case clang::X86::BI__builtin_ia32_pmulhw512:
+
   case clang::X86::BI__builtin_ia32_psllv2di:
   case clang::X86::BI__builtin_ia32_psllv4di:
   case clang::X86::BI__builtin_ia32_psllv4si:
@@ -11630,7 +11631,41 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
   case clang::X86::BI__builtin_ia32_psrlv2di:
   case clang::X86::BI__builtin_ia32_psrlv4di:
   case clang::X86::BI__builtin_ia32_psrlv4si:
-  case clang::X86::BI__builtin_ia32_psrlv8si:{
+  case clang::X86::BI__builtin_ia32_psrlv8si:
+
+  // Logical left shift by immediate
+  case clang::X86::BI__builtin_ia32_psllwi128:
+  case clang::X86::BI__builtin_ia32_pslldi128:
+  case clang::X86::BI__builtin_ia32_psllqi128:
+  case clang::X86::BI__builtin_ia32_psllwi256:
+  case clang::X86::BI__builtin_ia32_pslldi256:
+  case clang::X86::BI__builtin_ia32_psllqi256:
+  case clang::X86::BI__builtin_ia32_psllwi512:
+  case clang::X86::BI__builtin_ia32_pslldi512:
+  case clang::X86::BI__builtin_ia32_psllqi512:
+
+  // Logical right shift by immediate
+  case clang::X86::BI__builtin_ia32_psrlwi128:
+  case clang::X86::BI__builtin_ia32_psrldi128:
+  case clang::X86::BI__builtin_ia32_psrlqi128:
+  case clang::X86::BI__builtin_ia32_psrlwi256:
+  case clang::X86::BI__builtin_ia32_psrldi256:
+  case clang::X86::BI__builtin_ia32_psrlqi256:
+  case clang::X86::BI__builtin_ia32_psrlwi512:
+  case clang::X86::BI__builtin_ia32_psrldi512:
+  case clang::X86::BI__builtin_ia32_psrlqi512:
+
+  // Arithmetic right shift by immediate
+  case clang::X86::BI__builtin_ia32_psrawi128:
+  case clang::X86::BI__builtin_ia32_psradi128:
+  case clang::X86::BI__builtin_ia32_psraqi128:
+  case clang::X86::BI__builtin_ia32_psrawi256:
+  case clang::X86::BI__builtin_ia32_psradi256:
+  case clang::X86::BI__builtin_ia32_psraqi256:
+  case clang::X86::BI__builtin_ia32_psrawi512:
+  case clang::X86::BI__builtin_ia32_psradi512:
+  case clang::X86::BI__builtin_ia32_psraqi512: {
+
     APValue SourceLHS, SourceRHS;
     if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
         !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
@@ -11644,64 +11679,177 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
 
     for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
       APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
-      APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
-      switch (E->getBuiltinCallee()) {
-      case Builtin::BI__builtin_elementwise_add_sat:
-        ResultElements.push_back(APValue(
-            APSInt(LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS),
-                   DestUnsigned)));
-        break;
-      case Builtin::BI__builtin_elementwise_sub_sat:
-        ResultElements.push_back(APValue(
-            APSInt(LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS),
-                   DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_pmulhuw128:
-      case clang::X86::BI__builtin_ia32_pmulhuw256:
-      case clang::X86::BI__builtin_ia32_pmulhuw512:
-        ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhu(LHS, RHS),
-                                                /*isUnsigned=*/true)));
-        break;
-      case clang::X86::BI__builtin_ia32_pmulhw128:
-      case clang::X86::BI__builtin_ia32_pmulhw256:
-      case clang::X86::BI__builtin_ia32_pmulhw512:
-        ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhs(LHS, RHS),
-                                                /*isUnsigned=*/false)));
-        break;
-      case clang::X86::BI__builtin_ia32_psllv2di:
-      case clang::X86::BI__builtin_ia32_psllv4di:
-      case clang::X86::BI__builtin_ia32_psllv4si:
-      case clang::X86::BI__builtin_ia32_psllv8si:
-        if (RHS.uge(RHS.getBitWidth())) {
-          ResultElements.push_back(
-              APValue(APSInt(APInt::getZero(RHS.getBitWidth()), DestUnsigned)));
+
+      if (SourceRHS.isInt()) {
+        uint64_t LaneWidth = 0;
+        bool IsLeftShift = false;
+        bool IsRightShift = false;
+        bool IsArithmeticRightShift = false;
+
+        switch (E->getBuiltinCallee()) {
+        case clang::X86::BI__builtin_ia32_psllwi128:
+        case clang::X86::BI__builtin_ia32_psllwi256:
+        case clang::X86::BI__builtin_ia32_psllwi512:
+          IsLeftShift = true;
+          LaneWidth = 16;
+          break;
+        case clang::X86::BI__builtin_ia32_pslldi128:
+        case clang::X86::BI__builtin_ia32_pslldi256:
+        case clang::X86::BI__builtin_ia32_pslldi512:
+          IsLeftShift = true;
+          LaneWidth = 32;
           break;
+        case clang::X86::BI__builtin_ia32_psllqi128:
+        case clang::X86::BI__builtin_ia32_psllqi256:
+        case clang::X86::BI__builtin_ia32_psllqi512:
+          IsLeftShift = true;
+          LaneWidth = 64;
+          break;
+
+        case clang::X86::BI__builtin_ia32_psrlwi128:
+        case clang::X86::BI__builtin_ia32_psrlwi256:
+        case clang::X86::BI__builtin_ia32_psrlwi512:
+          IsRightShift = true;
+          LaneWidth = 16;
+          break;
+        case clang::X86::BI__builtin_ia32_psrldi128:
+        case clang::X86::BI__builtin_ia32_psrldi256:
+        case clang::X86::BI__builtin_ia32_psrldi512:
+          IsRightShift = true;
+          LaneWidth = 32;
+          break;
+        case clang::X86::BI__builtin_ia32_psrlqi128:
+        case clang::X86::BI__builtin_ia32_psrlqi256:
+        case clang::X86::BI__builtin_ia32_psrlqi512:
+          IsRightShift = true;
+          LaneWidth = 64;
+          break;
+
+        case clang::X86::BI__builtin_ia32_psrawi128:
+        case clang::X86::BI__builtin_ia32_psrawi256:
+        case clang::X86::BI__builtin_ia32_psrawi512:
+          IsArithmeticRightShift = true;
+          LaneWidth = 16;
+          break;
+        case clang::X86::BI__builtin_ia32_psradi128:
+        case clang::X86::BI__builtin_ia32_psradi256:
+        case clang::X86::BI__builtin_ia32_psradi512:
+          IsArithmeticRightShift = true;
+          LaneWidth = 32;
+          break;
+        case clang::X86::BI__builtin_ia32_psraqi128:
+        case clang::X86::BI__builtin_ia32_psraqi256:
+        case clang::X86::BI__builtin_ia32_psraqi512:
+          IsArithmeticRightShift = true;
+          LaneWidth = 64;
+          break;
+
+        default:
+          llvm_unreachable("Unexpected builtin callee");
         }
-        ResultElements.push_back(
-            APValue(APSInt(LHS.shl(RHS.getZExtValue()), DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_psrav4si:
-      case clang::X86::BI__builtin_ia32_psrav8si:
-        if (RHS.uge(RHS.getBitWidth())) {
+
+        const APSInt RHS = SourceRHS.getInt();
+        const auto ShiftAmount = RHS.getZExtValue();
+        APInt ResultOut;
+        if (IsArithmeticRightShift) {
+          ResultOut = LHS.ashr(std::min(ShiftAmount, LaneWidth));
+        } else if (ShiftAmount >= LaneWidth) {
+          ResultOut = APInt(LaneWidth, 0);
+        } else if (IsLeftShift) {
+          ResultOut = LHS.shl(ShiftAmount);
+        } else if (IsRightShift) {
+          ResultOut = LHS.lshr(ShiftAmount);
+        } else {
+          llvm_unreachable("Invalid shift type");
+        }
+        ResultElements.push_back(APValue(APSInt(
+            std::move(ResultOut),
+            /*isUnsigned=*/DestEltTy->isUnsignedIntegerOrEnumerationType())));
+      } else {
+        APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
+        switch (E->getBuiltinCallee()) {
+        case Builtin::BI__builtin_elementwise_add_sat:
+          ResultElements.push_back(APValue(
+              APSInt(LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS),
+                     DestUnsigned)));
+          break;
+        case Builtin::BI__builtin_elementwise_sub_sat:
+          ResultElements.push_back(APValue(
+              APSInt(LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS),
+                     DestUnsigned)));
+          break;
+        case clang::X86::BI__builtin_ia32_pmulhuw128:
+        case clang::X86::BI__builtin_ia32_pmulhuw256:
+        case clang::X86::BI__builtin_ia32_pmulhuw512:
+          ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhu(LHS, RHS),
+                                                  /*isUnsigned=*/true)));
+          break;
+        case clang::X86::BI__builtin_ia32_pmulhw128:
+        case clang::X86::BI__builtin_ia32_pmulhw256:
+        case clang::X86::BI__builtin_ia32_pmulhw512:
+          ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhs(LHS, RHS),
+                                                  /*isUnsigned=*/false)));
+          break;
+        case clang::X86::BI__builtin_ia32_psllv2di:
+        case clang::X86::BI__builtin_ia32_psllv4di:
+        case clang::X86::BI__builtin_ia32_psllv4si:
+        case clang::X86::BI__builtin_ia32_psllv8si:
+          if (RHS.uge(RHS.getBitWidth())) {
+            ResultElements.push_back(
+                APValue(APSInt(APInt::getZero(RHS.getBitWidth()), DestUnsi...
[truncated]
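
The ExprConstant.cpp hunk above branches on SourceRHS.isInt(): the immediate builtins pass one scalar count that applies to every lane, whereas the existing psllv/psrlv builtins read a separate count per lane. The sketch below illustrates that difference for 32-bit logical left shifts only. `shl32`, `shlImmediate`, and `shlVariable` are hypothetical names, plain std::array stands in for the vector types, and none of this is the patch's own code.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Logical left shift of one 32-bit lane; counts of 32 or more give 0.
constexpr std::uint32_t shl32(std::uint32_t lane, std::uint64_t count) {
  return count >= 32 ? 0u : lane << count;
}

// Immediate form: a single scalar count is applied to every lane.
template <std::size_t N>
constexpr std::array<std::uint32_t, N>
shlImmediate(const std::array<std::uint32_t, N>& a, std::uint64_t count) {
  std::array<std::uint32_t, N> out{};
  for (std::size_t i = 0; i < N; ++i)
    out[i] = shl32(a[i], count);
  return out;
}

// Variable form: lane i is shifted by counts[i].
template <std::size_t N>
constexpr std::array<std::uint32_t, N>
shlVariable(const std::array<std::uint32_t, N>& a,
            const std::array<std::uint32_t, N>& counts) {
  std::array<std::uint32_t, N> out{};
  for (std::size_t i = 0; i < N; ++i)
    out[i] = shl32(a[i], counts[i]);
  return out;
}

constexpr std::array<std::uint32_t, 4> ones{1, 1, 1, 1};
constexpr std::array<std::uint32_t, 4> perLane{0, 1, 31, 32};

constexpr auto byImmediate = shlImmediate(ones, 3);
constexpr auto byVariable = shlVariable(ones, perLane);

static_assert(byImmediate[0] == 8 && byImmediate[3] == 8);
static_assert(byVariable[0] == 1 && byVariable[1] == 2);
static_assert(byVariable[2] == 0x80000000u && byVariable[3] == 0);
```

The count is taken as an unsigned 64-bit value here because the hunk reads the immediate with getZExtValue(); a count read that way can only be compared against the lane width, never treated as negative.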


llvmbot commented Aug 27, 2025

@llvm/pr-subscribers-backend-x86

Author: Justin Riddell (Arghnews)

Patch is 91.46 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155542.diff

10 Files Affected:

  • (modified) clang/include/clang/Basic/BuiltinsX86.td (+51-32)
  • (modified) clang/lib/AST/ExprConstant.cpp (+200-52)
  • (modified) clang/lib/Headers/avx2intrin.h (+14-18)
  • (modified) clang/lib/Headers/avx512bwintrin.h (+16-21)
  • (modified) clang/lib/Headers/avx512fintrin.h (+30-41)
  • (modified) clang/lib/Headers/avx512vlbwintrin.h (+20-26)
  • (modified) clang/lib/Headers/avx512vlintrin.h (+35-46)
  • (modified) clang/lib/Headers/emmintrin.h (+13-13)
  • (modified) clang/lib/Headers/mmintrin.h (+21-31)
  • (added) clang/test/CodeGen/X86/shift-immediate-constexpr.c (+441)
diff --git a/clang/include/clang/Basic/BuiltinsX86.td b/clang/include/clang/Basic/BuiltinsX86.td
index 527acd9ef086e..fb194fbb667fa 100644
--- a/clang/include/clang/Basic/BuiltinsX86.td
+++ b/clang/include/clang/Basic/BuiltinsX86.td
@@ -275,15 +275,8 @@ let Features = "sse2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] i
   def psrlq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
   def psllw128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
   def pslld128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>)">;
-  def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Vector<2, long long int>)">;
-  def psllwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
-  def pslldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
-  def psllqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
-  def psrlwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
-  def psrldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
-  def psrlqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
-  def psrawi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
-  def psradi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
+  def psllq128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long "
+                            "int>, _Vector<2, long long int>)">;
   def pmaddwd128 : X86Builtin<"_Vector<4, int>(_Vector<8, short>, _Vector<8, short>)">;
   def pslldqi128_byteshift : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant int)">;
   def psrldqi128_byteshift : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, _Constant int)">;
@@ -291,6 +284,18 @@ let Features = "sse2", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] i
 
 let Features = "sse2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
   def pmuludq128 : X86Builtin<"_Vector<2, long long int>(_Vector<4, int>, _Vector<4, int>)">;
+
+  def psllwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
+  def pslldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
+  def psllqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
+
+  def psrlwi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
+  def psrldi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
+  def psrlqi128
+      : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
+
+  def psrawi128 : X86Builtin<"_Vector<8, short>(_Vector<8, short>, int)">;
+  def psradi128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, int)">;
 }
 
 let Features = "sse3", Attributes = [NoThrow] in {
@@ -595,23 +600,20 @@ let Features = "avx2", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] i
   def psignb256 : X86Builtin<"_Vector<32, char>(_Vector<32, char>, _Vector<32, char>)">;
   def psignw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<16, short>)">;
   def psignd256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>)">;
-  def psllwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
   def psllw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
   def pslldqi256_byteshift : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
-  def pslldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
   def pslld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
-  def psllqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
-  def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
-  def psrawi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
-  def psraw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
-  def psradi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+  def psllq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long "
+                            "int>, _Vector<2, long long int>)">;
+  def psraw256
+      : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
   def psrad256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
-  def psrldqi256_byteshift : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Constant int)">;
-  def psrlwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
-  def psrlw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
-  def psrldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
-  def psrld256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
-  def psrlqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
+  def psrldqi256_byteshift : X86Builtin<"_Vector<4, long long int>(_Vector<4, "
+                                        "long long int>, _Constant int)">;
+  def psrlw256
+      : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<8, short>)">;
+  def psrld256
+      : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<4, int>)">;
   def psrlq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
   def pblendd128 : X86Builtin<"_Vector<4, int>(_Vector<4, int>, _Vector<4, int>, _Constant int)">;
   def pblendd256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, _Vector<8, int>, _Constant int)">;
@@ -628,6 +630,17 @@ let Features = "avx2", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWi
   def pmuldq256 : X86Builtin<"_Vector<4, long long int>(_Vector<8, int>, _Vector<8, int>)">;
   def pmuludq256 : X86Builtin<"_Vector<4, long long int>(_Vector<8, int>, _Vector<8, int>)">;
 
+  def psllwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
+  def pslldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+  def psllqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
+
+  def psrlwi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
+  def psrldi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+  def psrlqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
+
+  def psrawi256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, int)">;
+  def psradi256 : X86Builtin<"_Vector<8, int>(_Vector<8, int>, int)">;
+
   def pmulhuw256 : X86Builtin<"_Vector<16, unsigned short>(_Vector<16, unsigned short>, _Vector<16, unsigned short>)">;
   def pmulhw256 : X86Builtin<"_Vector<16, short>(_Vector<16, short>, _Vector<16, short>)">;
 
@@ -2098,7 +2111,6 @@ let Features = "avx512bw,evex512", Attributes = [NoThrow, Const, RequiredVectorW
   def pshuflw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Constant int)">;
   def psllv32hi : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<32, short>)">;
   def psllw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
-  def psllwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
 }
 
 let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
@@ -2109,7 +2121,8 @@ let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, RequiredVector
   def psllv8hi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512", Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
+  def psllwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
   def pslldi512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, int)">;
   def psllqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
 }
@@ -2126,7 +2139,9 @@ let Features = "avx512bw,avx512vl", Attributes = [NoThrow, Const, RequiredVector
   def psrlv8hi : X86Builtin<"_Vector<8, short>(_Vector<8, short>, _Vector<8, short>)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
+  def psrlwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
   def psrldi512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, int)">;
   def psrlqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
 }
@@ -2152,10 +2167,10 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256
 }
 
 let Features = "avx512bw,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
-  def psraw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
-  def psrawi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
-  def psrlw512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
-  def psrlwi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
+  def psraw512
+      : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
+  def psrlw512
+      : X86Builtin<"_Vector<32, short>(_Vector<32, short>, _Vector<8, short>)">;
   def pslldqi512_byteshift : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Constant int)">;
   def psrldqi512_byteshift : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, _Constant int)">;
 }
@@ -2487,7 +2502,9 @@ let Features = "avx512f", Attributes = [NoThrow, Const, RequiredVectorWidth<128>
   def scalefss_round_mask : X86Builtin<"_Vector<4, float>(_Vector<4, float>, _Vector<4, float>, _Vector<4, float>, unsigned char, _Constant int)">;
 }
 
-let Features = "avx512f,evex512", Attributes = [NoThrow, Const, RequiredVectorWidth<512>] in {
+let Features = "avx512f,evex512",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<512>] in {
+  def psrawi512 : X86Builtin<"_Vector<32, short>(_Vector<32, short>, int)">;
   def psradi512 : X86Builtin<"_Vector<16, int>(_Vector<16, int>, int)">;
   def psraqi512 : X86Builtin<"_Vector<8, long long int>(_Vector<8, long long int>, int)">;
 }
@@ -2500,11 +2517,13 @@ let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256
   def psraq256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, _Vector<2, long long int>)">;
 }
 
-let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<128>] in {
+let Features = "avx512vl",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<128>] in {
   def psraqi128 : X86Builtin<"_Vector<2, long long int>(_Vector<2, long long int>, int)">;
 }
 
-let Features = "avx512vl", Attributes = [NoThrow, Const, RequiredVectorWidth<256>] in {
+let Features = "avx512vl",
+    Attributes = [NoThrow, Const, Constexpr, RequiredVectorWidth<256>] in {
   def psraqi256 : X86Builtin<"_Vector<4, long long int>(_Vector<4, long long int>, int)">;
 }
 
diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp
index 19703e40d2696..85a3283568fdd 100644
--- a/clang/lib/AST/ExprConstant.cpp
+++ b/clang/lib/AST/ExprConstant.cpp
@@ -11621,6 +11621,7 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
   case clang::X86::BI__builtin_ia32_pmulhw128:
   case clang::X86::BI__builtin_ia32_pmulhw256:
   case clang::X86::BI__builtin_ia32_pmulhw512:
+
   case clang::X86::BI__builtin_ia32_psllv2di:
   case clang::X86::BI__builtin_ia32_psllv4di:
   case clang::X86::BI__builtin_ia32_psllv4si:
@@ -11630,7 +11631,41 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
   case clang::X86::BI__builtin_ia32_psrlv2di:
   case clang::X86::BI__builtin_ia32_psrlv4di:
   case clang::X86::BI__builtin_ia32_psrlv4si:
-  case clang::X86::BI__builtin_ia32_psrlv8si:{
+  case clang::X86::BI__builtin_ia32_psrlv8si:
+
+  // Logical left shift by immediate
+  case clang::X86::BI__builtin_ia32_psllwi128:
+  case clang::X86::BI__builtin_ia32_pslldi128:
+  case clang::X86::BI__builtin_ia32_psllqi128:
+  case clang::X86::BI__builtin_ia32_psllwi256:
+  case clang::X86::BI__builtin_ia32_pslldi256:
+  case clang::X86::BI__builtin_ia32_psllqi256:
+  case clang::X86::BI__builtin_ia32_psllwi512:
+  case clang::X86::BI__builtin_ia32_pslldi512:
+  case clang::X86::BI__builtin_ia32_psllqi512:
+
+  // Logical right shift by immediate
+  case clang::X86::BI__builtin_ia32_psrlwi128:
+  case clang::X86::BI__builtin_ia32_psrldi128:
+  case clang::X86::BI__builtin_ia32_psrlqi128:
+  case clang::X86::BI__builtin_ia32_psrlwi256:
+  case clang::X86::BI__builtin_ia32_psrldi256:
+  case clang::X86::BI__builtin_ia32_psrlqi256:
+  case clang::X86::BI__builtin_ia32_psrlwi512:
+  case clang::X86::BI__builtin_ia32_psrldi512:
+  case clang::X86::BI__builtin_ia32_psrlqi512:
+
+  // Arithmetic right shift by immediate
+  case clang::X86::BI__builtin_ia32_psrawi128:
+  case clang::X86::BI__builtin_ia32_psradi128:
+  case clang::X86::BI__builtin_ia32_psraqi128:
+  case clang::X86::BI__builtin_ia32_psrawi256:
+  case clang::X86::BI__builtin_ia32_psradi256:
+  case clang::X86::BI__builtin_ia32_psraqi256:
+  case clang::X86::BI__builtin_ia32_psrawi512:
+  case clang::X86::BI__builtin_ia32_psradi512:
+  case clang::X86::BI__builtin_ia32_psraqi512: {
+
     APValue SourceLHS, SourceRHS;
     if (!EvaluateAsRValue(Info, E->getArg(0), SourceLHS) ||
         !EvaluateAsRValue(Info, E->getArg(1), SourceRHS))
@@ -11644,64 +11679,177 @@ bool VectorExprEvaluator::VisitCallExpr(const CallExpr *E) {
 
     for (unsigned EltNum = 0; EltNum < SourceLen; ++EltNum) {
       APSInt LHS = SourceLHS.getVectorElt(EltNum).getInt();
-      APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
-      switch (E->getBuiltinCallee()) {
-      case Builtin::BI__builtin_elementwise_add_sat:
-        ResultElements.push_back(APValue(
-            APSInt(LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS),
-                   DestUnsigned)));
-        break;
-      case Builtin::BI__builtin_elementwise_sub_sat:
-        ResultElements.push_back(APValue(
-            APSInt(LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS),
-                   DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_pmulhuw128:
-      case clang::X86::BI__builtin_ia32_pmulhuw256:
-      case clang::X86::BI__builtin_ia32_pmulhuw512:
-        ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhu(LHS, RHS),
-                                                /*isUnsigned=*/true)));
-        break;
-      case clang::X86::BI__builtin_ia32_pmulhw128:
-      case clang::X86::BI__builtin_ia32_pmulhw256:
-      case clang::X86::BI__builtin_ia32_pmulhw512:
-        ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhs(LHS, RHS),
-                                                /*isUnsigned=*/false)));
-        break;
-      case clang::X86::BI__builtin_ia32_psllv2di:
-      case clang::X86::BI__builtin_ia32_psllv4di:
-      case clang::X86::BI__builtin_ia32_psllv4si:
-      case clang::X86::BI__builtin_ia32_psllv8si:
-        if (RHS.uge(RHS.getBitWidth())) {
-          ResultElements.push_back(
-              APValue(APSInt(APInt::getZero(RHS.getBitWidth()), DestUnsigned)));
+
+      if (SourceRHS.isInt()) {
+        uint64_t LaneWidth = 0;
+        bool IsLeftShift = false;
+        bool IsRightShift = false;
+        bool IsArithmeticRightShift = false;
+
+        switch (E->getBuiltinCallee()) {
+        case clang::X86::BI__builtin_ia32_psllwi128:
+        case clang::X86::BI__builtin_ia32_psllwi256:
+        case clang::X86::BI__builtin_ia32_psllwi512:
+          IsLeftShift = true;
+          LaneWidth = 16;
+          break;
+        case clang::X86::BI__builtin_ia32_pslldi128:
+        case clang::X86::BI__builtin_ia32_pslldi256:
+        case clang::X86::BI__builtin_ia32_pslldi512:
+          IsLeftShift = true;
+          LaneWidth = 32;
           break;
+        case clang::X86::BI__builtin_ia32_psllqi128:
+        case clang::X86::BI__builtin_ia32_psllqi256:
+        case clang::X86::BI__builtin_ia32_psllqi512:
+          IsLeftShift = true;
+          LaneWidth = 64;
+          break;
+
+        case clang::X86::BI__builtin_ia32_psrlwi128:
+        case clang::X86::BI__builtin_ia32_psrlwi256:
+        case clang::X86::BI__builtin_ia32_psrlwi512:
+          IsRightShift = true;
+          LaneWidth = 16;
+          break;
+        case clang::X86::BI__builtin_ia32_psrldi128:
+        case clang::X86::BI__builtin_ia32_psrldi256:
+        case clang::X86::BI__builtin_ia32_psrldi512:
+          IsRightShift = true;
+          LaneWidth = 32;
+          break;
+        case clang::X86::BI__builtin_ia32_psrlqi128:
+        case clang::X86::BI__builtin_ia32_psrlqi256:
+        case clang::X86::BI__builtin_ia32_psrlqi512:
+          IsRightShift = true;
+          LaneWidth = 64;
+          break;
+
+        case clang::X86::BI__builtin_ia32_psrawi128:
+        case clang::X86::BI__builtin_ia32_psrawi256:
+        case clang::X86::BI__builtin_ia32_psrawi512:
+          IsArithmeticRightShift = true;
+          LaneWidth = 16;
+          break;
+        case clang::X86::BI__builtin_ia32_psradi128:
+        case clang::X86::BI__builtin_ia32_psradi256:
+        case clang::X86::BI__builtin_ia32_psradi512:
+          IsArithmeticRightShift = true;
+          LaneWidth = 32;
+          break;
+        case clang::X86::BI__builtin_ia32_psraqi128:
+        case clang::X86::BI__builtin_ia32_psraqi256:
+        case clang::X86::BI__builtin_ia32_psraqi512:
+          IsArithmeticRightShift = true;
+          LaneWidth = 64;
+          break;
+
+        default:
+          llvm_unreachable("Unexpected builtin callee");
         }
-        ResultElements.push_back(
-            APValue(APSInt(LHS.shl(RHS.getZExtValue()), DestUnsigned)));
-        break;
-      case clang::X86::BI__builtin_ia32_psrav4si:
-      case clang::X86::BI__builtin_ia32_psrav8si:
-        if (RHS.uge(RHS.getBitWidth())) {
+
+        const APSInt RHS = SourceRHS.getInt();
+        const auto ShiftAmount = RHS.getZExtValue();
+        APInt ResultOut;
+        if (IsArithmeticRightShift) {
+          ResultOut = LHS.ashr(std::min(ShiftAmount, LaneWidth));
+        } else if (ShiftAmount >= LaneWidth) {
+          ResultOut = APInt(LaneWidth, 0);
+        } else if (IsLeftShift) {
+          ResultOut = LHS.shl(ShiftAmount);
+        } else if (IsRightShift) {
+          ResultOut = LHS.lshr(ShiftAmount);
+        } else {
+          llvm_unreachable("Invalid shift type");
+        }
+        ResultElements.push_back(APValue(APSInt(
+            std::move(ResultOut),
+            /*isUnsigned=*/DestEltTy->isUnsignedIntegerOrEnumerationType())));
+      } else {
+        APSInt RHS = SourceRHS.getVectorElt(EltNum).getInt();
+        switch (E->getBuiltinCallee()) {
+        case Builtin::BI__builtin_elementwise_add_sat:
+          ResultElements.push_back(APValue(
+              APSInt(LHS.isSigned() ? LHS.sadd_sat(RHS) : LHS.uadd_sat(RHS),
+                     DestUnsigned)));
+          break;
+        case Builtin::BI__builtin_elementwise_sub_sat:
+          ResultElements.push_back(APValue(
+              APSInt(LHS.isSigned() ? LHS.ssub_sat(RHS) : LHS.usub_sat(RHS),
+                     DestUnsigned)));
+          break;
+        case clang::X86::BI__builtin_ia32_pmulhuw128:
+        case clang::X86::BI__builtin_ia32_pmulhuw256:
+        case clang::X86::BI__builtin_ia32_pmulhuw512:
+          ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhu(LHS, RHS),
+                                                  /*isUnsigned=*/true)));
+          break;
+        case clang::X86::BI__builtin_ia32_pmulhw128:
+        case clang::X86::BI__builtin_ia32_pmulhw256:
+        case clang::X86::BI__builtin_ia32_pmulhw512:
+          ResultElements.push_back(APValue(APSInt(llvm::APIntOps::mulhs(LHS, RHS),
+                                                  /*isUnsigned=*/false)));
+          break;
+        case clang::X86::BI__builtin_ia32_psllv2di:
+        case clang::X86::BI__builtin_ia32_psllv4di:
+        case clang::X86::BI__builtin_ia32_psllv4si:
+        case clang::X86::BI__builtin_ia32_psllv8si:
+          if (RHS.uge(RHS.getBitWidth())) {
+            ResultElements.push_back(
+                APValue(APSInt(APInt::getZero(RHS.getBitWidth()), DestUnsi...
[truncated]
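
For readers skimming the diff, the per-lane rule in the new `SourceRHS.isInt()` branch can be sketched outside Clang. This is a minimal model on a single 32-bit lane using plain integers, not the `APInt` code from the patch; `ShiftKind` and `shiftLane32` are hypothetical names chosen for illustration. It assumes the behaviour visible in the diff: logical shifts give zero once the count reaches the lane width, while the arithmetic right shift clamps the count so the lane fills with copies of the sign bit.

```cpp
#include <algorithm>
#include <cstdint>

// Hypothetical names; this models one 32-bit lane, not Clang's APInt-based code.
enum class ShiftKind { Left, LogicalRight, ArithmeticRight };

constexpr std::uint32_t shiftLane32(std::uint32_t lane, std::uint64_t count,
                                    ShiftKind kind) {
  constexpr std::uint64_t LaneWidth = 32;
  if (kind == ShiftKind::ArithmeticRight) {
    // The patch clamps the count to LaneWidth (APInt::ashr accepts that).
    // Native C++ shifts do not, so this model clamps to LaneWidth - 1,
    // which produces the same all-sign-bits result.
    const std::uint64_t c = std::min<std::uint64_t>(count, LaneWidth - 1);
    const bool negative = (lane >> 31) != 0;
    std::uint32_t shifted = lane >> c;
    if (negative && c != 0)
      shifted |= ~std::uint32_t{0} << (LaneWidth - c); // fill vacated bits with 1s
    return shifted;
  }
  // Logical shifts: an out-of-range count yields zero instead of being UB.
  if (count >= LaneWidth)
    return 0;
  return kind == ShiftKind::Left ? lane << count : lane >> count;
}
```

A count of 32 or more therefore gives `0` for both logical shifts, and either `0` or `0xFFFFFFFF` for the arithmetic shift depending on the sign of the lane.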

@github-actions

github-actions bot commented Aug 27, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@Arghnews Arghnews force-pushed the vector_shift branch 2 times, most recently from 409524a to 759f06f August 27, 2025 03:06
@RKSimon RKSimon self-requested a review August 27, 2025 07:30
@RKSimon
Collaborator

RKSimon commented Aug 27, 2025

We're putting these tests in the appropriate llvm-project\clang\test\CodeGen\X86\*-builtins.c files - search for TEST_CONSTEXPR
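
As a rough illustration of the compile-time test style being referred to, the sketch below checks a shift result with a `static_assert`-based macro. It is self-contained and does not use the real test helpers or intrinsics: the `TEST_CONSTEXPR` definition here is an assumed stand-in, and `V4`, `model_slli_epi32` and `match_v4` are hypothetical names that only model a 4 x 32-bit shift left by immediate.

```cpp
#include <cstdint>

// Assumption: a stand-in for the helper macro named in the comment above;
// the real definition lives in the test headers and may differ.
#define TEST_CONSTEXPR(...) static_assert(__VA_ARGS__, "constexpr check failed")

// Hypothetical 4 x 32-bit vector model (not __m128i).
struct V4 { std::uint32_t e[4]; };

// Hypothetical model of a logical shift left by immediate on each lane.
constexpr V4 model_slli_epi32(V4 v, unsigned count) {
  V4 out{};
  for (int i = 0; i < 4; ++i)
    out.e[i] = count >= 32 ? 0u : v.e[i] << count;
  return out;
}

// Compares every lane against the expected values.
constexpr bool match_v4(V4 v, std::uint32_t a, std::uint32_t b,
                        std::uint32_t c, std::uint32_t d) {
  return v.e[0] == a && v.e[1] == b && v.e[2] == c && v.e[3] == d;
}

// Evaluated entirely at compile time; a wrong result fails the build.
TEST_CONSTEXPR(match_v4(model_slli_epi32(V4{{1, 2, 3, 4}}, 1), 2, 4, 6, 8));
TEST_CONSTEXPR(match_v4(model_slli_epi32(V4{{1, 2, 3, 4}}, 32), 0, 0, 0, 0));
```

Keeping such checks next to the existing codegen tests means one file can cover both the emitted IR and the constant-evaluated result of an intrinsic.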

@Arghnews
Contributor Author

Arghnews commented Aug 28, 2025

> We're putting these tests in the appropriate llvm-project\clang\test\CodeGen\X86\*-builtins.c files - search for TEST_CONSTEXPR

I have been splitting these out as specified, but I've run into an issue.

The test files I'd put these tests in, such as clang/test/CodeGen/X86/avx2-builtins.c or avx512f-builtins.c, use -fexperimental-new-constant-interpreter, and the new tests error under that flag.

If I remove the flag, they work fine (after a little fixing/neatening up of casts etc.).

I guess this is because clang/lib/AST/ExprConstant is the current interpreter and we're not adding support to the new one.

I can't see a way to opt out of the new interpreter for specific lines when using TEST_CONSTEXPR. Shall I just make new files for these cases? Or is adding support for these in the new interpreter small enough to be in scope here? If it is, I'm happy to take a look.

Please suggest what you think is the best way forward here, thanks
@RKSimon

@RKSimon RKSimon requested a review from tbaederr August 28, 2025 07:52
@RKSimon
Collaborator

RKSimon commented Aug 28, 2025

CC @tbaederr who can advise on helping you add support for the new interp as well

@tbaederr
Contributor

> I guess this is because clang/lib/AST/ExprConstant is the current interpreter and we're not adding support to this new one.

The trick is to add support for the new one as well 👼

@RKSimon
Collaborator

RKSimon commented Aug 28, 2025

@Arghnews Maybe use #155358 as reference

@llvmbot llvmbot added the clang:bytecode Issues for the clang bytecode constexpr interpreter label Aug 29, 2025
@Arghnews
Contributor Author

I've made updates and neatened/refactored; this now passes all tests with both the normal and the experimental interpreter, and has been formatted.

@tbaederr @RKSimon please review, thanks!

@Arghnews
Contributor Author

Refactored to use your WIP change in #156017, @RKSimon. It's a much neater way of doing it, cheers

@Arghnews Arghnews requested a review from RKSimon August 29, 2025 12:51
RKSimon added a commit to RKSimon/llvm-project that referenced this pull request Aug 29, 2025
…ut of bounds per-element shift amounts

This should allow us to reuse these cases for the shift-by-immediate builtins in llvm#155542
@Arghnews
Contributor Author

Updated with suggested changes from #156019

RKSimon added a commit that referenced this pull request Aug 29, 2025
…ut of bounds per-element shift amounts (#156019)

This should allow us to reuse these cases for the shift-by-immediate builtins in #155542
Collaborator

@RKSimon RKSimon left a comment

LGTM - but let @tbaederr comment on the bytecode

Implement VectorExprEvaluator::VisitCallExpr constexpr support for left,
right, arithmetic shift for MMX/SSE/AVX2/AVX512 intrinsics.
Also implement in experimental-new-constant-interpreter

Adds support and tests for
_mm*_slli_epi*
_mm*_srli_epi*
_mm*_srai_epi*
_mm*_mask_slli_epi*
_mm*_maskz_slli_epi*

NOTE: not all intrinsics exist at every width, e.g.
_mm_srli_pi32 has no pi64 counterpart.
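
The `mask` and `maskz` forms listed above layer a per-lane select on top of the plain shift: with a write mask, lanes whose mask bit is clear keep the value from a pass-through source, and with a zero mask those lanes become zero. The sketch below models that on four 32-bit lanes. It is an illustration under those assumptions, not the header or evaluator code from this PR, and `Lanes4`, `slli`, `select_lanes`, `mask_slli` and `maskz_slli` are hypothetical names.

```cpp
#include <cstdint>

// Hypothetical 4 x 32-bit lane model (not a real vector type).
struct Lanes4 { std::uint32_t e[4]; };

// Plain shift left by immediate; counts >= 32 give zero.
constexpr Lanes4 slli(Lanes4 a, unsigned count) {
  Lanes4 out{};
  for (int i = 0; i < 4; ++i)
    out.e[i] = count >= 32 ? 0u : a.e[i] << count;
  return out;
}

// Bit i of k picks lane i from `shifted`; otherwise lane i comes from `fallback`.
constexpr Lanes4 select_lanes(std::uint8_t k, Lanes4 shifted, Lanes4 fallback) {
  Lanes4 out{};
  for (int i = 0; i < 4; ++i)
    out.e[i] = ((k >> i) & 1) ? shifted.e[i] : fallback.e[i];
  return out;
}

// Write-mask form: unselected lanes are copied from src.
constexpr Lanes4 mask_slli(Lanes4 src, std::uint8_t k, Lanes4 a, unsigned count) {
  return select_lanes(k, slli(a, count), src);
}

// Zero-mask form: unselected lanes are zeroed.
constexpr Lanes4 maskz_slli(std::uint8_t k, Lanes4 a, unsigned count) {
  return select_lanes(k, slli(a, count), Lanes4{});
}
```

Because the select is independent of the shift, making the unmasked shift constant-evaluable is the main step; the masked forms then only need the select to be constant-evaluable as well.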
Collaborator

@RKSimon RKSimon left a comment

LGTM

@RKSimon RKSimon enabled auto-merge (squash) August 29, 2025 14:36
@RKSimon RKSimon merged commit f9e16fa into llvm:main Aug 29, 2025
9 checks passed
Labels

backend:X86 clang:bytecode Issues for the clang bytecode constexpr interpreter clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category

Development

Successfully merging this pull request may close these issues.

[Headers][X86] VectorExprEvaluator::VisitCallExpr - allow MMX/SSE/AVX2/AVX512 shift by immediate intrinsics to be used in constexpr
