Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(bpf/ssl): first HTTPS request on the server side might not be captured #259

Merged
merged 1 commit into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 26 additions & 10 deletions agent/analysis/stat.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,33 +202,38 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
annotatedRecord.ReqPlainTextSize = events.ingressMessage.ByteSize()
annotatedRecord.RespPlainTextSize = events.egressMessage.ByteSize()
}
canCalculateReadPathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslReadSyscallEvents)
canCalculateWritePathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslWriteSyscallEvents)
annotatedRecord.ReqSize = events.ingressKernLen
annotatedRecord.RespSize = events.egressKernLen
if annotatedRecord.StartTs != math.MaxUint64 && hasDevOutEvents {
if annotatedRecord.StartTs != math.MaxUint64 && hasDevOutEvents &&
(canCalculateReadPathTime && canCalculateWritePathTime) {
annotatedRecord.TotalDuration = float64(annotatedRecord.EndTs) - float64(annotatedRecord.StartTs)
}
if hasReadSyscallEvents && hasWriteSyscallEvents {
if hasReadSyscallEvents && hasWriteSyscallEvents && canCalculateReadPathTime && canCalculateWritePathTime {
annotatedRecord.BlackBoxDuration = float64(events.writeSyscallEvents[len(events.writeSyscallEvents)-1].GetEndTs()) - float64(events.readSyscallEvents[0].GetStartTs())
} else {
annotatedRecord.BlackBoxDuration = float64(events.egressMessage.TimestampNs()) - float64(events.ingressMessage.TimestampNs())
}
if hasUserCopyEvents && hasTcpInEvents {
if hasUserCopyEvents && hasTcpInEvents && canCalculateReadPathTime {
annotatedRecord.ReadFromSocketBufferDuration = float64(events.userCopyEvents[len(events.userCopyEvents)-1].GetStartTs()) - float64(events.tcpInEvents[0].GetStartTs())
}
if hasTcpInEvents && hasNicInEvents {
if hasTcpInEvents && hasNicInEvents && canCalculateWritePathTime {
annotatedRecord.CopyToSocketBufferDuration = float64(events.tcpInEvents[len(events.tcpInEvents)-1].GetStartTs() - events.nicIngressEvents[0].GetStartTs())
}
annotatedRecord.ReqSyscallEventDetails = KernEventsToEventDetails[analysisCommon.SyscallEventDetail](events.readSyscallEvents)
annotatedRecord.RespSyscallEventDetails = KernEventsToEventDetails[analysisCommon.SyscallEventDetail](events.writeSyscallEvents)
annotatedRecord.ReqNicEventDetails = KernEventsToNicEventDetails(events.nicIngressEvents)
annotatedRecord.RespNicEventDetails = KernEventsToNicEventDetails(events.devOutEvents)
} else {
if hasWriteSyscallEvents {
canCalculateReadPathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslReadSyscallEvents)
canCalculateWritePathTime := !connection.IsSsl() || isKernEvtCanMatchSslEvt(events.sslWriteSyscallEvents)
if hasWriteSyscallEvents && canCalculateWritePathTime {
annotatedRecord.StartTs = findMinTimestamp(events.writeSyscallEvents, true)
} else {
annotatedRecord.StartTs = events.egressMessage.TimestampNs()
}
if hasReadSyscallEvents {
if hasReadSyscallEvents && canCalculateReadPathTime {
annotatedRecord.EndTs = findMaxTimestamp(events.readSyscallEvents, false)
} else {
annotatedRecord.EndTs = events.ingressMessage.TimestampNs()
Expand All @@ -239,12 +244,12 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
}
annotatedRecord.ReqSize = events.egressKernLen
annotatedRecord.RespSize = events.ingressKernLen
if hasReadSyscallEvents && hasWriteSyscallEvents {
if hasReadSyscallEvents && hasWriteSyscallEvents && canCalculateReadPathTime && canCalculateWritePathTime {
annotatedRecord.TotalDuration = float64(annotatedRecord.EndTs) - float64(annotatedRecord.StartTs)
} else {
annotatedRecord.TotalDuration = float64(events.ingressMessage.TimestampNs()) - float64(events.egressMessage.TimestampNs())
}
if hasNicInEvents && hasDevOutEvents {
if hasNicInEvents && hasDevOutEvents && canCalculateReadPathTime && canCalculateWritePathTime {
nicIngressTimestamp := int64(0)
for _, nicIngressEvent := range events.nicIngressEvents {
_nicIngressTimestamp, _, ok := nicIngressEvent.GetMinIfItmestampAttr()
Expand All @@ -271,7 +276,7 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
annotatedRecord.BlackBoxDuration = -1
}
}
if (hasUserCopyEvents || hasReadSyscallEvents) && hasTcpInEvents {
if (hasUserCopyEvents || hasReadSyscallEvents) && hasTcpInEvents && canCalculateReadPathTime {
var readFromEndTime float64
if hasUserCopyEvents {
readFromEndTime = float64(events.userCopyEvents[len(events.userCopyEvents)-1].GetStartTs())
Expand All @@ -280,7 +285,7 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
}
annotatedRecord.ReadFromSocketBufferDuration = readFromEndTime - float64(events.tcpInEvents[0].GetStartTs())
}
if hasTcpInEvents && hasNicInEvents {
if hasTcpInEvents && hasNicInEvents && canCalculateReadPathTime {
annotatedRecord.CopyToSocketBufferDuration = float64(events.tcpInEvents[len(events.tcpInEvents)-1].GetStartTs() - events.nicIngressEvents[0].GetStartTs())
}
annotatedRecord.ReqSyscallEventDetails = KernEventsToEventDetails[analysisCommon.SyscallEventDetail](events.writeSyscallEvents)
Expand Down Expand Up @@ -319,6 +324,17 @@ func (s *StatRecorder) ReceiveRecord(r protocol.Record, connection *conn.Connect
return nil
}

// isKernEvtCanMatchSslEvt reports whether every SSL event in events carries a
// non-zero KernLen.
//
// Some syscalls are not nested inside SSL events. When an SSL event has
// KernLen == 0 the kernel sequence is not valid, so durations derived from
// kernel events cannot be calculated for that path. An empty slice yields true.
func isKernEvtCanMatchSslEvt(events []conn.SslEvent) bool {
	for i := range events {
		if events[i].KernLen != 0 {
			continue
		}
		// One unmatched SSL event is enough to invalidate the whole path.
		return false
	}
	return true
}

func findMaxTimestamp(events []conn.KernEvent, useStartTs bool) uint64 {
var maxTimestamp uint64 = 0
for _, each := range events {
Expand Down
26 changes: 15 additions & 11 deletions agent/conn/conntrack.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type Connection4 struct {

ssl bool

tracable bool
tracable bpf.AgentConnTraceStateT
onRoleChanged func()

TempKernEvents []*bpf.AgentKernEvt
Expand Down Expand Up @@ -75,7 +75,7 @@ func NewConnFromEvent(event *bpf.AgentConnEvtT, p *Processor) *Connection4 {
Role: event.ConnInfo.Role,
TgidFd: TgidFd,
Status: Connected,
tracable: true,
tracable: bpf.AgentConnTraceStateTUnset,

MessageFilter: p.messageFilter,
LatencyFilter: p.latencyFilter,
Expand Down Expand Up @@ -330,37 +330,37 @@ func (c *Connection4) OnClose(needClearBpfMap bool) {
monitor.UnregisterMetricExporter(c.StreamEvents)
}

func (c *Connection4) UpdateConnectionTraceable(traceable bool) {
if c.tracable == traceable {
func (c *Connection4) UpdateConnectionTraceable(traceableState bpf.AgentConnTraceStateT) {
if c.tracable == traceableState {
return
}
c.tracable = traceable
c.tracable = traceableState
key, _ := c.extractSockKeys()
sockKeyConnIdMap := bpf.GetMapFromObjs(bpf.Objs, "SockKeyConnIdMap")
c.doUpdateConnIdMapProtocolToUnknwon(key, sockKeyConnIdMap, traceable)
c.doUpdateConnIdMapProtocolToUnknwon(key, sockKeyConnIdMap, traceableState)
// c.doUpdateConnIdMapProtocolToUnknwon(revKey, sockKeyConnIdMap, traceable)

connInfoMap := bpf.GetMapFromObjs(bpf.Objs, "ConnInfoMap")
connInfo := bpf.AgentConnInfoT{}
err := connInfoMap.Lookup(c.TgidFd, &connInfo)
if err == nil {
connInfo.NoTrace = !traceable
connInfo.NoTrace = traceableState
connInfoMap.Update(c.TgidFd, &connInfo, ebpf.UpdateExist)
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v success!", c.ToString(), traceable)
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v success!", c.ToString(), traceableState)
}
} else {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v, but no entry in map found!", c.ToString(), traceable)
common.ConntrackLog.Debugf("try to update %s conn_info_map to traceable: %v, but no entry in map found!", c.ToString(), traceableState)
}
}
}

func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m *ebpf.Map, traceable bool) {
func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m *ebpf.Map, traceable bpf.AgentConnTraceStateT) {
var connIds bpf.AgentConnIdS_t
err := m.Lookup(&key, &connIds)
if err == nil {
connIds.NoTrace = !traceable
connIds.NoTrace = traceable
m.Update(&key, &connIds, ebpf.UpdateExist)
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("try to update %s conn_id_map to traceable: %v, success, sock key: %v", c.ToString(), traceable, key)
Expand All @@ -372,6 +372,10 @@ func (c *Connection4) doUpdateConnIdMapProtocolToUnknwon(key bpf.AgentSockKey, m
}
}

// IsTraceble reports whether the connection's current trace state is at or
// below bpf.AgentConnTraceStateTTraceable.
// NOTE(review): presumably the initial Unset state also counts as traceable
// (i.e. orders before Traceable) — confirm against the generated bpf enum.
func (c *Connection4) IsTraceble() bool {
	state := c.tracable
	return state <= bpf.AgentConnTraceStateTTraceable
}

// func (c *Connection4) OnCloseWithoutClearBpfMap() {
// c.OnClose(false)
// }
Expand Down
4 changes: 2 additions & 2 deletions agent/conn/first_packet_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func (p *FirstPacketProcessor) processEvent(event *timedFirstPacketEvent) {
channel := p.channels[int(conn.(*Connection4).TgidFd)%len(p.channels)]
connId := &bpf.AgentConnIdS_t{
TgidFd: conn.(*Connection4).TgidFd,
NoTrace: false,
NoTrace: conn.(*Connection4).tracable,
}
common.BPFEventLog.Debugf("%s First packet event: %+v", conn.(*Connection4).ToString(), event.FirstPacketEvent)
kernEvent := timedFirstPacketEventAsKernEvent(event, connId)
Expand All @@ -90,7 +90,7 @@ func (p *FirstPacketProcessor) extractTgidFdFromSockKey(key *bpf.AgentSockKey) (
sockKeyConnIdMap := bpf.GetMapFromObjs(bpf.Objs, "SockKeyConnIdMap")
var connIds bpf.AgentConnIdS_t
err := sockKeyConnIdMap.Lookup(key, &connIds)
if err == nil && !connIds.NoTrace {
if err == nil && connIds.NoTrace <= bpf.AgentConnTraceStateTTraceable {
return &connIds, nil
}
return nil, err
Expand Down
20 changes: 14 additions & 6 deletions agent/conn/processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ func (p *Processor) run() {
// previousProtocol := conn.Protocol
if conn != nil && conn.Status != Closed {
conn.Protocol = event.ConnInfo.Protocol
common.ConntrackLog.Debugf("[protocol-infer][%s] protocol updated: %d", conn.ToString(), conn.Protocol)
} else {
if conn == nil {
missedConn := NewConnFromEvent(event, p)
Expand All @@ -226,6 +227,7 @@ func (p *Processor) run() {
p.connManager.AddConnection4(TgidFd, missedConn)
conn = missedConn
} else {
common.ConntrackLog.Debugf("[protocol-infer][%s] protocol not updated: %d", conn.ToString(), conn.Protocol)
continue
}
}
Expand Down Expand Up @@ -258,7 +260,7 @@ func (p *Processor) run() {
}
}
conn.TempSslEvents = conn.TempSslEvents[0:0]
conn.UpdateConnectionTraceable(true)
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTTraceable)
// handle kern events
for _, kernEvent := range conn.TempKernEvents {
if conn.timeBoundCheck(kernEvent.Ts) {
Expand All @@ -275,7 +277,13 @@ func (p *Processor) run() {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("%s discarded due to not interested, isProtocolInterested: %v, isSideNotMatched:%v", conn.ToString(), isProtocolInterested, isSideNotMatched(p, conn))
}
conn.UpdateConnectionTraceable(false)
if conn.Protocol == bpf.AgentTrafficProtocolTKProtocolUnknown {
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTProtocolUnknown)
} else if !isProtocolInterested {
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTProtocolNotMatched)
} else {
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTOther)
}
// conn.OnClose(true)
}
}
Expand Down Expand Up @@ -464,7 +472,7 @@ func (p *Processor) processSyscallEvent(event *bpf.SyscallEventData, recordChann
}
return
}
if conn != nil && !conn.tracable {
if conn != nil && !conn.IsTraceble() {
if common.BPFEventLog.Level >= logrus.DebugLevel {
common.BPFEventLog.Debugf("[syscall][no-trace][len=%d][ts=%d]%s | %s", event.SyscallEvent.BufSize, event.SyscallEvent.Ke.Ts, conn.ToString(), string(event.Buf))
}
Expand Down Expand Up @@ -538,7 +546,7 @@ func (p *Processor) processSslEvent(event *bpf.SslData, recordChannel chan Recor
}
return
}
if conn != nil && !conn.tracable {
if conn != nil && !conn.IsTraceble() {
conn.AddSslEvent(event)
if common.BPFEventLog.Level >= logrus.DebugLevel {
common.BPFEventLog.Debugf("[ssl][no-trace][len=%d][ts=%d]%s | %s", event.SslEventHeader.BufSize, event.SslEventHeader.Ke.Ts, conn.ToString(), string(event.Buf))
Expand Down Expand Up @@ -576,12 +584,12 @@ func onRoleChanged(p *Processor, conn *Connection4) {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("[onRoleChanged] %s discarded due to not matched by side", conn.ToString())
}
conn.UpdateConnectionTraceable(false)
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTOther)
} else {
if common.ConntrackLog.Level >= logrus.DebugLevel {
common.ConntrackLog.Debugf("[onRoleChanged] %s actived due to matched by side", conn.ToString())
}
conn.UpdateConnectionTraceable(true)
conn.UpdateConnectionTraceable(bpf.AgentConnTraceStateTTraceable)
}
}

Expand Down
19 changes: 15 additions & 4 deletions bpf/agent_arm64_bpfel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 15 additions & 4 deletions bpf/agent_x86_bpfel.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 12 additions & 8 deletions bpf/data_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ static __inline bool should_trace_conn(struct conn_info_t *conn_info) {
// return true;
// }

return conn_info->protocol != kProtocolUnknown && !conn_info->no_trace;
return conn_info->protocol != kProtocolUnknown && conn_info->no_trace <= traceable ;
}

static void __always_inline report_syscall_buf_without_data(void* ctx, uint64_t seq, struct conn_id_s_t *conn_id_s, size_t len, enum step_t step, uint64_t ts, uint32_t ts_delta, enum source_function_t source_fn) {
Expand Down Expand Up @@ -276,7 +276,7 @@ static __always_inline void process_sendfile_with_conn_info(void* ctx, struct se
} else {
step = direct == kEgress ? SYSCALL_OUT : SYSCALL_IN;
}
if (conn_info->protocol != kProtocolUnknown && (!conn_info->no_trace)) {//, bytes_count
if (conn_info->protocol != kProtocolUnknown && ( conn_info->no_trace <= traceable)) {//, bytes_count
report_syscall_buf_without_data(ctx, seq, &conn_id_s, bytes_count, step, args->start_ts, args->end_ts - args->start_ts, kSyscallSendfile);
}
}
Expand All @@ -286,12 +286,12 @@ static __always_inline void process_syscall_data_with_conn_info(void* ctx, struc
bool inferred = false;
if ((conn_info->protocol == kProtocolUnset || conn_info->protocol == kProtocolUnknown) && with_data) {
enum traffic_protocol_t before_infer = conn_info->protocol;
// bpf_printk("[protocol infer]:start, bc:%d", bytes_count);
// bpf_printk("SSL[protocol infer]:start, bc:%d", bytes_count);
// conn_info->protocol = protocol_message.protocol;
struct protocol_message_t protocol_message = infer_protocol(args->buf, bytes_count, conn_info);
if (before_infer != protocol_message.protocol) {
conn_info->protocol = protocol_message.protocol;
// bpf_printk("[protocol infer]: %d, func: %d", conn_info->protocol, args->source_fn);
// bpf_printk("SSL[protocol infer]: %d, func: %d", conn_info->protocol, args->source_fn);

if (conn_info->role == kRoleUnknown && protocol_message.type != kUnknown) {
conn_info->role = ((direct == kEgress) ^ (protocol_message.type == kResponse))
Expand All @@ -313,20 +313,24 @@ static __always_inline void process_syscall_data_with_conn_info(void* ctx, struc
} else {
step = direct == kEgress ? SYSCALL_OUT : SYSCALL_IN;
}

if (conn_info->protocol != kProtocolUnknown && (inferred || !conn_info->no_trace)) {//, bytes_count

if (conn_info->protocol != kProtocolUnknown && (inferred || conn_info->no_trace <= traceable) ||
// The condition below covers the case where the protocol was already inferred in a previous syscall
// but user space has not yet updated conn_info.no_trace to traceable.
// So when conn_info.protocol is not unknown but no_trace is still protocol_unknown, we still trace the data.
(conn_info->protocol != kProtocolUnknown && conn_info->no_trace == protocol_unknown)) {
if (is_ssl) {
uint64_t syscall_seq = (direct == kEgress ? conn_info->write_bytes : conn_info->read_bytes) + 1;
seq = (direct == kEgress ? conn_info->ssl_write_bytes : conn_info->ssl_read_bytes) + 1;
report_ssl_evt(ctx, seq, &conn_id_s, bytes_count, step, args, syscall_len < 0 ? 0 : (syscall_seq - syscall_len), syscall_len < 0 ? 0 : syscall_len);
// bpf_printk("report ssl evt, seq: %lld len: %d", seq, bytes_count);
// bpf_printk("SSLreport ssl evt, seq: %lld len: %d, syscall_len:%d", seq, bytes_count, syscall_len);
} else if (with_data) {
report_syscall_evt(ctx, seq, &conn_id_s, bytes_count, step, args);
} else {
report_syscall_buf_without_data(ctx, seq, &conn_id_s, bytes_count, step, args->start_ts, args->end_ts - args->start_ts, args->source_fn);
}
} else {
// bpf_printk("no trace, bytes_count:%d", bytes_count);
// bpf_printk("SSLno trace, bytes_count:%d,p:%d,infer:%d", bytes_count,conn_info->protocol,inferred);
}
}

Expand Down
Loading
Loading