diff --git a/asm/amd/.gitignore b/asm/amd/.gitignore
new file mode 100644
index 000000000..f15395e87
--- /dev/null
+++ b/asm/amd/.gitignore
@@ -0,0 +1 @@
+testdata/fuzz
\ No newline at end of file
diff --git a/asm/amd/interpreter.go b/asm/amd/interpreter.go
new file mode 100644
index 000000000..7487e966b
--- /dev/null
+++ b/asm/amd/interpreter.go
@@ -0,0 +1,209 @@
+// Copyright The OpenTelemetry Authors
+// SPDX-License-Identifier: Apache-2.0
+
+package amd // import "go.opentelemetry.io/ebpf-profiler/asm/amd"
+
+import (
+	"fmt"
+	"io"
+
+	"go.opentelemetry.io/ebpf-profiler/asm/expression"
+	"golang.org/x/arch/x86/x86asm"
+)
+
+// CodeBlock pairs a chunk of machine code with the expression describing its
+// address.
+type CodeBlock struct {
+	Address expression.Expression
+	Code    []byte
+}
+
+// Interpreter symbolically executes x86-64 machine code. Register contents are
+// tracked as expressions in Regs rather than as concrete values.
+type Interpreter struct {
+	Regs        Registers
+	code        []byte
+	CodeAddress expression.Expression
+	pc          int
+}
+
+// NewInterpreter returns an Interpreter without code. Every register starts as
+// a distinct named expression.
+func NewInterpreter() *Interpreter {
+	it := &Interpreter{}
+	it.initRegs()
+	return it
+}
+
+// NewInterpreterWithCode returns an Interpreter for code whose start address is
+// the named expression "code address".
+func NewInterpreterWithCode(code []byte) *Interpreter {
+	it := &Interpreter{code: code, CodeAddress: expression.Named("code address")}
+	it.initRegs()
+	return it
+}
+
+// ResetCode swaps in new code and its start address and rewinds the program
+// counter. The register state is left as is.
+func (i *Interpreter) ResetCode(code []byte, address expression.Expression) {
+	i.code = code
+	i.CodeAddress = address
+	i.pc = 0
+}
+
+// Loop calls Step until it fails. See LoopWithBreak for the return values.
+func (i *Interpreter) Loop() (x86asm.Inst, error) {
+	return i.LoopWithBreak(func(x86asm.Inst) bool { return false })
+}
+
+// LoopWithBreak calls Step until breakLoop reports true for an instruction,
+// which is then returned with a nil error, or until Step fails, in which case
+// the last successfully stepped instruction is returned with Step's error
+// (io.EOF when the code ran out).
+func (i *Interpreter) LoopWithBreak(breakLoop func(op x86asm.Inst) bool) (x86asm.Inst, error) {
+	prev := x86asm.Inst{}
+	for {
+		op, err := i.Step()
+		if err != nil {
+			return prev, err
+		}
+		if breakLoop(op) {
+			return op, nil
+		}
+		prev = op
+	}
+}
+
+// Step decodes the instruction at the program counter, advances past it, sets
+// RIP to the address of the next instruction and applies the instruction to
+// the registers. Only ADD, SHL, MOV, MOVZX, MOVSX, MOVSXD, XOR, AND and LEA are
+// modeled, and only for the operand forms handled below; anything else changes
+// no register besides RIP. It returns io.EOF when no code is left and an error
+// wrapping the decoder's error when decoding fails.
+func (i *Interpreter) Step() (x86asm.Inst, error) {
+	if len(i.code) == 0 {
+		return x86asm.Inst{}, io.EOF
+	}
+	var inst x86asm.Inst
+	var err error
+	if ok, instLen := DecodeSkippable(i.code); ok {
+		inst = x86asm.Inst{Op: x86asm.NOP, Len: instLen}
+	} else {
+		inst, err = x86asm.Decode(i.code, 64)
+		if err != nil {
+			return inst, fmt.Errorf("at 0x%x : %w", i.pc, err)
+		}
+	}
+	i.pc += inst.Len
+	i.code = i.code[inst.Len:]
+	i.Regs.setX86asm(x86asm.RIP, expression.Add(i.CodeAddress, expression.Imm(uint64(i.pc))))
+	switch inst.Op {
+	case x86asm.ADD:
+		if dst, ok := inst.Args[0].(x86asm.Reg); ok {
+			left := i.Regs.getX86asm(dst)
+			switch src := inst.Args[1].(type) {
+			case x86asm.Imm:
+				right := expression.Imm(uint64(src))
+				i.Regs.setX86asm(dst, expression.Add(left, right))
+			case x86asm.Reg:
+				right := i.Regs.getX86asm(src)
+				i.Regs.setX86asm(dst, expression.Add(left, right))
+			case x86asm.Mem:
+				right := i.MemArg(src)
+				right = expression.MemWithSegment(src.Segment, right, inst.MemBytes)
+				i.Regs.setX86asm(dst, expression.Add(left, right))
+			}
+		}
+	case x86asm.SHL:
+		if dst, ok := inst.Args[0].(x86asm.Reg); ok {
+			if src, imm := inst.Args[1].(x86asm.Imm); imm {
+				// The CPU masks the shift count to 6 bits for 64-bit operands and
+				// to 5 bits otherwise. Masking also keeps the shift below 64, so
+				// the multiplier is always a well-defined power of two.
+				mask := uint64(31)
+				if regMappingFor(dst).bits == 64 {
+					mask = 63
+				}
+				v := expression.Multiply(
+					i.Regs.getX86asm(dst),
+					expression.Imm(uint64(1)<<(uint64(src)&mask)),
+				)
+				i.Regs.setX86asm(dst, v)
+			}
+		}
+	case x86asm.MOV, x86asm.MOVZX, x86asm.MOVSXD, x86asm.MOVSX:
+		if dst, ok := inst.Args[0].(x86asm.Reg); ok {
+			switch src := inst.Args[1].(type) {
+			case x86asm.Imm:
+				i.Regs.setX86asm(dst, expression.Imm(uint64(src)))
+			case x86asm.Reg:
+				i.Regs.setX86asm(dst, i.Regs.getX86asm(src))
+			case x86asm.Mem:
+				v := i.MemArg(src)
+
+				dataSizeBits := inst.DataSize
+
+				v = expression.MemWithSegment(src.Segment, v, inst.MemBytes)
+				if inst.Op == x86asm.MOVSXD || inst.Op == x86asm.MOVSX {
+					v = expression.SignExtend(v, dataSizeBits)
+				} else {
+					v = expression.ZeroExtend(v, dataSizeBits)
+				}
+				i.Regs.setX86asm(dst, v)
+			}
+		}
+	case x86asm.XOR:
+		if dst, ok := inst.Args[0].(x86asm.Reg); ok {
+			if src, reg := inst.Args[1].(x86asm.Reg); reg {
+				if src == dst {
+					i.Regs.setX86asm(dst, expression.Imm(0))
+				}
+			}
+		}
+	case x86asm.AND:
+		if dst, ok := inst.Args[0].(x86asm.Reg); ok {
+			if src, imm := inst.Args[1].(x86asm.Imm); imm {
+				if src == 3 { // todo other cases
+					i.Regs.setX86asm(dst, expression.ZeroExtend(i.Regs.getX86asm(dst), 2))
+				}
+			}
+		}
+	case x86asm.LEA:
+		if dst, ok := inst.Args[0].(x86asm.Reg); ok {
+			if src, mem := inst.Args[1].(x86asm.Mem); mem {
+				v := i.MemArg(src)
+				i.Regs.setX86asm(dst, v)
+			}
+		}
+	default:
+	}
+	return inst, nil
+}
+
+// MemArg returns the effective address of src as an expression:
+// displacement + base + index*scale, leaving out the parts that are zero.
+func (i *Interpreter) MemArg(src x86asm.Mem) expression.Expression {
+	vs := make([]expression.Expression, 0, 3)
+	if src.Disp != 0 {
+		vs = append(vs, expression.Imm(uint64(src.Disp)))
+	}
+	if src.Base != 0 {
+		vs = append(vs, i.Regs.getX86asm(src.Base))
+	}
+	if src.Index != 0 {
+		v := expression.Multiply(
+			i.Regs.getX86asm(src.Index),
+			expression.Imm(uint64(src.Scale)),
+		)
+		vs = append(vs, v)
+	}
+	v := expression.Add(vs...)
+	return v
+}
+
+// initRegs seeds every register slot with a named expression.
+func (i *Interpreter) initRegs() {
+	for j := 0; j < len(i.Regs.regs); j++ {
+		i.Regs.regs[j] = expression.Named(Reg(j).String())
+	}
+}
diff --git a/asm/amd/interpreter_test.go b/asm/amd/interpreter_test.go
new file mode 100644
index 000000000..be3fc2dad
--- /dev/null
+++ b/asm/amd/interpreter_test.go
@@ -0,0 +1,171 @@
+// Copyright The OpenTelemetry Authors
+// SPDX-License-Identifier: Apache-2.0
+
+package amd
+
+import (
+	"io"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"go.opentelemetry.io/ebpf-profiler/asm/expression"
+)
+
+func BenchmarkPythonInterpreter(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		testPythonInterpreter(b)
+	}
+}
+
+func TestPythonInterpreter(t *testing.T) {
+	testPythonInterpreter(t)
+}
+
+func testPythonInterpreter(t testing.TB) {
+	// 00010000 4D 89 F2 mov r10, r14
+	// 00010003 45 0F B6 36 movzx r14d, byte ptr [r14]
+	// 00010007 48 8D 05 2D B3 35 00 lea rax, [rip + 0x35b32d]
+	// 0001000E 4C 8B 6C 24 08 mov r13, qword ptr [rsp + 8]
+	// 00010013 48 89 C1 mov rcx, rax
+	// 00010016 48 89 44 24 10 mov qword ptr [rsp + 0x10], rax
+	// 0001001B 45 0F B6 5A 01 movzx r11d, byte ptr [r10 + 1]
+	// 00010020 41 0F B6 C6 movzx eax, r14b
+	// 00010024 48 8B 04 C1 mov rax, qword ptr [rcx + rax*8]
+	// 00010028 FF E0 jmp rax
+	code := []byte{ +
0x4d, 0x89, 0xf2, 0x45, 0x0f, 0xb6, 0x36, 0x48, 0x8d, 0x05, 0x2d, 0xb3, 0x35,
+		0x00, 0x4c, 0x8b, 0x6c, 0x24, 0x08, 0x48, 0x89, 0xc1, 0x48, 0x89, 0x44, 0x24,
+		0x10, 0x45, 0x0f, 0xb6, 0x5a, 0x01, 0x41, 0x0f, 0xb6, 0xc6, 0x48, 0x8b, 0x04,
+		0xc1, 0xff, 0xe0,
+	}
+	it := NewInterpreterWithCode(code)
+	it.CodeAddress = expression.Imm(0x8AF05)
+	r14 := it.Regs.Get(R14)
+	_, err := it.Loop()
+	if err != io.EOF {
+		t.Fatalf("expected io.EOF, got %v", err)
+	}
+	actual := it.Regs.Get(RAX)
+	expected := expression.Mem(
+		expression.Add(
+			expression.Multiply(
+				expression.ZeroExtend8(expression.Mem1(r14)),
+				expression.Imm(8),
+			),
+			expression.NewImmediateCapture("switch table"),
+		),
+		8,
+	)
+	if !actual.Match(expected) {
+		t.Fatalf("unexpected RAX expression: %s", actual.DebugString())
+	}
+}
+
+func TestRecoverSwitchCase(t *testing.T) {
+	blocks := []CodeBlock{
+		{
+			Address: expression.Imm(0x3310E3),
+			// 003310E3 48 8B 44 24 20 mov rax, qword ptr [rsp + 0x20]
+			// 003310E8 48 89 18 mov qword ptr [rax], rbx
+			// 003310EB 49 83 C2 02 add r10, 2
+			// 003310EF 44 89 E0 mov eax, r12d
+			// 003310F2 83 E0 03 and eax, 3
+			// 003310F5 31 DB xor ebx, ebx
+			// 003310F7 41 F6 C4 04 test r12b, 4
+			// 003310FB 4C 89 74 24 10 mov qword ptr [rsp + 0x10], r14
+			// 00331100 74 08 je 0x33110a
+			Code: []byte{0x48, 0x8b, 0x44, 0x24, 0x20, 0x48, 0x89, 0x18, 0x49,
+				0x83, 0xc2, 0x02, 0x44, 0x89, 0xe0, 0x83, 0xe0, 0x03, 0x31, 0xdb,
+				0x41, 0xf6, 0xc4, 0x04, 0x4c, 0x89, 0x74, 0x24, 0x10, 0x74, 0x08},
+		},
+		{
+			Address: expression.Imm(0x33110a),
+			// 0033110A 4D 89 DC mov r12, r11
+			// 0033110D 4D 8D 47 F8 lea r8, [r15 - 8]
+			// 00331111 4C 89 7C 24 60 mov qword ptr [rsp + 0x60], r15
+			// 00331116 4D 8B 7F F8 mov r15, qword ptr [r15 - 8]
+			// 0033111A 48 8B 0D 87 06 17 01 mov rcx, qword ptr [rip + 0x1170687]
+			// 00331121 89 C0 mov eax, eax
+			// 00331123 48 8D 15 02 E7 C0 00 lea rdx, [rip + 0xc0e702]
+			// 0033112A 48 63 04 82 movsxd rax, dword ptr [rdx + rax*4]
+			// 0033112E 48 01 D0 add rax, rdx
+			// 00331131 4C 89 D5 mov rbp, r10
+			// 00331134 4D 89 C5 mov r13, r8
+			// 00331137 FF E0 jmp rax
+			Code: []byte{
+				0x4d, 0x89, 0xdc, 0x4d, 0x8d, 0x47, 0xf8, 0x4c, 0x89, 0x7c, 0x24,
+				0x60, 0x4d, 0x8b, 0x7f, 0xf8, 0x48, 0x8b, 0x0d, 0x87, 0x06, 0x17,
+				0x01, 0x89, 0xc0, 0x48, 0x8d, 0x15, 0x02, 0xe7, 0xc0, 0x00, 0x48,
+				0x63, 0x04, 0x82, 0x48, 0x01, 0xd0, 0x4c, 0x89, 0xd5, 0x4d, 0x89,
+				0xc5, 0xff, 0xe0,
+			},
+		},
+	}
+	it := NewInterpreter()
+	initR12 := it.Regs.Get(R12)
+	it.ResetCode(blocks[0].Code, blocks[0].Address)
+	_, err := it.Loop()
+	require.ErrorIs(t, err, io.EOF)
+
+	expected := expression.ZeroExtend(initR12, 2)
+	assertEval(t, it.Regs.Get(RAX), expected)
+	it.ResetCode(blocks[1].Code, blocks[1].Address)
+	_, err = it.Loop()
+	require.ErrorIs(t, err, io.EOF)
+	table := expression.NewImmediateCapture("table")
+	base := expression.NewImmediateCapture("base")
+	expected = expression.Add(
+		expression.SignExtend(
+			expression.Mem(
+				expression.Add(
+					expression.Multiply(
+						expression.ZeroExtend(initR12, 2),
+						expression.Imm(4),
+					),
+					table,
+				),
+				4,
+			),
+			64,
+		),
+		base,
+	)
+	assertEval(t, it.Regs.Get(RAX), expected)
+	assert.EqualValues(t, 0xf3f82c, table.CapturedValue())
+	assert.EqualValues(t, 0xf3f82c, base.CapturedValue())
+}
+
+// assertEval reports a test failure when left does not match the pattern right.
+func assertEval(t *testing.T, left, right expression.Expression) {
+	if !left.Match(right) {
+		assert.Failf(t, "expression mismatch",
+			"left %s\nright %s", left.DebugString(), right.DebugString())
+	}
+}
+
+func FuzzInterpreter(f *testing.F) {
+	f.Fuzz(func(_ *testing.T, code []byte) {
+		i := NewInterpreterWithCode(code)
+		_, _ = i.Loop()
+	})
+}
+
+func TestMoveSignExtend(t *testing.T) {
+	i := NewInterpreterWithCode([]byte{
+		// 00000000 B8 01 00 00 00 mov eax, 1
+		// 00000005 8B 40 04 mov eax, dword ptr [rax + 4]
+		// 00000008 B8 02 00 00 00 mov eax, 2
+		// 0000000D 48 0F B6 40 04 movzx rax, byte ptr [rax + 4]
+		// 00000012 B8 03 00 00 00 mov eax, 3
+		// 00000017 48 0F BF 40 04 movsx rax, word ptr [rax + 4]
+		0xB8,
0x01, 0x00, 0x00, 0x00, 0x8B, 0x40, 0x04,
+		0xB8, 0x02, 0x00, 0x00, 0x00, 0x48, 0x0F, 0xB6,
+		0x40, 0x04, 0xB8, 0x03, 0x00, 0x00, 0x00, 0x48,
+		0x0F, 0xBF, 0x40, 0x04,
+	})
+	_, err := i.Loop()
+	require.ErrorIs(t, err, io.EOF)
+	pattern := expression.SignExtend(expression.Mem(expression.Imm(7), 2), 64)
+	require.True(t, i.Regs.Get(RAX).Match(pattern))
+}
diff --git a/asm/amd/regs_state.go b/asm/amd/regs_state.go
index 92b36ee39..77af3cc91 100644
--- a/asm/amd/regs_state.go
+++ b/asm/amd/regs_state.go
@@ -3,64 +3,248 @@
 
 package amd // import "go.opentelemetry.io/ebpf-profiler/asm/amd"
 
-import "golang.org/x/arch/x86/x86asm"
+import (
+	"fmt"
 
-// regIndex returns index into RegsState.regs
-func regIndex(reg x86asm.Reg) int {
+	"go.opentelemetry.io/ebpf-profiler/asm/expression"
+	"golang.org/x/arch/x86/x86asm"
+)
+
+// Registers holds one expression per Reg. Slot 0 belongs to no register and
+// absorbs accesses to x86asm registers that regMappingFor does not know.
+type Registers struct {
+	regs [int(registersCount)]expression.Expression
+}
+
+// regEntry names the Registers slot an x86asm register lives in and the
+// register's width in bits.
+type regEntry struct {
+	idx  Reg
+	bits int
+}
+// Reg indexes a full 64-bit register, or RIP, inside Registers.
+type Reg uint8
+
+const (
+	_ Reg = iota
+	RAX
+	RCX
+	RDX
+	RBX
+	RSP
+	RBP
+	RSI
+	RDI
+	R8
+	R9
+	R10
+	R11
+	R12
+	R13
+	R14
+	R15
+	RIP
+	registersCount
+)
+
+// regNames maps each Reg to the name returned by Reg.String.
+var regNames = [...]string{
+	RAX: "RAX",
+	RCX: "RCX",
+	RDX: "RDX",
+	RBX: "RBX",
+	RSP: "RSP",
+	RBP: "RBP",
+	RSI: "RSI",
+	RDI: "RDI",
+	R8:  "R8",
+	R9:  "R9",
+	R10: "R10",
+	R11: "R11",
+	R12: "R12",
+	R13: "R13",
+	R14: "R14",
+	R15: "R15",
+	RIP: "RIP",
+}
+
+// String returns the register's name, or "Reg(n)" for values without one.
+func (r Reg) String() string {
+	i := int(r)
+	if r == 0 || i >= len(regNames) || regNames[i] == "" {
+		return fmt.Sprintf("Reg(%d)", i)
+	}
+	return regNames[i]
+}
+
+// regMappingFor maps an x86asm register, including the 8-, 16- and 32-bit
+// views of the general-purpose registers, to its slot and width. Registers
+// that are not listed, such as AH, map to slot 0 with a width of 64 bits.
+func regMappingFor(reg x86asm.Reg) regEntry {
 	switch reg {
-	case x86asm.RAX, x86asm.EAX:
-		return 1
-	case x86asm.RBX, x86asm.EBX:
-		return 2
-	case x86asm.RCX, x86asm.ECX:
-		return 3
-	case x86asm.RDX, x86asm.EDX:
-		return 4
-	case x86asm.RDI, x86asm.EDI:
-		return 5
-	case x86asm.RSI, x86asm.ESI:
-		return 6
-	case x86asm.RBP, x86asm.EBP:
-		return 7
-	case x86asm.R8, x86asm.R8L:
-		return 8
-	case x86asm.R9, x86asm.R9L:
-		return 9
-	case x86asm.R10, x86asm.R10L:
-		return 10
-	case x86asm.R11, x86asm.R11L:
-		return 11
-	case x86asm.R12, x86asm.R12L:
-		return 12
-	case x86asm.R13, x86asm.R13L:
-		return 13
-	case x86asm.R14, x86asm.R14L:
-		return 14
-	case x86asm.R15, x86asm.R15L:
-		return 15
-	case x86asm.RSP, x86asm.ESP:
-		return 16
+	case x86asm.AL:
+		return regEntry{idx: RAX, bits: 8}
+	case x86asm.CL:
+		return regEntry{idx: RCX, bits: 8}
+	case x86asm.DL:
+		return regEntry{idx: RDX, bits: 8}
+	case x86asm.BL:
+		return regEntry{idx: RBX, bits: 8}
+	case x86asm.SPB:
+		return regEntry{idx: RSP, bits: 8}
+	case x86asm.BPB:
+		return regEntry{idx: RBP, bits: 8}
+	case x86asm.SIB:
+		return regEntry{idx: RSI, bits: 8}
+	case x86asm.DIB:
+		return regEntry{idx: RDI, bits: 8}
+	case x86asm.R8B:
+		return regEntry{idx: R8, bits: 8}
+	case x86asm.R9B:
+		return regEntry{idx: R9, bits: 8}
+	case x86asm.R10B:
+		return regEntry{idx: R10, bits: 8}
+	case x86asm.R11B:
+		return regEntry{idx: R11, bits: 8}
+	case x86asm.R12B:
+		return regEntry{idx: R12, bits: 8}
+	case x86asm.R13B:
+		return regEntry{idx: R13, bits: 8}
+	case x86asm.R14B:
+		return regEntry{idx: R14, bits: 8}
+	case x86asm.R15B:
+		return regEntry{idx: R15, bits: 8}
+	case x86asm.AX:
+		return regEntry{idx: RAX, bits: 16}
+	case x86asm.CX:
+		return regEntry{idx: RCX, bits: 16}
+	case x86asm.DX:
+		return regEntry{idx: RDX, bits: 16}
+	case x86asm.BX:
+		return regEntry{idx: RBX, bits: 16}
+	case x86asm.SP:
+		return regEntry{idx: RSP, bits: 16}
+	case x86asm.BP:
+		return regEntry{idx: RBP, bits: 16}
+	case x86asm.SI:
+		return regEntry{idx: RSI, bits: 16}
+	case x86asm.DI:
+		return regEntry{idx: RDI, bits: 16}
+	case x86asm.R8W:
+		return regEntry{idx: R8, bits: 16}
+	case x86asm.R9W:
+		return regEntry{idx: R9, bits: 16}
+	case x86asm.R10W:
+		return regEntry{idx: R10, bits: 16}
+	case x86asm.R11W:
+		return regEntry{idx: R11, bits: 16}
+	case x86asm.R12W:
+		return regEntry{idx: R12, bits: 16}
+	case x86asm.R13W:
+		return regEntry{idx: R13, bits: 16}
+	case x86asm.R14W:
+		return regEntry{idx: R14, bits: 16}
+	case x86asm.R15W:
+		return regEntry{idx: R15, bits: 16}
+	case x86asm.EAX:
+		return regEntry{idx: RAX, bits: 32}
+	case x86asm.ECX:
+		return regEntry{idx: RCX, bits: 32}
+	case x86asm.EDX:
+		return regEntry{idx: RDX, bits: 32}
+	case x86asm.EBX:
+		return regEntry{idx: RBX, bits: 32}
+	case x86asm.ESP:
+		return regEntry{idx: RSP, bits: 32}
+	case x86asm.EBP:
+		return regEntry{idx: RBP, bits: 32}
+	case x86asm.ESI:
+		return regEntry{idx: RSI, bits: 32}
+	case x86asm.EDI:
+		return regEntry{idx: RDI, bits: 32}
+	case x86asm.R8L:
+		return regEntry{idx: R8, bits: 32}
+	case x86asm.R9L:
+		return regEntry{idx: R9, bits: 32}
+	case x86asm.R10L:
+		return regEntry{idx: R10, bits: 32}
+	case x86asm.R11L:
+		return regEntry{idx: R11, bits: 32}
+	case x86asm.R12L:
+		return regEntry{idx: R12, bits: 32}
+	case x86asm.R13L:
+		return regEntry{idx: R13, bits: 32}
+	case x86asm.R14L:
+		return regEntry{idx: R14, bits: 32}
+	case x86asm.R15L:
+		return regEntry{idx: R15, bits: 32}
+	case x86asm.RAX:
+		return regEntry{idx: RAX, bits: 64}
+	case x86asm.RCX:
+		return regEntry{idx: RCX, bits: 64}
+	case x86asm.RDX:
+		return regEntry{idx: RDX, bits: 64}
+	case x86asm.RBX:
+		return regEntry{idx: RBX, bits: 64}
+	case x86asm.RSP:
+		return regEntry{idx: RSP, bits: 64}
+	case x86asm.RBP:
+		return regEntry{idx: RBP, bits: 64}
+	case x86asm.RSI:
+		return regEntry{idx: RSI, bits: 64}
+	case x86asm.RDI:
+		return regEntry{idx: RDI, bits: 64}
+	case x86asm.R8:
+		return regEntry{idx: R8, bits: 64}
+	case x86asm.R9:
+		return regEntry{idx: R9, bits: 64}
+	case x86asm.R10:
+		return regEntry{idx: R10, bits: 64}
+	case x86asm.R11:
+		return regEntry{idx: R11, bits: 64}
+	case x86asm.R12:
+		return regEntry{idx: R12, bits: 64}
+	case x86asm.R13:
+		return regEntry{idx: R13, bits: 64}
+	case x86asm.R14:
+		return regEntry{idx: R14, bits: 64}
+	case x86asm.R15:
+		return regEntry{idx: R15, bits: 64}
 	case x86asm.RIP:
-		return 17
+		return regEntry{idx: RIP, bits: 64}
 	default:
-		return 0
+		return regEntry{idx: 0, bits: 64}
 	}
 }
 
-type RegsState struct {
-	regs [18]regState
-}
-
-func (r *RegsState) Set(reg x86asm.Reg, value, loadedFrom uint64) {
-	r.regs[regIndex(reg)].Value = value
-	r.regs[regIndex(reg)].LoadedFrom = loadedFrom
+// setX86asm stores v as the new value of reg. For a register narrower than 64
+// bits, v is zero-extended from that width and replaces the whole slot.
+// NOTE(review): hardware keeps the upper bits on 8- and 16-bit writes; confirm
+// that this simplification is intended.
+func (r *Registers) setX86asm(reg x86asm.Reg, v expression.Expression) {
+	e := regMappingFor(reg)
+	if e.bits != 64 {
+		v = expression.ZeroExtend(v, e.bits)
+	}
+	r.regs[e.idx] = v
 }
 
-func (r *RegsState) Get(reg x86asm.Reg) (value, loadedFrom uint64) {
-	return r.regs[regIndex(reg)].Value, r.regs[regIndex(reg)].LoadedFrom
+// getX86asm returns the value of reg, zero-extended from the register's width
+// when reg is narrower than 64 bits.
+func (r *Registers) getX86asm(reg x86asm.Reg) expression.Expression {
+	e := regMappingFor(reg)
+	res := r.regs[e.idx]
+	if e.bits != 64 {
+		res = expression.ZeroExtend(res, e.bits)
+	}
+	return res
 }
 
-type regState struct {
-	LoadedFrom uint64
-	Value      uint64
+// Get returns the expression held for reg, or the content of slot 0 when reg is
+// out of range.
+func (r *Registers) Get(reg Reg) expression.Expression {
+	if int(reg) >= len(r.regs) {
+		return r.regs[0]
+	}
+	return r.regs[int(reg)]
 }
diff --git a/asm/expression/add.go b/asm/expression/add.go
new file mode 100644
index 000000000..7061f6fa2
--- /dev/null
+++ b/asm/expression/add.go
@@ -0,0 +1,40 @@
+// Copyright The OpenTelemetry Authors
+// SPDX-License-Identifier: Apache-2.0
+
+package expression // import "go.opentelemetry.io/ebpf-profiler/asm/expression"
+
+// Add returns the sum of vs. Operands that are themselves sums are flattened,
+// and all immediates are folded into one trailing immediate, which is left out
+// when it is zero. With no non-immediate operand the result is the folded
+// immediate; a single remaining operand is returned as is.
+func Add(vs ...Expression) Expression {
+	oss := make(operands, 0, len(vs)+1)
+	v := uint64(0)
+	for _, it := range vs {
+		if o, ok := it.(*op); ok && o.typ == opAdd {
+			for _, jit := range o.operands {
+				if imm, immOk := jit.(*immediate); immOk {
+					v += imm.Value
+				} else {
+					oss = append(oss, jit)
+				}
+			}
+		} else {
+			if imm, immOk := it.(*immediate); immOk {
+				v += imm.Value
+			} else {
+				oss = append(oss, it)
+			}
+		}
+	}
+	if len(oss) == 0 {
+		return Imm(v)
+	}
+	if v != 0 {
+		oss = append(oss, Imm(v))
+	}
+	if len(oss) == 1 {
+		return oss[0]
+	}
+	return newOp(opAdd, oss)
+}
diff --git a/asm/expression/capture.go b/asm/expression/capture.go
new file mode 100644
index 000000000..f28b6c551
--- /dev/null
+++ b/asm/expression/capture.go
@@ -0,0 +1,38 @@
+// Copyright The OpenTelemetry Authors
+// SPDX-License-Identifier: Apache-2.0
+
+package expression // import "go.opentelemetry.io/ebpf-profiler/asm/expression"
+
+var _ Expression = &ImmediateCapture{}
+
+// NewImmediateCapture returns an empty capture labeled name.
+func NewImmediateCapture(name string) *ImmediateCapture {
+	return &ImmediateCapture{
+		name: name,
+	}
+}
+
+// ImmediateCapture is a pattern element that stands for an immediate and
+// records the value it was matched against. The capture itself is performed
+// by the immediate's Match, see the Expression documentation.
+type ImmediateCapture struct {
+	name          string
+	capturedValue immediate
+}
+
+// CapturedValue returns the recorded immediate; it is zero if nothing has been
+// captured.
+func (v *ImmediateCapture) CapturedValue() uint64 {
+	return v.capturedValue.Value
+}
+
+// DebugString returns the capture's name prefixed with "@".
+func (v *ImmediateCapture) DebugString() string {
+	return "@" + v.name
+}
+
+// Match always returns false: a capture is only meaningful as the pattern
+// argument of another expression's Match.
+func (v *ImmediateCapture) Match(_ Expression) bool {
+	return false
+}
diff --git a/asm/expression/expression.go b/asm/expression/expression.go
new file mode 100644
index 000000000..8bc4ffff4
--- /dev/null
+++ b/asm/expression/expression.go
@@ -0,0 +1,82 @@
+// Copyright The OpenTelemetry Authors
+// SPDX-License-Identifier: Apache-2.0
+
+package expression // import "go.opentelemetry.io/ebpf-profiler/asm/expression"
+import "sort"
+
+// Expression represents a 64-bit value, either a known immediate or a symbolic
+// term such as a named value, a memory reference or an operation on other
+// expressions.
+type Expression interface {
+	// Match compares this Expression value against a pattern Expression.
+	// The order of the arguments matters: a.Match(b) or b.Match(a) may
+	// produce different results. The pattern should be passed as
+	// an argument, not the other way around.
+	// It returns true if the values are considered equal or compatible according to
+	// the type-specific rules:
+	// - For operations (add, mul): checks if operation types and operands match
+	// - For immediate: checks if values are equal and extracts value into an ImmediateCapture
+	// - For mem references: checks if segments and addresses match
+	// - For extend operations: checks if sizes and inner values match
+	// - For named: checks if they are pointing to the same object instance.
+	// - For ImmediateCapture: matches nothing - see immediate
+	Match(pattern Expression) bool
+	// DebugString returns a human-readable rendering of the expression.
+	DebugString() string
+}
+
+// operands is the operand list of an operation.
+type operands []Expression
+
+// Match reports whether os and other have the same length and, once both are
+// ordered by expression kind (see cmpOrder), match element by element. The
+// ordering is done in place, so both slices may end up permuted.
+func (os *operands) Match(other operands) bool {
+	if len(*os) != len(other) {
+		return false
+	}
+	sort.Sort(sortedOperands(*os))
+	sort.Sort(sortedOperands(other))
+	for i := 0; i < len(*os); i++ {
+		if !(*os)[i].Match(other[i]) {
+			return false
+		}
+	}
+	return true
+}
+
+// sortedOperands implements sort.Interface, ordering operands by cmpOrder.
+type sortedOperands operands
+
+func (s sortedOperands) Len() int {
+	return len(s)
+}
+
+func (s sortedOperands) Less(i, j int) bool {
+	o1 := cmpOrder(s[i])
+	o2 := cmpOrder(s[j])
+	return o1 < o2
+}
+
+func (s sortedOperands) Swap(i, j int) {
+	s[i], s[j] = s[j], s[i]
+}
+
+// cmpOrder returns the sort rank of an expression's concrete type; types not
+// listed rank first.
+func cmpOrder(u Expression) int {
+	switch u.(type) {
+	case *mem:
+		return 1
+	case *op:
+		return 2
+	case *ImmediateCapture:
+		return 3
+	case *named:
+		return 4
+	case *immediate:
+		return 5
+	default:
+		return 0
+	}
+}
diff --git a/asm/expression/expression_test.go b/asm/expression/expression_test.go
new file mode 100644
index 000000000..bf6772eaa
--- /dev/null
+++ b/asm/expression/expression_test.go
@@ -0,0 +1,116 @@
+// Copyright The OpenTelemetry Authors
+// SPDX-License-Identifier: Apache-2.0
+
+package expression
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/require"
+)
+
+func TestExpression(t *testing.T) {
+	t.Run("add sort-summ-immediate", func(t *testing.T) {
+		v := Named("v")
+		require.Equal(t, Add(v, Imm(14)), Add(Imm(1), Imm(3), Imm(1), v, Imm(9)))
+	})
+
+	t.Run("named match", func(t *testing.T) {
+		n := Named("v")
+		require.True(t, n.Match(n))
+		require.False(t, n.Match(Imm(239)))
+	})
+
+	t.Run("add 0", func(t *testing.T) {
+		v := Named("v")
+		require.Equal(t, v, Add(Imm(0), v))
+	})
+
+	t.Run("add nested", func(t *testing.T) {
+		s1 := Named("s1")
+		s2 := Named("s2")
+		s3 := Named("s3")
+		performAssertions := func(e Expression) {
+			opp, ok := e.(*op)
+			require.True(t, ok)
+			require.Len(t, opp.operands, 3)
+			require.Contains(t, opp.operands,
s1)
+			require.Contains(t, opp.operands, s2)
+			require.Contains(t, opp.operands, s3)
+		}
+		performAssertions(Add(Add(s1, s3), s2))
+		performAssertions(Add(s1, Add(s3, s2)))
+	})
+
+	t.Run("add opt", func(t *testing.T) {
+		v := Named("v")
+		require.Equal(t, Add(Add(Imm(2), v), Imm(7)), Add(v, Imm(9)))
+	})
+
+	t.Run("add 1 element", func(t *testing.T) {
+		require.Equal(t, Add(Imm(2)), Imm(2))
+	})
+
+	t.Run("mul immediate", func(t *testing.T) {
+		v := Named("v")
+		require.Equal(t, Multiply(v, Imm(27)), Multiply(Imm(1), Imm(3), Imm(1), v, Imm(9)))
+	})
+
+	t.Run("mul 1", func(t *testing.T) {
+		v := Named("v")
+
+		require.Equal(t, v, Multiply(Imm(1), v))
+	})
+
+	t.Run("mul add", func(t *testing.T) {
+		v1 := Named("v1")
+		v2 := Named("v2")
+		v3 := Named("v3")
+		require.Equal(t, Add(Multiply(v1, v3), Multiply(v2, v3)), Multiply(Add(v1, v2), v3))
+	})
+
+	t.Run("op order", func(t *testing.T) {
+		v := Named("v")
+		v2 := Mem8(Named("v2"))
+		require.True(t, Multiply(v, v2).Match(Multiply(v2, v)))
+	})
+
+	t.Run("mul order", func(t *testing.T) {
+		v := Named("v")
+
+		var a Expression = &op{opMul, []Expression{v, Imm(239)}}
+		require.Equal(t, a, Multiply(Imm(239), v))
+	})
+
+	t.Run("mul 0", func(t *testing.T) {
+		v := Named("v")
+
+		require.Equal(t, Imm(0), Multiply(Imm(0), Imm(3), Imm(1), v, Imm(9)))
+	})
+
+	t.Run("extend nested", func(t *testing.T) {
+		v := Named("v")
+
+		require.Equal(t, ZeroExtend(v, 7), ZeroExtend(ZeroExtend(v, 7), 7))
+	})
+
+	t.Run("extend nested smaller", func(t *testing.T) {
+		v := Named("v")
+
+		require.Equal(t, ZeroExtend(v, 5), ZeroExtend(ZeroExtend(v, 7), 5))
+	})
+	t.Run("extend nested larger", func(t *testing.T) {
+		v := Named("v")
+
+		require.Equal(t, ZeroExtend(v, 5), ZeroExtend(ZeroExtend(v, 5), 7))
+	})
+
+	t.Run("extend 0", func(t *testing.T) {
+		require.Equal(t, Imm(0), ZeroExtend(Named("v1"), 0))
+	})
+
+	t.Run("nested extend ", func(t *testing.T) {
+		v1 := Named("v1")
+		require.Equal(t, ZeroExtend(v1, 8), ZeroExtend(ZeroExtend(v1,
8), 8)) + }) +} diff --git a/asm/expression/extend.go b/asm/expression/extend.go new file mode 100644 index 000000000..cdbc74a9f --- /dev/null +++ b/asm/expression/extend.go @@ -0,0 +1,84 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package expression // import "go.opentelemetry.io/ebpf-profiler/asm/expression" +import ( + "fmt" + "math" +) + +var _ Expression = &extend{} + +func SignExtend(v Expression, bits int) Expression { + return &extend{v, bits, true} +} + +func ZeroExtend32(v Expression) Expression { + return ZeroExtend(v, 32) +} + +func ZeroExtend8(v Expression) Expression { + return ZeroExtend(v, 8) +} + +func ZeroExtend(v Expression, bits int) Expression { + if bits >= 64 { + bits = 64 + } + c := &extend{ + v: v, + bits: bits, + } + if c.bits == 0 { + return Imm(0) + } + if c.bits == 64 { + return c.v + } + switch typed := c.v.(type) { + case *immediate: + return Imm(typed.Value & c.MaxValue()) + case *extend: + if typed.sign { + return c + } + if typed.bits <= c.bits { + return typed + } + return &extend{typed.v, c.bits, false} + default: + return c + } +} + +type extend struct { + v Expression + bits int + sign bool +} + +func (c *extend) MaxValue() uint64 { + if c.bits >= 64 || c.sign { + return math.MaxUint64 + } + return 1< } rip := uint64(0x1bbba0) - val, _ := decodeStubArgumentAMD64( - code, - rip, - 0, - ) + val, _ := decodeStubArgumentAMD64(code, rip, 0) if val != 0x3a4c2c { b.Fail() } @@ -72,6 +68,7 @@ func TestAmd64DecodeStub(t *testing.T) { name string code []byte rip uint64 + memBase uint64 expected uint64 expectedError string }{ @@ -113,7 +110,8 @@ func TestAmd64DecodeStub(t *testing.T) { 0xe9, 0x2e, 0x41, 0xeb, 0xff, // 1adcad: jmp 61de0 }, rip: 0x1adc90, - expected: 0x248, + memBase: 0xcafe0000, + expected: 0xcafe0248, }, { name: "3.12.8 gcc12 disable-optimizations enabled-shared", @@ -126,7 +124,8 @@ func TestAmd64DecodeStub(t *testing.T) { 0xe8, 0x95, 0x78, 0xe2, 0xff, // 2e25e6: call 109e80 }, 
rip: 0x2e25d0, - expected: 0x608, + expected: 0x608 + 0xef00000, + memBase: 0xef00000, }, { name: "3.10.16 clang18 enable-optimizations enabled-shared", @@ -139,7 +138,8 @@ func TestAmd64DecodeStub(t *testing.T) { 0xe9, 0x24, 0x55, 0xf9, 0xff, // cac67: jmp 60190 }, rip: 0xcac50, - expected: 0x24c, + expected: 0x24c + 0xef00000, + memBase: 0xef00000, }, { name: "3.10.16 clang18 enable-optimizations disable-shared", @@ -222,14 +222,14 @@ func TestAmd64DecodeStub(t *testing.T) { { name: "empty code", code: nil, - expectedError: "no call/jump instructions found", + expectedError: "EOF", }, { name: "no call/jump instructions found", code: []byte{ 0x48, 0xC7, 0xC7, 0xEF, 0xEF, 0xEF, 0x00, // mov rdi, 0xefefef }, - expectedError: "no call/jump instructions found", + expectedError: "EOF", }, { name: "bad instruction", @@ -237,17 +237,7 @@ func TestAmd64DecodeStub(t *testing.T) { 0x48, 0xC7, 0xC7, 0xEF, 0xEF, 0xEF, 0x00, // mov rdi, 0xefefef 0xea, // :shrug: }, - expectedError: "failed to decode instruction at 0x7", - }, - { - name: "synthetic mov scale index", - code: []byte{ - 0x48, 0xC7, 0xC0, 0xCA, 0xCA, 0x00, 0x00, // mov rax, 0xcaca - 0xBB, 0x00, 0x00, 0x00, 0x5E, // mov ebx, 0x5e000000 - 0x67, 0x48, 0x8B, 0x7C, 0x43, 0x05, // mov rdi, qword ptr [ebx + eax*2 + 5] - 0xEB, 0x00, // jmp 0x14 - }, - expected: 0xCACA*2 + 0x5E000000 + 5, + expectedError: "at 0x7", }, { name: "synthetic lea scale index", @@ -276,7 +266,7 @@ func TestAmd64DecodeStub(t *testing.T) { val, err := decodeStubArgumentAMD64( td.code, td.rip, - 0, // NULL pointer as mem + td.memBase, ) if td.expectedError != "" { require.Error(t, err) diff --git a/interpreter/python/python.go b/interpreter/python/python.go index 0d1c491ce..305b88a1b 100644 --- a/interpreter/python/python.go +++ b/interpreter/python/python.go @@ -666,7 +666,15 @@ func decodeStub(ef *pfelf.File, memoryBase libpf.SymbolValue, return libpf.SymbolValueInvalid, fmt.Errorf("unable to read '%s': %v", symbolName, err) } - value, err := 
decodeStubArgumentWrapper(code, sym.Address, memoryBase) + var value libpf.SymbolValue + switch ef.Machine { + case elf.EM_AARCH64: + value, err = decodeStubArgumentARM64(code, memoryBase), nil + case elf.EM_X86_64: + value, err = decodeStubArgumentAMD64(code, uint64(sym.Address), uint64(memoryBase)) + default: + return libpf.SymbolValueInvalid, fmt.Errorf("unsupported arch %s", ef.Machine.String()) + } // Sanity check the value range and alignment if err != nil || value%4 != 0 {