From 787e3ebf0b7845d52d0cb59cbfdab2bf802958eb Mon Sep 17 00:00:00 2001 From: hengyoush Date: Mon, 6 Jan 2025 03:22:26 +0800 Subject: [PATCH] feat: support for parsing ipip packet --- agent/analysis/stat.go | 56 ++++- agent/buffer/stream_buffer.go | 13 +- agent/conn/conntrack.go | 50 ++-- agent/conn/kern_event_handler.go | 94 +++++++- agent/conn/net.go | 144 +++++++++++ agent/conn/processor.go | 381 +++++++++++++++++++++--------- agent/metadata/process.go | 105 +++++++- agent/render/watch/time_detail.go | 49 +++- bpf/loader/container.go | 15 -- bpf/loader/loader.go | 48 +++- bpf/pktlatency.bpf.c | 5 - common/net.go | 101 -------- testdata/ipip_test.sh | 81 ++++--- testdata/ipip_test3.sh | 24 -- testdata/ipip_test_clean.sh | 8 +- 15 files changed, 829 insertions(+), 345 deletions(-) create mode 100644 agent/conn/net.go delete mode 100644 common/net.go delete mode 100644 testdata/ipip_test3.sh diff --git a/agent/analysis/stat.go b/agent/analysis/stat.go index 7868481a..01bf0c90 100644 --- a/agent/analysis/stat.go +++ b/agent/analysis/stat.go @@ -181,7 +181,10 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect // because we could missed some nicIngressEvents, the total duration may be negative annotatedRecord.StartTs = math.MaxUint64 if hasNicInEvents { - annotatedRecord.StartTs = min(events.nicIngressEvents[0].GetTimestamp(), annotatedRecord.StartTs) + nicInTimestamp, _, ok := events.nicIngressEvents[0].GetMinIfItmestampAttr() + if ok { + annotatedRecord.StartTs = min(uint64(nicInTimestamp), annotatedRecord.StartTs) + } } if hasTcpInEvents { annotatedRecord.StartTs = min(events.tcpInEvents[0].GetTimestamp(), annotatedRecord.StartTs) @@ -193,7 +196,10 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect annotatedRecord.StartTs = min(events.readSyscallEvents[0].GetTimestamp(), annotatedRecord.StartTs) } if hasDevOutEvents { - annotatedRecord.EndTs = events.devOutEvents[len(events.devOutEvents)-1].GetTimestamp() + devOutTimestamp, _, ok := events.devOutEvents[len(events.devOutEvents)-1].GetMaxIfItmestampAttr() + if ok { + annotatedRecord.EndTs = uint64(devOutTimestamp) + } } if connection.IsSsl() { annotatedRecord.ReqPlainTextSize = events.ingressMessage.ByteSize() @@ -221,12 +227,12 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect annotatedRecord.RespNicEventDetails = KernEventsToNicEventDetails(events.devOutEvents) } else { if hasWriteSyscallEvents { - annotatedRecord.StartTs = events.writeSyscallEvents[0].GetTimestamp() + annotatedRecord.StartTs = findMinTimestamp(events.writeSyscallEvents) } else { annotatedRecord.StartTs = events.egressMessage.TimestampNs() } if hasReadSyscallEvents { - annotatedRecord.EndTs = events.readSyscallEvents[len(events.readSyscallEvents)-1].GetTimestamp() + annotatedRecord.EndTs = findMaxTimestamp(events.readSyscallEvents) } else { annotatedRecord.EndTs = events.ingressMessage.TimestampNs() } @@ -242,7 +248,31 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect annotatedRecord.TotalDuration = float64(events.ingressMessage.TimestampNs()) - float64(events.egressMessage.TimestampNs()) } if hasNicInEvents && hasDevOutEvents { - annotatedRecord.BlackBoxDuration = float64(events.nicIngressEvents[len(events.nicIngressEvents)-1].GetTimestamp()) - float64(events.devOutEvents[0].GetTimestamp()) + nicIngressTimestamp := int64(0) + for _, nicIngressEvent := range events.nicIngressEvents { + _nicIngressTimestamp, _, ok := nicIngressEvent.GetMinIfItmestampAttr() + if ok { + nicIngressTimestamp = max(nicIngressTimestamp, _nicIngressTimestamp) + } + } + + if nicIngressTimestamp != 0 { + nicEgressTimestamp := int64(math.MaxInt64) + for _, devOutEvent := range events.devOutEvents { + _nicEgressTimestamp, _, ok := devOutEvent.GetMaxIfItmestampAttr() + if ok { + nicEgressTimestamp = min(nicEgressTimestamp, _nicEgressTimestamp) + } + } + if nicEgressTimestamp != int64(math.MaxInt64) { + annotatedRecord.BlackBoxDuration = float64(nicIngressTimestamp) - float64(nicEgressTimestamp) + } else { + annotatedRecord.BlackBoxDuration = -1 + } + nicEgressTimestamp++ + } else { + annotatedRecord.BlackBoxDuration = -1 + } } if (hasUserCopyEvents || hasReadSyscallEvents) && hasTcpInEvents { var readFromEndTime float64 @@ -292,6 +322,22 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect return nil } +func findMaxTimestamp(events []conn.KernEvent) uint64 { + var maxTimestamp uint64 = 0 + for _, each := range events { + maxTimestamp = max(maxTimestamp, each.GetTimestamp()) + } + return maxTimestamp +} + +func findMinTimestamp(events []conn.KernEvent) uint64 { + var minTimestamp uint64 = math.MaxUint64 + for _, each := range events { + minTimestamp = min(minTimestamp, each.GetTimestamp()) + } + return minTimestamp +} + func KernEventsToEventDetails[K analysisCommon.PacketEventDetail | analysisCommon.SyscallEventDetail](kernEvents []conn.KernEvent) []K { if len(kernEvents) == 0 { return []K{} diff --git a/agent/buffer/stream_buffer.go b/agent/buffer/stream_buffer.go index 980528bb..585de32e 100644 --- a/agent/buffer/stream_buffer.go +++ b/agent/buffer/stream_buffer.go @@ -128,7 +128,11 @@ func (sb *StreamBuffer) FindTimestampBySeq(targetSeq uint64) (uint64, bool) { return value.(uint64), true } -func (sb *StreamBuffer) Add(seq uint64, data []byte, timestamp uint64) { +func (sb *StreamBuffer) Add(seq uint64, data []byte, timestamp uint64) bool { + _, found := sb.timestamps.Get(seq) + if found { + return false + } dataLen := uint64(len(data)) newBuffer := &Buffer{ buf: data, @@ -137,16 +141,16 @@ func (sb *StreamBuffer) Add(seq uint64, data []byte, timestamp uint64) { if sb.IsEmpty() { sb.updateTimestamp(seq, timestamp) sb.buffers = append(sb.buffers, newBuffer) - return + return true } if sb.Position0()-int(seq) >= maxBytesGap { - return + return true } if int(seq)-sb.PositionN() >= maxBytesGap { sb.Clear() sb.buffers = append(sb.buffers, newBuffer) sb.updateTimestamp(seq, timestamp) - return + return true } leftIndex, leftBuffer := sb.findLeftBufferBySeq(seq) @@ -180,6 +184,7 @@ func (sb *StreamBuffer) Add(seq uint64, data []byte, timestamp uint64) { } sb.updateTimestamp(seq, timestamp) sb.shrinkBufferUntilSizeBelowCapacity() + return true } func (sb *StreamBuffer) updateTimestamp(seq uint64, timestamp uint64) { diff --git a/agent/conn/conntrack.go b/agent/conn/conntrack.go index 8c0370cf..0b4e353d 100644 --- a/agent/conn/conntrack.go +++ b/agent/conn/conntrack.go @@ -209,7 +209,7 @@ func (c *ConnManager) FindConnection4Exactly(TgidFd uint64) *Connection4 { } } -func (c *ConnManager) FindConnection4Or(TgidFd uint64, ts uint64) *Connection4 { +func (c *ConnManager) LookupConnection4ByTimestamp(TgidFd uint64, ts uint64) *Connection4 { v, _ := c.connMap.Load(TgidFd) connection, _ := v.(*Connection4) if connection == nil { @@ -219,14 +219,8 @@ func (c *ConnManager) FindConnection4Or(TgidFd uint64, ts uint64) *Connection4 { return connection } else { curConnList := connection.prevConn - if len(curConnList) > 0 { - lastPrevConn := curConnList[len(curConnList)-1] - if lastPrevConn.CloseTs != 0 && lastPrevConn.CloseTs < ts { - return connection - } - } for idx := len(curConnList) - 1; idx >= 0; idx-- { - if curConnList[idx].ConnectStartTs < ts { + if curConnList[idx].timeBoundCheck(ts) { return curConnList[idx] } } @@ -265,6 +259,19 @@ func (c *Connection4) ProtocolInferred() bool { return (c.Protocol != bpf.AgentTrafficProtocolTKProtocolUnknown) && (c.Protocol != bpf.AgentTrafficProtocolTKProtocolUnset) } +func (c *Connection4) timeBoundCheck(toCheck uint64) bool { + if c.ConnectStartTs == 0 { + return true + } + if toCheck < c.ConnectStartTs { + return false + } + if c.CloseTs != 0 && toCheck > c.CloseTs { + return false + } + return true +} + func (c *Connection4) extractSockKeys() (bpf.AgentSockKey, bpf.AgentSockKey) { var key bpf.AgentSockKey key.Dip = [2]uint64(common.BytesToSockKey(c.RemoteIp)) @@ -349,7 +356,10 @@ func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m func (c *Connection4) OnKernEvent(event *bpf.AgentKernEvt) bool { isReq, ok := isReq(c, event) if event.Len > 0 { - c.StreamEvents.AddKernEvent(event) + alreadyExisted := c.StreamEvents.AddKernEvent(event) + if !alreadyExisted { + return false + } } else if ok { if (event.Flags&uint8(common.TCP_FLAGS_SYN) != 0) && !isReq && event.Step == bpf.AgentStepTIP_IN { // 接收到Server给的Syn包 @@ -375,12 +385,16 @@ func (c *Connection4) OnKernEvent(event *bpf.AgentKernEvt) bool { } return true } -func (c *Connection4) addDataToBufferAndTryParse(data []byte, ke *bpf.AgentKernEvt) { +func (c *Connection4) addDataToBufferAndTryParse(data []byte, ke *bpf.AgentKernEvt) bool { + addedToBuffer := false isReq, _ := isReq(c, ke) if isReq { - c.reqStreamBuffer.Add(ke.Seq, data, ke.Ts) + addedToBuffer = c.reqStreamBuffer.Add(ke.Seq, data, ke.Ts) } else { - c.respStreamBuffer.Add(ke.Seq, data, ke.Ts) + addedToBuffer = c.respStreamBuffer.Add(ke.Seq, data, ke.Ts) + } + if !addedToBuffer { + return false } reqSteamMessageType := protocol.Request if c.Role == bpf.AgentEndpointRoleTKRoleUnknown { @@ -392,6 +406,7 @@ func (c *Connection4) addDataToBufferAndTryParse(data []byte, ke *bpf.AgentKernE } c.parseStreamBuffer(c.reqStreamBuffer, reqSteamMessageType, &c.ReqQueue, ke) c.parseStreamBuffer(c.respStreamBuffer, respSteamMessageType, &c.RespQueue, ke) + return true } func (c *Connection4) OnSslDataEvent(data []byte, event *bpf.SslData, recordChannel chan RecordWithConn) { if len(data) > 0 { @@ -413,20 +428,24 @@ func (c *Connection4) OnSslDataEvent(data []byte, event *bpf.SslData, recordChan } } } -func (c *Connection4) OnSyscallEvent(data []byte, event *bpf.SyscallEventData, recordChannel chan RecordWithConn) { +func (c *Connection4) OnSyscallEvent(data []byte, event *bpf.SyscallEventData, recordChannel chan RecordWithConn) bool { + addedToBuffer := true if len(data) > 0 { if c.ssl { if common.ConntrackLog.Level >= logrus.WarnLevel { common.ConntrackLog.Warnf("%s is ssl, but receive syscall event with data!", c.ToString()) } } else { - c.addDataToBufferAndTryParse(data, &event.SyscallEvent.Ke) + addedToBuffer = c.addDataToBufferAndTryParse(data, &event.SyscallEvent.Ke) } } else if event.SyscallEvent.GetSourceFunction() == bpf.AgentSourceFunctionTKSyscallSendfile { // sendfile has no data, so we need to fill a fake data common.ConntrackLog.Errorln("sendfile has no data, so we need to fill a fake data") fakeData := make([]byte, event.SyscallEvent.Ke.Len) - c.addDataToBufferAndTryParse(fakeData, &event.SyscallEvent.Ke) + addedToBuffer = c.addDataToBufferAndTryParse(fakeData, &event.SyscallEvent.Ke) + } + if !addedToBuffer { + return false } c.StreamEvents.AddSyscallEvent(event) @@ -441,6 +460,7 @@ func (c *Connection4) OnSyscallEvent(data []byte, event *bpf.SyscallEventData, r recordChannel <- RecordWithConn{record, c} } } + return true } func (c *Connection4) parseStreamBuffer(streamBuffer *buffer.StreamBuffer, messageType protocol.MessageType, resultQueue *[]protocol.ParsedMessage, ke *bpf.AgentKernEvt) { diff --git a/agent/conn/kern_event_handler.go b/agent/conn/kern_event_handler.go index a9749fb5..b8995e55 100644 --- a/agent/conn/kern_event_handler.go +++ b/agent/conn/kern_event_handler.go @@ -7,6 +7,7 @@ import ( "kyanos/common" "kyanos/monitor" "slices" + "strings" "sync" "github.com/jefurry/logrus" @@ -85,7 +86,7 @@ func (s *KernEventStream) AddSyscallEvent(event *bpf.SyscallEventData) { s.AddKernEvent(&event.SyscallEvent.Ke) } -func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) { +func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) bool { s.kernEventsMu.Lock() defer s.kernEventsMu.Unlock() s.discardEventsIfNeeded() @@ -98,9 +99,24 @@ func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) { index, found := slices.BinarySearchFunc(kernEvtSlice, KernEvent{seq: event.Seq}, func(i KernEvent, j KernEvent) int { return cmp.Compare(i.seq, j.seq) }) + isNicEvnt := event.Step == bpf.AgentStepTDEV_OUT || event.Step == bpf.AgentStepTDEV_IN + var kernEvent *KernEvent if found { - kernEvent = &kernEvtSlice[index] + oldKernEvent := &kernEvtSlice[index] + if oldKernEvent.timestamp > event.Ts && !isNicEvnt { + // this is a duplicate event which belongs to a future conn + oldKernEvent.seq = event.Seq + oldKernEvent.len = int(event.Len) + oldKernEvent.timestamp = event.Ts + oldKernEvent.step = event.Step + kernEvent = oldKernEvent + } else if !isNicEvnt { + kernEvent = &kernEvtSlice[index] + return false + } else { + kernEvent = &kernEvtSlice[index] + } } else { kernEvent = &KernEvent{ seq: event.Seq, @@ -110,20 +126,22 @@ func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) { } } - if event.Step == bpf.AgentStepTDEV_OUT || event.Step == bpf.AgentStepTDEV_IN { + if isNicEvnt { if kernEvent.attributes == nil { kernEvent.attributes = make(map[string]any) } - ifname, err := common.GetInterfaceNameByIndex(int(event.Ifindex), int(event.ConnIdS.TgidFd>>32)) + ifname, err := getInterfaceNameByIndex(int(event.Ifindex), int(event.ConnIdS.TgidFd>>32)) if err != nil { ifname = "unknown" } - kernEvent.UpdateIfTimestampAttr(ifname, int64(event.Ts)) - } else if found { - return - // panic("found duplicate kern event on same seq") + updated := kernEvent.UpdateIfTimestampAttr(ifname, int64(event.Ts)) + if !updated { + return false + } + } + if !found { + kernEvtSlice = slices.Insert(kernEvtSlice, index, *kernEvent) } - kernEvtSlice = slices.Insert(kernEvtSlice, index, *kernEvent) if len(kernEvtSlice) > s.maxLen { if common.ConntrackLog.Level >= logrus.DebugLevel { common.ConntrackLog.Debugf("kern event stream size: %d exceed maxLen", len(kernEvtSlice)) @@ -134,6 +152,7 @@ func (s *KernEventStream) AddKernEvent(event *bpf.AgentKernEvt) { } s.kernEvents[event.Step] = kernEvtSlice } + return true } func (s *KernEventStream) FindSslEventsBySeqAndLen(step bpf.AgentStepT, seq uint64, len int) []SslEvent { @@ -289,8 +308,63 @@ func (kernevent *KernEvent) GetAttributes() map[string]any { return kernevent.attributes } -func (kernevent *KernEvent) UpdateIfTimestampAttr(ifname string, time int64) { +func (kernevent *KernEvent) UpdateIfTimestampAttr(ifname string, time int64) bool { + if timestamp, ok := kernevent.attributes["time-"+ifname]; ok { + if ts, valid := timestamp.(int64); valid { + if ts < time { + return false + } + } + } + kernevent.attributes["time-"+ifname] = time + return true +} + +func (kernevent *KernEvent) GetMaxIfItmestampAttr() (int64, string, bool) { + maxTimestamp := int64(0) + var maxIfname string + found := false + for key, value := range kernevent.attributes { + if strings.HasPrefix(key, "time-") { + if timestamp, ok := value.(int64); ok { + if timestamp > maxTimestamp { + maxTimestamp = timestamp + maxIfname = strings.TrimPrefix(key, "time-") + found = true + } + } + } + } + return maxTimestamp, maxIfname, found +} + +func (kernevent *KernEvent) GetMinIfItmestampAttr() (int64, string, bool) { + minTimestamp := int64(^uint64(0) >> 1) // Max int64 value + var minIfname string + found := false + for key, value := range kernevent.attributes { + if strings.HasPrefix(key, "time-") { + if timestamp, ok := value.(int64); ok { + if timestamp < minTimestamp { + minTimestamp = timestamp + minIfname = strings.TrimPrefix(key, "time-") + found = true + } + } + } + } + return minTimestamp, minIfname, found +} + +func (kernevent *KernEvent) GetTimestampByIfname(ifname string) (int64, bool) { + key := "time-" + ifname + if timestamp, ok := kernevent.attributes[key]; ok { + if ts, valid := timestamp.(int64); valid { + return ts, true + } + } + return 0, false } type SslEvent struct { diff --git a/agent/conn/net.go b/agent/conn/net.go new file mode 100644 index 00000000..27ac02fb --- /dev/null +++ b/agent/conn/net.go @@ -0,0 +1,144 @@ +package conn + +import ( + "errors" + "kyanos/common" + "os" + "os/exec" + "regexp" + "strconv" + "strings" + "sync" + "syscall" +) + +var ifIdxToName map[string]map[string]string = make(map[string]map[string]string) +var netnsIDMap map[string]string = make(map[string]string) +var lock *sync.Mutex = &sync.Mutex{} + +func init() { + nicsFromAllNs, err := GetAllNICs() + if err != nil { + return + } + ifIdxToName = nicsFromAllNs + netnsIDMap, _ = getNetnsIDMap() +} + +func getInterfaceNameByIndex(index int, pid int) (string, error) { + netnsName, found := netnsIDMap[strconv.FormatInt(common.GetNetworkNamespaceFromPid(pid), 10)] + if !found { + netnsName = "default" + } + exist, found := ifIdxToName[netnsName] + if found { + ifName, found := exist[strconv.Itoa(index)] + if found { + return ifName, nil + } + } + return "", errors.New("interface not found") +} + +func parseIpCmdLine(line string) (int, string, bool) { + // 使用正则表达式匹配接口索引和接口名称 + // 假设接口索引是以数字开头,后面跟着冒号和接口名称 + re := regexp.MustCompile(`^(\d+):\s*([^:]+)`) + match := re.FindStringSubmatch(line) + if len(match) < 3 { + return 0, "", false // 没有匹配到 + } + + index, err := strconv.Atoi(match[1]) + if err != nil { + return 0, "", false // 转换索引失败 + } + + interfaceName := match[2] + return index, interfaceName, true // 成功提取 +} + +var ( + hostNetNs int64 +) + +func init() { + hostNetNs = common.GetNetworkNamespaceFromPid(1) +} + +func GetAllNICs() (map[string]map[string]string, error) { + result := make(map[string]map[string]string) + + // 默认命名空间 + defaultNS, err := exec.Command("ip", "link", "show").Output() + if err != nil { + return nil, err + } + result["default"] = parseLinkOutput(string(defaultNS)) + + // 自定义网络命名空间 + nsList, err := exec.Command("ip", "netns", "list").Output() + if err != nil { + // 若命令失败,可能没有任何自定义网络命名空间 + return result, nil + } + namespaces := strings.Split(strings.TrimSpace(string(nsList)), "\n") + + for _, ns := range namespaces { + parts := strings.Fields(ns) + if len(parts) == 0 { + continue + } + nsName := parts[0] + nsOutput, err := exec.Command("ip", "netns", "exec", nsName, "ip", "link", "show").Output() + if err != nil { + // 忽略该命名空间的错误 + continue + } + result[nsName] = parseLinkOutput(string(nsOutput)) + } + + return result, nil +} + +func parseLinkOutput(output string) map[string]string { + interfaces := make(map[string]string) + re := regexp.MustCompile(`^(\d+):\s+([^:]+):`) + for _, line := range strings.Split(output, "\n") { + match := re.FindStringSubmatch(strings.TrimSpace(line)) + if len(match) == 3 { + interfaces[match[1]] = match[2] + } + } + return interfaces +} + +func getNetnsIDMap() (map[string]string, error) { + nsMap := make(map[string]string) + + dir, err := os.Open("/var/run/netns") + if err != nil { + return nil, err + } + defer dir.Close() + + files, err := dir.Readdirnames(0) + if err != nil { + return nil, err + } + + for _, f := range files { + info, err := os.Stat("/var/run/netns/" + f) + if err != nil { + continue + } + statT, ok := info.Sys().(*syscall.Stat_t) + if !ok { + continue + } + inode := strconv.FormatUint(statT.Ino, 10) + nsMap[inode] = f + } + + return nsMap, nil +} diff --git a/agent/conn/processor.go b/agent/conn/processor.go index 6cfc83d7..761985f3 100644 --- a/agent/conn/processor.go +++ b/agent/conn/processor.go @@ -95,6 +95,24 @@ type Processor struct { side common.SideEnum recordProcessor *RecordsProcessor conntrackCloseWaitTimeMills int + tempKernEvents []TimedEvent + tempSyscallEvents []TimedSyscallEvent + tempSslEvents []TimedSslEvent +} + +type TimedEvent struct { + event *bpf.AgentKernEvt + timestamp time.Time +} + +type TimedSyscallEvent struct { + event *bpf.SyscallEventData + timestamp time.Time +} + +type TimedSslEvent struct { + event *bpf.SslData + timestamp time.Time } func initProcessor(name string, wg *sync.WaitGroup, ctx context.Context, connManager *ConnManager, filter protocol.ProtocolFilter, @@ -116,6 +134,9 @@ func initProcessor(name string, wg *sync.WaitGroup, ctx context.Context, connMan records: make([]RecordWithConn, 0), } p.conntrackCloseWaitTimeMills = conntrackCloseWaitTimeMills + p.tempKernEvents = make([]TimedEvent, 0, 100) // Preallocate with a capacity of 100 + p.tempSyscallEvents = make([]TimedSyscallEvent, 0, 100) // Preallocate with a capacity of 100 + p.tempSslEvents = make([]TimedSslEvent, 0, 100) // Preallocate with a capacity of 100 return p } @@ -137,6 +158,11 @@ func (p *Processor) AddKernEvent(record *bpf.AgentKernEvt) { func (p *Processor) run() { recordChannel := make(chan RecordWithConn) go p.recordProcessor.Run(recordChannel, time.NewTicker(1*time.Second)) + + // Timer to process kern, syscall, and ssl events + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + for { select { case <-p.ctx.Done(): @@ -171,7 +197,7 @@ func (p *Processor) run() { }(conn) } else if event.ConnType == bpf.AgentConnTypeTKProtocolInfer { // 协议推断 - conn = p.connManager.FindConnection4Or(TgidFd, event.Ts+common.LaunchEpochTime) + conn = p.connManager.LookupConnection4ByTimestamp(TgidFd, event.Ts+common.LaunchEpochTime) // previousProtocol := conn.Protocol if conn != nil && conn.Status != Closed { conn.Protocol = event.ConnInfo.Protocol @@ -199,16 +225,16 @@ func (p *Processor) run() { if isProtocolInterested && !isSideNotMatched(p, conn) { if conn.Protocol != bpf.AgentTrafficProtocolTKProtocolUnknown { for _, sysEvent := range conn.TempSyscallEvents { - if sysEvent.SyscallEvent.Ke.Ts > conn.ConnectStartTs { + if conn.timeBoundCheck(sysEvent.SyscallEvent.Ke.Ts) { if common.ConntrackLog.Level >= logrus.DebugLevel { common.ConntrackLog.Debugf("%s process %d temp syscall events before infer\n", conn.ToString(), len(conn.TempSyscallEvents)) } conn.OnSyscallEvent(sysEvent.Buf, sysEvent, recordChannel) } } - conn.TempConnEvents = conn.TempConnEvents[0:0] + conn.TempSyscallEvents = conn.TempSyscallEvents[0:0] for _, sslEvent := range conn.TempSslEvents { - if sslEvent.SslEventHeader.Ke.Ts > conn.ConnectStartTs { + if conn.timeBoundCheck(sslEvent.SslEventHeader.Ke.Ts) { if common.ConntrackLog.Level >= logrus.DebugLevel { common.ConntrackLog.Debugf("%s process %d temp ssl events before infer\n", conn.ToString(), len(conn.TempSslEvents)) } @@ -217,8 +243,17 @@ func (p *Processor) run() { } conn.TempSslEvents = conn.TempSslEvents[0:0] conn.UpdateConnectionTraceable(true) + // handle kern events + for _, kernEvent := range conn.TempKernEvents { + if conn.timeBoundCheck(kernEvent.Ts) { + if common.ConntrackLog.Level >= logrus.DebugLevel { + common.ConntrackLog.Debugf("%s process %d temp kern events before infer\n", conn.ToString(), len(conn.TempKernEvents)) + } + conn.OnKernEvent(kernEvent) + } + } + conn.TempKernEvents = conn.TempKernEvents[0:0] } - conn.TempKernEvents = conn.TempKernEvents[0:0] conn.TempConnEvents = conn.TempConnEvents[0:0] } else { if common.ConntrackLog.Level >= logrus.DebugLevel { @@ -242,124 +277,246 @@ func (p *Processor) run() { common.ConntrackLog.Debugf("[conn][ts=%d] %s | type: %s, protocol: %d, \n", event.Ts, conn.ToString(), eventType, conn.Protocol) } case event := <-p.syscallEvents: - tgidFd := event.SyscallEvent.Ke.ConnIdS.TgidFd - conn := p.connManager.FindConnection4Or(tgidFd, event.SyscallEvent.Ke.Ts+common.LaunchEpochTime) - event.SyscallEvent.Ke.Ts += common.LaunchEpochTime - if conn != nil && conn.Status == Closed { - conn.AddSyscallEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[syscall][closed conn][len=%d][ts=%d][fn=%d]%s | %s", max(event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Len), event.SyscallEvent.Ke.Ts, event.SyscallEvent.GetSourceFunction(), conn.ToString(), string(event.Buf)) - } - continue - } - if conn != nil && !conn.tracable { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[syscall][no-trace][len=%d][ts=%d]%s | %s", event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Ts, conn.ToString(), string(event.Buf)) - } - continue - } - if conn != nil && conn.ProtocolInferred() { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[syscall][len=%d][ts=%d][fn=%d]%s | %s", max(event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Len), event.SyscallEvent.Ke.Ts, event.SyscallEvent.GetSourceFunction(), conn.ToString(), string(event.Buf)) - } + p.handleSyscallEvent(event, recordChannel) + case event := <-p.sslEvents: + p.handleSslEvent(event, recordChannel) + case event := <-p.kernEvents: + p.handleKernEvent(event, recordChannel) + case <-ticker.C: + p.processTimedKernEvents(recordChannel) + p.processTimedSyscallEvents(recordChannel) + p.processTimedSslEvents(recordChannel) + } + } +} - conn.OnSyscallEvent(event.Buf, event, recordChannel) - } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { - conn.AddSyscallEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[syscall][protocol unset][ts=%d][len=%d]%s | %s", event.SyscallEvent.Ke.Ts, event.SyscallEvent.BufSize, conn.ToString(), string(event.Buf)) - } +func (p *Processor) handleKernEvent(event *bpf.AgentKernEvt, recordChannel chan RecordWithConn) { + // Add event to the temporary queue + p.tempKernEvents = append(p.tempKernEvents, TimedEvent{event: event, timestamp: time.Now()}) - } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnknown { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[syscall][protocol unknown][ts=%d][len=%d]%s | %s", event.SyscallEvent.Ke.Ts, event.SyscallEvent.BufSize, conn.ToString(), string(event.Buf)) - } - } else { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[syscall][no conn][ts=%d][tgid=%d fd=%d][len=%d] %s", event.SyscallEvent.Ke.Ts, tgidFd>>32, uint32(tgidFd), event.SyscallEvent.BufSize, string(event.Buf)) - } - } - case event := <-p.sslEvents: - tgidFd := event.SslEventHeader.Ke.ConnIdS.TgidFd - conn := p.connManager.FindConnection4Or(tgidFd, event.SslEventHeader.Ke.Ts+common.LaunchEpochTime) - event.SslEventHeader.Ke.Ts += common.LaunchEpochTime - if conn != nil && conn.Status == Closed { - conn.AddSslEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[ssl][closed conn][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf)) - } - continue - } - if conn != nil && !conn.tracable { - conn.AddSslEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[ssl][no-trace][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf)) - } - continue - } - if conn != nil && conn.ProtocolInferred() { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[ssl][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf)) - } + // Process events in the queue that have been there for more than 100ms + p.processOldKernEvents(recordChannel) +} - conn.OnSslDataEvent(event.Buf, event, recordChannel) - } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { - conn.AddSslEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[ssl][protocol unset][len=%d]%s | %s", event.SslEventHeader.BufSize, conn.ToString(), string(event.Buf)) - } - } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnknown { - conn.AddSslEvent(event) +func (p *Processor) processTimedKernEvents(recordChannel chan RecordWithConn) { + p.processOldKernEvents(recordChannel) +} + +func (p *Processor) processOldKernEvents(recordChannel chan RecordWithConn) { + now := time.Now() + lastIndex := 0 + for i := 0; i < len(p.tempKernEvents); i++ { + if now.Sub(p.tempKernEvents[i].timestamp) > 100*time.Millisecond { + p.processKernEvent(p.tempKernEvents[i].event, recordChannel) + lastIndex = i + 1 + } else { + break + } + } + p.tempKernEvents = p.tempKernEvents[lastIndex:] +} + +func (p *Processor) processKernEvent(event *bpf.AgentKernEvt, recordChannel chan RecordWithConn) { + tgidFd := event.ConnIdS.TgidFd + event.Ts += common.LaunchEpochTime + conn := p.connManager.LookupConnection4ByTimestamp(tgidFd, event.Ts) + timeCheck := conn != nil && conn.timeBoundCheck(event.Ts) + if conn != nil && conn.Status == Closed { + if timeCheck { + conn.OnKernEvent(event) + } else { + conn.AddKernEvent(event) + } + + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[closed conn]%s", FormatKernEvt(event, conn)) + } + return + } + if event.Len > 0 && conn != nil && conn.Protocol != bpf.AgentTrafficProtocolTKProtocolUnknown { + if conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { + added := conn.OnKernEvent(event) + + if added { if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[ssl][protocol unknown][len=%d]%s | %s", event.SslEventHeader.BufSize, conn.ToString(), string(event.Buf)) + common.BPFEventLog.Debugf("[protocol-unset]%s", FormatKernEvt(event, conn)) } } else { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[ssl][no conn][tgid=%d fd=%d][len=%d] %s", tgidFd>>32, uint32(tgidFd), event.SslEventHeader.BufSize, string(event.Buf)) - } - } - case event := <-p.kernEvents: - tgidFd := event.ConnIdS.TgidFd - conn := p.connManager.FindConnection4Or(tgidFd, event.Ts+common.LaunchEpochTime) - if conn != nil && conn.Status == Closed { conn.AddKernEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[closed conn]%s", FormatKernEvt(event, conn)) - } - continue } - event.Ts += common.LaunchEpochTime - if event.Len > 0 && conn != nil && conn.Protocol != bpf.AgentTrafficProtocolTKProtocolUnknown { - if conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { - conn.OnKernEvent(event) - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[protocol-unset]%s", FormatKernEvt(event, conn)) - } - } else if conn.Protocol != bpf.AgentTrafficProtocolTKProtocolUnknown { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("%s", FormatKernEvt(event, conn)) - } - conn.OnKernEvent(event) - } - } else if event.Len > 0 && conn != nil { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[protocol-unknown]%s\n", FormatKernEvt(event, conn)) - } - } else if event.Len == 0 && conn != nil { - conn.OnKernEvent(event) - } else if conn == nil { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[no-conn]%s\n", FormatKernEvt(event, conn)) - } - } else { - if common.BPFEventLog.Level >= logrus.DebugLevel { - common.BPFEventLog.Debugf("[other]%s\n", FormatKernEvt(event, conn)) - } + } else if conn.Protocol != bpf.AgentTrafficProtocolTKProtocolUnknown { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("%s", FormatKernEvt(event, conn)) } + conn.OnKernEvent(event) + } + } else if event.Len > 0 && conn != nil { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[protocol-unknown]%s\n", FormatKernEvt(event, conn)) + } + } else if event.Len == 0 && conn != nil { + conn.OnKernEvent(event) + } else if conn == nil { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[no-conn]%s\n", FormatKernEvt(event, conn)) + } + } else { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[other]%s\n", FormatKernEvt(event, conn)) + } + } +} + +func (p *Processor) handleSyscallEvent(event *bpf.SyscallEventData, recordChannel chan RecordWithConn) { + // Add event to the temporary queue + p.tempSyscallEvents = append(p.tempSyscallEvents, TimedSyscallEvent{event: event, timestamp: time.Now()}) + + // Process events in the queue that have been there for more than 100ms + p.processOldSyscallEvents(recordChannel) + // p.processSyscallEvent(event, recordChannel) +} + +func (p *Processor) processTimedSyscallEvents(recordChannel chan RecordWithConn) { + p.processOldSyscallEvents(recordChannel) +} + +func (p *Processor) processOldSyscallEvents(recordChannel chan RecordWithConn) { + now := time.Now() + lastIndex := 0 + for i := 0; i < len(p.tempSyscallEvents); i++ { + if now.Sub(p.tempSyscallEvents[i].timestamp) > 100*time.Millisecond { + p.processSyscallEvent(p.tempSyscallEvents[i].event, recordChannel) + lastIndex = i + 1 + } else { + break } } + p.tempSyscallEvents = p.tempSyscallEvents[lastIndex:] } + +func (p *Processor) processSyscallEvent(event *bpf.SyscallEventData, recordChannel chan RecordWithConn) { + tgidFd := event.SyscallEvent.Ke.ConnIdS.TgidFd + conn := p.connManager.LookupConnection4ByTimestamp(tgidFd, event.SyscallEvent.Ke.Ts+common.LaunchEpochTime) + event.SyscallEvent.Ke.Ts += common.LaunchEpochTime + + timeCheck := conn != nil && conn.timeBoundCheck(event.SyscallEvent.Ke.Ts) + + if conn != nil && conn.Status == Closed { + if timeCheck { + conn.OnSyscallEvent(event.Buf, event, recordChannel) + } else { + conn.AddSyscallEvent(event) + } + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[syscall][closed conn][len=%d][ts=%d][fn=%d][check=%v]%s | %s", max(event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Len), event.SyscallEvent.Ke.Ts, event.SyscallEvent.GetSourceFunction(), timeCheck, conn.ToString(), string(event.Buf)) + } + return + } + if conn != nil && !conn.tracable { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[syscall][no-trace][len=%d][ts=%d]%s | %s", event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Ts, conn.ToString(), string(event.Buf)) + } + return + } + if conn != nil && conn.ProtocolInferred() && timeCheck { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[syscall][len=%d][ts=%d][fn=%d]%s | %s", max(event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Len), event.SyscallEvent.Ke.Ts, event.SyscallEvent.GetSourceFunction(), conn.ToString(), string(event.Buf)) + } + + addedToBuffer := conn.OnSyscallEvent(event.Buf, event, recordChannel) + if addedToBuffer { + conn.AddSyscallEvent(event) + } + } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { + conn.AddSyscallEvent(event) + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[syscall][protocol unset][ts=%d][len=%d]%s | %s", event.SyscallEvent.Ke.Ts, event.SyscallEvent.BufSize, conn.ToString(), string(event.Buf)) + } + + } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnknown { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[syscall][protocol unknown][ts=%d][len=%d]%s | %s", event.SyscallEvent.Ke.Ts, event.SyscallEvent.BufSize, conn.ToString(), string(event.Buf)) + } + } else { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[syscall][no conn][ts=%d][tgid=%d fd=%d][len=%d] %s", event.SyscallEvent.Ke.Ts, tgidFd>>32, uint32(tgidFd), event.SyscallEvent.BufSize, string(event.Buf)) + } + } +} + +func (p *Processor) handleSslEvent(event *bpf.SslData, recordChannel chan RecordWithConn) { + // Add event to the temporary queue + p.tempSslEvents = append(p.tempSslEvents, TimedSslEvent{event: event, timestamp: time.Now()}) + + // Process events in the queue that have been there for more than 100ms + p.processOldSslEvents(recordChannel) +} + +func (p *Processor) processTimedSslEvents(recordChannel chan RecordWithConn) { + p.processOldSslEvents(recordChannel) +} + +func (p *Processor) processOldSslEvents(recordChannel chan RecordWithConn) { + now := time.Now() + lastIndex := 0 + for i := 0; i < len(p.tempSslEvents); i++ { + if now.Sub(p.tempSslEvents[i].timestamp) > 100*time.Millisecond { + p.processSslEvent(p.tempSslEvents[i].event, recordChannel) + lastIndex = i + 1 + } else { + break + } + } + p.tempSslEvents = p.tempSslEvents[lastIndex:] +} + +func (p *Processor) processSslEvent(event *bpf.SslData, recordChannel chan RecordWithConn) { + tgidFd := event.SslEventHeader.Ke.ConnIdS.TgidFd + conn := p.connManager.LookupConnection4ByTimestamp(tgidFd, event.SslEventHeader.Ke.Ts+common.LaunchEpochTime) + event.SslEventHeader.Ke.Ts += common.LaunchEpochTime + timeCheck := conn != nil && conn.timeBoundCheck(event.SslEventHeader.Ke.Ts) + if conn != nil && conn.Status == Closed { + if timeCheck { + conn.OnSslDataEvent(event.Buf, event, recordChannel) + } else { + conn.AddSslEvent(event) + } + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[ssl][closed conn][len=%d][ts=%d][check=%v]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, timeCheck, conn.ToString(), string(event.Buf)) + } + return + } + if conn != nil && !conn.tracable { + conn.AddSslEvent(event) + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[ssl][no-trace][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf)) + } + return + } + if conn != nil && conn.ProtocolInferred() { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[ssl][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf)) + } + + conn.OnSslDataEvent(event.Buf, event, recordChannel) + } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnset { + conn.AddSslEvent(event) + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[ssl][protocol unset][len=%d]%s | %s", event.SslEventHeader.BufSize, conn.ToString(), string(event.Buf)) + } + } else if conn != nil && conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnknown { + conn.AddSslEvent(event) + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[ssl][protocol unknown][len=%d]%s | %s", event.SslEventHeader.BufSize, conn.ToString(), string(event.Buf)) + } + } else { + if common.BPFEventLog.Level >= logrus.DebugLevel { + common.BPFEventLog.Debugf("[ssl][no conn][tgid=%d fd=%d][len=%d] %s", tgidFd>>32, uint32(tgidFd), event.SslEventHeader.BufSize, string(event.Buf)) + } + } +} + func isSideNotMatched(p *Processor, conn *Connection4) bool { return (p.side != common.AllSide) && ((conn.Role == bpf.AgentEndpointRoleTKRoleClient) != (p.side == common.ClientSide)) } @@ -380,9 +537,9 @@ func onRoleChanged(p *Processor, conn *Connection4) { func FormatKernEvt(evt *bpf.AgentKernEvt, conn *Connection4) string { var interfaceStr string if evt.Ifindex != 0 { - name, err := common.GetInterfaceNameByIndex(int(evt.Ifindex), int(evt.ConnIdS.TgidFd>>32)) + name, err := getInterfaceNameByIndex(int(evt.Ifindex), int(evt.ConnIdS.TgidFd>>32)) if err != nil { - interfaceStr = "[if=unknown]" + interfaceStr = fmt.Sprintf("[if=%d]", evt.Ifindex) } else { interfaceStr = fmt.Sprintf("[if=%s]", name) } diff --git a/agent/metadata/process.go b/agent/metadata/process.go index 7a6093cb..517ff977 100644 --- a/agent/metadata/process.go +++ b/agent/metadata/process.go @@ -1,17 +1,120 @@ package metadata -import "kyanos/common" +import ( + "context" + "kyanos/bpf" + "kyanos/common" + "sync" + "time" + + "github.com/shirou/gopsutil/process" +) + +var cleanupTimeout = 5 * time.Second const defaultProcDir = "/proc" +type PIDInfo struct { + PID int + NetNS int64 + Timestamp time.Time +} + var ( HostMntNs int64 HostPidNs int64 HostNetNs int64 + pidCache = sync.Map{} + deadPids = sync.Map{} + cacheLock sync.Mutex ) func init() { HostPidNs = common.GetPidNamespaceFromPid(1) HostMntNs = common.GetMountNamespaceFromPid(1) HostNetNs = common.GetNetworkNamespaceFromPid(1) + go func() { + for range time.Tick(1 * time.Second) { + cleanupDeadPIDs() + } + }() +} + +func StartHandleSchedExecEvent(ch chan *bpf.AgentProcessExecEvent, ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case execEvent := <-ch: + proc, err := process.NewProcess(execEvent.Pid) + if err != nil { + common.AgentLog.Infof("Failed to create process for PID %d: %v", execEvent.Pid, err) + continue + } + startPID(int(proc.Pid), common.GetNetworkNamespaceFromPid(int(proc.Pid))) + } + } + }() +} + +func StartHandleSchedExitEvent(ch chan *bpf.AgentProcessExitEvent, ctx context.Context) { + go func() { + for { + select { + case <-ctx.Done(): + return + case execEvent := <-ch: + stopPID(int(execEvent.Pid)) + } + } + }() +} + +func startPID(pid int, netns int64) { + cacheLock.Lock() + defer cacheLock.Unlock() + common.AgentLog.Infof("Start tracking PID %d, netns: %d", pid, netns) + pidCache.Store(pid, PIDInfo{ + PID: pid, + NetNS: netns, + Timestamp: time.Now(), + }) +} + +func stopPID(pid int) { + cacheLock.Lock() + defer cacheLock.Unlock() + common.AgentLog.Debugf("Stop tracking PID %d, netns: %d", pid) + if info, exists := pidCache.Load(pid); exists { + pidCache.Delete(pid) + pidInfo := info.(PIDInfo) + pidInfo.Timestamp = time.Now() + deadPids.Store(pid, pidInfo) + } +} + +func cleanupDeadPIDs() { + cacheLock.Lock() + defer cacheLock.Unlock() + now := time.Now() + deadPids.Range(func(key, value interface{}) bool { + info := value.(PIDInfo) + if now.Sub(info.Timestamp) > cleanupTimeout { + deadPids.Delete(key) + } + return true + }) +} + +func GetPidInfo(pid int) PIDInfo { + if info, exists := pidCache.Load(pid); exists { + return info.(PIDInfo) + } + // find from deadPids + if info, exists := deadPids.Load(pid); exists { + return info.(PIDInfo) + } + + return PIDInfo{} } diff --git a/agent/render/watch/time_detail.go b/agent/render/watch/time_detail.go index 1c5ebf3c..1a9add4d 100644 --- a/agent/render/watch/time_detail.go +++ b/agent/render/watch/time_detail.go @@ -99,20 +99,28 @@ func addSocketBufferDiagram(duration int64, prevDiagram *diagrams.Shape, shapes Content: "", Type: arrowType, } - connectFunc(prevDiagram, &lastNicToSocketArrow) + if prevDiagram != nil { + connectFunc(prevDiagram, &lastNicToSocketArrow) + } socketBuffer := diagrams.Shape{ Content: fmt.Sprintf(" Socket(used:%.2fms) ", c.ConvertDurationToMillisecondsIfNeeded(float64(duration), false)), Type: diagrams.Rectangle, } - connectFunc(&lastNicToSocketArrow, &socketBuffer) + if prevDiagram != nil { + connectFunc(&lastNicToSocketArrow, &socketBuffer) + } socketToAppArrow := diagrams.Shape{ Content: "", Type: arrowType, } connectFunc(&socketBuffer, &socketToAppArrow) defer func() { - *shapes = append(*shapes, &lastNicToSocketArrow, &socketBuffer) + if prevDiagram != nil { + *shapes = append(*shapes, &lastNicToSocketArrow, &socketBuffer) + } else { + *shapes = append(*shapes, &socketBuffer) + } }() return &socketToAppArrow } @@ -121,7 +129,7 @@ func getFlowChartString(diagram *diagrams.Diagram) string { s := diagrams.NewStore() canvasRow := 200 canvas := draw.NewCanvas(canvasRow, canvasRow) - canvas.Cursor.X = canvasRow / 4 + canvas.Cursor.X = calculateFirstComponentOffsetAtX(diagram) c.DefaultLog.Debugf("shapes: %v", diagram.S) for _, shape := range diagram.S { c.DefaultLog.Debugf("shape: %v", shape) @@ -171,6 +179,39 @@ func ViewRecordTimeDetailAsFlowChartForServer(r *common.AnnotatedRecord) string return getFlowChartString(diagram) } +func calculateFirstComponentOffsetAtX(diagram *diagrams.Diagram) (maxX int) { + upperX := 0 + bottomX := 0 + downArrowIndex := 0 + // 1. find the down arrow + for i, shape := range diagram.S { + if shape.Type == diagrams.DownArrow { + downArrowIndex = i + break + } + } + for i, shape := range diagram.S { + if i < downArrowIndex { + if shape.Type > diagrams.HRectangle { + upperX += diagrams.ARROWLEN + 1 + } else { + upperX += len(shape.Content) + } + } else if i > downArrowIndex { + if shape.Type > diagrams.HRectangle { + bottomX += diagrams.ARROWLEN + 1 + } else { + bottomX += len(shape.Content) + } + } + } + if upperX > bottomX { + return len(diagram.S[0].Content) + 10 + } else { + return bottomX - upperX + len(diagram.S[0].Content) + 10 + } +} + func ViewRecordTimeDetailAsFlowChartForClientSide(r *common.AnnotatedRecord) string { shapes := make([]*diagrams.Shape, 0) diagram := diagrams.New() diff --git a/bpf/loader/container.go b/bpf/loader/container.go index 2288e07f..48123420 100644 --- a/bpf/loader/container.go +++ b/bpf/loader/container.go @@ -124,18 +124,3 @@ func removeNonFilterAbleContainers(containers []types.Container) []types.Contain } return final } - -func initProcExitEventChannel(ctx context.Context) chan *bpf.AgentProcessExitEvent { - ch := make(chan *bpf.AgentProcessExitEvent, 10) - go func() { - for { - select { - case <-ctx.Done(): - return - case evt := <-ch: - common.DeleteIfIdxToNameEntry(int(evt.Pid)) - } - } - }() - return ch -} diff --git a/bpf/loader/loader.go b/bpf/loader/loader.go index 477fb794..6385da86 100644 --- a/bpf/loader/loader.go +++ b/bpf/loader/loader.go @@ -7,6 +7,7 @@ import ( "fmt" ac "kyanos/agent/common" "kyanos/agent/compatible" + "kyanos/agent/metadata" "kyanos/agent/uprobe" "kyanos/bpf" "kyanos/common" @@ -116,11 +117,14 @@ func LoadBPF(options *ac.AgentOptions) (*BPF, error) { } options.LoadPorgressChannel <- "🍓 Setup traffic filters" - bpf.PullProcessExitEvents(options.Ctx, []chan *bpf.AgentProcessExitEvent{initProcExitEventChannel(options.Ctx)}) - return bf, nil } +const ( + execEventChannelBufferSize = 10 + exitEventChannelBufferSize = 10 +) + func (bf *BPF) AttachProgs(options *ac.AgentOptions) error { var links *list.List var err error @@ -135,16 +139,10 @@ func (bf *BPF) AttachProgs(options *ac.AgentOptions) error { options.LoadPorgressChannel <- "🍆 Attached base eBPF programs." + bf.attachExecEventChannels(options) + bf.attachExitEventChannels(options) if !options.DisableOpensslUprobe { - uprobeSchedEventChannel := make(chan *bpf.AgentProcessExecEvent, 10) - uprobe.StartHandleSchedExecEvent(uprobeSchedEventChannel) - execEventChannels := []chan *bpf.AgentProcessExecEvent{uprobeSchedEventChannel} - if options.ProcessExecEventChannel != nil { - execEventChannels = append(execEventChannels, options.ProcessExecEventChannel) - } - bpf.PullProcessExecEvents(options.Ctx, &execEventChannels) - - attachOpenSslUprobes(links, options, options.Kv, bf.Objs) + bf.attachOpenSslUprobes(links, options) options.LoadPorgressChannel <- "🍕 Attached ssl eBPF programs." } attachSchedProgs(links) @@ -154,6 +152,34 @@ func (bf *BPF) AttachProgs(options *ac.AgentOptions) error { return nil } +func (bf *BPF) attachExecEventChannels(options *ac.AgentOptions) { + execEventChannels := []chan *bpf.AgentProcessExecEvent{} + execEventChannelForMetadata := make(chan *bpf.AgentProcessExecEvent, execEventChannelBufferSize) + execEventChannels = append(execEventChannels, execEventChannelForMetadata) + metadata.StartHandleSchedExecEvent(execEventChannelForMetadata, options.Ctx) + if !options.DisableOpensslUprobe { + uprobeSchedEventChannel := make(chan *bpf.AgentProcessExecEvent, execEventChannelBufferSize) + uprobe.StartHandleSchedExecEvent(uprobeSchedEventChannel) + execEventChannels = append(execEventChannels, uprobeSchedEventChannel) + if options.ProcessExecEventChannel != nil { + execEventChannels = append(execEventChannels, options.ProcessExecEventChannel) + } + bpf.PullProcessExecEvents(options.Ctx, &execEventChannels) + } +} + +func (bf *BPF) attachOpenSslUprobes(links *list.List, options *ac.AgentOptions) { + attachOpenSslUprobes(links, options, options.Kv, bf.Objs) +} + +func (bf *BPF) attachExitEventChannels(options *ac.AgentOptions) { + exitEventChannels := []chan *bpf.AgentProcessExitEvent{} + exitEventChannelForMetadata := make(chan *bpf.AgentProcessExitEvent, exitEventChannelBufferSize) + metadata.StartHandleSchedExitEvent(exitEventChannelForMetadata, options.Ctx) + exitEventChannels = append(exitEventChannels, exitEventChannelForMetadata) + bpf.PullProcessExitEvents(options.Ctx, exitEventChannels) +} + // writeToFile writes the []uint8 slice to a specified file in the system's temp directory. // If the temp directory does not exist, it creates a ".kyanos" directory in the current directory. func writeToFile(data []uint8, filename string) (string, error) { diff --git a/bpf/pktlatency.bpf.c b/bpf/pktlatency.bpf.c index e496d9af..307f96e8 100644 --- a/bpf/pktlatency.bpf.c +++ b/bpf/pktlatency.bpf.c @@ -216,10 +216,6 @@ static __always_inline void report_kern_evt(struct parse_kern_evt_body *param) struct conn_id_s_t* conn_id_s = bpf_map_lookup_elem(&sock_key_conn_id_map, key); if (conn_id_s == NULL || conn_id_s->no_trace) { - // if (key->sport==3306&& step == DEV_IN) { - // bpf_printk("discard!"); - // print_sock_key(key); - // } return; } uint64_t tgid_fd = conn_id_s->tgid_fd; @@ -484,7 +480,6 @@ static __always_inline int parse_skb(void* ctx, struct sk_buff *skb, bool sk_not tcp_len -= ip_hdr_len; BPF_CORE_READ_INTO(&proto_l4, ipv4, protocol); l4 = ip + ip_hdr_len; - bpf_printk("ipipprotocol_l4:%d,len:%lld",proto_l4,_C(skb, len));//629 } } else if (l3_proto == ETH_P_IPV6) { proto_l4 = _(ipv6->nexthdr); diff --git a/common/net.go b/common/net.go deleted file mode 100644 index da805dea..00000000 --- a/common/net.go +++ /dev/null @@ -1,101 +0,0 @@ -package common - -import ( - "bufio" - "net" - "regexp" - "strconv" - "strings" - "sync" -) - -var ifIdxToName map[int]string = make(map[int]string) -var lock *sync.Mutex = &sync.Mutex{} - -func init() { - ifs, err := net.Interfaces() - if err == nil { - for _, each := range ifs { - ifIdxToName[each.Index] = each.Name - } - } -} - -func DeleteIfIdxToNameEntry(pid int) { - delete(ifIdxToName, pid) -} - -func GetInterfaceNameByIndex(index int, pid int) (string, error) { - exist, found := ifIdxToName[index] - if found { - return exist, nil - } - - netNs := GetNetworkNamespaceFromPid(pid) - var result string - - lock.Lock() - defer lock.Unlock() - if netNs == hostNetNs { - exist, found := ifIdxToName[index] - if found { - return exist, nil - } - interfc, err := net.InterfaceByIndex(index) - if err != nil { - result = "" - // return "", fmt.Errorf("GetInterfaceNameByIndex(%d) err: %v ", index, err) - } else { - result = interfc.Name - // ifIdxToName[interfc.Index] = interfc.Name - // return interfc.Name, nil - } - } else { - config := NsEnterConfig{ - Net: true, - Target: pid, - } - stdout, _, _ := config.Execute("sh", "-c", "ip a") - scanner := bufio.NewScanner(strings.NewReader(stdout)) - for scanner.Scan() { - line := scanner.Text() - if strings.TrimSpace(line) == "" || !strings.Contains(line, ":") { - continue - } - - parsedIndex, parsedName, ok := parseIpCmdLine(line) - if ok && index == parsedIndex { - result = parsedName - break - } - } - } - ifIdxToName[index] = result - return result, nil -} - -func parseIpCmdLine(line string) (int, string, bool) { - // 使用正则表达式匹配接口索引和接口名称 - // 假设接口索引是以数字开头,后面跟着冒号和接口名称 - re := regexp.MustCompile(`^(\d+):\s*([^:]+)`) - match := re.FindStringSubmatch(line) - if len(match) < 3 { - return 0, "", false // 没有匹配到 - } - - index, err := strconv.Atoi(match[1]) - if err != nil { - return 0, "", false // 转换索引失败 - } - - interfaceName := match[2] - return index, interfaceName, true // 成功提取 -} - -var ( - hostNetNs int64 -) - -func init() { - hostNetNs = GetNetworkNamespaceFromPid(1) -} diff --git a/testdata/ipip_test.sh b/testdata/ipip_test.sh index f77195fc..19931b96 100755 --- a/testdata/ipip_test.sh +++ b/testdata/ipip_test.sh @@ -1,34 +1,49 @@ #!/bin/bash -set -ex -# 创建两个网络命名空间 -ip netns add ns1 -ip netns add ns2 -# 创建两对 veth pair ,一端各挂在一个命名空间下 -ip link add v1 type veth peer name v1_p -ip link add v2 type veth peer name v2_p - -ip link set v1 netns ns1 -ip link set v2 netns ns2 -# 分别配置地址,并启用 -ip addr add 10.10.10.1/24 dev v1_p -ip link set v1_p up -ip addr add 10.10.20.1/24 dev v2_p -ip link set v2_p up - -ip netns exec ns1 ip addr add 10.10.10.2/24 dev v1 -ip netns exec ns1 ip link set v1 up -ip netns exec ns2 ip addr add 10.10.20.2/24 dev v2 -ip netns exec ns2 ip link set v2 up - -# 分别配置路由 -ip netns exec ns1 route add -net 10.10.20.0/24 gw 10.10.10.1 -ip netns exec ns2 route add -net 10.10.10.0/24 gw 10.10.20.1 - -# 创建 tun设备,并设置为ipip隧道 -# ip netns exec ns1 ip tunnel add tun1 mode ipip remote 10.10.20.2 local 10.10.10.2 -# ip netns exec ns1 ip link set tun1 up -# ip netns exec ns1 ip addr add 10.10.100.10 peer 10.10.200.10 dev tun1 - -# ip netns exec ns2 ip tunnel add tun2 mode ipip remote 10.10.10.2 local 10.10.20.2 -# ip netns exec ns2 ip link set tun2 up -# ip netns exec ns2 ip addr add 10.10.200.10 peer 10.10.100.10 dev tun2 \ No newline at end of file + +# Enable IPIP module +modprobe ipip + +# Create the namespaces +ip netns add host1 +ip netns add host2 +ip netns add internet + +# Create the topology +ip link add veth0 type veth peer name veth1 +ip link add veth2 type veth peer name veth3 + +ip link set veth0 netns host1 +ip link set veth1 netns internet +ip link set veth2 netns internet +ip link set veth3 netns host2 + +ip netns exec host1 ip addr add 172.16.10.2/24 dev veth0 +ip netns exec host1 ip link set veth0 up +ip netns exec host1 ip link set lo up + +ip netns exec host2 ip addr add 152.16.10.2/24 dev veth3 +ip netns exec host2 ip link set veth3 up +ip netns exec host2 ip link set lo up + +ip netns exec internet ip addr add 172.16.10.1/24 dev veth1 +ip netns exec internet ip link set veth1 up +ip netns exec internet ip addr add 152.16.10.1/24 dev veth2 +ip netns exec internet ip link set veth2 up +ip netns exec internet ip link set lo up +ip netns exec internet sysctl -w net.ipv4.ip_forward=1 + +# Create gre tunnel on host1 +ip netns exec host1 ip tunnel add tun0 mode ipip local 172.16.10.2 remote 152.16.10.2 ttl 255 +ip netns exec host1 ip addr add 192.168.50.1/30 dev tun0 +ip netns exec host1 ip link set tun0 up + +# Create gre tunnel on host2 +ip netns exec host2 ip tunnel add tun0 mode ipip local 152.16.10.2 remote 172.16.10.2 ttl 255 +ip netns exec host2 ip addr add 192.168.50.2/30 dev tun0 +ip netns exec host2 ip link set tun0 up + +# Add static route +ip netns exec host1 route add -net 152.16.10.0/24 gw 172.16.10.1 +ip netns exec host2 route add -net 172.16.10.0/24 gw 152.16.10.1 + +echo "Setup done." \ No newline at end of file diff --git a/testdata/ipip_test3.sh b/testdata/ipip_test3.sh deleted file mode 100644 index d0ec5fb3..00000000 --- a/testdata/ipip_test3.sh +++ /dev/null @@ -1,24 +0,0 @@ -#!/bin/bash -set -ex -# -- A -ip link add name mybr0 type bridge -ip addr add 10.42.1.1/24 dev mybr0 -ip link set dev mybr0 up - -ip tunnel add tunl1 mode ipip remote 10.0.4.2 local 10.0.4.9 -ip addr add 10.42.1.1/24 dev tunl1 -ip link set tunl1 up - -# 为了保证我们通过创建的 IPIP 隧道来访问两个不同主机上的子网,我们需要手动添加如下静态路由 -ip route add 10.42.2.0/24 dev tunl1 - -# -- B -ip link add name mybr0 type bridge -ip addr add 10.42.2.1/24 dev mybr0 -ip link set dev mybr0 up - -ip tunnel add tunl1 mode ipip remote 10.0.4.9 local 10.0.4.2 -ip addr add 10.42.2.1/24 dev tunl1 -ip link set tunl1 up - -ip route add 10.42.1.0/24 dev tunl1 \ No newline at end of file diff --git a/testdata/ipip_test_clean.sh b/testdata/ipip_test_clean.sh index bf863d51..41588b4d 100755 --- a/testdata/ipip_test_clean.sh +++ b/testdata/ipip_test_clean.sh @@ -1,8 +1,6 @@ #!/bin/bash set -ex - -ip link del v1_p -ip link del v2_p -ip netns del ns1 -ip netns del ns2 \ No newline at end of file +ip netns del host1 +ip netns del host2 +ip netns del internet \ No newline at end of file