Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[clang][ARM64EC] Add support for hybrid_patchable attribute. #99478

Merged
merged 1 commit into from
Jul 27, 2024

Conversation

cjacek
Copy link
Contributor

@cjacek cjacek commented Jul 18, 2024

This adds support for hybrid_patchable on top of LLVM part from #92965 (so it depends on #92965 and this PR is meant only for the second commit). For the most part, it just adds LLVM attribute whenever C/C++ attribute is specified.

I added a warning when it's used on static function. Due to the way it works (emitting a reference to an undefined symbol that linker substitutes with a generated thunk), it can't work on static functions. MSVC just silently ignores it; using it like that seems like a non-obvious mistake to me, so I followed MSVC by allowing it too, but emitting an additional warning.

The patch does nothing special for function inlining, which matches my experimentation with MSVC. hybrid_patchable functions may be specified on inline functions. Depending on optimizations taken, when they are actually inlined, it has no effect, but when they are not inlined, it works as expected.

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:AArch64 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:codegen llvm:ir llvm:transforms labels Jul 18, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-clang
@llvm/pr-subscribers-llvm-ir
@llvm/pr-subscribers-clang-codegen

@llvm/pr-subscribers-llvm-transforms

Author: Jacek Caban (cjacek)

Changes

This adds support for hybrid_patchable on top of LLVM part from #92965 (so it depends on #92965 and this PR is meant only for the second commit). For the most part, it just adds LLVM attribute whenever C/C++ attribute is specified.

I added a warning when it's used on static function. Due to the way it works (emitting a reference to an undefined symbol that linker substitutes with a generated thunk), it can't work on static functions. MSVC just silently ignores it; using it like that seems like a non-obvious mistake to me, so I followed MSVC by allowing it too, but emitting an additional warning.

The patch does nothing special for function inlining, which matches my experimentation with MSVC. hybrid_patchable functions may be specified on inline functions. Depending on optimizations taken, when they are actually inlined, it has no effect, but when they are not inlined, it works as expected.


Patch is 37.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99478.diff

19 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+9)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+10)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
  • (modified) clang/lib/AST/TypePrinter.cpp (+3)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+3)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+7)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (added) clang/test/CodeGen/arm64ec-hybrid-patchable.c (+29)
  • (modified) clang/test/Misc/pragma-attribute-supported-attributes-list.test (+1)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/CodeGen/AsmPrinter.h (+1-1)
  • (modified) llvm/include/llvm/IR/Attributes.td (+3)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+2)
  • (modified) llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp (+2-2)
  • (modified) llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (+131-7)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+27)
  • (modified) llvm/lib/Target/AArch64/AArch64CallingConvention.td (+1-1)
  • (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+1)
  • (added) llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll (+315)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 1293d0ddbc117..2b6d421729322 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -477,6 +477,9 @@ def TargetELF : TargetSpec {
 def TargetELFOrMachO : TargetSpec {
   let ObjectFormats = ["ELF", "MachO"];
 }
+def TargetArm64EC : TargetSpec {
+  let CustomCode = [{ Target.getTriple().isWindowsArm64EC() }];
+}
 
 def TargetSupportsInitPriority : TargetSpec {
   let CustomCode = [{ !Target.getTriple().isOSzOS() }];
@@ -4027,6 +4030,12 @@ def SelectAny : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def HybridPatchable : DeclOrTypeAttr, TargetSpecificAttr<TargetArm64EC> {
+  let Spellings = [Declspec<"hybrid_patchable">, GCC<"hybrid_patchable">];
+  let Subjects = SubjectList<[Function]>;
+  let Documentation = [HybridPatchableDocs];
+}
+
 def Thread : Attr {
   let Spellings = [Declspec<"thread">];
   let LangOpts = [MicrosoftExt];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 09cf4f80bd999..b5944382ba7a4 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -5984,6 +5984,16 @@ For more information see
 or `msvc documentation <https://docs.microsoft.com/pl-pl/cpp/cpp/selectany>`_.
 }]; }
 
+def HybridPatchableDocs : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+The ``hybrid_patchable`` attribute declares an ARM64EC function with an additional
+x86-64 thunk, which may be patched in runtime.
+
+For more information see
+`ARM64EC ABI documentation <https://learn.microsoft.com/en-us/windows/arm/arm64ec-abi>`_.
+}]; }
+
 def WebAssemblyExportNameDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index de3d94155a9a0..60d9f3d0ec022 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3671,6 +3671,9 @@ def err_attribute_weak_static : Error<
   "weak declaration cannot have internal linkage">;
 def err_attribute_selectany_non_extern_data : Error<
   "'selectany' can only be applied to data items with external linkage">;
+def warn_attribute_hybrid_patchable_non_extern : Warning<
+  "'hybrid_patchable' is ignored on functions without external linkage">,
+  InGroup<IgnoredAttributes>;
 def err_declspec_thread_on_thread_variable : Error<
   "'__declspec(thread)' applied to variable that already has a "
   "thread-local storage specifier">;
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ffec3ef9d2269..c2c3807399a8a 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1984,6 +1984,9 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::MSABI: OS << "ms_abi"; break;
   case attr::SysVABI: OS << "sysv_abi"; break;
   case attr::RegCall: OS << "regcall"; break;
+  case attr::HybridPatchable:
+    OS << "hybrid_patchable";
+    break;
   case attr::Pcs: {
     OS << "pcs(";
    QualType t = T->getEquivalentType();
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index ea4635c039cb2..df2f493b3aa4a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -977,6 +977,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (D && D->hasAttr<NoProfileFunctionAttr>())
     Fn->addFnAttr(llvm::Attribute::NoProfile);
 
+  if (D && D->hasAttr<HybridPatchableAttr>())
+    Fn->addFnAttr(llvm::Attribute::HybridPatchable);
+
   if (D) {
     // Function attributes take precedence over command line flags.
     if (auto *A = D->getAttr<FunctionReturnThunksAttr>()) {
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 1f2fde12c9d24..0843071ed7ec1 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -6886,6 +6886,13 @@ static void checkAttributesAfterMerging(Sema &S, NamedDecl &ND) {
     }
   }
 
+  if (HybridPatchableAttr *Attr = ND.getAttr<HybridPatchableAttr>()) {
+    if (!ND.isExternallyVisible()) {
+      S.Diag(Attr->getLocation(),
+             diag::warn_attribute_hybrid_patchable_non_extern);
+      ND.dropAttr<SelectAnyAttr>();
+    }
+  }
   if (const InheritableAttr *Attr = getDLLAttr(&ND)) {
     auto *VD = dyn_cast<VarDecl>(&ND);
     bool IsAnonymousNS = false;
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 20f46c003a464..edaeef4bdb2de 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7033,6 +7033,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_MSConstexpr:
     handleMSConstexprAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HybridPatchable:
+    handleSimpleAttribute<HybridPatchableAttr>(S, D, AL);
+    break;
 
   // HLSL attributes:
   case ParsedAttr::AT_HLSLNumThreads:
diff --git a/clang/test/CodeGen/arm64ec-hybrid-patchable.c b/clang/test/CodeGen/arm64ec-hybrid-patchable.c
new file mode 100644
index 0000000000000..30e5c7cdbb102
--- /dev/null
+++ b/clang/test/CodeGen/arm64ec-hybrid-patchable.c
@@ -0,0 +1,29 @@
+// REQUIRES: aarch64-registered-target
+// RUN: %clang_cc1 -triple arm64ec-pc-windows -fms-extensions -emit-llvm -o - %s -verify | FileCheck %s
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define dso_local i32 @func() #0 {
+int __attribute__((hybrid_patchable)) func(void) {  return 1; }
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define dso_local i32 @func2() #0 {
+int __declspec(hybrid_patchable) func2(void) {  return 2; }
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define dso_local i32 @func3() #0 {
+int __declspec(hybrid_patchable) func3(void);
+int func3(void) {  return 3; }
+
+// CHECK: ; Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define internal void @static_func() #0 {
+// expected-warning@+1 {{'hybrid_patchable' is ignored on functions without external linkage}}
+static void __declspec(hybrid_patchable) static_func(void) {}
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define linkonce_odr dso_local i32 @func4() #0 comdat {
+int inline __declspec(hybrid_patchable) func4(void) {  return 4; }
+
+void caller(void) {
+  static_func();
+  func4();
+}
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index 33f9c2f51363c..e082db698ef0c 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -83,6 +83,7 @@
 // CHECK-NEXT: HIPManaged (SubjectMatchRule_variable)
 // CHECK-NEXT: HLSLResourceClass (SubjectMatchRule_record_not_is_union)
 // CHECK-NEXT: Hot (SubjectMatchRule_function)
+// CHECK-NEXT: HybridPatchable (SubjectMatchRule_function)
 // CHECK-NEXT: IBAction (SubjectMatchRule_objc_method_is_instance)
 // CHECK-NEXT: IFunc (SubjectMatchRule_function)
 // CHECK-NEXT: InitPriority (SubjectMatchRule_variable)
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 184bbe32df695..fb88f2fe75adb 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -757,6 +757,7 @@ enum AttributeKindCodes {
   ATTR_KIND_RANGE = 92,
   ATTR_KIND_SANITIZE_NUMERICAL_STABILITY = 93,
   ATTR_KIND_INITIALIZES = 94,
+  ATTR_KIND_HYBRID_PATCHABLE = 95,
 };
 
 enum ComdatSelectionKindCodes {
diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index dc00bd57d655d..1c4e9e9111441 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -892,7 +892,6 @@ class AsmPrinter : public MachineFunctionPass {
   virtual void emitModuleCommandLines(Module &M);
 
   GCMetadataPrinter *getOrCreateGCPrinter(GCStrategy &S);
-  virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
   void emitGlobalIFunc(Module &M, const GlobalIFunc &GI);
 
 private:
@@ -900,6 +899,7 @@ class AsmPrinter : public MachineFunctionPass {
   bool shouldEmitLabelForBasicBlock(const MachineBasicBlock &MBB) const;
 
 protected:
+  virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
   virtual bool shouldEmitWeakSwiftAsyncExtendedFramePointerFlags() const {
     return false;
   }
diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index 0457f0c388d26..e1bd193891c1e 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -112,6 +112,9 @@ def ElementType : TypeAttr<"elementtype", [ParamAttr]>;
 /// symbol.
 def FnRetThunkExtern : EnumAttr<"fn_ret_thunk_extern", [FnAttr]>;
 
+/// Function has a hybrid patchable thunk.
+def HybridPatchable : EnumAttr<"hybrid_patchable", [FnAttr]>;
+
 /// Pass structure in an alloca.
 def InAlloca : TypeAttr<"inalloca", [ParamAttr]>;
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index b3ebe70e8c52f..324dcbca8137e 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -727,6 +727,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_HOT;
   case Attribute::ElementType:
     return bitc::ATTR_KIND_ELEMENTTYPE;
+  case Attribute::HybridPatchable:
+    return bitc::ATTR_KIND_HYBRID_PATCHABLE;
   case Attribute::InlineHint:
     return bitc::ATTR_KIND_INLINE_HINT;
   case Attribute::InReg:
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 1f59ec545b4f7..b46a6d348413b 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2859,8 +2859,8 @@ bool AsmPrinter::emitSpecialLLVMGlobal(const GlobalVariable *GV) {
     auto *Arr = cast<ConstantArray>(GV->getInitializer());
     for (auto &U : Arr->operands()) {
       auto *C = cast<Constant>(U);
-      auto *Src = cast<Function>(C->getOperand(0)->stripPointerCasts());
-      auto *Dst = cast<Function>(C->getOperand(1)->stripPointerCasts());
+      auto *Src = cast<GlobalValue>(C->getOperand(0)->stripPointerCasts());
+      auto *Dst = cast<GlobalValue>(C->getOperand(1)->stripPointerCasts());
       int Kind = cast<ConstantInt>(C->getOperand(2))->getZExtValue();
 
       if (Src->hasDLLImportStorageClass()) {
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
index 4ff52eb252d20..310b152ef9817 100644
--- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/CallingConv.h"
+#include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Mangler.h"
@@ -70,15 +71,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
   Function *buildEntryThunk(Function *F);
   void lowerCall(CallBase *CB);
   Function *buildGuestExitThunk(Function *F);
-  bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
+  Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
+                                GlobalAlias *MangledAlias);
+  bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
+                       DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
   bool runOnModule(Module &M) override;
 
 private:
   int cfguard_module_flag = 0;
   FunctionType *GuardFnType = nullptr;
   PointerType *GuardFnPtrType = nullptr;
+  FunctionType *DispatchFnType = nullptr;
+  PointerType *DispatchFnPtrType = nullptr;
   Constant *GuardFnCFGlobal = nullptr;
   Constant *GuardFnGlobal = nullptr;
+  Constant *DispatchFnGlobal = nullptr;
   Module *M = nullptr;
 
   Type *PtrTy;
@@ -672,6 +679,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
   return GuestExit;
 }
 
+Function *
+AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
+                                                GlobalAlias *MangledAlias) {
+  llvm::raw_null_ostream NullThunkName;
+  FunctionType *Arm64Ty, *X64Ty;
+  Function *F = cast<Function>(MangledAlias->getAliasee());
+  SmallVector<ThunkArgTranslation> ArgTranslations;
+  getThunkType(F->getFunctionType(), F->getAttributes(),
+               Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
+               ArgTranslations);
+  std::string ThunkName(MangledAlias->getName());
+  if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
+    ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
+  } else {
+    ThunkName.append("$hybpatch_thunk");
+  }
+
+  Function *GuestExit =
+      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
+  GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
+  GuestExit->setSection(".wowthk$aa");
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
+  IRBuilder<> B(BB);
+
+  // Load the global symbol as a pointer to the check function.
+  LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
+
+  // Create new dispatch call instruction.
+  Function *ExitThunk =
+      buildExitThunk(F->getFunctionType(), F->getAttributes());
+  CallInst *Dispatch =
+      B.CreateCall(DispatchFnType, DispatchLoad,
+                   {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
+
+  // Ensure that the first arguments are passed in the correct registers.
+  Dispatch->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
+  SmallVector<Value *> Args;
+  for (Argument &Arg : GuestExit->args())
+    Args.push_back(&Arg);
+  CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
+  Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
+
+  if (Call->getType()->isVoidTy())
+    B.CreateRetVoid();
+  else
+    B.CreateRet(Call);
+
+  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
+  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
+  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
+    GuestExit->addParamAttr(0, SRetAttr);
+    Call->addParamAttr(0, SRetAttr);
+  }
+
+  MangledAlias->setAliasee(GuestExit);
+  return GuestExit;
+}
+
 // Lower an indirect call with inline code.
 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
@@ -727,17 +794,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
 
   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
   GuardFnPtrType = PointerType::get(GuardFnType, 0);
+  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
+  DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
   GuardFnCFGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
   GuardFnGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
+  DispatchFnGlobal =
+      M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
+
+  DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
+  SetVector<GlobalAlias *> PatchableFns;
 
-  SetVector<Function *> DirectCalledFns;
+  for (Function &F : Mod) {
+    if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
+        F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
+      continue;
+
+    // Rename hybrid patchable functions and change callers to use a global
+    // alias instead.
+    if (std::optional<std::string> MangledName =
+            getArm64ECMangledFunctionName(F.getName().str())) {
+      std::string OrigName(F.getName());
+      F.setName(MangledName.value() + "$hp_target");
+
+      // The unmangled symbol is a weak alias to an undefined symbol with the
+      // "EXP+" prefix. This undefined symbol is resolved by the linker by
+      // creating an x86 thunk that jumps back to the actual EC target. Since we
+      // can't represent that in IR, we create an alias to the target instead.
+      // The "EXP+" symbol is set as metadata, which is then used by
+      // emitGlobalAlias to emit the right alias.
+      auto *A =
+          GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
+      F.replaceAllUsesWith(A);
+      F.setMetadata("arm64ec_exp_name",
+                    MDNode::get(M->getContext(),
+                                MDString::get(M->getContext(),
+                                              "EXP+" + MangledName.value())));
+      A->setAliasee(&F);
+
+      FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
+                                      MangledName.value(), &F);
+      PatchableFns.insert(A);
+    }
+  }
+
+  SetVector<GlobalValue *> DirectCalledFns;
   for (Function &F : Mod)
     if (!F.isDeclaration() &&
         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
-      processFunction(F, DirectCalledFns);
+      processFunction(F, DirectCalledFns, FnsMap);
 
   struct ThunkInfo {
     Constant *Src;
@@ -755,14 +862,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
           {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
     }
   }
-  for (Function *F : DirectCalledFns) {
+  for (GlobalValue *O : DirectCalledFns) {
+    auto GA = dyn_cast<GlobalAlias>(O);
+    auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
     ThunkMapping.push_back(
-        {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
+        {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
          Arm64ECThunkType::Exit});
-    if (!F->hasDLLImportStorageClass())
+    if (!GA && !F->hasDLLImportStorageClass())
       ThunkMapping.push_back(
           {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
   }
+  for (GlobalAlias *A : PatchableFns) {
+    Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
+    ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
+  }
 
   if (!ThunkMapping.empty()) {
     SmallVector<Constant *> ThunkMappingArrayElems;
@@ -785,7 +898,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
 }
 
 bool AArch64Arm64ECCallLowering::processFunction(
-    Function &F, SetVector<Function *> &DirectCalledFns) {
+    Function &F, SetVector<GlobalValue *> &DirectCalledFns,
+    DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
   SmallVector<CallBase *, 8> IndirectCalls;
 
   // For ARM64EC targets, a function definition's name is mangled differently
@@ -837,6 +951,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
         continue;
       }
 
+      // Use mangled global alias for direct calls to patchable functions.
+      if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
+        auto I = FnsMap.find(A);
+        if (I != FnsMap.end()) {
+          CB->setCalledOperand(I->second);
+          DirectCalledFns.insert(I->first);
+          continue;
+        }
+      }
+
       IndirectCalls.push_back(CB);
       ++Arm64ECCallsLowered;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 6f6b551993b6d..63358c1568a35 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -201,6 +201,7 @@ class AArch64AsmPrinter : public AsmPrinter {
   void PrintDebugValueComment(const MachineInstr *MI, raw_ostream &OS);
 
   void emitFunctionBodyEnd() override;
+  void emitGlobalAlias(const Module &M, const GlobalAlias &GA) override;
 
   MCSymbol *GetCPISymbol(unsigned CPID) const override;
   void emitEndOfAsmFile(Module &M) override;
@@ -1263,6 +1264,32 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
   }
 }
 
+void AArch64AsmPrinter::emitGlobalAlias(const Module &M,
+                                        const GlobalAlias &GA) {
+  if (auto F = dyn_cast_or_null<Function>(GA.getAliasee())) {
+    // Global aliases must poin...
[truncated]

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 18, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Jacek Caban (cjacek)

Changes

This adds support for hybrid_patchable on top of LLVM part from #92965 (so it depends on #92965 and this PR is meant only for the second commit). For the most part, it just adds LLVM attribute whenever C/C++ attribute is specified.

I added a warning when it's used on static function. Due to the way it works (emitting a reference to an undefined symbol that linker substitutes with a generated thunk), it can't work on static functions. MSVC just silently ignores it; using it like that seems like a non-obvious mistake to me, so I followed MSVC by allowing it too, but emitting an additional warning.

The patch does nothing special for function inlining, which matches my experimentation with MSVC. hybrid_patchable functions may be specified on inline functions. Depending on optimizations taken, when they are actually inlined, it has no effect, but when they are not inlined, it works as expected.


Patch is 37.20 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/99478.diff

19 Files Affected:

  • (modified) clang/include/clang/Basic/Attr.td (+9)
  • (modified) clang/include/clang/Basic/AttrDocs.td (+10)
  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+3)
  • (modified) clang/lib/AST/TypePrinter.cpp (+3)
  • (modified) clang/lib/CodeGen/CodeGenFunction.cpp (+3)
  • (modified) clang/lib/Sema/SemaDecl.cpp (+7)
  • (modified) clang/lib/Sema/SemaDeclAttr.cpp (+3)
  • (added) clang/test/CodeGen/arm64ec-hybrid-patchable.c (+29)
  • (modified) clang/test/Misc/pragma-attribute-supported-attributes-list.test (+1)
  • (modified) llvm/include/llvm/Bitcode/LLVMBitCodes.h (+1)
  • (modified) llvm/include/llvm/CodeGen/AsmPrinter.h (+1-1)
  • (modified) llvm/include/llvm/IR/Attributes.td (+3)
  • (modified) llvm/lib/Bitcode/Writer/BitcodeWriter.cpp (+2)
  • (modified) llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp (+2-2)
  • (modified) llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp (+131-7)
  • (modified) llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp (+27)
  • (modified) llvm/lib/Target/AArch64/AArch64CallingConvention.td (+1-1)
  • (modified) llvm/lib/Transforms/Utils/CodeExtractor.cpp (+1)
  • (added) llvm/test/CodeGen/AArch64/arm64ec-hybrid-patchable.ll (+315)
diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td
index 1293d0ddbc117..2b6d421729322 100644
--- a/clang/include/clang/Basic/Attr.td
+++ b/clang/include/clang/Basic/Attr.td
@@ -477,6 +477,9 @@ def TargetELF : TargetSpec {
 def TargetELFOrMachO : TargetSpec {
   let ObjectFormats = ["ELF", "MachO"];
 }
+def TargetArm64EC : TargetSpec {
+  let CustomCode = [{ Target.getTriple().isWindowsArm64EC() }];
+}
 
 def TargetSupportsInitPriority : TargetSpec {
   let CustomCode = [{ !Target.getTriple().isOSzOS() }];
@@ -4027,6 +4030,12 @@ def SelectAny : InheritableAttr {
   let SimpleHandler = 1;
 }
 
+def HybridPatchable : DeclOrTypeAttr, TargetSpecificAttr<TargetArm64EC> {
+  let Spellings = [Declspec<"hybrid_patchable">, GCC<"hybrid_patchable">];
+  let Subjects = SubjectList<[Function]>;
+  let Documentation = [HybridPatchableDocs];
+}
+
 def Thread : Attr {
   let Spellings = [Declspec<"thread">];
   let LangOpts = [MicrosoftExt];
diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td
index 09cf4f80bd999..b5944382ba7a4 100644
--- a/clang/include/clang/Basic/AttrDocs.td
+++ b/clang/include/clang/Basic/AttrDocs.td
@@ -5984,6 +5984,16 @@ For more information see
 or `msvc documentation <https://docs.microsoft.com/pl-pl/cpp/cpp/selectany>`_.
 }]; }
 
+def HybridPatchableDocs : Documentation {
+  let Category = DocCatDecl;
+  let Content = [{
+The ``hybrid_patchable`` attribute declares an ARM64EC function with an additional
+x86-64 thunk, which may be patched in runtime.
+
+For more information see
+`ARM64EC ABI documentation <https://learn.microsoft.com/en-us/windows/arm/arm64ec-abi>`_.
+}]; }
+
 def WebAssemblyExportNameDocs : Documentation {
   let Category = DocCatFunction;
   let Content = [{
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index de3d94155a9a0..60d9f3d0ec022 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -3671,6 +3671,9 @@ def err_attribute_weak_static : Error<
   "weak declaration cannot have internal linkage">;
 def err_attribute_selectany_non_extern_data : Error<
   "'selectany' can only be applied to data items with external linkage">;
+def warn_attribute_hybrid_patchable_non_extern : Warning<
+  "'hybrid_patchable' is ignored on functions without external linkage">,
+  InGroup<IgnoredAttributes>;
 def err_declspec_thread_on_thread_variable : Error<
   "'__declspec(thread)' applied to variable that already has a "
   "thread-local storage specifier">;
diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp
index ffec3ef9d2269..c2c3807399a8a 100644
--- a/clang/lib/AST/TypePrinter.cpp
+++ b/clang/lib/AST/TypePrinter.cpp
@@ -1984,6 +1984,9 @@ void TypePrinter::printAttributedAfter(const AttributedType *T,
   case attr::MSABI: OS << "ms_abi"; break;
   case attr::SysVABI: OS << "sysv_abi"; break;
   case attr::RegCall: OS << "regcall"; break;
+  case attr::HybridPatchable:
+    OS << "hybrid_patchable";
+    break;
   case attr::Pcs: {
     OS << "pcs(";
    QualType t = T->getEquivalentType();
diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp
index ea4635c039cb2..df2f493b3aa4a 100644
--- a/clang/lib/CodeGen/CodeGenFunction.cpp
+++ b/clang/lib/CodeGen/CodeGenFunction.cpp
@@ -977,6 +977,9 @@ void CodeGenFunction::StartFunction(GlobalDecl GD, QualType RetTy,
   if (D && D->hasAttr<NoProfileFunctionAttr>())
     Fn->addFnAttr(llvm::Attribute::NoProfile);
 
+  if (D && D->hasAttr<HybridPatchableAttr>())
+    Fn->addFnAttr(llvm::Attribute::HybridPatchable);
+
   if (D) {
     // Function attributes take precedence over command line flags.
     if (auto *A = D->getAttr<FunctionReturnThunksAttr>()) {
diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp
index 1f2fde12c9d24..0843071ed7ec1 100644
--- a/clang/lib/Sema/SemaDecl.cpp
+++ b/clang/lib/Sema/SemaDecl.cpp
@@ -6886,6 +6886,13 @@ static void checkAttributesAfterMerging(Sema &S, NamedDecl &ND) {
     }
   }
 
+  if (HybridPatchableAttr *Attr = ND.getAttr<HybridPatchableAttr>()) {
+    if (!ND.isExternallyVisible()) {
+      S.Diag(Attr->getLocation(),
+             diag::warn_attribute_hybrid_patchable_non_extern);
+      ND.dropAttr<SelectAnyAttr>();
+    }
+  }
   if (const InheritableAttr *Attr = getDLLAttr(&ND)) {
     auto *VD = dyn_cast<VarDecl>(&ND);
     bool IsAnonymousNS = false;
diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp
index 20f46c003a464..edaeef4bdb2de 100644
--- a/clang/lib/Sema/SemaDeclAttr.cpp
+++ b/clang/lib/Sema/SemaDeclAttr.cpp
@@ -7033,6 +7033,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL,
   case ParsedAttr::AT_MSConstexpr:
     handleMSConstexprAttr(S, D, AL);
     break;
+  case ParsedAttr::AT_HybridPatchable:
+    handleSimpleAttribute<HybridPatchableAttr>(S, D, AL);
+    break;
 
   // HLSL attributes:
   case ParsedAttr::AT_HLSLNumThreads:
diff --git a/clang/test/CodeGen/arm64ec-hybrid-patchable.c b/clang/test/CodeGen/arm64ec-hybrid-patchable.c
new file mode 100644
index 0000000000000..30e5c7cdbb102
--- /dev/null
+++ b/clang/test/CodeGen/arm64ec-hybrid-patchable.c
@@ -0,0 +1,29 @@
+// REQUIRES: aarch64-registered-target
+// RUN: %clang_cc1 -triple arm64ec-pc-windows -fms-extensions -emit-llvm -o - %s -verify | FileCheck %s
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define dso_local i32 @func() #0 {
+int __attribute__((hybrid_patchable)) func(void) {  return 1; }
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define dso_local i32 @func2() #0 {
+int __declspec(hybrid_patchable) func2(void) {  return 2; }
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define dso_local i32 @func3() #0 {
+int __declspec(hybrid_patchable) func3(void);
+int func3(void) {  return 3; }
+
+// CHECK: ; Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define internal void @static_func() #0 {
+// expected-warning@+1 {{'hybrid_patchable' is ignored on functions without external linkage}}
+static void __declspec(hybrid_patchable) static_func(void) {}
+
+// CHECK: ;    Function Attrs: hybrid_patchable noinline nounwind optnone
+// CHECK-NEXT: define linkonce_odr dso_local i32 @func4() #0 comdat {
+int inline __declspec(hybrid_patchable) func4(void) {  return 4; }
+
+void caller(void) {
+  static_func();
+  func4();
+}
diff --git a/clang/test/Misc/pragma-attribute-supported-attributes-list.test b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
index 33f9c2f51363c..e082db698ef0c 100644
--- a/clang/test/Misc/pragma-attribute-supported-attributes-list.test
+++ b/clang/test/Misc/pragma-attribute-supported-attributes-list.test
@@ -83,6 +83,7 @@
 // CHECK-NEXT: HIPManaged (SubjectMatchRule_variable)
 // CHECK-NEXT: HLSLResourceClass (SubjectMatchRule_record_not_is_union)
 // CHECK-NEXT: Hot (SubjectMatchRule_function)
+// CHECK-NEXT: HybridPatchable (SubjectMatchRule_function)
 // CHECK-NEXT: IBAction (SubjectMatchRule_objc_method_is_instance)
 // CHECK-NEXT: IFunc (SubjectMatchRule_function)
 // CHECK-NEXT: InitPriority (SubjectMatchRule_variable)
diff --git a/llvm/include/llvm/Bitcode/LLVMBitCodes.h b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
index 184bbe32df695..fb88f2fe75adb 100644
--- a/llvm/include/llvm/Bitcode/LLVMBitCodes.h
+++ b/llvm/include/llvm/Bitcode/LLVMBitCodes.h
@@ -757,6 +757,7 @@ enum AttributeKindCodes {
   ATTR_KIND_RANGE = 92,
   ATTR_KIND_SANITIZE_NUMERICAL_STABILITY = 93,
   ATTR_KIND_INITIALIZES = 94,
+  ATTR_KIND_HYBRID_PATCHABLE = 95,
 };
 
 enum ComdatSelectionKindCodes {
diff --git a/llvm/include/llvm/CodeGen/AsmPrinter.h b/llvm/include/llvm/CodeGen/AsmPrinter.h
index dc00bd57d655d..1c4e9e9111441 100644
--- a/llvm/include/llvm/CodeGen/AsmPrinter.h
+++ b/llvm/include/llvm/CodeGen/AsmPrinter.h
@@ -892,7 +892,6 @@ class AsmPrinter : public MachineFunctionPass {
   virtual void emitModuleCommandLines(Module &M);
 
   GCMetadataPrinter *getOrCreateGCPrinter(GCStrategy &S);
-  virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
   void emitGlobalIFunc(Module &M, const GlobalIFunc &GI);
 
 private:
@@ -900,6 +899,7 @@ class AsmPrinter : public MachineFunctionPass {
   bool shouldEmitLabelForBasicBlock(const MachineBasicBlock &MBB) const;
 
 protected:
+  virtual void emitGlobalAlias(const Module &M, const GlobalAlias &GA);
   virtual bool shouldEmitWeakSwiftAsyncExtendedFramePointerFlags() const {
     return false;
   }
diff --git a/llvm/include/llvm/IR/Attributes.td b/llvm/include/llvm/IR/Attributes.td
index 0457f0c388d26..e1bd193891c1e 100644
--- a/llvm/include/llvm/IR/Attributes.td
+++ b/llvm/include/llvm/IR/Attributes.td
@@ -112,6 +112,9 @@ def ElementType : TypeAttr<"elementtype", [ParamAttr]>;
 /// symbol.
 def FnRetThunkExtern : EnumAttr<"fn_ret_thunk_extern", [FnAttr]>;
 
+/// Function has a hybrid patchable thunk.
+def HybridPatchable : EnumAttr<"hybrid_patchable", [FnAttr]>;
+
 /// Pass structure in an alloca.
 def InAlloca : TypeAttr<"inalloca", [ParamAttr]>;
 
diff --git a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
index b3ebe70e8c52f..324dcbca8137e 100644
--- a/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
+++ b/llvm/lib/Bitcode/Writer/BitcodeWriter.cpp
@@ -727,6 +727,8 @@ static uint64_t getAttrKindEncoding(Attribute::AttrKind Kind) {
     return bitc::ATTR_KIND_HOT;
   case Attribute::ElementType:
     return bitc::ATTR_KIND_ELEMENTTYPE;
+  case Attribute::HybridPatchable:
+    return bitc::ATTR_KIND_HYBRID_PATCHABLE;
   case Attribute::InlineHint:
     return bitc::ATTR_KIND_INLINE_HINT;
   case Attribute::InReg:
diff --git a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
index 1f59ec545b4f7..b46a6d348413b 100644
--- a/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
+++ b/llvm/lib/CodeGen/AsmPrinter/AsmPrinter.cpp
@@ -2859,8 +2859,8 @@ bool AsmPrinter::emitSpecialLLVMGlobal(const GlobalVariable *GV) {
     auto *Arr = cast<ConstantArray>(GV->getInitializer());
     for (auto &U : Arr->operands()) {
       auto *C = cast<Constant>(U);
-      auto *Src = cast<Function>(C->getOperand(0)->stripPointerCasts());
-      auto *Dst = cast<Function>(C->getOperand(1)->stripPointerCasts());
+      auto *Src = cast<GlobalValue>(C->getOperand(0)->stripPointerCasts());
+      auto *Dst = cast<GlobalValue>(C->getOperand(1)->stripPointerCasts());
       int Kind = cast<ConstantInt>(C->getOperand(2))->getZExtValue();
 
       if (Src->hasDLLImportStorageClass()) {
diff --git a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
index 4ff52eb252d20..310b152ef9817 100644
--- a/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64Arm64ECCallLowering.cpp
@@ -21,6 +21,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/IR/CallingConv.h"
+#include "llvm/IR/GlobalAlias.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Mangler.h"
@@ -70,15 +71,21 @@ class AArch64Arm64ECCallLowering : public ModulePass {
   Function *buildEntryThunk(Function *F);
   void lowerCall(CallBase *CB);
   Function *buildGuestExitThunk(Function *F);
-  bool processFunction(Function &F, SetVector<Function *> &DirectCalledFns);
+  Function *buildPatchableThunk(GlobalAlias *UnmangledAlias,
+                                GlobalAlias *MangledAlias);
+  bool processFunction(Function &F, SetVector<GlobalValue *> &DirectCalledFns,
+                       DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap);
   bool runOnModule(Module &M) override;
 
 private:
   int cfguard_module_flag = 0;
   FunctionType *GuardFnType = nullptr;
   PointerType *GuardFnPtrType = nullptr;
+  FunctionType *DispatchFnType = nullptr;
+  PointerType *DispatchFnPtrType = nullptr;
   Constant *GuardFnCFGlobal = nullptr;
   Constant *GuardFnGlobal = nullptr;
+  Constant *DispatchFnGlobal = nullptr;
   Module *M = nullptr;
 
   Type *PtrTy;
@@ -672,6 +679,66 @@ Function *AArch64Arm64ECCallLowering::buildGuestExitThunk(Function *F) {
   return GuestExit;
 }
 
+Function *
+AArch64Arm64ECCallLowering::buildPatchableThunk(GlobalAlias *UnmangledAlias,
+                                                GlobalAlias *MangledAlias) {
+  llvm::raw_null_ostream NullThunkName;
+  FunctionType *Arm64Ty, *X64Ty;
+  Function *F = cast<Function>(MangledAlias->getAliasee());
+  SmallVector<ThunkArgTranslation> ArgTranslations;
+  getThunkType(F->getFunctionType(), F->getAttributes(),
+               Arm64ECThunkType::GuestExit, NullThunkName, Arm64Ty, X64Ty,
+               ArgTranslations);
+  std::string ThunkName(MangledAlias->getName());
+  if (ThunkName[0] == '?' && ThunkName.find("@") != std::string::npos) {
+    ThunkName.insert(ThunkName.find("@"), "$hybpatch_thunk");
+  } else {
+    ThunkName.append("$hybpatch_thunk");
+  }
+
+  Function *GuestExit =
+      Function::Create(Arm64Ty, GlobalValue::WeakODRLinkage, 0, ThunkName, M);
+  GuestExit->setComdat(M->getOrInsertComdat(ThunkName));
+  GuestExit->setSection(".wowthk$aa");
+  BasicBlock *BB = BasicBlock::Create(M->getContext(), "", GuestExit);
+  IRBuilder<> B(BB);
+
+  // Load the global symbol as a pointer to the check function.
+  LoadInst *DispatchLoad = B.CreateLoad(DispatchFnPtrType, DispatchFnGlobal);
+
+  // Create new dispatch call instruction.
+  Function *ExitThunk =
+      buildExitThunk(F->getFunctionType(), F->getAttributes());
+  CallInst *Dispatch =
+      B.CreateCall(DispatchFnType, DispatchLoad,
+                   {UnmangledAlias, ExitThunk, UnmangledAlias->getAliasee()});
+
+  // Ensure that the first arguments are passed in the correct registers.
+  Dispatch->setCallingConv(CallingConv::CFGuard_Check);
+
+  Value *DispatchRetVal = B.CreateBitCast(Dispatch, PtrTy);
+  SmallVector<Value *> Args;
+  for (Argument &Arg : GuestExit->args())
+    Args.push_back(&Arg);
+  CallInst *Call = B.CreateCall(Arm64Ty, DispatchRetVal, Args);
+  Call->setTailCallKind(llvm::CallInst::TCK_MustTail);
+
+  if (Call->getType()->isVoidTy())
+    B.CreateRetVoid();
+  else
+    B.CreateRet(Call);
+
+  auto SRetAttr = F->getAttributes().getParamAttr(0, Attribute::StructRet);
+  auto InRegAttr = F->getAttributes().getParamAttr(0, Attribute::InReg);
+  if (SRetAttr.isValid() && !InRegAttr.isValid()) {
+    GuestExit->addParamAttr(0, SRetAttr);
+    Call->addParamAttr(0, SRetAttr);
+  }
+
+  MangledAlias->setAliasee(GuestExit);
+  return GuestExit;
+}
+
 // Lower an indirect call with inline code.
 void AArch64Arm64ECCallLowering::lowerCall(CallBase *CB) {
   assert(Triple(CB->getModule()->getTargetTriple()).isOSWindows() &&
@@ -727,17 +794,57 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
 
   GuardFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy}, false);
   GuardFnPtrType = PointerType::get(GuardFnType, 0);
+  DispatchFnType = FunctionType::get(PtrTy, {PtrTy, PtrTy, PtrTy}, false);
+  DispatchFnPtrType = PointerType::get(DispatchFnType, 0);
   GuardFnCFGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall_cfg", GuardFnPtrType);
   GuardFnGlobal =
       M->getOrInsertGlobal("__os_arm64x_check_icall", GuardFnPtrType);
+  DispatchFnGlobal =
+      M->getOrInsertGlobal("__os_arm64x_dispatch_call", DispatchFnPtrType);
+
+  DenseMap<GlobalAlias *, GlobalAlias *> FnsMap;
+  SetVector<GlobalAlias *> PatchableFns;
 
-  SetVector<Function *> DirectCalledFns;
+  for (Function &F : Mod) {
+    if (!F.hasFnAttribute(Attribute::HybridPatchable) || F.isDeclaration() ||
+        F.hasLocalLinkage() || F.getName().ends_with("$hp_target"))
+      continue;
+
+    // Rename hybrid patchable functions and change callers to use a global
+    // alias instead.
+    if (std::optional<std::string> MangledName =
+            getArm64ECMangledFunctionName(F.getName().str())) {
+      std::string OrigName(F.getName());
+      F.setName(MangledName.value() + "$hp_target");
+
+      // The unmangled symbol is a weak alias to an undefined symbol with the
+      // "EXP+" prefix. This undefined symbol is resolved by the linker by
+      // creating an x86 thunk that jumps back to the actual EC target. Since we
+      // can't represent that in IR, we create an alias to the target instead.
+      // The "EXP+" symbol is set as metadata, which is then used by
+      // emitGlobalAlias to emit the right alias.
+      auto *A =
+          GlobalAlias::create(GlobalValue::LinkOnceODRLinkage, OrigName, &F);
+      F.replaceAllUsesWith(A);
+      F.setMetadata("arm64ec_exp_name",
+                    MDNode::get(M->getContext(),
+                                MDString::get(M->getContext(),
+                                              "EXP+" + MangledName.value())));
+      A->setAliasee(&F);
+
+      FnsMap[A] = GlobalAlias::create(GlobalValue::LinkOnceODRLinkage,
+                                      MangledName.value(), &F);
+      PatchableFns.insert(A);
+    }
+  }
+
+  SetVector<GlobalValue *> DirectCalledFns;
   for (Function &F : Mod)
     if (!F.isDeclaration() &&
         F.getCallingConv() != CallingConv::ARM64EC_Thunk_Native &&
         F.getCallingConv() != CallingConv::ARM64EC_Thunk_X64)
-      processFunction(F, DirectCalledFns);
+      processFunction(F, DirectCalledFns, FnsMap);
 
   struct ThunkInfo {
     Constant *Src;
@@ -755,14 +862,20 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
           {&F, buildEntryThunk(&F), Arm64ECThunkType::Entry});
     }
   }
-  for (Function *F : DirectCalledFns) {
+  for (GlobalValue *O : DirectCalledFns) {
+    auto GA = dyn_cast<GlobalAlias>(O);
+    auto F = dyn_cast<Function>(GA ? GA->getAliasee() : O);
     ThunkMapping.push_back(
-        {F, buildExitThunk(F->getFunctionType(), F->getAttributes()),
+        {O, buildExitThunk(F->getFunctionType(), F->getAttributes()),
          Arm64ECThunkType::Exit});
-    if (!F->hasDLLImportStorageClass())
+    if (!GA && !F->hasDLLImportStorageClass())
       ThunkMapping.push_back(
           {buildGuestExitThunk(F), F, Arm64ECThunkType::GuestExit});
   }
+  for (GlobalAlias *A : PatchableFns) {
+    Function *Thunk = buildPatchableThunk(A, FnsMap[A]);
+    ThunkMapping.push_back({Thunk, A, Arm64ECThunkType::GuestExit});
+  }
 
   if (!ThunkMapping.empty()) {
     SmallVector<Constant *> ThunkMappingArrayElems;
@@ -785,7 +898,8 @@ bool AArch64Arm64ECCallLowering::runOnModule(Module &Mod) {
 }
 
 bool AArch64Arm64ECCallLowering::processFunction(
-    Function &F, SetVector<Function *> &DirectCalledFns) {
+    Function &F, SetVector<GlobalValue *> &DirectCalledFns,
+    DenseMap<GlobalAlias *, GlobalAlias *> &FnsMap) {
   SmallVector<CallBase *, 8> IndirectCalls;
 
   // For ARM64EC targets, a function definition's name is mangled differently
@@ -837,6 +951,16 @@ bool AArch64Arm64ECCallLowering::processFunction(
         continue;
       }
 
+      // Use mangled global alias for direct calls to patchable functions.
+      if (GlobalAlias *A = dyn_cast<GlobalAlias>(CB->getCalledOperand())) {
+        auto I = FnsMap.find(A);
+        if (I != FnsMap.end()) {
+          CB->setCalledOperand(I->second);
+          DirectCalledFns.insert(I->first);
+          continue;
+        }
+      }
+
       IndirectCalls.push_back(CB);
       ++Arm64ECCallsLowered;
     }
diff --git a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
index 6f6b551993b6d..63358c1568a35 100644
--- a/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
+++ b/llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
@@ -201,6 +201,7 @@ class AArch64AsmPrinter : public AsmPrinter {
   void PrintDebugValueComment(const MachineInstr *MI, raw_ostream &OS);
 
   void emitFunctionBodyEnd() override;
+  void emitGlobalAlias(const Module &M, const GlobalAlias &GA) override;
 
   MCSymbol *GetCPISymbol(unsigned CPID) const override;
   void emitEndOfAsmFile(Module &M) override;
@@ -1263,6 +1264,32 @@ void AArch64AsmPrinter::emitFunctionEntryLabel() {
   }
 }
 
+void AArch64AsmPrinter::emitGlobalAlias(const Module &M,
+                                        const GlobalAlias &GA) {
+  if (auto F = dyn_cast_or_null<Function>(GA.getAliasee())) {
+    // Global aliases must poin...
[truncated]

Copy link
Collaborator

@AaronBallman AaronBallman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At some point, you should also add documentation to clang/docs/ReleaseNotes.rst so that users know about the new functionality.

@@ -4027,6 +4030,12 @@ def SelectAny : InheritableAttr {
let SimpleHandler = 1;
}

def HybridPatchable : DeclOrTypeAttr, TargetSpecificAttr<TargetArm64EC> {
let Spellings = [Declspec<"hybrid_patchable">, GCC<"hybrid_patchable">];
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see any evidence that GCC supports this attribute, were you aiming for supporting __attribute__((hybrid_patchable)) and not [[gnu::hybrid_patchable]]? If so, that would be a GNU spelling instead of a GCC spelling, but we typically would use a Clang spelling so that [[clang::hybrid_patchable]] works.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking mostly about __attribute__((hybrid_patchable)) as it's the form that mingw will mostly likely want to use, so I meant GNU. Clang sounds good to me if proffered, I changed that and added a test for it.

@@ -477,6 +477,9 @@ def TargetELF : TargetSpec {
def TargetELFOrMachO : TargetSpec {
let ObjectFormats = ["ELF", "MachO"];
}
def TargetArm64EC : TargetSpec {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this say Windows somewhere in the name to make it more clear that this isn't an ARM64 attribute in general?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed it to TargetWindowsArm64EC in the new version.

@@ -4027,6 +4030,12 @@ def SelectAny : InheritableAttr {
let SimpleHandler = 1;
}

def HybridPatchable : DeclOrTypeAttr, TargetSpecificAttr<TargetArm64EC> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You've specified this as a type attribute but implemented it purely as a declaration attribute; did you intend for this to be InheritableAttr instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I don't remember why I did it like that. I changed it in the new version.

if (!ND.isExternallyVisible()) {
S.Diag(Attr->getLocation(),
diag::warn_attribute_hybrid_patchable_non_extern);
ND.dropAttr<SelectAnyAttr>();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
ND.dropAttr<SelectAnyAttr>();
ND.dropAttr<HybridPatchableAttr>();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, it was a copoy&paste typo. Actually, we don't really need it as LLVM part ignores the attribute for static functions itself. I removed that line and left only diagnostics part here to slightly simplify the code.

@cjacek
Copy link
Contributor Author

cjacek commented Jul 19, 2024

At some point, you should also add documentation to clang/docs/ReleaseNotes.rst so that users know about the new functionality.

I included release notes changelog in the new version. I also rebased it on top of merged #92965 and addressed other comments. Thanks for reviews!

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

(Please cherry-pick to 19; I'd like to have this in 19. And maybe cjacek@4febef2, if you wouldn't mind opening a pull request for that.)

@cjacek cjacek merged commit ea98dc8 into llvm:main Jul 27, 2024
8 checks passed
@cjacek cjacek deleted the clang-hybrid-patchable branch July 27, 2024 12:29
@cjacek
Copy link
Contributor Author

cjacek commented Jul 27, 2024

/cherry-pick ea98dc8

@cjacek
Copy link
Contributor Author

cjacek commented Jul 27, 2024

Thanks. I created #100872 for dllexport fix.

@llvmbot
Copy link
Collaborator

llvmbot commented Jul 27, 2024

/cherry-pick ea98dc8

Error: Command failed due to missing milestone.

@cjacek cjacek added this to the LLVM 19.X Release milestone Jul 27, 2024
@cjacek
Copy link
Contributor Author

cjacek commented Jul 27, 2024

/cherry-pick ea98dc8

llvmbot pushed a commit to llvmbot/llvm-project that referenced this pull request Jul 27, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jul 27, 2024

/pull-request #100873

tru pushed a commit to llvmbot/llvm-project that referenced this pull request Jul 30, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AArch64 clang:codegen clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category llvm:ir llvm:transforms
Projects
Development

Successfully merging this pull request may close these issues.

4 participants