Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion internal/build/build.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,8 +303,10 @@ func Do(args []string, conf *Config) ([]Package, error) {
})

buildMode := ssaBuildMode
cabiOptimize := true
if IsDbgEnabled() {
buildMode |= ssa.GlobalDebug
cabiOptimize = false
}
if !IsOptimizeEnabled() {
buildMode |= ssa.NaiveForm
Expand All @@ -324,7 +326,7 @@ func Do(args []string, conf *Config) ([]Package, error) {
needPyInit: make(map[*packages.Package]bool),
buildConf: conf,
crossCompile: export,
cTransformer: cabi.NewTransformer(prog, conf.AbiMode),
cTransformer: cabi.NewTransformer(prog, conf.AbiMode, cabiOptimize),
}
pkgs, err := buildAllPkgs(ctx, initial, verbose)
check(err)
Expand Down
155 changes: 139 additions & 16 deletions internal/cabi/cabi.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ const (
ModeAllFunc
)

func NewTransformer(prog ssa.Program, mode Mode) *Transformer {
func NewTransformer(prog ssa.Program, mode Mode, optimize bool) *Transformer {
target := prog.Target()
tr := &Transformer{
prog: prog,
td: prog.TargetData(),
GOOS: target.GOOS,
GOARCH: target.GOARCH,
mode: mode,
prog: prog,
td: prog.TargetData(),
GOOS: target.GOOS,
GOARCH: target.GOARCH,
mode: mode,
optimize: optimize,
}
switch target.GOARCH {
case "amd64":
Expand All @@ -42,12 +43,13 @@ func NewTransformer(prog ssa.Program, mode Mode) *Transformer {
}

type Transformer struct {
prog ssa.Program
td llvm.TargetData
GOOS string
GOARCH string
sys TypeInfoSys
mode Mode
prog ssa.Program
td llvm.TargetData
GOOS string
GOARCH string
sys TypeInfoSys
mode Mode
optimize bool
}

func (p *Transformer) isCFunc(name string) bool {
Expand Down Expand Up @@ -113,7 +115,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) {
}
}
for _, call := range callInstrs {
p.transformCallInstr(ctx, call.call, call.fn)
p.transformCallInstr(m, ctx, call.call, call.fn)
}
for _, fn := range fns {
p.transformFunc(m, fn)
Expand Down Expand Up @@ -191,6 +193,10 @@ func funcInlineHint(ctx llvm.Context) llvm.Attribute {
return ctx.CreateEnumAttribute(llvm.AttributeKindID("inlinehint"), 0)
}

func funcNoUnwind(ctx llvm.Context) llvm.Attribute {
return ctx.CreateEnumAttribute(llvm.AttributeKindID("nounwind"), 0)
}

func (p *Transformer) IsWrapType(ctx llvm.Context, ftyp llvm.Type, typ llvm.Type, index int) bool {
if p.sys != nil {
bret := index == 0
Expand Down Expand Up @@ -314,15 +320,15 @@ func (p *Transformer) transformFunc(m llvm.Module, fn llvm.Value) bool {
}

if !fn.IsDeclaration() {
p.transformFuncBody(ctx, &info, fn, nfn, nft)
p.transformFuncBody(m, ctx, &info, fn, nfn, nft)
}

fn.ReplaceAllUsesWith(nfn)
fn.EraseFromParentAsFunction()
return true
}

func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llvm.Value, nfn llvm.Value, nft llvm.Type) {
func (p *Transformer) transformFuncBody(m llvm.Module, ctx llvm.Context, info *FuncInfo, fn llvm.Value, nfn llvm.Value, nft llvm.Type) {
var blocks []llvm.BasicBlock
bb := fn.FirstBasicBlock()
for !bb.IsNil() {
Expand Down Expand Up @@ -353,12 +359,29 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
// skip
continue
case AttrPointer:
// void @fn(%typ %0)
// %1 = alloca %typ, align 8
// call void @llvm.memset(ptr %1, i8 0, i64 36, i1 false)
// store %typ %0, ptr %1, align 4
//
// void @fn(ptr byval(%typ) %0)
// %1 = load %typ, ptr %0, align 4
// %2 = alloca %typ, align 8
// call void @llvm.memset(ptr %2, i8 0, i64 36, i1 false)
// store %typ %1, ptr %2, align 4
nv = b.CreateLoad(ti.Type, params[index], "")
// replace %0 to %2
if p.optimize {
replaceAllocaInstrs(fn.Param(i), params[index])
}
case AttrWidthType:
iptr := llvm.CreateAlloca(b, ti.Type1)
b.CreateStore(params[index], iptr)
ptr := b.CreateBitCast(iptr, llvm.PointerType(ti.Type, 0), "")
nv = b.CreateLoad(ti.Type, ptr, "")
if p.optimize {
replaceAllocaInstrs(fn.Param(i), ptr)
}
case AttrWidthType2:
typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false)
iptr := llvm.CreateAlloca(b, typ)
Expand All @@ -367,6 +390,9 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
b.CreateStore(params[index], b.CreateStructGEP(typ, iptr, 1, ""))
ptr := b.CreateBitCast(iptr, llvm.PointerType(ti.Type, 0), "")
nv = b.CreateLoad(ti.Type, ptr, "")
if p.optimize {
replaceAllocaInstrs(fn.Param(i), ptr)
}
case AttrExtract:
nsubs := ti.Type.StructElementTypesCount()
nv = llvm.Undef(ti.Type)
Expand Down Expand Up @@ -400,9 +426,31 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
var rv llvm.Value
switch info.Return.Kind {
case AttrPointer:
// %typ @fn()
// %2 = load %typ, ptr %1
// ret %typ %2
//
// void @fn(ptr sret(%typ) %0)
// %2 = load %typ, ptr %1
// store %typ %2, ptr %0 # llvm.memcpy(ptr %0, ptr %1, i64 size, i1 false)
// ret void
if p.optimize {
if load := ret.IsALoadInst(); !load.IsNil() {
p.callMemcpy(m, ctx, b, params[0], ret.Operand(0), info.Return.Size)
rv = b.CreateRetVoid()
break
}
}
b.CreateStore(ret, params[0])
rv = b.CreateRetVoid()
case AttrWidthType, AttrWidthType2:
if p.optimize {
if load := ret.IsALoadInst(); !load.IsNil() {
iptr := b.CreateBitCast(ret.Operand(0), llvm.PointerType(nft.ReturnType(), 0), "")
rv = b.CreateRet(b.CreateLoad(nft.ReturnType(), iptr, ""))
break
}
}
ptr := llvm.CreateAlloca(b, info.Return.Type)
b.CreateStore(ret, ptr)
iptr := b.CreateBitCast(ptr, llvm.PointerType(nft.ReturnType(), 0), "")
Expand All @@ -414,7 +462,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv
}
}

func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn llvm.Value) bool {
func (p *Transformer) transformCallInstr(m llvm.Module, ctx llvm.Context, call llvm.Value, fn llvm.Value) bool {
nfn := call.CalledValue()
info := p.GetFuncInfo(ctx, call.CalledFunctionType())
if !info.HasWrap() {
Expand Down Expand Up @@ -443,6 +491,19 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn l
case AttrVoid:
// none
case AttrPointer:
if p.optimize {
if rv := param.IsALoadInst(); !rv.IsNil() {
ptr := rv.Operand(0)
if p.sys.SupportByVal() {
nparams = append(nparams, ptr)
} else {
nptr := createAlloca(ti.Type)
p.callMemcpy(m, ctx, b, nptr, ptr, ti.Size)
nparams = append(nparams, nptr)
}
break
}
}
ptr := createAlloca(ti.Type)
b.CreateStore(param, ptr)
nparams = append(nparams, ptr)
Expand Down Expand Up @@ -613,3 +674,65 @@ func (p *Transformer) transformCallbackFunc(m llvm.Module, fn llvm.Value) (wrap
}
return wrapFunc, true
}

func (p *Transformer) callMemcpy(m llvm.Module, ctx llvm.Context, b llvm.Builder, dst llvm.Value, src llvm.Value, size int) llvm.Value {
memcpy := p.getMemcpy(m, ctx)
sz := llvm.ConstInt(ctx.IntType(p.prog.PointerSize()*8), uint64(size), false)
return b.CreateCall(memcpy.GlobalValueType(), memcpy, []llvm.Value{
dst, src, sz, llvm.ConstInt(ctx.Int1Type(), 0, false),
}, "")
}

func (p *Transformer) getMemcpy(m llvm.Module, ctx llvm.Context) llvm.Value {
memcpy := m.NamedFunction("llvm.memcpy")
if !memcpy.IsNil() {
return memcpy
}
ftyp := llvm.FunctionType(ctx.VoidType(), []llvm.Type{
llvm.PointerType(ctx.Int8Type(), 0),
llvm.PointerType(ctx.Int8Type(), 0),
ctx.IntType(p.prog.PointerSize() * 8),
ctx.Int1Type(),
}, false)
memcpy = llvm.AddFunction(m, "llvm.memcpy", ftyp)
memcpy.SetFunctionCallConv(llvm.CCallConv)
memcpy.AddFunctionAttr(funcNoUnwind(ctx))
return memcpy
}

func replaceAllocaInstrs(param llvm.Value, nv llvm.Value) {
u := param.FirstUse()
var storeInstrs []llvm.Value
for !u.IsNil() {
if user := u.User().IsAStoreInst(); !user.IsNil() && user.Operand(0) == param {
storeInstrs = append(storeInstrs, user)
}
u = u.NextUse()
}
for _, instr := range storeInstrs {
if alloc := instr.Operand(1).IsAAllocaInst(); !alloc.IsNil() {
skips := make(map[llvm.Value]bool)
next := llvm.NextInstruction(alloc)
for !next.IsNil() && next != instr {
skips[next] = true
next = llvm.NextInstruction(next)
}
var uses []llvm.Value
u := alloc.FirstUse()
for !u.IsNil() {
if v := u.User(); !skips[v] {
uses = append(uses, v)
}
u = u.NextUse()
}
for _, use := range uses {
n := use.OperandsCount()
for i := 0; i < n; i++ {
if use.Operand(i) == alloc {
use.SetOperand(i, nv)
}
}
}
}
}
}
Loading