Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions _demo/cabisret/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package main

type array9 struct {
x [9]float32
}

func demo1(a array9) array9 {
a.x[0] += 1
return a
}

func demo2(a array9) array9 {
for i := 0; i < 1024*128; i++ {
a = demo1(a)
}
return a
}

func testDemo() {
ar := array9{x: [9]float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}
for i := 0; i < 1024*128; i++ {
ar = demo1(ar)
}
ar = demo2(ar)
println(ar.x[0], ar.x[1])
}

func testSlice() {
var b []byte
for i := 0; i < 1024*128; i++ {
b = append(b, byte(i))
}
_ = b
}

func main() {
testDemo()
testSlice()
}
51 changes: 36 additions & 15 deletions internal/cabi/cabi.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,15 @@ func (p *Transformer) isCFunc(name string) bool {
return !strings.Contains(name, ".")
}

type CallInstr struct {
call llvm.Value
fn llvm.Value
}

func (p *Transformer) TransformModule(path string, m llvm.Module) {
ctx := m.Context()
var fns []llvm.Value
var callInstrs []llvm.Value
var callInstrs []CallInstr
switch p.mode {
case ModeNone:
return
Expand All @@ -66,16 +71,22 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
for !fn.IsNil() {
if p.isCFunc(fn.Name()) {
p.transformFuncCall(m, fn)
if p.isWrapFunctionType(m.Context(), fn.GlobalValueType()) {
if p.isWrapFunctionType(ctx, fn.GlobalValueType()) {
fns = append(fns, fn)
use := fn.FirstUse()
for !use.IsNil() {
if call := use.User().IsACallInst(); !call.IsNil() && call.CalledValue() == fn {
callInstrs = append(callInstrs, call)
}
}
bb := fn.FirstBasicBlock()
for !bb.IsNil() {
instr := bb.FirstInstruction()
for !instr.IsNil() {
if call := instr.IsACallInst(); !call.IsNil() && p.isCFunc(call.CalledValue().Name()) {
if p.isWrapFunctionType(ctx, call.CalledFunctionType()) {
callInstrs = append(callInstrs, CallInstr{call, fn})
}
use = use.NextUse()
}
instr = llvm.NextInstruction(instr)
}
bb = llvm.NextBasicBlock(bb)
}
fn = llvm.NextFunction(fn)
}
Expand All @@ -91,7 +102,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
for !instr.IsNil() {
if call := instr.IsACallInst(); !call.IsNil() {
if p.isWrapFunctionType(ctx, call.CalledFunctionType()) {
callInstrs = append(callInstrs, call)
callInstrs = append(callInstrs, CallInstr{call, fn})
}
}
instr = llvm.NextInstruction(instr)
Expand All @@ -102,7 +113,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
}
}
for _, call := range callInstrs {
p.transformCallInstr(ctx, call)
p.transformCallInstr(ctx, call.call, call.fn)
}
for _, fn := range fns {
p.transformFunc(m, fn)
Expand Down Expand Up @@ -369,6 +380,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
fn.Param(i).ReplaceAllUsesWith(nv)
index++
}

if info.Return.Kind >= AttrPointer {
var retInstrs []llvm.Value
bb := nfn.FirstBasicBlock()
Expand Down Expand Up @@ -402,7 +414,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
}
}

func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool {
func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn llvm.Value) bool {
nfn := call.CalledValue()
info := p.GetFuncInfo(ctx, call.CalledFunctionType())
if !info.HasWrap() {
Expand All @@ -411,6 +423,15 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
nft, attrs := p.transformFuncType(ctx, &info)
b := ctx.NewBuilder()
b.SetInsertPointBefore(call)

first := fn.EntryBasicBlock().FirstInstruction()
createAlloca := func(t llvm.Type) (ret llvm.Value) {
b.SetInsertPointBefore(first)
ret = llvm.CreateAlloca(b, t)
b.SetInsertPointBefore(call)
return
}

operandCount := len(info.Params)
var nparams []llvm.Value
for i := 0; i < operandCount; i++ {
Expand All @@ -422,16 +443,16 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
case AttrVoid:
// none
case AttrPointer:
ptr := llvm.CreateAlloca(b, ti.Type)
ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr)
nparams = append(nparams, ptr)
case AttrWidthType:
ptr := llvm.CreateAlloca(b, ti.Type)
ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr)
iptr := b.CreateBitCast(ptr, llvm.PointerType(ti.Type1, 0), "")
nparams = append(nparams, b.CreateLoad(ti.Type1, iptr, ""))
case AttrWidthType2:
ptr := llvm.CreateAlloca(b, ti.Type)
ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr)
typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false) // {i8,i64}
iptr := b.CreateBitCast(ptr, llvm.PointerType(typ, 0), "")
Expand All @@ -457,14 +478,14 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool
instr = llvm.CreateCall(b, nft, nfn, nparams)
updateCallAttr(instr)
case AttrPointer:
ret := llvm.CreateAlloca(b, info.Return.Type)
ret := createAlloca(info.Return.Type)
call := llvm.CreateCall(b, nft, nfn, append([]llvm.Value{ret}, nparams...))
updateCallAttr(call)
instr = b.CreateLoad(info.Return.Type, ret, "")
case AttrWidthType, AttrWidthType2:
ret := llvm.CreateCall(b, nft, nfn, nparams)
updateCallAttr(ret)
ptr := llvm.CreateAlloca(b, nft.ReturnType())
ptr := createAlloca(nft.ReturnType())
b.CreateStore(ret, ptr)
pret := b.CreateBitCast(ptr, llvm.PointerType(info.Return.Type, 0), "")
instr = b.CreateLoad(info.Return.Type, pret, "")
Expand Down
Loading