Skip to content

Commit

Permalink
[SOL] Correctly copy 16-byte aligned memory (#97)
Browse files Browse the repository at this point in the history
* Fix issue with copying 16-byte aligned memory
  • Loading branch information
LucasSte authored Jun 20, 2024
1 parent 3f1b8db commit 7d055ed
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 34 deletions.
72 changes: 43 additions & 29 deletions llvm/lib/Target/SBF/SBFInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ void SBFInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
DebugLoc dl = MI->getDebugLoc();
unsigned LdOpc, StOpc;

unsigned BytesPerOp = std::min(static_cast<unsigned>(Alignment), 8u);
switch (Alignment) {
case 1:
LdOpc = SBF::LDB;
Expand All @@ -66,49 +67,62 @@ void SBFInstrInfo::expandMEMCPY(MachineBasicBlock::iterator MI) const {
StOpc = SBF::STW;
break;
case 8:
case 16:
LdOpc = SBF::LDD;
StOpc = SBF::STD;
break;
default:
llvm_unreachable("unsupported memcpy alignment");
}

unsigned IterationNum = CopyLen >> Log2_64(Alignment);
for(unsigned I = 0; I < IterationNum; ++I) {
unsigned IterationNum = (CopyLen >> Log2_64(BytesPerOp));
for (unsigned I = 0; I < IterationNum; ++I) {
BuildMI(*BB, MI, dl, get(LdOpc))
.addReg(ScratchReg, RegState::Define).addReg(SrcReg)
.addImm(I * Alignment);
.addReg(ScratchReg, RegState::Define)
.addReg(SrcReg)
.addImm(I * BytesPerOp);
BuildMI(*BB, MI, dl, get(StOpc))
.addReg(ScratchReg, RegState::Kill).addReg(DstReg)
.addImm(I * Alignment);
.addReg(ScratchReg, RegState::Kill)
.addReg(DstReg)
.addImm(I * BytesPerOp);
}

unsigned BytesLeft = CopyLen & (Alignment - 1);
unsigned Offset = IterationNum * Alignment;
bool Hanging4Byte = BytesLeft & 0x4;
bool Hanging2Byte = BytesLeft & 0x2;
bool Hanging1Byte = BytesLeft & 0x1;
if (Hanging4Byte) {
BuildMI(*BB, MI, dl, get(SBF::LDW))
.addReg(ScratchReg, RegState::Define).addReg(SrcReg).addImm(Offset);
BuildMI(*BB, MI, dl, get(SBF::STW))
.addReg(ScratchReg, RegState::Kill).addReg(DstReg).addImm(Offset);
Offset += 4;
unsigned BytesLeft = CopyLen - IterationNum * BytesPerOp;
unsigned Offset;
if (BytesLeft == 0) {
BB->erase(MI);
return;
}
if (Hanging2Byte) {
BuildMI(*BB, MI, dl, get(SBF::LDH))
.addReg(ScratchReg, RegState::Define).addReg(SrcReg).addImm(Offset);
BuildMI(*BB, MI, dl, get(SBF::STH))
.addReg(ScratchReg, RegState::Kill).addReg(DstReg).addImm(Offset);
Offset += 2;
}
if (Hanging1Byte) {
BuildMI(*BB, MI, dl, get(SBF::LDB))
.addReg(ScratchReg, RegState::Define).addReg(SrcReg).addImm(Offset);
BuildMI(*BB, MI, dl, get(SBF::STB))
.addReg(ScratchReg, RegState::Kill).addReg(DstReg).addImm(Offset);

if (BytesLeft < 2) {
Offset = CopyLen - 1;
LdOpc = SBF::LDB;
StOpc = SBF::STB;
} else if (BytesLeft <= 2) {
Offset = CopyLen - 2;
LdOpc = SBF::LDH;
StOpc = SBF::STH;
} else if (BytesLeft <= 4) {
Offset = CopyLen - 4;
LdOpc = SBF::LDW;
StOpc = SBF::STW;
} else if (BytesLeft <= 8) {
Offset = CopyLen - 8;
LdOpc = SBF::LDD;
StOpc = SBF::STD;
} else {
llvm_unreachable("There cannot be more than 8 bytes left");
}

BuildMI(*BB, MI, dl, get(LdOpc))
.addReg(ScratchReg, RegState::Define)
.addReg(SrcReg)
.addImm(Offset);
BuildMI(*BB, MI, dl, get(StOpc))
.addReg(ScratchReg, RegState::Kill)
.addReg(DstReg)
.addImm(Offset);

BB->erase(MI);
}

Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/Target/SBF/SBFSelectionDAGInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@ SDValue SBFSelectionDAGInfo::EmitTargetCodeForMemcpy(
return SDValue();

unsigned CopyLen = ConstantSize->getZExtValue();
unsigned StoresNumEstimate = alignTo(CopyLen, Alignment) >> Log2(Alignment);
// If the alignment is greater than 8, we can only store and load 8 bytes at a
// time.
uint64_t BytesPerOp = std::min(Alignment.value(), static_cast<uint64_t>(8));
unsigned StoresNumEstimate =
alignTo(CopyLen, Alignment) >> Log2_64(BytesPerOp);
// Impose the same copy length limit as MaxStoresPerMemcpy.
if (StoresNumEstimate > getCommonMaxStoresPerMemFunc())
return SDValue();
Expand Down
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/SBF/memcpy-expand-in-order.ll
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,5 @@ entry:
; CHECK: stxdw [[[DST_REG]] + 8], [[SCRATCH_REG]]
; CHECK: ldxdw [[SCRATCH_REG]], [[[SRC_REG]] + 16]
; CHECK: stxdw [[[DST_REG]] + 16], [[SCRATCH_REG]]
; CHECK: ldxh [[SCRATCH_REG]], [[[SRC_REG]] + 24]
; CHECK: stxh [[[DST_REG]] + 24], [[SCRATCH_REG]]
; CHECK: ldxb [[SCRATCH_REG]], [[[SRC_REG]] + 26]
; CHECK: stxb [[[DST_REG]] + 26], [[SCRATCH_REG]]
; CHECK: ldxw [[SCRATCH_REG]], [[[SRC_REG]] + 23]
; CHECK: stxw [[[DST_REG]] + 23], [[SCRATCH_REG]]
92 changes: 92 additions & 0 deletions llvm/test/CodeGen/SBF/memcpy_16.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
; RUN: llc < %s -march=sbf -sbf-expand-memcpy-in-order | FileCheck %s

; Function Attrs: mustprogress nocallback nofree nounwind willreturn memory(argmem: readwrite)
declare void @llvm.memcpy.p0.p0.i64(ptr noalias nocapture writeonly, ptr noalias nocapture readonly, i64, i1 immarg) #1

define void @memcpy_test_1(ptr align 16 %a, ptr align 16 %b) local_unnamed_addr #0 {
entry:
call void @llvm.memcpy.p0.p0.i64(ptr align 16 %a, ptr align 16 %b, i64 32, i1 0)

; Even with 16-byte alignment, each op is capped at the 8-byte register
; width, so 32 bytes expand to 32/8 = 4 load/store pairs with no tail op.
; 4 pairs of loads and stores
; CHECK: memcpy_test_1
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 0]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 0], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 8]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 8], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 16]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 16], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 24]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 24], [[SCRATCH_REG:r[0-9]]]
ret void
}

define void @memcpy_test_2(ptr align 16 %a, ptr align 16 %b) local_unnamed_addr #0 {
entry:
call void @llvm.memcpy.p0.p0.i64(ptr align 16 %a, ptr align 16 %b, i64 17, i1 0)

; 17 = 2*8 + 1: two 8-byte pairs, then the single leftover byte is copied
; with a byte-sized pair at offset CopyLen - 1 = 16.
; 2 pairs of loads and stores + 1 pair for the byte
; CHECK: memcpy_test_2
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 0]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 0], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 8]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 8], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxb [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 16]
; CHECK: stxb [[[DST_REG:r[0-9]]] + 16], [[SCRATCH_REG:r[0-9]]]
ret void
}

define void @memcpy_test_3(ptr align 16 %a, ptr align 16 %b) local_unnamed_addr #0 {
entry:
call void @llvm.memcpy.p0.p0.i64(ptr align 16 %a, ptr align 16 %b, i64 18, i1 0)

; 18 = 2*8 + 2: two 8-byte pairs, then the 2 leftover bytes are copied
; with one halfword pair at offset CopyLen - 2 = 16.
; 2 pairs of loads and stores + 1 pair for the 2 bytes
; CHECK: memcpy_test_3
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 0]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 0], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 8]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 8], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxh [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 16]
; CHECK: stxh [[[DST_REG:r[0-9]]] + 16], [[SCRATCH_REG:r[0-9]]]
ret void
}

define void @memcpy_test_4(ptr align 16 %a, ptr align 16 %b) local_unnamed_addr #0 {
entry:
call void @llvm.memcpy.p0.p0.i64(ptr align 16 %a, ptr align 16 %b, i64 19, i1 0)

; 19 = 2*8 + 3: two 8-byte pairs, then the 3 leftover bytes are covered by
; one word-sized pair at offset CopyLen - 4 = 15. Byte 15 is re-copied with
; the same value, which is fine since memcpy src and dst may not overlap.
; 2 pairs of loads and stores + 1 pair for the 3 bytes
; CHECK: memcpy_test_4
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 0]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 0], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 8]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 8], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 15]
; CHECK: stxw [[[DST_REG:r[0-9]]] + 15], [[SCRATCH_REG:r[0-9]]]
ret void
}

define void @memcpy_test_5(ptr align 16 %a, ptr align 16 %b) local_unnamed_addr #0 {
entry:
call void @llvm.memcpy.p0.p0.i64(ptr align 16 %a, ptr align 16 %b, i64 21, i1 0)

; 21 = 2*8 + 5: two 8-byte pairs, then the 5 leftover bytes are covered by
; one overlapping doubleword pair at offset CopyLen - 8 = 13 (bytes 13-15
; are re-copied with the same value).
; 2 pairs of loads and stores + 1 pair for the 5 bytes
; CHECK: memcpy_test_5
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 0]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 0], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 8]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 8], [[SCRATCH_REG:r[0-9]]]
; CHECK: ldxdw [[SCRATCH_REG:r[0-9]]], [[[SRC_REG:r[0-9]]] + 13]
; CHECK: stxdw [[[DST_REG:r[0-9]]] + 13], [[SCRATCH_REG:r[0-9]]]
ret void
}

define void @memcpy_test_6(ptr align 16 %a, ptr align 16 %b) local_unnamed_addr #0 {
entry:
call void @llvm.memcpy.p0.p0.i64(ptr align 16 %a, ptr align 16 %b, i64 33, i1 0)

; 33 bytes: the store-count estimate (alignTo(33, 16) / 8 = 6 ops) exceeds
; the per-memcpy expansion limit, so the inline expansion is rejected and a
; library call is emitted instead. NOTE(review): the exact limit comes from
; getCommonMaxStoresPerMemFunc(), whose value is not visible here — confirm.
; More than 32 bytes, call memcpy
; CHECK: memcpy_test_6
; CHECK: mov64 r3, 33
; CHECK: call memcpy
ret void
}

0 comments on commit 7d055ed

Please sign in to comment.