Skip to content

Commit a57d5d0

Browse files
xal-0KristofferC
authored andcommitted
Add JLJITLinkMemoryManager (ports memory manager to JITLink) (#60105)
Ports our RTDyLD memory manager to JITLink in order to avoid memory use regressions after switching to JITLink everywhere (#60031). This is a direct port: finalization must happen all at once, because it invalidates all allocation `wr_ptr`s. I decided it wasn't worth it to associate `OnFinalizedFunction` callbacks with each block, since they are large enough to make it extremely likely that all in-flight allocations land in the same block; everything must be relocated before finalization can happen. (cherry picked from commit 6fa0e75)
1 parent d05709c commit a57d5d0

File tree

2 files changed

+188
-52
lines changed

2 files changed

+188
-52
lines changed

src/cgmemmgr.cpp

Lines changed: 187 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,11 @@
33
#include "llvm-version.h"
44
#include "platform.h"
55

6+
#include <llvm/ExecutionEngine/JITLink/JITLink.h>
7+
#include <llvm/ExecutionEngine/JITLink/JITLinkMemoryManager.h>
8+
#include <llvm/ExecutionEngine/Orc/MapperJITLinkMemoryManager.h>
69
#include <llvm/ExecutionEngine/SectionMemoryManager.h>
10+
711
#include "julia.h"
812
#include "julia_internal.h"
913

@@ -458,26 +462,36 @@ struct Block {
458462
}
459463
};
460464

465+
struct Allocation {
466+
// Address to write to (the one returned by the allocation function)
467+
void *wr_addr;
468+
// Runtime address
469+
void *rt_addr;
470+
size_t sz;
471+
bool relocated;
472+
};
473+
461474
class RWAllocator {
462475
static constexpr int nblocks = 8;
463476
Block blocks[nblocks]{};
464477
public:
465478
RWAllocator() JL_NOTSAFEPOINT = default;
466-
void *alloc(size_t size, size_t align) JL_NOTSAFEPOINT
479+
Allocation alloc(size_t size, size_t align) JL_NOTSAFEPOINT
467480
{
468481
size_t min_size = (size_t)-1;
469482
int min_id = 0;
470483
for (int i = 0;i < nblocks && blocks[i].ptr;i++) {
471484
if (void *ptr = blocks[i].alloc(size, align))
472-
return ptr;
485+
return {ptr, ptr, size, false};
473486
if (blocks[i].avail < min_size) {
474487
min_size = blocks[i].avail;
475488
min_id = i;
476489
}
477490
}
478491
size_t block_size = get_block_size(size);
479492
blocks[min_id].reset(map_anon_page(block_size), block_size);
480-
return blocks[min_id].alloc(size, align);
493+
void *ptr = blocks[min_id].alloc(size, align);
494+
return {ptr, ptr, size, false};
481495
}
482496
};
483497

@@ -517,16 +531,6 @@ struct SplitPtrBlock : public Block {
517531
}
518532
};
519533

520-
struct Allocation {
521-
// Address to write to (the one returned by the allocation function)
522-
void *wr_addr;
523-
// Runtime address
524-
void *rt_addr;
525-
size_t sz;
526-
bool relocated;
527-
};
528-
529-
template<bool exec>
530534
class ROAllocator {
531535
protected:
532536
static constexpr int nblocks = 8;
@@ -554,7 +558,7 @@ class ROAllocator {
554558
}
555559
// Allocations that have not been finalized yet.
556560
SmallVector<Allocation, 16> allocations;
557-
void *alloc(size_t size, size_t align) JL_NOTSAFEPOINT
561+
Allocation alloc(size_t size, size_t align) JL_NOTSAFEPOINT
558562
{
559563
size_t min_size = (size_t)-1;
560564
int min_id = 0;
@@ -570,8 +574,9 @@ class ROAllocator {
570574
wr_ptr = get_wr_ptr(block, ptr, size, align);
571575
}
572576
block.state |= SplitPtrBlock::Alloc;
573-
allocations.push_back(Allocation{wr_ptr, ptr, size, false});
574-
return wr_ptr;
577+
Allocation a{wr_ptr, ptr, size, false};
578+
allocations.push_back(a);
579+
return a;
575580
}
576581
if (block.avail < min_size) {
577582
min_size = block.avail;
@@ -592,18 +597,21 @@ class ROAllocator {
592597
#ifdef _OS_WINDOWS_
593598
block.state = SplitPtrBlock::Alloc;
594599
void *wr_ptr = get_wr_ptr(block, ptr, size, align);
595-
allocations.push_back(Allocation{wr_ptr, ptr, size, false});
600+
Allocation a{wr_ptr, ptr, size, false};
601+
allocations.push_back(a);
596602
ptr = wr_ptr;
597603
#else
598604
block.state = SplitPtrBlock::Alloc | SplitPtrBlock::InitAlloc;
599-
allocations.push_back(Allocation{ptr, ptr, size, false});
605+
Allocation a{ptr, ptr, size, false};
606+
allocations.push_back(a);
600607
#endif
601-
return ptr;
608+
return a;
602609
}
603610
};
604611

605-
template<bool exec>
606-
class DualMapAllocator : public ROAllocator<exec> {
612+
class DualMapAllocator : public ROAllocator {
613+
bool exec;
614+
607615
protected:
608616
void *get_wr_ptr(SplitPtrBlock &block, void *rt_ptr, size_t, size_t) override JL_NOTSAFEPOINT
609617
{
@@ -664,7 +672,7 @@ class DualMapAllocator : public ROAllocator<exec> {
664672
}
665673
}
666674
public:
667-
DualMapAllocator() JL_NOTSAFEPOINT
675+
DualMapAllocator(bool exec) JL_NOTSAFEPOINT : exec(exec)
668676
{
669677
assert(anon_hdl != -1);
670678
}
@@ -677,13 +685,13 @@ class DualMapAllocator : public ROAllocator<exec> {
677685
finalize_block(block, true);
678686
block.reset(nullptr, 0);
679687
}
680-
ROAllocator<exec>::finalize();
688+
ROAllocator::finalize();
681689
}
682690
};
683691

684692
#ifdef _OS_LINUX_
685-
template<bool exec>
686-
class SelfMemAllocator : public ROAllocator<exec> {
693+
class SelfMemAllocator : public ROAllocator {
694+
bool exec;
687695
SmallVector<Block, 16> temp_buff;
688696
protected:
689697
void *get_wr_ptr(SplitPtrBlock &block, void *rt_ptr,
@@ -720,9 +728,7 @@ class SelfMemAllocator : public ROAllocator<exec> {
720728
}
721729
}
722730
public:
723-
SelfMemAllocator() JL_NOTSAFEPOINT
724-
: ROAllocator<exec>(),
725-
temp_buff()
731+
SelfMemAllocator(bool exec) JL_NOTSAFEPOINT : exec(exec), temp_buff()
726732
{
727733
assert(get_self_mem_fd() != -1);
728734
}
@@ -756,11 +762,25 @@ class SelfMemAllocator : public ROAllocator<exec> {
756762
}
757763
if (cached)
758764
temp_buff.resize(1);
759-
ROAllocator<exec>::finalize();
765+
ROAllocator::finalize();
760766
}
761767
};
762768
#endif // _OS_LINUX_
763769

770+
std::pair<std::unique_ptr<ROAllocator>, std::unique_ptr<ROAllocator>>
771+
get_preferred_allocators() JL_NOTSAFEPOINT
772+
{
773+
#ifdef _OS_LINUX_
774+
if (get_self_mem_fd() != -1)
775+
return {std::make_unique<SelfMemAllocator>(false),
776+
std::make_unique<SelfMemAllocator>(true)};
777+
#endif
778+
if (init_shared_map() != -1)
779+
return {std::make_unique<DualMapAllocator>(false),
780+
std::make_unique<DualMapAllocator>(true)};
781+
return {};
782+
}
783+
764784
class RTDyldMemoryManagerJL : public SectionMemoryManager {
765785
struct EHFrame {
766786
uint8_t *addr;
@@ -770,29 +790,18 @@ class RTDyldMemoryManagerJL : public SectionMemoryManager {
770790
void operator=(const RTDyldMemoryManagerJL&) = delete;
771791
SmallVector<EHFrame, 16> pending_eh;
772792
RWAllocator rw_alloc;
773-
std::unique_ptr<ROAllocator<false>> ro_alloc;
774-
std::unique_ptr<ROAllocator<true>> exe_alloc;
793+
std::unique_ptr<ROAllocator> ro_alloc;
794+
std::unique_ptr<ROAllocator> exe_alloc;
775795
size_t total_allocated;
776796

777797
public:
778798
RTDyldMemoryManagerJL() JL_NOTSAFEPOINT
779799
: SectionMemoryManager(),
780800
pending_eh(),
781801
rw_alloc(),
782-
ro_alloc(),
783-
exe_alloc(),
784802
total_allocated(0)
785803
{
786-
#ifdef _OS_LINUX_
787-
if (!ro_alloc && get_self_mem_fd() != -1) {
788-
ro_alloc.reset(new SelfMemAllocator<false>());
789-
exe_alloc.reset(new SelfMemAllocator<true>());
790-
}
791-
#endif
792-
if (!ro_alloc && init_shared_map() != -1) {
793-
ro_alloc.reset(new DualMapAllocator<false>());
794-
exe_alloc.reset(new DualMapAllocator<true>());
795-
}
804+
std::tie(ro_alloc, exe_alloc) = get_preferred_allocators();
796805
}
797806
~RTDyldMemoryManagerJL() override JL_NOTSAFEPOINT
798807
{
@@ -845,7 +854,7 @@ uint8_t *RTDyldMemoryManagerJL::allocateCodeSection(uintptr_t Size,
845854
jl_timing_counter_inc(JL_TIMING_COUNTER_JITSize, Size);
846855
jl_timing_counter_inc(JL_TIMING_COUNTER_JITCodeSize, Size);
847856
if (exe_alloc)
848-
return (uint8_t*)exe_alloc->alloc(Size, Alignment);
857+
return (uint8_t*)exe_alloc->alloc(Size, Alignment).wr_addr;
849858
return SectionMemoryManager::allocateCodeSection(Size, Alignment, SectionID,
850859
SectionName);
851860
}
@@ -860,9 +869,9 @@ uint8_t *RTDyldMemoryManagerJL::allocateDataSection(uintptr_t Size,
860869
jl_timing_counter_inc(JL_TIMING_COUNTER_JITSize, Size);
861870
jl_timing_counter_inc(JL_TIMING_COUNTER_JITDataSize, Size);
862871
if (!isReadOnly)
863-
return (uint8_t*)rw_alloc.alloc(Size, Alignment);
872+
return (uint8_t*)rw_alloc.alloc(Size, Alignment).wr_addr;
864873
if (ro_alloc)
865-
return (uint8_t*)ro_alloc->alloc(Size, Alignment);
874+
return (uint8_t*)ro_alloc->alloc(Size, Alignment).wr_addr;
866875
return SectionMemoryManager::allocateDataSection(Size, Alignment, SectionID,
867876
SectionName, isReadOnly);
868877
}
@@ -917,6 +926,133 @@ void RTDyldMemoryManagerJL::deregisterEHFrames(uint8_t *Addr,
917926
}
918927
#endif
919928

929+
class JLJITLinkMemoryManager : public jitlink::JITLinkMemoryManager {
930+
using OnFinalizedFunction =
931+
jitlink::JITLinkMemoryManager::InFlightAlloc::OnFinalizedFunction;
932+
933+
std::mutex Mutex;
934+
RWAllocator RWAlloc;
935+
std::unique_ptr<ROAllocator> ROAlloc;
936+
std::unique_ptr<ROAllocator> ExeAlloc;
937+
SmallVector<OnFinalizedFunction> FinalizedCallbacks;
938+
uint32_t InFlight{0};
939+
940+
public:
941+
class InFlightAlloc;
942+
943+
static std::unique_ptr<JITLinkMemoryManager> Create()
944+
{
945+
auto [ROAlloc, ExeAlloc] = get_preferred_allocators();
946+
if (ROAlloc && ExeAlloc)
947+
return std::unique_ptr<JLJITLinkMemoryManager>(
948+
new JLJITLinkMemoryManager(std::move(ROAlloc), std::move(ExeAlloc)));
949+
950+
return cantFail(
951+
orc::MapperJITLinkMemoryManager::CreateWithMapper<orc::InProcessMemoryMapper>(
952+
/*Reservation Granularity*/ 16 * 1024 * 1024));
953+
}
954+
955+
void allocate(const jitlink::JITLinkDylib *JD, jitlink::LinkGraph &G,
956+
OnAllocatedFunction OnAllocated) override;
957+
958+
void deallocate(std::vector<FinalizedAlloc> Allocs,
959+
OnDeallocatedFunction OnDeallocated) override
960+
{
961+
jl_unreachable();
962+
}
963+
964+
protected:
965+
JLJITLinkMemoryManager(std::unique_ptr<ROAllocator> ROAlloc,
966+
std::unique_ptr<ROAllocator> ExeAlloc)
967+
: ROAlloc(std::move(ROAlloc)), ExeAlloc(std::move(ExeAlloc))
968+
{
969+
}
970+
971+
void finalize(OnFinalizedFunction OnFinalized)
972+
{
973+
SmallVector<OnFinalizedFunction> Callbacks;
974+
{
975+
std::unique_lock Lock{Mutex};
976+
FinalizedCallbacks.push_back(std::move(OnFinalized));
977+
978+
if (--InFlight > 0)
979+
return;
980+
981+
ROAlloc->finalize();
982+
ExeAlloc->finalize();
983+
Callbacks = std::move(FinalizedCallbacks);
984+
}
985+
986+
for (auto &CB : Callbacks)
987+
std::move(CB)(FinalizedAlloc{});
988+
}
989+
};
990+
991+
class JLJITLinkMemoryManager::InFlightAlloc
992+
: public jitlink::JITLinkMemoryManager::InFlightAlloc {
993+
JLJITLinkMemoryManager &MM;
994+
jitlink::LinkGraph &G;
995+
996+
public:
997+
InFlightAlloc(JLJITLinkMemoryManager &MM, jitlink::LinkGraph &G) : MM(MM), G(G) {}
998+
999+
void abandon(OnAbandonedFunction OnAbandoned) override { jl_unreachable(); }
1000+
1001+
void finalize(OnFinalizedFunction OnFinalized) override
1002+
{
1003+
auto *GP = &G;
1004+
MM.finalize([GP, OnFinalized =
1005+
std::move(OnFinalized)](Expected<FinalizedAlloc> FA) mutable {
1006+
if (!FA)
1007+
return OnFinalized(FA.takeError());
1008+
// Need to handle dealloc actions when we GC code
1009+
auto E = orc::shared::runFinalizeActions(GP->allocActions());
1010+
if (!E)
1011+
return OnFinalized(E.takeError());
1012+
OnFinalized(std::move(FA));
1013+
});
1014+
}
1015+
};
1016+
1017+
using orc::MemProt;
1018+
1019+
void JLJITLinkMemoryManager::allocate(const jitlink::JITLinkDylib *JD,
1020+
jitlink::LinkGraph &G,
1021+
OnAllocatedFunction OnAllocated)
1022+
{
1023+
jitlink::BasicLayout BL{G};
1024+
1025+
{
1026+
std::unique_lock Lock{Mutex};
1027+
for (auto &[AG, Seg] : BL.segments()) {
1028+
if (AG.getMemLifetime() == orc::MemLifetime::NoAlloc)
1029+
continue;
1030+
assert(AG.getMemLifetime() == orc::MemLifetime::Standard);
1031+
1032+
auto Prot = AG.getMemProt();
1033+
uint64_t Alignment = Seg.Alignment.value();
1034+
uint64_t Size = Seg.ContentSize + Seg.ZeroFillSize;
1035+
Allocation Alloc;
1036+
if (Prot == (MemProt::Read | MemProt::Write))
1037+
Alloc = RWAlloc.alloc(Size, Alignment);
1038+
else if (Prot == MemProt::Read)
1039+
Alloc = ROAlloc->alloc(Size, Alignment);
1040+
else if (Prot == (MemProt::Read | MemProt::Exec))
1041+
Alloc = ExeAlloc->alloc(Size, Alignment);
1042+
else
1043+
abort();
1044+
1045+
Seg.Addr = orc::ExecutorAddr::fromPtr(Alloc.rt_addr);
1046+
Seg.WorkingMem = (char *)Alloc.wr_addr;
1047+
}
1048+
}
1049+
1050+
if (auto Err = BL.apply())
1051+
return OnAllocated(std::move(Err));
1052+
1053+
++InFlight;
1054+
OnAllocated(std::make_unique<InFlightAlloc>(*this, G));
1055+
}
9201056
}
9211057

9221058
RTDyldMemoryManager* createRTDyldMemoryManager() JL_NOTSAFEPOINT
@@ -928,3 +1064,8 @@ size_t getRTDyldMemoryManagerTotalBytes(RTDyldMemoryManager *mm) JL_NOTSAFEPOINT
9281064
{
9291065
return ((RTDyldMemoryManagerJL*)mm)->getTotalBytes();
9301066
}
1067+
1068+
std::unique_ptr<jitlink::JITLinkMemoryManager> createJITLinkMemoryManager()
1069+
{
1070+
return JLJITLinkMemoryManager::Create();
1071+
}

src/jitlayers.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,12 +1156,6 @@ class JLMemoryUsagePlugin : public ObjectLinkingLayer::Plugin {
11561156
#pragma clang diagnostic ignored "-Wunused-function"
11571157
#endif
11581158

1159-
// TODO: Port our memory management optimisations to JITLink instead of using the
1160-
// default InProcessMemoryManager.
1161-
std::unique_ptr<jitlink::JITLinkMemoryManager> createJITLinkMemoryManager() JL_NOTSAFEPOINT {
1162-
return cantFail(orc::MapperJITLinkMemoryManager::CreateWithMapper<orc::InProcessMemoryMapper>(/*Reservation Granularity*/ 16 * 1024 * 1024));
1163-
}
1164-
11651159
#ifdef _COMPILER_CLANG_
11661160
#pragma clang diagnostic pop
11671161
#endif
@@ -1185,6 +1179,7 @@ class JLEHFrameRegistrar final : public jitlink::EHFrameRegistrar {
11851179
};
11861180

11871181
RTDyldMemoryManager *createRTDyldMemoryManager(void) JL_NOTSAFEPOINT;
1182+
std::unique_ptr<jitlink::JITLinkMemoryManager> createJITLinkMemoryManager() JL_NOTSAFEPOINT;
11881183

11891184
// A simple forwarding class, since OrcJIT v2 needs a unique_ptr, while we have a shared_ptr
11901185
class ForwardingMemoryManager : public RuntimeDyld::MemoryManager {

0 commit comments

Comments
 (0)