diff --git a/pkg/ebpf/common/common.go b/pkg/ebpf/common/common.go index bcf2b6c41..ed5753baf 100644 --- a/pkg/ebpf/common/common.go +++ b/pkg/ebpf/common/common.go @@ -174,7 +174,7 @@ type EBPFParseContext struct { h2c *lru.Cache[uint64, h2Connection] redisDBCache *simplelru.LRU[BpfConnectionInfoT, int] couchbaseBucketCache *simplelru.LRU[BpfConnectionInfoT, CouchbaseBucketInfo] - largeBuffers *expirable.LRU[largeBufferKey, *largeBuffer] + largeBuffers *expirable.LRU[largeBufferKey, *LargeBuffer] mongoRequestCache PendingMongoDBRequests mysqlPreparedStatements *simplelru.LRU[mysqlPreparedStatementsKey, string] postgresPreparedStatements *simplelru.LRU[postgresPreparedStatementsKey, string] @@ -213,7 +213,7 @@ func NewEBPFParseContext(cfg *config.EBPFTracer, spansChan *msg.Queue[[]request. ) h2c, _ := lru.New[uint64, h2Connection](1024 * 10) - largeBuffers := expirable.NewLRU[largeBufferKey, *largeBuffer](1024, nil, 5*time.Minute) + largeBuffers := expirable.NewLRU[largeBufferKey, *LargeBuffer](1024, nil, 5*time.Minute) if cfg != nil { protocolDebug = cfg.ProtocolDebug diff --git a/pkg/ebpf/common/couchbase_detect_transform.go b/pkg/ebpf/common/couchbase_detect_transform.go index 7cc154cc5..9ba574c29 100644 --- a/pkg/ebpf/common/couchbase_detect_transform.go +++ b/pkg/ebpf/common/couchbase_detect_transform.go @@ -31,12 +31,14 @@ type CouchbaseInfo struct { // ProcessPossibleCouchbaseEvent attempts to parse the event as a Couchbase memcached binary protocol event. // Returns a slice of CouchbaseInfo if successful, along with a boolean indicating if the event should be ignored, // and an error if parsing failed. Multiple packets may be present in a single TCP segment due to pipelining. 
-func ProcessPossibleCouchbaseEvent(event *TCPRequestInfo, requestBuf []byte, responseBuf []byte, bucketCache *simplelru.LRU[BpfConnectionInfoT, CouchbaseBucketInfo]) (*CouchbaseInfo, bool, error) { - info, ignore, err := processCouchbaseEvent(event.ConnInfo, requestBuf, responseBuf, bucketCache) +func ProcessPossibleCouchbaseEvent(event *TCPRequestInfo, requestBuf *LargeBuffer, responseBuf *LargeBuffer, bucketCache *simplelru.LRU[BpfConnectionInfoT, CouchbaseBucketInfo]) (*CouchbaseInfo, bool, error) { + reqRaw := requestBuf.UnsafeView() + respRaw := responseBuf.UnsafeView() + info, ignore, err := processCouchbaseEvent(event.ConnInfo, reqRaw, respRaw, bucketCache) // If parsing failed (error or no valid packets found), try with buffers reversed if err != nil { // Try with buffers reversed - we might have captured it backwards - info, ignore, err = processCouchbaseEvent(event.ConnInfo, responseBuf, requestBuf, bucketCache) + info, ignore, err = processCouchbaseEvent(event.ConnInfo, respRaw, reqRaw, bucketCache) if err == nil { reverseTCPEvent(event) return info, false, nil diff --git a/pkg/ebpf/common/couchbase_detect_transform_test.go b/pkg/ebpf/common/couchbase_detect_transform_test.go index 621df48fb..c0ffccfa4 100644 --- a/pkg/ebpf/common/couchbase_detect_transform_test.go +++ b/pkg/ebpf/common/couchbase_detect_transform_test.go @@ -551,7 +551,7 @@ func TestProcessPossibleCouchbaseEventReversedBuffers(t *testing.T) { } // Test with buffers in correct order - info, ignore, err := ProcessPossibleCouchbaseEvent(event, requestBuf, responseBuf, cache) + info, ignore, err := ProcessPossibleCouchbaseEvent(event, NewLargeBufferFrom(requestBuf), NewLargeBufferFrom(responseBuf), cache) require.NoError(t, err) assert.False(t, ignore) assert.Equal(t, "GET", info.Operation) @@ -563,7 +563,7 @@ func TestProcessPossibleCouchbaseEventReversedBuffers(t *testing.T) { ConnInfo: connInfo, Direction: 1, } - info, ignore, err = ProcessPossibleCouchbaseEvent(event2, garbageBuf, 
requestBuf, cache) + info, ignore, err = ProcessPossibleCouchbaseEvent(event2, NewLargeBufferFrom(garbageBuf), NewLargeBufferFrom(requestBuf), cache) require.NoError(t, err) assert.False(t, ignore) require.NotNil(t, info) @@ -594,7 +594,7 @@ func TestProcessPossibleCouchbaseEventConnectionIsolation(t *testing.T) { selectBucket1Resp := makeCouchbaseResponsePacket(couchbasekv.OpcodeSelectBucket, couchbasekv.StatusSuccess, "") event1 := &TCPRequestInfo{ConnInfo: connInfo1, Direction: 1} - _, ignore, err := ProcessPossibleCouchbaseEvent(event1, selectBucket1Req, selectBucket1Resp, cache) + _, ignore, err := ProcessPossibleCouchbaseEvent(event1, NewLargeBufferFrom(selectBucket1Req), NewLargeBufferFrom(selectBucket1Resp), cache) require.NoError(t, err) assert.True(t, ignore) // SELECT_BUCKET is ignored @@ -603,7 +603,7 @@ func TestProcessPossibleCouchbaseEventConnectionIsolation(t *testing.T) { selectBucket2Resp := makeCouchbaseResponsePacket(couchbasekv.OpcodeSelectBucket, couchbasekv.StatusSuccess, "") event2 := &TCPRequestInfo{ConnInfo: connInfo2, Direction: 1} - _, ignore, err = ProcessPossibleCouchbaseEvent(event2, selectBucket2Req, selectBucket2Resp, cache) + _, ignore, err = ProcessPossibleCouchbaseEvent(event2, NewLargeBufferFrom(selectBucket2Req), NewLargeBufferFrom(selectBucket2Resp), cache) require.NoError(t, err) assert.True(t, ignore) @@ -612,7 +612,7 @@ func TestProcessPossibleCouchbaseEventConnectionIsolation(t *testing.T) { getCollID1Resp := makeCouchbaseResponsePacket(couchbasekv.OpcodeCollectionsGetID, couchbasekv.StatusSuccess, "") event1 = &TCPRequestInfo{ConnInfo: connInfo1, Direction: 1} - _, ignore, err = ProcessPossibleCouchbaseEvent(event1, getCollID1Req, getCollID1Resp, cache) + _, ignore, err = ProcessPossibleCouchbaseEvent(event1, NewLargeBufferFrom(getCollID1Req), NewLargeBufferFrom(getCollID1Resp), cache) require.NoError(t, err) assert.True(t, ignore) @@ -621,7 +621,7 @@ func TestProcessPossibleCouchbaseEventConnectionIsolation(t *testing.T) 
{ getCollID2Resp := makeCouchbaseResponsePacket(couchbasekv.OpcodeCollectionsGetID, couchbasekv.StatusSuccess, "") event2 = &TCPRequestInfo{ConnInfo: connInfo2, Direction: 1} - _, ignore, err = ProcessPossibleCouchbaseEvent(event2, getCollID2Req, getCollID2Resp, cache) + _, ignore, err = ProcessPossibleCouchbaseEvent(event2, NewLargeBufferFrom(getCollID2Req), NewLargeBufferFrom(getCollID2Resp), cache) require.NoError(t, err) assert.True(t, ignore) @@ -631,7 +631,7 @@ func TestProcessPossibleCouchbaseEventConnectionIsolation(t *testing.T) { // GET on connection 1 should have bucket1/scope1/coll1 event1 = &TCPRequestInfo{ConnInfo: connInfo1, Direction: 1} - info1, ignore, err := ProcessPossibleCouchbaseEvent(event1, getReq, getResp, cache) + info1, ignore, err := ProcessPossibleCouchbaseEvent(event1, NewLargeBufferFrom(getReq), NewLargeBufferFrom(getResp), cache) require.NoError(t, err) assert.False(t, ignore) require.NotNil(t, info1) @@ -641,7 +641,7 @@ func TestProcessPossibleCouchbaseEventConnectionIsolation(t *testing.T) { // GET on connection 2 should have bucket2/scope2/coll2 event2 = &TCPRequestInfo{ConnInfo: connInfo2, Direction: 1} - info2, ignore, err := ProcessPossibleCouchbaseEvent(event2, getReq, getResp, cache) + info2, ignore, err := ProcessPossibleCouchbaseEvent(event2, NewLargeBufferFrom(getReq), NewLargeBufferFrom(getResp), cache) require.NoError(t, err) assert.False(t, ignore) require.NotNil(t, info2) diff --git a/pkg/ebpf/common/fast_cgi_detect_transform.go b/pkg/ebpf/common/fast_cgi_detect_transform.go index 727fab366..273fe0ccb 100644 --- a/pkg/ebpf/common/fast_cgi_detect_transform.go +++ b/pkg/ebpf/common/fast_cgi_detect_transform.go @@ -108,50 +108,51 @@ func parseCGITable(b []byte) map[string]string { return res } -func maybeFastCGI(b []byte) bool { - if len(b) <= fastCGIRequestHeaderLen { +func maybeFastCGI(b *LargeBuffer) bool { + if b.Len() <= fastCGIRequestHeaderLen { return false } - - methodPos := bytes.Index(b, 
[]byte(requestMethodKey)) - - return methodPos >= 0 + return bytes.Contains(b.UnsafeView(), []byte(requestMethodKey)) } -func parseHeader(b []byte) ([]byte, error) { +func parseHeader(b *LargeBuffer) ([]byte, error) { + r := b.NewReader() for { - hdr, err := readFastCGIHeader(b) + if r.Remaining() < fastCGIRequestHeaderLen { + return nil, errors.New("payload too short") + } + hdrBytes, err := r.ReadN(fastCGIRequestHeaderLen) + if err != nil { + return nil, errors.New("payload too short") + } + hdr, err := readFastCGIHeader(hdrBytes) if err != nil { return nil, errors.New("payload too short") } if hdr.Type == fcgiFrameTypeParams { - if len(b) <= fastCGIRequestHeaderLen { + if r.Remaining() == 0 { return nil, errors.New("payload too short") } - b = b[fastCGIRequestHeaderLen:] - break + rest, _ := r.ReadN(r.Remaining()) + return rest, nil } - payloadOffset := int(fastCGIRequestHeaderLen + hdr.ContentLength + uint16(hdr.PaddingLength)) - if len(b) <= payloadOffset { + payloadOffset := int(hdr.ContentLength) + int(hdr.PaddingLength) + if err := r.Skip(payloadOffset); err != nil { return nil, errors.New("payload too short") } - b = b[payloadOffset:] } - - return b, nil } -func detectFastCGI(b, rb []byte) (string, string, int) { - var err error - b, err = parseHeader(b) +func detectFastCGI(b, rb *LargeBuffer) (string, string, int) { + raw, err := parseHeader(b) if err != nil { return "", "", -1 } - methodPos := bytes.Index(b, []byte(requestMethodKey)) + methodPos := bytes.Index(raw, []byte(requestMethodKey)) if methodPos >= 0 { - kv := parseCGITable(b) + kv := parseCGITable(raw) method, ok := kv[requestMethodKey] if !ok { @@ -162,17 +163,18 @@ func detectFastCGI(b, rb []byte) (string, string, int) { // Translate the status code into HTTP, 200 OK, 500 ERR status := 200 - if len(rb) >= 2 { - if rb[1] == responseError { + rbRaw := rb.UnsafeView() + if len(rbRaw) >= 2 { + if rbRaw[1] == responseError { status = 500 } - statusPos := bytes.Index(rb, []byte(responseStatusKey)) + 
statusPos := bytes.Index(rbRaw, []byte(responseStatusKey)) if statusPos >= 0 { - rb = rb[statusPos+len(responseStatusKey):] - nextSpace := bytes.Index(rb, []byte(" ")) + rbRaw = rbRaw[statusPos+len(responseStatusKey):] + nextSpace := bytes.Index(rbRaw, []byte(" ")) if nextSpace > 0 { - statusStr := string(rb[:nextSpace]) + statusStr := string(rbRaw[:nextSpace]) if parsed, err := strconv.ParseInt(statusStr, 10, 32); err == nil { status = int(parsed) } diff --git a/pkg/ebpf/common/fast_cgi_detect_transform_test.go b/pkg/ebpf/common/fast_cgi_detect_transform_test.go index ef22f60db..f1857955f 100644 --- a/pkg/ebpf/common/fast_cgi_detect_transform_test.go +++ b/pkg/ebpf/common/fast_cgi_detect_transform_test.go @@ -45,7 +45,7 @@ func TestMaybeFastCGI(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ilen := min(len(tt.input), tt.inputLen) - res := maybeFastCGI(tt.input[0:ilen]) + res := maybeFastCGI(NewLargeBufferFrom(tt.input[0:ilen])) assert.Equal(t, tt.expected, res) }) } @@ -214,7 +214,7 @@ func TestDetectFastCGI(t *testing.T) { t.Run(tt.name, func(t *testing.T) { ilen := min(len(tt.input), tt.inputLen) olen := min(len(tt.output), tt.outputLen) - method, path, status := detectFastCGI(tt.input[0:ilen], tt.output[0:olen]) + method, path, status := detectFastCGI(NewLargeBufferFrom(tt.input[0:ilen]), NewLargeBufferFrom(tt.output[0:olen])) assert.Equal(t, tt.expectedMethod, method) assert.Equal(t, tt.expectedPath, path) assert.Equal(t, tt.expectedResult, status) diff --git a/pkg/ebpf/common/go_kafka_transform.go b/pkg/ebpf/common/go_kafka_transform.go index d46bcf0c4..f71b84a90 100644 --- a/pkg/ebpf/common/go_kafka_transform.go +++ b/pkg/ebpf/common/go_kafka_transform.go @@ -19,7 +19,7 @@ func ReadGoSaramaRequestIntoSpan(record *ringbuf.Record) (request.Span, bool, er return request.Span{}, true, err } - info, ignore, err := ProcessKafkaRequest(event.Buf[:], nil) + info, ignore, err := 
ProcessKafkaRequest(NewLargeBufferFrom(event.Buf[:]).NewReader(), nil) if err == nil && !ignore { return GoKafkaSaramaToSpan(event, info), false, nil diff --git a/pkg/ebpf/common/http2grpc_transform.go b/pkg/ebpf/common/http2grpc_transform.go index d0cedfc14..307e0fb18 100644 --- a/pkg/ebpf/common/http2grpc_transform.go +++ b/pkg/ebpf/common/http2grpc_transform.go @@ -544,16 +544,16 @@ func isLikelyHTTP2(data []uint8, eventLen int) bool { return false } -func isHTTP2(data []uint8, eventLen int) bool { +func isHTTP2(data *LargeBuffer, eventLen int) bool { // Parsing HTTP2 frames with the Go HTTP2/gRPC parser is very expensive. // Therefore, we replicate some of our HTTP2 frame reader from eBPF here to // check if this payload even remotely looks like HTTP2/gRPC, e.g. we must // find a resonably looking HTTP "headers" frame. - if !isLikelyHTTP2(data, eventLen) { + if !isLikelyHTTP2(data.UnsafeView(), eventLen) { return false } - framer := byteFramer(data) + framer := http2.NewFramer(io.Discard, data.NewReader()) for { f, err := framer.ReadFrame() diff --git a/pkg/ebpf/common/http2grpc_transform_test.go b/pkg/ebpf/common/http2grpc_transform_test.go index 0a3b1ed06..9a3a43a19 100644 --- a/pkg/ebpf/common/http2grpc_transform_test.go +++ b/pkg/ebpf/common/http2grpc_transform_test.go @@ -117,7 +117,7 @@ func TestHTTP2QuickDetection(t *testing.T) { t.Run(tt.name, func(t *testing.T) { res := isLikelyHTTP2(tt.input, tt.inputLen) assert.Equal(t, tt.expectedQuick, res) - res1 := isHTTP2(tt.input, tt.inputLen) + res1 := isHTTP2(NewLargeBufferFrom(tt.input), tt.inputLen) assert.Equal(t, tt.expected, res1) }) } @@ -535,7 +535,7 @@ func TestHandleHeaderField(t *testing.T) { func BenchmarkIsHTTP2(b *testing.B) { for i := 0; i < b.N; i++ { for _, tt := range isHTTP2TestCases { - _ = isHTTP2(tt.input, tt.inputLen) + _ = isHTTP2(NewLargeBufferFrom(tt.input), tt.inputLen) } } } diff --git a/pkg/ebpf/common/http_transform.go b/pkg/ebpf/common/http_transform.go index 478bdee1d..b8099cb46 
100644 --- a/pkg/ebpf/common/http_transform.go +++ b/pkg/ebpf/common/http_transform.go @@ -191,7 +191,7 @@ func ReadHTTPInfoIntoSpan(parseCtx *EBPFParseContext, record *ringbuf.Record, fi func HTTPInfoEventToSpan(parseCtx *EBPFParseContext, event *BPFHTTPInfo) (request.Span, bool, error) { var ( - requestBuffer, responseBuffer []byte + requestBuffer, responseBuffer *LargeBuffer hasResponse bool isClient = isClientEvent(event.Type) ) @@ -204,7 +204,7 @@ func HTTPInfoEventToSpan(parseCtx *EBPFParseContext, event *BPFHTTPInfo) (reques requestBuffer = b } else { slog.Debug("missing large buffer for HTTP request", "traceID", event.Tp.TraceId, "conn", event.ConnInfo, "packetType", packetTypeRequest) - requestBuffer = event.Buf[:] + requestBuffer = NewLargeBufferFrom(event.Buf[:]) } b, ok = extractTCPLargeBuffer(parseCtx, event.Tp.TraceId, packetTypeResponse, directionByPacketType(packetTypeResponse, isClient), event.ConnInfo) @@ -215,7 +215,7 @@ func HTTPInfoEventToSpan(parseCtx *EBPFParseContext, event *BPFHTTPInfo) (reques slog.Debug("missing large buffer for HTTP response", "traceID", event.Tp.TraceId, "conn", event.ConnInfo, "packetType", packetTypeResponse) } } else { - requestBuffer = event.Buf[:] + requestBuffer = NewLargeBufferFrom(event.Buf[:]) } if parseCtx != nil && !parseCtx.payloadExtraction.Enabled() { @@ -229,7 +229,8 @@ func HTTPInfoEventToSpan(parseCtx *EBPFParseContext, event *BPFHTTPInfo) (reques return httpRequestToSpan(event, requestBuffer), false, nil } - req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(requestBuffer))) + // http.ReadRequest requires a *bufio.Reader; that one allocation is unavoidable. 
+ req, err := http.ReadRequest(bufio.NewReader(requestBuffer.NewReader())) resp, err2 := httpSafeParseResponse(responseBuffer, req) if err != nil || err2 != nil { slog.Debug("error while parsing http request or response, falling back to manual HTTP info parsing", "reqErr", err, "respErr", err2) @@ -242,19 +243,19 @@ func HTTPInfoEventToSpan(parseCtx *EBPFParseContext, event *BPFHTTPInfo) (reques // HTTP response buffers might have been sent incomplete, before the full body. // Try to parse the original buffer first, if an EOF is encountered, append an empty // body to the buffer and try again. -func httpSafeParseResponse(responseBuffer []byte, req *http.Request) (*http.Response, error) { - rd := bufio.NewReader(bytes.NewReader(responseBuffer)) +func httpSafeParseResponse(responseBuffer *LargeBuffer, req *http.Request) (*http.Response, error) { + rd := bufio.NewReader(responseBuffer.NewReader()) resp, err := http.ReadResponse(rd, req) if err != nil && errors.Is(err, io.ErrUnexpectedEOF) { - // Append empty body and try again - responseBuffer := append(responseBuffer, []byte("\r\n\r\n")...) - rd = bufio.NewReader(bytes.NewReader(responseBuffer)) + // Append empty body terminator and retry with a fresh reader. + responseBuffer.AppendChunk([]byte("\r\n\r\n")) + rd.Reset(responseBuffer.NewReader()) return http.ReadResponse(rd, req) } return resp, nil } -func httpRequestToSpan(event *BPFHTTPInfo, requestBuffer []byte) request.Span { +func httpRequestToSpan(event *BPFHTTPInfo, requestBuffer *LargeBuffer) request.Span { var ( result = HTTPInfo{BPFHTTPInfo: *event} bufHost string @@ -262,6 +263,8 @@ func httpRequestToSpan(event *BPFHTTPInfo, requestBuffer []byte) request.Span { parsedHost bool ) + raw := requestBuffer.UnsafeView() + // When we can't find the connection info, we signal that through making the // source and destination ports equal to max short. E.g. 
async SSL if event.ConnInfo.S_port != 0 || event.ConnInfo.D_port != 0 { @@ -269,7 +272,7 @@ func httpRequestToSpan(event *BPFHTTPInfo, requestBuffer []byte) request.Span { result.Host = target result.Peer = source } else { - bufHost, bufPort = httpHostFromBuf(requestBuffer) + bufHost, bufPort = httpHostFromBuf(raw) parsedHost = true if bufPort >= 0 { @@ -277,11 +280,11 @@ func httpRequestToSpan(event *BPFHTTPInfo, requestBuffer []byte) request.Span { result.ConnInfo.D_port = uint16(bufPort) } } - result.URL = httpURLFromBuf(requestBuffer) - result.Method = httpMethodFromBuf(requestBuffer) + result.URL = httpURLFromBuf(raw) + result.Method = httpMethodFromBuf(raw) if request.EventType(result.Type) == request.EventTypeHTTPClient && !parsedHost { - bufHost, _ = httpHostFromBuf(requestBuffer) + bufHost, _ = httpHostFromBuf(raw) } result.HeaderHost = bufHost @@ -290,64 +293,55 @@ func httpRequestToSpan(event *BPFHTTPInfo, requestBuffer []byte) request.Span { } func httpURLFromBuf(req []byte) string { - buf := string(req) - space := strings.Index(buf, " ") - if space < 0 { - return "" + if end := bytes.IndexByte(req, 0); end >= 0 { + req = req[:end] } - bufEnd := bytes.IndexByte(req, 0) // We assume the buffer was zero initialized in eBPF - if bufEnd < 0 { - bufEnd = len(buf) - } - - if space+1 > bufEnd { + space := bytes.IndexByte(req, ' ') + if space < 0 { return "" } - nextSpace := strings.IndexAny(buf[space+1:bufEnd], " \r\n") + req = req[space+1:] + + nextSpace := bytes.IndexAny(req, " \r\n") if nextSpace < 0 { - return buf[space+1 : bufEnd] + return string(req) } - end := min(nextSpace+space+1, bufEnd) - - return buf[space+1 : end] + return string(req[:nextSpace]) } func httpMethodFromBuf(req []byte) string { - buf := string(req) - space := strings.Index(buf, " ") - if space < 0 { + method, _, found := bytes.Cut(req, []byte(" ")) + if !found { return "" } - return buf[:space] + return string(method) } func httpHostFromBuf(req []byte) (string, int) { - buf := 
cstr(req) - - host := "Host: " - idx := strings.Index(buf, host) + if end := bytes.IndexByte(req, 0); end >= 0 { + req = req[:end] + } + idx := bytes.Index(req, []byte("Host: ")) if idx < 0 { return "", -1 } - buf = buf[idx+len(host):] - - rIdx := strings.Index(buf, "\r") + req = req[idx+len("Host: "):] // only parse full host information, partial may // get the wrong name or wrong port - if rIdx < 0 { + hostPort, _, found := bytes.Cut(req, []byte("\r")) + if !found { return "", -1 } - - host, portStr, err := net.SplitHostPort(buf[:rIdx]) + host, portStr, err := net.SplitHostPort(string(hostPort)) if err != nil { - return buf[:rIdx], -1 + return string(hostPort), -1 } port, _ := strconv.Atoi(portStr) diff --git a/pkg/ebpf/common/kafka_detect_transform.go b/pkg/ebpf/common/kafka_detect_transform.go index f1f0447d9..6b2bd2047 100644 --- a/pkg/ebpf/common/kafka_detect_transform.go +++ b/pkg/ebpf/common/kafka_detect_transform.go @@ -46,11 +46,18 @@ func (k Operation) String() string { // ProcessPossibleKafkaEvent processes a TCP packet and returns error if the packet is not a valid Kafka request. // Otherwise, return kafka.Info with the processed data. -func ProcessPossibleKafkaEvent(event *TCPRequestInfo, pkt []byte, rpkt []byte, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { +func ProcessPossibleKafkaEvent(event *TCPRequestInfo, pkt *LargeBufferReader, rpkt *LargeBufferReader, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { k, ok, err := ProcessKafkaEvent(pkt, rpkt, kafkaTopicUUIDToName) if err != nil { // If we are getting the information in the response buffer, the event // must be reversed and that's how we captured it. + // Reset readers before retrying with swapped buffers. 
+ if pkt != nil { + pkt.Reset() + } + if rpkt != nil { + rpkt.Reset() + } k, ok, err = ProcessKafkaEvent(rpkt, pkt, kafkaTopicUUIDToName) if err == nil { reverseTCPEvent(event) @@ -59,16 +66,16 @@ func ProcessPossibleKafkaEvent(event *TCPRequestInfo, pkt []byte, rpkt []byte, k return k, ok, err } -func ProcessKafkaEvent(pkt []byte, rpkt []byte, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { - hdr, offset, err := kafkaparser.ParseKafkaRequestHeader(pkt) +func ProcessKafkaEvent(pkt *LargeBufferReader, rpkt *LargeBufferReader, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { + hdr, err := kafkaparser.ParseKafkaRequestHeader(pkt) if err != nil { return nil, true, err } switch hdr.APIKey { case kafkaparser.APIKeyProduce: - return processProduceRequest(pkt, hdr, offset) + return processProduceRequest(pkt, hdr) case kafkaparser.APIKeyFetch: - return processFetchRequest(pkt, hdr, offset, kafkaTopicUUIDToName) + return processFetchRequest(pkt, hdr, kafkaTopicUUIDToName) case kafkaparser.APIKeyMetadata: return processMetadataResponse(rpkt, hdr, kafkaTopicUUIDToName) default: @@ -76,8 +83,8 @@ func ProcessKafkaEvent(pkt []byte, rpkt []byte, kafkaTopicUUIDToName *simplelru. 
} } -func processProduceRequest(pkt []byte, hdr *kafkaparser.KafkaRequestHeader, offset kafkaparser.Offset) (*KafkaInfo, bool, error) { - produceReq, err := kafkaparser.ParseProduceRequest(pkt, hdr, offset) +func processProduceRequest(pkt *LargeBufferReader, hdr *kafkaparser.KafkaRequestHeader) (*KafkaInfo, bool, error) { + produceReq, err := kafkaparser.ParseProduceRequest(pkt, hdr) if err != nil { return nil, true, err } @@ -96,8 +103,8 @@ func processProduceRequest(pkt []byte, hdr *kafkaparser.KafkaRequestHeader, offs }, false, nil } -func processFetchRequest(pkt []byte, hdr *kafkaparser.KafkaRequestHeader, offset kafkaparser.Offset, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { - fetchReq, err := kafkaparser.ParseFetchRequest(pkt, hdr, offset) +func processFetchRequest(pkt *LargeBufferReader, hdr *kafkaparser.KafkaRequestHeader, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { + fetchReq, err := kafkaparser.ParseFetchRequest(pkt, hdr) if err != nil { return nil, true, err } @@ -127,13 +134,16 @@ func processFetchRequest(pkt []byte, hdr *kafkaparser.KafkaRequestHeader, offset }, false, nil } -func processMetadataResponse(rpkt []byte, hdr *kafkaparser.KafkaRequestHeader, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { +func processMetadataResponse(rpkt *LargeBufferReader, hdr *kafkaparser.KafkaRequestHeader, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { + if rpkt == nil { + return nil, true, errors.New("no response buffer for metadata request") + } // only interested in response - _, offset, err := kafkaparser.ParseKafkaResponseHeader(rpkt, hdr) + _, err := kafkaparser.ParseKafkaResponseHeader(rpkt, hdr) if err != nil { return nil, true, err } - metadataResponse, err := kafkaparser.ParseMetadataResponse(rpkt, hdr, offset) + metadataResponse, err := kafkaparser.ParseMetadataResponse(rpkt, 
hdr) if err != nil { return nil, true, err } @@ -143,16 +153,16 @@ func processMetadataResponse(rpkt []byte, hdr *kafkaparser.KafkaRequestHeader, k return nil, true, nil } -func ProcessKafkaRequest(pkt []byte, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { - hdr, offset, err := kafkaparser.ParseKafkaRequestHeader(pkt) +func ProcessKafkaRequest(pkt *LargeBufferReader, kafkaTopicUUIDToName *simplelru.LRU[kafkaparser.UUID, string]) (*KafkaInfo, bool, error) { + hdr, err := kafkaparser.ParseKafkaRequestHeader(pkt) if err != nil { return nil, true, err } switch hdr.APIKey { case kafkaparser.APIKeyProduce: - return processProduceRequest(pkt, hdr, offset) + return processProduceRequest(pkt, hdr) case kafkaparser.APIKeyFetch: - return processFetchRequest(pkt, hdr, offset, kafkaTopicUUIDToName) + return processFetchRequest(pkt, hdr, kafkaTopicUUIDToName) default: return nil, true, errors.New("unsupported Kafka API key") } diff --git a/pkg/ebpf/common/kafka_detect_transform_test.go b/pkg/ebpf/common/kafka_detect_transform_test.go index 5710a6af9..5d2c186f1 100644 --- a/pkg/ebpf/common/kafka_detect_transform_test.go +++ b/pkg/ebpf/common/kafka_detect_transform_test.go @@ -172,12 +172,12 @@ func TestProcessKafkaRequest(t *testing.T) { cache, _ := simplelru.NewLRU[kafkaparser.UUID, string](1000, nil) if len(tt.preRequests) > 0 { for _, preInput := range tt.preRequests { - _, ignore, err := ProcessKafkaEvent(preInput.request, preInput.response, cache) + _, ignore, err := ProcessKafkaEvent(NewLargeBufferFrom(preInput.request).NewReader(), NewLargeBufferFrom(preInput.response).NewReader(), cache) require.NoError(t, err) require.True(t, ignore) } } - res, _, err := ProcessKafkaEvent(tt.request, nil, cache) + res, _, err := ProcessKafkaEvent(NewLargeBufferFrom(tt.request).NewReader(), nil, cache) if tt.err { assert.Error(t, err) return diff --git a/pkg/ebpf/common/large_buffer.go b/pkg/ebpf/common/large_buffer.go new file mode 100644 index 
000000000..f4335e96d --- /dev/null +++ b/pkg/ebpf/common/large_buffer.go @@ -0,0 +1,561 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package ebpfcommon // import "go.opentelemetry.io/obi/pkg/ebpf/common" + +import ( + "encoding/binary" + "fmt" + "io" +) + +// LargeBuffer assembles chunked eBPF ring-buffer events into a contiguous byte stream. +// +// # Storage +// +// Chunks are stored independently as [][]byte. Each [LargeBuffer.AppendChunk] call allocates +// exactly one new slice and records its header (pointer + length + capacity, 24 bytes) in the +// chunk index. No previously-written chunk data is ever reallocated or copied when new chunks +// arrive — contrast with a flat []byte whose backing array must be copied on every capacity +// growth. +// +// # Reading +// +// Use [LargeBuffer.NewReader] to create a [LargeBufferReader] positioned at byte 0, then call +// its cursor methods ([LargeBufferReader.ReadN], [LargeBufferReader.Peek], etc.) to parse the +// payload. Multiple independent readers can operate on the same buffer simultaneously. +// +// For random-access reads without a cursor, use [LargeBuffer.UnsafeViewAt], +// [LargeBuffer.CopyAt], and the scalar helpers (e.g. [LargeBuffer.U32BEAt]) directly on the +// buffer. +// +// # Ring-buffer memory safety +// +// eBPF ring-buffer records share kernel-mapped memory that is reclaimed on the next ReadInto +// call. [LargeBuffer.AppendChunk] always copies the provided data into a new Go-owned allocation, +// so no reference to ring-buffer memory is retained across event-loop iterations. +// +// [LargeBuffer.NewLargeBufferFrom] is the only exception: it wraps an existing slice without +// copying. It is safe only when the wrapped slice outlives all reads — use it exclusively for +// inline event buffers consumed within the same call frame. 
+type LargeBuffer struct {
+	chunks  [][]byte // payload chunks in arrival order; each entry is one appended (copied) or wrapped slice
+	total   int      // sum of len(chunk) over chunks
+	scratch []byte   // used only by UnsafeViewAt for cross-chunk absolute-offset reads
+}
+
+// NewLargeBuffer returns an empty LargeBuffer ready to receive chunks.
+func NewLargeBuffer() *LargeBuffer {
+	return &LargeBuffer{}
+}
+
+// NewLargeBufferFrom wraps b as a single-chunk LargeBuffer without copying.
+//
+// The caller must ensure that b remains valid for the lifetime of all reads.
+// Do NOT use this with ring-buffer memory that will be reclaimed across event-loop iterations.
+// Safe use: inline event fields (e.g. event.Buf[:]) consumed within the same call frame.
+func NewLargeBufferFrom(b []byte) *LargeBuffer {
+	return &LargeBuffer{
+		chunks: [][]byte{b},
+		total:  len(b),
+	}
+}
+
+// AppendChunk copies data into a new independently-allocated chunk.
+//
+// Empty input is a no-op: it contributes no bytes, so recording a zero-length chunk would only
+// grow the chunk index. The receiver must be non-nil.
+func (lb *LargeBuffer) AppendChunk(data []byte) {
+	if len(data) == 0 {
+		return
+	}
+
+	chunk := make([]byte, len(data))
+	copy(chunk, data)
+
+	lb.chunks = append(lb.chunks, chunk)
+	lb.total += len(chunk)
+}
+
+// Len returns the total number of bytes across all chunks.
+// A nil receiver is treated as an empty buffer, mirroring how a nil []byte behaves.
+func (lb *LargeBuffer) Len() int {
+	if lb == nil {
+		return 0
+	}
+
+	return lb.total
+}
+
+// IsEmpty reports whether the buffer contains no bytes (Len() == 0).
+// A nil receiver is reported as empty.
+func (lb *LargeBuffer) IsEmpty() bool {
+	return lb.Len() == 0
+}
+
+// CloneBytes returns a freshly allocated copy of all chunks.
+// The caller owns the returned slice — it is never shared with the LargeBuffer's
+// internal storage. Returns nil when the buffer is empty or the receiver is nil.
+func (lb *LargeBuffer) CloneBytes() []byte {
+	if lb.IsEmpty() {
+		return nil
+	}
+
+	out := make([]byte, lb.total)
+	pos := 0
+
+	for _, chunk := range lb.chunks {
+		pos += copy(out[pos:], chunk)
+	}
+
+	return out
+}
+
+// Reset drops all chunks so that Len() reports 0.
+// The chunk index capacity and the scratch buffer are retained to avoid re-allocation on the next use.
+// Intended for future use with sync.Pool to allow instance reuse.
+func (lb *LargeBuffer) Reset() {
+	// Zero the index slots before truncating: lb.chunks[:0] alone keeps every old chunk
+	// reachable through the retained backing array, pinning the payloads in memory for as
+	// long as this instance is reused.
+	clear(lb.chunks)
+	lb.chunks = lb.chunks[:0]
+	lb.total = 0
+}
+
+// NewReader creates a new [LargeBufferReader] positioned at byte 0 of the buffer.
+// Multiple independent readers can operate on the same buffer simultaneously.
+func (lb *LargeBuffer) NewReader() *LargeBufferReader {
+	return &LargeBufferReader{lb: lb}
+}
+
+// ── Absolute-offset access ────────────────────────────────────────────────────
+
+// findChunk maps absOff (an absolute byte offset from the start of the buffer) to the chunk
+// index and the byte offset within that chunk. Returns (-1, 0) when absOff is out of
+// [0, lb.total). O(number of chunks); fast for the typical 1–3 chunk case.
+func (lb *LargeBuffer) findChunk(absOff int) (int, int) {
+	if absOff < 0 || absOff >= lb.total {
+		return -1, 0
+	}
+
+	pos := 0 // absolute offset of the first byte of the chunk being inspected
+
+	for i, chunk := range lb.chunks {
+		end := pos + len(chunk)
+		if absOff < end {
+			return i, absOff - pos
+		}
+		pos = end
+	}
+
+	return -1, 0
+}
+
+// UnsafeView returns a view over the entire buffer contents, equivalent to UnsafeViewAt(0, Len()).
+//
+// The returned slice MUST NOT be retained across the next UnsafeView or UnsafeViewAt call on the
+// same buffer. Returns nil when the buffer is empty; a nil receiver is treated as an empty
+// buffer, mirroring how a nil []byte behaves.
+func (lb *LargeBuffer) UnsafeView() []byte {
+	if lb == nil || lb.total == 0 {
+		return nil
+	}
+
+	// The range [0, total) is always in bounds for a non-empty buffer, so the error is
+	// deliberately ignored.
+	b, _ := lb.UnsafeViewAt(0, lb.total)
+
+	return b
+}
+
+// UnsafeViewAt returns n bytes starting at absOff.
+//
+// Zero-copy path: when all n bytes lie within one chunk, a sub-slice of that chunk's backing
+// array is returned — no allocation, no copy.
+//
+// Cross-chunk path: bytes are copied into the internal scratch buffer (grown as needed, never
+// freed). The same scratch slice is reused on subsequent cross-chunk calls.
+//
+// The returned slice MUST NOT be retained across the next UnsafeViewAt call on the same buffer.
+//
+// Returns an error when the range [absOff, absOff+n) is out of [0, Len()).
+func (lb *LargeBuffer) UnsafeViewAt(absOff, n int) ([]byte, error) {
+	// An empty request always succeeds with an empty, non-nil slice; absOff is not validated.
+	if n == 0 {
+		return []byte{}, nil
+	}
+
+	// Compare n against the remaining length instead of computing absOff+n, which can wrap
+	// around for very large values and let an out-of-range request through.
+	if n < 0 || absOff < 0 || absOff > lb.total || n > lb.total-absOff {
+		return nil, fmt.Errorf("LargeBuffer.UnsafeViewAt: [%d, %d) out of range [0, %d)", absOff, absOff+n, lb.total)
+	}
+
+	ci, off := lb.findChunk(absOff)
+
+	// Fast path: all bytes within one chunk — zero-copy. The capacity is capped at the view's
+	// length so that an append by the caller reallocates instead of overwriting the bytes that
+	// follow the view inside the chunk.
+	if off+n <= len(lb.chunks[ci]) {
+		return lb.chunks[ci][off : off+n : off+n], nil
+	}
+
+	// Slow path: crosses chunk boundary — copy into reusable scratch.
+	if cap(lb.scratch) < n {
+		lb.scratch = make([]byte, n)
+	}
+
+	lb.scratch = lb.scratch[:n]
+
+	// Drain consecutive chunks until n bytes are gathered; only the first chunk is read from a
+	// non-zero offset.
+	for filled := 0; filled < n; {
+		copied := copy(lb.scratch[filled:], lb.chunks[ci][off:])
+		filled += copied
+		ci++
+		off = 0
+	}
+
+	return lb.scratch, nil
+}
+
+// CopyAt copies exactly len(dst) bytes starting at absolute offset absOff into dst.
+//
+// Works across chunk boundaries. The caller owns the result.
+// An empty dst always succeeds without validating absOff.
+// Returns an error when the range [absOff, absOff+len(dst)) is out of [0, Len()).
+func (lb *LargeBuffer) CopyAt(absOff int, dst []byte) error {
+	n := len(dst)
+
+	if n == 0 {
+		return nil
+	}
+
+	// Overflow-safe form of absOff+n > lb.total.
+	if absOff < 0 || absOff > lb.total || n > lb.total-absOff {
+		return fmt.Errorf("LargeBuffer.CopyAt: [%d, %d) out of range [0, %d)", absOff, absOff+n, lb.total)
+	}
+
+	ci, off := lb.findChunk(absOff)
+
+	// Fill dst chunk by chunk; copy stops at whichever of dst or the chunk ends first.
+	for filled := 0; filled < n; {
+		copied := copy(dst[filled:], lb.chunks[ci][off:])
+		filled += copied
+		ci++
+		off = 0
+	}
+
+	return nil
+}
+
+// ── Scalar helpers ────────────────────────────────────────────────────────────
+//
+// Each helper reads a fixed-width integer at absOff.
+// All delegate to UnsafeViewAt: zero-copy within a chunk, scratch-backed across boundaries.
+
+// U8At reads a uint8 at absOff.
+func (lb *LargeBuffer) U8At(absOff int) (uint8, error) {
+	b, err := lb.UnsafeViewAt(absOff, 1)
+	if err != nil {
+		return 0, err
+	}
+
+	return b[0], nil
+}
+
+// U16BEAt reads a big-endian uint16 at absOff.
+func (lb *LargeBuffer) U16BEAt(absOff int) (uint16, error) { + b, err := lb.UnsafeViewAt(absOff, 2) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint16(b), nil +} + +// U32BEAt reads a big-endian uint32 at absOff. +func (lb *LargeBuffer) U32BEAt(absOff int) (uint32, error) { + b, err := lb.UnsafeViewAt(absOff, 4) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint32(b), nil +} + +// U64BEAt reads a big-endian uint64 at absOff. +func (lb *LargeBuffer) U64BEAt(absOff int) (uint64, error) { + b, err := lb.UnsafeViewAt(absOff, 8) + if err != nil { + return 0, err + } + + return binary.BigEndian.Uint64(b), nil +} + +// I16BEAt reads a big-endian int16 at absOff. +func (lb *LargeBuffer) I16BEAt(absOff int) (int16, error) { + v, err := lb.U16BEAt(absOff) + + return int16(v), err +} + +// I32BEAt reads a big-endian int32 at absOff. +func (lb *LargeBuffer) I32BEAt(absOff int) (int32, error) { + v, err := lb.U32BEAt(absOff) + + return int32(v), err +} + +// I64BEAt reads a big-endian int64 at absOff. +func (lb *LargeBuffer) I64BEAt(absOff int) (int64, error) { + v, err := lb.U64BEAt(absOff) + + return int64(v), err +} + +// U16LEAt reads a little-endian uint16 at absOff. +func (lb *LargeBuffer) U16LEAt(absOff int) (uint16, error) { + b, err := lb.UnsafeViewAt(absOff, 2) + if err != nil { + return 0, err + } + + return binary.LittleEndian.Uint16(b), nil +} + +// U32LEAt reads a little-endian uint32 at absOff. +func (lb *LargeBuffer) U32LEAt(absOff int) (uint32, error) { + b, err := lb.UnsafeViewAt(absOff, 4) + if err != nil { + return 0, err + } + + return binary.LittleEndian.Uint32(b), nil +} + +// U64LEAt reads a little-endian uint64 at absOff. +func (lb *LargeBuffer) U64LEAt(absOff int) (uint64, error) { + b, err := lb.UnsafeViewAt(absOff, 8) + if err != nil { + return 0, err + } + + return binary.LittleEndian.Uint64(b), nil +} + +// I16LEAt reads a little-endian int16 at absOff. 
+func (lb *LargeBuffer) I16LEAt(absOff int) (int16, error) { + v, err := lb.U16LEAt(absOff) + + return int16(v), err +} + +// I32LEAt reads a little-endian int32 at absOff. +func (lb *LargeBuffer) I32LEAt(absOff int) (int32, error) { + v, err := lb.U32LEAt(absOff) + + return int32(v), err +} + +// I64LEAt reads a little-endian int64 at absOff. +func (lb *LargeBuffer) I64LEAt(absOff int) (int64, error) { + v, err := lb.U64LEAt(absOff) + + return int64(v), err +} + +// ── LargeBufferReader ───────────────────────────────────────────────────────── + +// LargeBufferReader provides cursor-based sequential access to a [LargeBuffer]. +// +// Create with [LargeBuffer.NewReader]. Multiple independent readers can operate on the same +// buffer simultaneously — each reader maintains its own cursor and scratch buffer. +// +// ReadN returns the next n bytes and advances the cursor. +// When all n bytes lie within the current chunk, a sub-slice of that chunk's backing array is +// returned (zero allocation, zero copy). When n crosses a chunk boundary the internal scratch +// buffer is reused (one copy, no heap allocation after the first cross-boundary read). The +// returned slice must NOT be retained across the next ReadN or Read call. +// +// LargeBufferReader implements [io.Reader] for use with bufio.NewReader and stream-oriented +// parsers such as net/http. +type LargeBufferReader struct { + lb *LargeBuffer + rchunk int // index of the current read chunk + roff int // byte offset within lb.chunks[rchunk] + scratch []byte +} + +// Reset repositions this reader to the beginning of the buffer. +// Equivalent to discarding this reader and calling lb.NewReader(), but without the allocation. +func (r *LargeBufferReader) Reset() { + r.rchunk = 0 + r.roff = 0 +} + +// Remaining returns the number of unread bytes from the cursor to the end. 
+func (r *LargeBufferReader) Remaining() int { + consumed := r.roff + + for i := range r.rchunk { + consumed += len(r.lb.chunks[i]) + } + + return r.lb.total - consumed +} + +// ReadOffset returns the current cursor position as an absolute byte offset from the start of +// the buffer. +func (r *LargeBufferReader) ReadOffset() int { + return r.lb.total - r.Remaining() +} + +// BaseOffset always returns 0. Provided for API symmetry with ReadOffset. +func (r *LargeBufferReader) BaseOffset() int { + return 0 +} + +// ── Cursor-based access ─────────────────────────────────────────────────────── + +// ReadN returns exactly n bytes starting at the current read position and advances the cursor. +// +// Zero-copy path: when all n bytes lie within the current chunk, a sub-slice of that chunk's +// backing array is returned — no allocation, no copy. +// +// Cross-chunk path: bytes are copied into the internal scratch buffer (grown as needed, never +// freed). The same scratch slice is reused on subsequent cross-chunk calls. +// +// The returned slice MUST NOT be retained across the next ReadN or Read call. +func (r *LargeBufferReader) ReadN(n int) ([]byte, error) { + if n == 0 { + return nil, nil + } + + if n > r.Remaining() { + return nil, fmt.Errorf("LargeBuffer.ReadN: requested %d bytes but only %d remaining", n, r.Remaining()) + } + + // Fast path: all bytes within the current chunk — zero allocation, zero copy. + if r.rchunk < len(r.lb.chunks) && r.roff+n <= len(r.lb.chunks[r.rchunk]) { + s := r.lb.chunks[r.rchunk][r.roff : r.roff+n] + + r.roff += n + + if r.roff == len(r.lb.chunks[r.rchunk]) { + r.rchunk++ + r.roff = 0 + } + + return s, nil + } + + // Slow path: copy across chunk boundaries into reusable scratch. + if cap(r.scratch) < n { + r.scratch = make([]byte, n) + } + + r.scratch = r.scratch[:n] + r.copyN(r.scratch) + + return r.scratch, nil +} + +// Peek returns the next n bytes without advancing the read cursor. 
+// +// Zero-copy path: when all n bytes lie within the current chunk, a sub-slice of that chunk is +// returned with no allocation. +// +// Cross-chunk path: copies into the internal scratch buffer (same reuse semantics as ReadN). +// +// The returned slice MUST NOT be retained across the next ReadN or Read call. +func (r *LargeBufferReader) Peek(n int) ([]byte, error) { + if n == 0 { + return nil, nil + } + + if n > r.Remaining() { + return nil, fmt.Errorf("LargeBuffer.Peek: requested %d bytes but only %d remaining", n, r.Remaining()) + } + + // Fast path: within current chunk — return sub-slice directly. + if r.rchunk < len(r.lb.chunks) && r.roff+n <= len(r.lb.chunks[r.rchunk]) { + return r.lb.chunks[r.rchunk][r.roff : r.roff+n], nil + } + + // Slow path: copy into scratch, then restore cursor position. + savedChunk, savedOff := r.rchunk, r.roff + + if cap(r.scratch) < n { + r.scratch = make([]byte, n) + } + r.scratch = r.scratch[:n] + r.copyN(r.scratch) + + r.rchunk, r.roff = savedChunk, savedOff + + return r.scratch, nil +} + +// Skip advances the read cursor by n bytes without copying any data. +func (r *LargeBufferReader) Skip(n int) error { + if n > r.Remaining() { + return fmt.Errorf("LargeBuffer.Skip: requested %d bytes but only %d remaining", n, r.Remaining()) + } + + for n > 0 { + avail := len(r.lb.chunks[r.rchunk]) - r.roff + + if n < avail { + r.roff += n + return nil + } + + n -= avail + r.rchunk++ + r.roff = 0 + } + + return nil +} + +// Read implements [io.Reader]. Fills p with up to len(p) bytes from the current read position. +// +// Returns (n, nil) when bytes were read but the cursor has not yet reached the end. +// Returns (0, io.EOF) when the cursor is already at the end of the buffer. +// Per the io.Reader contract, may return (n, nil) even when the last byte was just read; +// the subsequent call returns (0, io.EOF). 
+func (r *LargeBufferReader) Read(p []byte) (int, error) { + if r.Remaining() == 0 { + return 0, io.EOF + } + + n := 0 + for n < len(p) && r.rchunk < len(r.lb.chunks) { + src := r.lb.chunks[r.rchunk][r.roff:] + copied := copy(p[n:], src) + + n += copied + r.roff += copied + + if r.roff == len(r.lb.chunks[r.rchunk]) { + r.rchunk++ + r.roff = 0 + } + } + + return n, nil +} + +// Bytes returns the unread portion of the buffer (from the current read cursor to the end) +// without advancing the cursor — analogous to [bytes.Buffer.Bytes]. +// +// Zero-copy path: when all remaining bytes lie within the current chunk, a sub-slice of that +// chunk's backing array is returned with no allocation. +// +// Cross-chunk path: copies into the internal scratch buffer (same reuse semantics as ReadN). +// The returned slice MUST NOT be retained across the next ReadN, Read, or Bytes call. +// +// Returns nil when there are no unread bytes remaining (Remaining() == 0). +func (r *LargeBufferReader) Bytes() []byte { + rem := r.Remaining() + + if rem == 0 { + return nil + } + + b, _ := r.Peek(rem) + + return b +} + +// copyN copies exactly len(dst) bytes from the current read position into dst, advancing the +// cursor. Assumes the caller has already verified that len(dst) <= r.Remaining(). 
+func (r *LargeBufferReader) copyN(dst []byte) { + filled := 0 + + for filled < len(dst) { + src := r.lb.chunks[r.rchunk][r.roff:] + copied := copy(dst[filled:], src) + + filled += copied + r.roff += copied + + if r.roff == len(r.lb.chunks[r.rchunk]) { + r.rchunk++ + r.roff = 0 + } + } +} diff --git a/pkg/ebpf/common/large_buffer_test.go b/pkg/ebpf/common/large_buffer_test.go new file mode 100644 index 000000000..f45d7a309 --- /dev/null +++ b/pkg/ebpf/common/large_buffer_test.go @@ -0,0 +1,1004 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package ebpfcommon + +import ( + "bufio" + "bytes" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// ── Construction ───────────────────────────────────────────────────────────── + +func TestNewLargeBuffer_empty(t *testing.T) { + lb := NewLargeBuffer() + r := lb.NewReader() + + assert.Equal(t, 0, lb.Len()) + assert.Equal(t, 0, r.Remaining()) +} + +func TestNewLargeBufferFrom_wrapsWithoutCopy(t *testing.T) { + src := []byte("hello") + lb := NewLargeBufferFrom(src) + r := lb.NewReader() + + assert.Equal(t, 5, lb.Len()) + assert.Equal(t, 5, r.Remaining()) + + got, err := r.ReadN(5) + require.NoError(t, err) + assert.Equal(t, src, got) + + // Verify the slice is backed by the same array (zero-copy). + assert.Equal(t, &src[0], &got[0]) +} + +// ── AppendChunk ─────────────────────────────────────────────────────────────── + +func TestAppendChunk_copiesData(t *testing.T) { + src := []byte("world") + lb := NewLargeBuffer() + lb.AppendChunk(src) + + // Mutating src must not affect the buffer. 
+ src[0] = 'X' + + got, err := lb.NewReader().ReadN(5) + require.NoError(t, err) + assert.Equal(t, "world", string(got)) +} + +func TestAppendChunk_multipleChunks(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("foo")) + lb.AppendChunk([]byte("bar")) + lb.AppendChunk([]byte("baz")) + + assert.Equal(t, 9, lb.Len()) + assert.Equal(t, 9, lb.NewReader().Remaining()) +} + +// ── ReadN ───────────────────────────────────────────────────────────────────── + +func TestReadN_withinChunk_zeroCopy(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdefgh")) + + r := lb.NewReader() + allocs := testing.AllocsPerRun(100, func() { + r.Reset() + _, _ = r.ReadN(4) + }) + + assert.InDelta(t, float64(0), allocs, 0, "ReadN within a single chunk must not allocate") +} + +func TestReadN_withinChunk_returnsCorrectBytes(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdefgh")) + r := lb.NewReader() + + got, err := r.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "abc", string(got)) + + got, err = r.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "def", string(got)) + + assert.Equal(t, 2, r.Remaining()) +} + +func TestReadN_crossChunk_reusesScatch(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + + r := lb.NewReader() + + // Warm up scratch. + _, _ = r.ReadN(4) + scratch1 := r.scratch + + r.Reset() + _, _ = r.ReadN(4) + scratch2 := r.scratch + + // Same backing array reused. 
+ assert.Equal(t, &scratch1[0], &scratch2[0], "scratch buffer must be reused across cross-chunk ReadN calls") +} + +func TestReadN_crossChunk_returnsCorrectBytes(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + + got, err := lb.NewReader().ReadN(5) + require.NoError(t, err) + assert.Equal(t, "abcde", string(got)) +} + +func TestReadN_exactlyChunkBoundary(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + r := lb.NewReader() + + got, err := r.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "abc", string(got)) + + got, err = r.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "def", string(got)) + + assert.Equal(t, 0, r.Remaining()) +} + +func TestReadN_tooManyBytes_returnsError(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + + _, err := lb.NewReader().ReadN(10) + assert.Error(t, err) +} + +func TestReadN_zero_returnsNil(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + + got, err := lb.NewReader().ReadN(0) + require.NoError(t, err) + assert.Nil(t, got) +} + +// ── Peek ────────────────────────────────────────────────────────────────────── + +func TestPeek_doesNotAdvanceCursor(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + + p, err := r.Peek(3) + require.NoError(t, err) + assert.Equal(t, "hel", string(p)) + assert.Equal(t, 5, r.Remaining(), "Peek must not advance cursor") + + got, err := r.ReadN(5) + require.NoError(t, err) + assert.Equal(t, "hello", string(got)) +} + +func TestPeek_crossChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("ab")) + lb.AppendChunk([]byte("cd")) + r := lb.NewReader() + + p, err := r.Peek(3) + require.NoError(t, err) + assert.Equal(t, "abc", string(p)) + assert.Equal(t, 4, r.Remaining(), "Peek must not advance cursor") +} + +// ── Skip 
────────────────────────────────────────────────────────────────────── + +func TestSkip_withinChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdef")) + r := lb.NewReader() + + require.NoError(t, r.Skip(3)) + assert.Equal(t, 3, r.Remaining()) + + got, err := r.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "def", string(got)) +} + +func TestSkip_crossChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + r := lb.NewReader() + + require.NoError(t, r.Skip(4)) + + got, err := r.ReadN(2) + require.NoError(t, err) + assert.Equal(t, "ef", string(got)) +} + +func TestSkip_tooMany_returnsError(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + + assert.Error(t, lb.NewReader().Skip(10)) +} + +// ── Remaining ──────────────────────────────────────────────────────────────── + +func TestRemaining_tracksReadPosition(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + r := lb.NewReader() + + assert.Equal(t, 6, r.Remaining()) + + _, _ = r.ReadN(2) + assert.Equal(t, 4, r.Remaining()) + + _, _ = r.ReadN(3) + assert.Equal(t, 1, r.Remaining()) +} + +// ── Reset ───────────────────────────────────────────────────────────────────── + +func TestReaderReset_restartsFromBeginning(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + + _, _ = r.ReadN(5) + assert.Equal(t, 0, r.Remaining()) + + r.Reset() + assert.Equal(t, 5, r.Remaining()) + + got, err := r.ReadN(5) + require.NoError(t, err) + assert.Equal(t, "hello", string(got)) +} + +func TestReaderReset_afterAppendChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + _, _ = r.ReadN(5) + + lb.AppendChunk([]byte(" world")) + r.Reset() + + got, err := r.ReadN(11) + require.NoError(t, err) + assert.Equal(t, "hello world", string(got)) +} + +// ── Read (io.Reader) 
────────────────────────────────────────────────────────── + +func TestRead_ioReaderCompliance(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello ")) + lb.AppendChunk([]byte("world")) + + all, err := io.ReadAll(lb.NewReader()) + require.NoError(t, err) + assert.Equal(t, "hello world", string(all)) +} + +func TestRead_eoFOnEmpty(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + r := lb.NewReader() + + _, _ = io.ReadAll(r) + + n, err := r.Read(make([]byte, 4)) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) +} + +func TestRead_withBufioReader(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("GET / HTTP/1.0\r\nHost: x\r\n\r\n")) + + br := bufio.NewReader(lb.NewReader()) + line, err := br.ReadString('\n') + require.NoError(t, err) + assert.Equal(t, "GET / HTTP/1.0\r\n", line) +} + +// ── Bytes (cursor-aware, non-advancing) ────────────────────────────────────── + +func TestBytes_cursorAtZero_singleChunk_zeroCopy(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + + got := r.Bytes() + + // Cursor at start, single chunk: sub-slice of chunk's backing array — zero-copy. 
+ assert.Equal(t, &lb.chunks[0][0], &got[0], "Bytes() at cursor=0 single-chunk must be zero-copy") + assert.Equal(t, "hello", string(got)) + assert.Equal(t, 5, r.Remaining(), "Bytes() must not advance cursor") +} + +func TestBytes_cursorAware_returnsUnreadPortion(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdef")) + r := lb.NewReader() + _, _ = r.ReadN(3) // advance cursor past first 3 bytes + + got := r.Bytes() + assert.Equal(t, "def", string(got)) + assert.Equal(t, 3, r.Remaining(), "Bytes() must not advance cursor") +} + +func TestBytes_multiChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("foo")) + lb.AppendChunk([]byte("bar")) + r := lb.NewReader() + + got := r.Bytes() + assert.Equal(t, "foobar", string(got)) + assert.Equal(t, 6, r.Remaining(), "Bytes() must not advance cursor") +} + +func TestBytes_empty(t *testing.T) { + lb := NewLargeBuffer() + assert.Nil(t, lb.NewReader().Bytes()) +} + +func TestBytes_afterReadAll_returnsNil(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + r := lb.NewReader() + _, _ = r.ReadN(2) + + assert.Nil(t, r.Bytes(), "Bytes() at end of buffer must return nil") +} + +func TestBytes_singleChunk_isSharedView(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + + got := lb.NewReader().Bytes() + + // Bytes() returns a view into the internal chunk — mutating it affects the chunk. + got[0] = 'X' + assert.Equal(t, "Xello", string(lb.chunks[0]), "Bytes() single-chunk must be a shared view, not a copy") +} + +func TestBytes_newLargeBufferFrom_isSharedView(t *testing.T) { + src := []byte("hello") + lb := NewLargeBufferFrom(src) + + got := lb.NewReader().Bytes() + + // Bytes() returns a view into src — mutating it affects the original slice. 
+ got[0] = 'X' + assert.Equal(t, "Xello", string(src), "Bytes() on NewLargeBufferFrom must be a shared view into src") +} + +// ── CloneBytes (cursor-independent) ────────────────────────────────────────── + +func TestCloneBytes_singleChunk_alwaysCopies(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + + got := lb.CloneBytes() + assert.Equal(t, "hello", string(got)) + + // Mutate the returned slice — the internal chunk must be unaffected. + got[0] = 'X' + assert.Equal(t, "hello", string(lb.chunks[0]), "CloneBytes() must return an independent copy") +} + +func TestCloneBytes_multiChunk_materializes(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("foo")) + lb.AppendChunk([]byte("bar")) + + got := lb.CloneBytes() + assert.Equal(t, "foobar", string(got)) +} + +func TestCloneBytes_cursorIndependent(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdef")) + r := lb.NewReader() + _, _ = r.ReadN(3) // advance cursor past first 3 bytes + + got := lb.CloneBytes() + // CloneBytes always returns all chunks regardless of cursor position. + assert.Equal(t, "abcdef", string(got)) +} + +func TestCloneBytes_empty(t *testing.T) { + lb := NewLargeBuffer() + assert.Nil(t, lb.CloneBytes()) +} + +// ── Reset ───────────────────────────────────────────────────────────────────── + +func TestReset_clearsAllState(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("data")) + + lb.Reset() + + assert.Equal(t, 0, lb.Len()) + assert.Empty(t, lb.chunks) +} + +func TestReset_allowsReuse(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("first")) + lb.AppendChunk([]byte("pass")) + + // Populate scratch via a cross-chunk UnsafeView. + _ = lb.UnsafeView() + scratch := lb.scratch + + lb.Reset() + assert.Equal(t, 0, lb.Len()) + assert.True(t, lb.IsEmpty()) + + // Reuse: append new data and read it back correctly. 
+ lb.AppendChunk([]byte("second pass")) + got := lb.UnsafeView() + assert.Equal(t, "second pass", string(got)) + + // Reset must not free the scratch backing array. + assert.Equal(t, &scratch[0], &lb.scratch[0], "scratch backing array must survive Reset") +} + +// ── Multi-chunk edge cases ─────────────────────────────────────────────────── + +func TestReadN_manySmallChunks(t *testing.T) { + lb := NewLargeBuffer() + expected := make([]byte, 0, 26) + + for b := byte('a'); b <= 'z'; b++ { + lb.AppendChunk([]byte{b}) + expected = append(expected, b) + } + + got, err := lb.NewReader().ReadN(26) + require.NoError(t, err) + assert.Equal(t, expected, got) +} + +func TestReadN_spanThreeChunks(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("ab")) + lb.AppendChunk([]byte("cd")) + lb.AppendChunk([]byte("ef")) + + got, err := lb.NewReader().ReadN(5) + require.NoError(t, err) + assert.Equal(t, "abcde", string(got)) +} + +func TestCloneBytes_singleChunkAfterNewLargeBufferFrom(t *testing.T) { + src := []byte("direct") + lb := NewLargeBufferFrom(src) + + got := lb.CloneBytes() + assert.Equal(t, "direct", string(got)) + + // Mutate the returned slice — src must be unaffected. 
+ got[0] = 'X' + assert.Equal(t, "direct", string(src), "CloneBytes() must return an independent copy") +} + +// ── Interleaved reads across all methods ───────────────────────────────────── + +func TestInterleaved_peekReadSkip(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdefghij")) + r := lb.NewReader() + + p, err := r.Peek(3) + require.NoError(t, err) + assert.Equal(t, "abc", string(p)) + + got, err := r.ReadN(2) + require.NoError(t, err) + assert.Equal(t, "ab", string(got)) + + require.NoError(t, r.Skip(3)) + + got, err = r.ReadN(5) + require.NoError(t, err) + assert.Equal(t, "fghij", string(got)) + + assert.Equal(t, 0, r.Remaining()) +} + +// ── ReadOffset / BaseOffset / IsEmpty ──────────────────────────────────────── + +func TestReadOffset_tracksAdvancingCursor(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdef")) + r := lb.NewReader() + + assert.Equal(t, 0, r.ReadOffset()) + + _, _ = r.ReadN(3) + assert.Equal(t, 3, r.ReadOffset()) + + _, _ = r.ReadN(3) + assert.Equal(t, 6, r.ReadOffset()) +} + +func TestReadOffset_afterReset(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + _, _ = r.ReadN(5) + + r.Reset() + assert.Equal(t, 0, r.ReadOffset()) +} + +func TestBaseOffset_alwaysZero(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("anything")) + r := lb.NewReader() + _, _ = r.ReadN(4) + + assert.Equal(t, 0, r.BaseOffset()) +} + +func TestIsEmpty(t *testing.T) { + lb := NewLargeBuffer() + assert.True(t, lb.IsEmpty()) + + lb.AppendChunk([]byte("x")) + assert.False(t, lb.IsEmpty()) + + r := lb.NewReader() + _, _ = r.ReadN(1) // cursor at end, but buffer is not empty + assert.False(t, lb.IsEmpty()) +} + +// ── findChunk ───────────────────────────────────────────────────────────────── + +func TestFindChunk_singleChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcde")) + + ci, off := lb.findChunk(0) + assert.Equal(t, 0, ci) + 
assert.Equal(t, 0, off) + + ci, off = lb.findChunk(4) + assert.Equal(t, 0, ci) + assert.Equal(t, 4, off) +} + +func TestFindChunk_multiChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) // offsets 0-2 + lb.AppendChunk([]byte("de")) // offsets 3-4 + lb.AppendChunk([]byte("fgh")) // offsets 5-7 + + ci, off := lb.findChunk(3) + assert.Equal(t, 1, ci) + assert.Equal(t, 0, off) + + ci, off = lb.findChunk(5) + assert.Equal(t, 2, ci) + assert.Equal(t, 0, off) + + ci, off = lb.findChunk(7) + assert.Equal(t, 2, ci) + assert.Equal(t, 2, off) +} + +func TestFindChunk_outOfRange(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + + ci, _ := lb.findChunk(3) + assert.Equal(t, -1, ci) + + ci, _ = lb.findChunk(-1) + assert.Equal(t, -1, ci) +} + +// ── UnsafeViewAt ────────────────────────────────────────────────────────────── + +func TestUnsafeViewAt_withinSingleChunk_zeroCopy(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdefgh")) + + got, err := lb.UnsafeViewAt(2, 3) + require.NoError(t, err) + assert.Equal(t, "cde", string(got)) + // Verify it's a sub-slice of the chunk (zero-copy). + assert.Equal(t, &lb.chunks[0][2], &got[0]) +} + +func TestUnsafeViewAt_crossBoundary_usesScratch(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) // offsets 0-2 + lb.AppendChunk([]byte("def")) // offsets 3-5 + + // Read straddling the boundary. 
+ got, err := lb.UnsafeViewAt(1, 4) // "bcde" + require.NoError(t, err) + assert.Equal(t, "bcde", string(got)) +} + +func TestUnsafeViewAt_doesNotMoveCursor(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + + before := r.ReadOffset() + _, err := lb.UnsafeViewAt(1, 3) + require.NoError(t, err) + assert.Equal(t, before, r.ReadOffset()) +} + +func TestUnsafeViewAt_outOfRange_returnsError(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + + _, err := lb.UnsafeViewAt(1, 5) + require.Error(t, err) + + _, err = lb.UnsafeViewAt(-1, 1) + require.Error(t, err) + + _, err = lb.UnsafeViewAt(0, -1) + require.Error(t, err) +} + +func TestUnsafeViewAt_zero_returnsEmpty(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + + got, err := lb.UnsafeViewAt(0, 0) + require.NoError(t, err) + assert.NotNil(t, got) + assert.Empty(t, got) +} + +func TestUnsafeViewAt_scratchReuseSemantics(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + + // First cross-chunk call — allocates scratch (4 bytes). + got1, _ := lb.UnsafeViewAt(1, 4) // "bcde" + scratch1 := lb.scratch + + // Second cross-chunk call with same size — reuses scratch. + got2, _ := lb.UnsafeViewAt(0, 4) // "abcd" + scratch2 := lb.scratch + + assert.Equal(t, &scratch1[0], &scratch2[0], "scratch buffer must be reused") + // got1 is now stale (points at same scratch as got2 but overwritten). 
+ assert.Equal(t, "abcd", string(got2)) + _ = got1 // intentionally not asserted — it is stale +} + +// ── UnsafeView ─────────────────────────────────────────────────────────────── + +func TestUnsafeView_empty_returnsNil(t *testing.T) { + lb := NewLargeBuffer() + assert.Nil(t, lb.UnsafeView()) +} + +func TestUnsafeView_singleChunk_zeroCopy(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + + got := lb.UnsafeView() + assert.Equal(t, "hello", string(got)) + // Single chunk: must be a direct sub-slice of the chunk's backing array. + assert.Equal(t, &lb.chunks[0][0], &got[0], "UnsafeView single-chunk must be zero-copy") +} + +func TestUnsafeView_multiChunk_materializes(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("foo")) + lb.AppendChunk([]byte("bar")) + + got := lb.UnsafeView() + assert.Equal(t, "foobar", string(got)) +} + +func TestUnsafeView_scratchReuseSemantics(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + + lb.UnsafeView() // cross-chunk → allocates scratch + scratch1 := lb.scratch + + lb.UnsafeView() // second call — reuses scratch + scratch2 := lb.scratch + + assert.Equal(t, &scratch1[0], &scratch2[0], "UnsafeView must reuse scratch across calls") +} + +// ── CopyAt ──────────────────────────────────────────────────────────────────── + +func TestCopyAt_withinSingleChunk(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdefgh")) + + dst := make([]byte, 4) + require.NoError(t, lb.CopyAt(2, dst)) + assert.Equal(t, "cdef", string(dst)) +} + +func TestCopyAt_crossBoundary(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abc")) + lb.AppendChunk([]byte("def")) + + dst := make([]byte, 4) + require.NoError(t, lb.CopyAt(1, dst)) + assert.Equal(t, "bcde", string(dst)) +} + +func TestCopyAt_doesNotMoveCursor(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + r := lb.NewReader() + + before := r.ReadOffset() 
+ dst := make([]byte, 3) + require.NoError(t, lb.CopyAt(1, dst)) + assert.Equal(t, before, r.ReadOffset()) +} + +func TestCopyAt_outOfRange_returnsError(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hi")) + + require.Error(t, lb.CopyAt(0, make([]byte, 10))) + require.Error(t, lb.CopyAt(-1, make([]byte, 1))) +} + +func TestCopyAt_alwaysOwned(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("hello")) + + dst := make([]byte, 5) + require.NoError(t, lb.CopyAt(0, dst)) + + // Mutating dst must not affect the chunk. + dst[0] = 'X' + assert.Equal(t, "hello", string(lb.chunks[0])) +} + +// ── Scalar helpers — big-endian ─────────────────────────────────────────────── + +func TestScalarBE_withinSingleChunk(t *testing.T) { + lb := NewLargeBuffer() + // Lay out known bytes at known offsets. + // offset 0: U8 = 0x42 + // offset 1: U16BE = 0x0102 + // offset 3: U32BE = 0x01020304 + // offset 7: U64BE = 0x0102030405060708 + data := []byte{ + 0x42, + 0x01, 0x02, + 0x01, 0x02, 0x03, 0x04, + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + } + lb.AppendChunk(data) + + u8, err := lb.U8At(0) + require.NoError(t, err) + assert.Equal(t, uint8(0x42), u8) + + u16, err := lb.U16BEAt(1) + require.NoError(t, err) + assert.Equal(t, uint16(0x0102), u16) + + u32, err := lb.U32BEAt(3) + require.NoError(t, err) + assert.Equal(t, uint32(0x01020304), u32) + + u64, err := lb.U64BEAt(7) + require.NoError(t, err) + assert.Equal(t, uint64(0x0102030405060708), u64) + + i16, err := lb.I16BEAt(1) + require.NoError(t, err) + assert.Equal(t, int16(0x0102), i16) + + i32, err := lb.I32BEAt(3) + require.NoError(t, err) + assert.Equal(t, int32(0x01020304), i32) + + i64, err := lb.I64BEAt(7) + require.NoError(t, err) + assert.Equal(t, int64(0x0102030405060708), i64) +} + +func TestScalarBE_crossChunkBoundary(t *testing.T) { + // Split so that the U32 straddles the boundary: chunk0=[0x01,0x02], chunk1=[0x03,0x04,...] 
+ lb := NewLargeBuffer() + lb.AppendChunk([]byte{0x01, 0x02}) + lb.AppendChunk([]byte{0x03, 0x04, 0x05, 0x06, 0x07, 0x08}) + + u32, err := lb.U32BEAt(0) + require.NoError(t, err) + assert.Equal(t, uint32(0x01020304), u32) + + u64, err := lb.U64BEAt(0) + require.NoError(t, err) + assert.Equal(t, uint64(0x0102030405060708), u64) +} + +func TestScalarBE_signedNegativeValues(t *testing.T) { + lb := NewLargeBuffer() + // -1 as int16 BE = 0xFF 0xFF + lb.AppendChunk([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + + i16, err := lb.I16BEAt(0) + require.NoError(t, err) + assert.Equal(t, int16(-1), i16) + + i32, err := lb.I32BEAt(0) + require.NoError(t, err) + assert.Equal(t, int32(-1), i32) + + i64, err := lb.I64BEAt(0) + require.NoError(t, err) + assert.Equal(t, int64(-1), i64) +} + +// ── Scalar helpers — little-endian ──────────────────────────────────────────── + +func TestScalarLE_withinSingleChunk(t *testing.T) { + lb := NewLargeBuffer() + // offset 0: U16LE = 0x0201 → value 0x0102 read as LE + // offset 2: U32LE = 0x04030201 + // offset 6: U64LE = 0x0807060504030201 + data := []byte{ + 0x02, 0x01, + 0x04, 0x03, 0x02, 0x01, + 0x08, 0x07, 0x06, 0x05, 0x04, 0x03, 0x02, 0x01, + } + lb.AppendChunk(data) + + u16, err := lb.U16LEAt(0) + require.NoError(t, err) + assert.Equal(t, uint16(0x0102), u16) + + u32, err := lb.U32LEAt(2) + require.NoError(t, err) + assert.Equal(t, uint32(0x01020304), u32) + + u64, err := lb.U64LEAt(6) + require.NoError(t, err) + assert.Equal(t, uint64(0x0102030405060708), u64) + + i16, err := lb.I16LEAt(0) + require.NoError(t, err) + assert.Equal(t, int16(0x0102), i16) + + i32, err := lb.I32LEAt(2) + require.NoError(t, err) + assert.Equal(t, int32(0x01020304), i32) + + i64, err := lb.I64LEAt(6) + require.NoError(t, err) + assert.Equal(t, int64(0x0102030405060708), i64) +} + +func TestScalarLE_crossChunkBoundary(t *testing.T) { + // U32LE straddles boundary: chunk0=[0x04,0x03], chunk1=[0x02,0x01,...] 
+ lb := NewLargeBuffer() + lb.AppendChunk([]byte{0x04, 0x03}) + lb.AppendChunk([]byte{0x02, 0x01, 0x08, 0x07, 0x06, 0x05}) + + u32, err := lb.U32LEAt(0) + require.NoError(t, err) + assert.Equal(t, uint32(0x01020304), u32) + + u64, err := lb.U64LEAt(0) + require.NoError(t, err) + assert.Equal(t, uint64(0x0506070801020304), u64) +} + +func TestScalarLE_signedNegativeValues(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}) + + i16, err := lb.I16LEAt(0) + require.NoError(t, err) + assert.Equal(t, int16(-1), i16) + + i32, err := lb.I32LEAt(0) + require.NoError(t, err) + assert.Equal(t, int32(-1), i32) + + i64, err := lb.I64LEAt(0) + require.NoError(t, err) + assert.Equal(t, int64(-1), i64) +} + +func TestScalar_outOfRange_returnsError(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte{0x01, 0x02}) + + _, err := lb.U8At(2) + require.Error(t, err) + + _, err = lb.U32BEAt(0) // only 2 bytes, needs 4 + require.Error(t, err) + + _, err = lb.U32LEAt(0) + require.Error(t, err) +} + +func TestCursorUnchanged_byAbsoluteOffsetMethods(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdefgh")) + r := lb.NewReader() + + _, _ = r.ReadN(3) // advance cursor to 3 + before := r.ReadOffset() + + _, _ = lb.UnsafeViewAt(0, 4) + assert.Equal(t, before, r.ReadOffset(), "UnsafeViewAt must not move cursor") + + _ = lb.CopyAt(0, make([]byte, 4)) + assert.Equal(t, before, r.ReadOffset(), "CopyAt must not move cursor") + + _, _ = lb.U32BEAt(0) + assert.Equal(t, before, r.ReadOffset(), "U32BEAt must not move cursor") + + _, _ = lb.U32LEAt(0) + assert.Equal(t, before, r.ReadOffset(), "U32LEAt must not move cursor") +} + +// ── Zero-alloc verification for hot path ───────────────────────────────────── + +func TestReadN_singleChunk_zeroAllocsWithBinaryDecode(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk(bytes.Repeat([]byte{0x01}, 64)) + + r := lb.NewReader() + allocs := 
testing.AllocsPerRun(1000, func() { + r.Reset() + for r.Remaining() >= 4 { + b, _ := r.ReadN(4) + _ = b[0] | b[1] | b[2] | b[3] // simulate scalar decode + } + }) + + assert.InDelta(t, float64(0), allocs, 0, "hot path (single-chunk scalar decoding) must be zero-alloc") +} + +// ── Multiple readers on same buffer ────────────────────────────────────────── + +func TestMultipleReaders_independent(t *testing.T) { + lb := NewLargeBuffer() + lb.AppendChunk([]byte("abcdef")) + + r1 := lb.NewReader() + r2 := lb.NewReader() + + got1, err := r1.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "abc", string(got1)) + + // r2 is unaffected by r1's advance. + got2, err := r2.ReadN(6) + require.NoError(t, err) + assert.Equal(t, "abcdef", string(got2)) + + // r1 can continue from where it left off. + got1, err = r1.ReadN(3) + require.NoError(t, err) + assert.Equal(t, "def", string(got1)) +} diff --git a/pkg/ebpf/common/mongo_detect_transform.go b/pkg/ebpf/common/mongo_detect_transform.go index d8f3122d2..0da36d3ea 100644 --- a/pkg/ebpf/common/mongo_detect_transform.go +++ b/pkg/ebpf/common/mongo_detect_transform.go @@ -359,17 +359,19 @@ func validateFlagBits(flagBits int32) error { return nil } -func mongoInfoFromEvent(event *TCPRequestInfo, requestBuffer []byte, responseBuffer []byte, mongoRequestCache PendingMongoDBRequests) *mongoSpanInfo { +func mongoInfoFromEvent(event *TCPRequestInfo, requestBuffer *LargeBuffer, responseBuffer *LargeBuffer, mongoRequestCache PendingMongoDBRequests) *mongoSpanInfo { if event.Direction == 0 { return nil } + reqRaw := requestBuffer.UnsafeView() + respRaw := responseBuffer.UnsafeView() var mongoRequest *MongoRequestValue var moreToCome bool - _, _, err := ProcessMongoEvent(requestBuffer, int64(event.StartMonotimeNs), int64(event.EndMonotimeNs), event.ConnInfo, mongoRequestCache) + _, _, err := ProcessMongoEvent(reqRaw, int64(event.StartMonotimeNs), int64(event.EndMonotimeNs), event.ConnInfo, mongoRequestCache) if err != nil { return nil } - 
mongoRequest, moreToCome, err = ProcessMongoEvent(responseBuffer, int64(event.StartMonotimeNs), int64(event.EndMonotimeNs), event.ConnInfo, mongoRequestCache) + mongoRequest, moreToCome, err = ProcessMongoEvent(respRaw, int64(event.StartMonotimeNs), int64(event.EndMonotimeNs), event.ConnInfo, mongoRequestCache) if err != nil || mongoRequest == nil || moreToCome { return nil } diff --git a/pkg/ebpf/common/mqtt_detect_transform.go b/pkg/ebpf/common/mqtt_detect_transform.go index 0a056b44b..418b27d31 100644 --- a/pkg/ebpf/common/mqtt_detect_transform.go +++ b/pkg/ebpf/common/mqtt_detect_transform.go @@ -46,12 +46,12 @@ func packetTypeToMethod(packetType mqttparser.PacketType) string { // ProcessPossibleMQTTEvent processes a TCP packet and returns error if the packet is not a valid MQTT packet. // Otherwise, returns MQTTInfo with the processed data. The ignore bool indicates whether the event // should be ignored for span creation (e.g., control packets like CONNECT). -func ProcessPossibleMQTTEvent(event *TCPRequestInfo, pkt []byte, rpkt []byte) (*MQTTInfo, bool, error) { - m, ignore, err := ProcessMQTTEvent(pkt) +func ProcessPossibleMQTTEvent(event *TCPRequestInfo, pkt *LargeBuffer, rpkt *LargeBuffer) (*MQTTInfo, bool, error) { + m, ignore, err := ProcessMQTTEvent(pkt.UnsafeView()) if err != nil { // If we are getting the information in the response buffer, the event // must be reversed and that's how we captured it. - m, ignore, err = ProcessMQTTEvent(rpkt) + m, ignore, err = ProcessMQTTEvent(rpkt.UnsafeView()) if err == nil && !ignore { reverseTCPEvent(event) } @@ -177,8 +177,11 @@ func processConnectPacket(pkt []byte, offset int) (*MQTTInfo, bool, error) { // isMQTT performs a quick heuristic check to determine if the packet looks like MQTT. // This is used for userspace protocol detection when the kernel hasn't classified the protocol. 
-func isMQTT(pkt []byte) bool { - _, err := mqttparser.NewMQTTControlPacket(pkt) +func isMQTT(pkt *LargeBuffer) bool { + if pkt == nil { + return false + } + _, err := mqttparser.NewMQTTControlPacket(pkt.UnsafeView()) return err == nil } diff --git a/pkg/ebpf/common/mqtt_detect_transform_test.go b/pkg/ebpf/common/mqtt_detect_transform_test.go index 46e52d633..fa18bf7d1 100644 --- a/pkg/ebpf/common/mqtt_detect_transform_test.go +++ b/pkg/ebpf/common/mqtt_detect_transform_test.go @@ -295,7 +295,7 @@ func TestProcessPossibleMQTTEvent(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { event := &TCPRequestInfo{} - res, _, err := ProcessPossibleMQTTEvent(event, tt.request, tt.response) + res, _, err := ProcessPossibleMQTTEvent(event, NewLargeBufferFrom(tt.request), NewLargeBufferFrom(tt.response)) if tt.err { assert.Error(t, err) return @@ -429,7 +429,7 @@ func TestIsMQTT(t *testing.T) { validPacket := []byte{0xC0, 0x00} // PINGREQ - minimal valid MQTT packet invalidPacket := []byte{0x00, 0x00} // Reserved packet type (invalid) - assert.True(t, isMQTT(validPacket), "valid MQTT packet should return true") - assert.False(t, isMQTT(invalidPacket), "invalid packet should return false") + assert.True(t, isMQTT(NewLargeBufferFrom(validPacket)), "valid MQTT packet should return true") + assert.False(t, isMQTT(NewLargeBufferFrom(invalidPacket)), "invalid packet should return false") assert.False(t, isMQTT(nil), "nil packet should return false") } diff --git a/pkg/ebpf/common/redis_detect_transform.go b/pkg/ebpf/common/redis_detect_transform.go index 6bcadf266..c5f877e97 100644 --- a/pkg/ebpf/common/redis_detect_transform.go +++ b/pkg/ebpf/common/redis_detect_transform.go @@ -32,12 +32,11 @@ var redisErrorCodes = [...]string{ "READONLY ", } -func isRedis(buf []uint8) bool { - if len(buf) < minRedisFrameLen { +func isRedis(buf *LargeBuffer) bool { + if buf.Len() < minRedisFrameLen { return false } - - return isRedisOp(buf) + return 
isRedisOp(buf.UnsafeView()) } //nolint:cyclop @@ -183,17 +182,19 @@ func parseRedisRequest(buf string) (string, string, bool) { return op, strings.TrimSpace(text.String()), true } -func redisStatus(buf []byte) (request.DBError, int) { - status := 0 - firstChar := buf[0] - if firstChar != '-' { - return request.DBError{}, status +func redisStatus(buf *LargeBuffer) (request.DBError, int) { + if buf.Len() == 0 { + return request.DBError{}, 0 + } + data := buf.UnsafeView() + if data[0] != '-' { + return request.DBError{}, 0 } - dbError, isError := getRedisError(buf[1:]) + dbError, isError := getRedisError(data[1:]) + status := 0 if isError { status = 1 } - return dbError, status } diff --git a/pkg/ebpf/common/redis_detect_transform_test.go b/pkg/ebpf/common/redis_detect_transform_test.go index 6f706cd10..cba6effbd 100644 --- a/pkg/ebpf/common/redis_detect_transform_test.go +++ b/pkg/ebpf/common/redis_detect_transform_test.go @@ -75,8 +75,8 @@ func TestRedisParsing(t *testing.T) { func TestIsRedis(t *testing.T) { buf := []byte{42, 51, 13, 10, 36, 52, 13, 10, 72, 71, 69, 84, 13, 10, 36, 51, 54, 13, 10, 56, 97, 100, 48, 101, 56, 99, 97, 45, 101, 97, 49, 57, 45, 52, 50, 97, 57, 45, 98, 51, 55, 48, 45, 98, 99, 97, 102, 102, 50, 55, 54, 55, 98, 56, 54, 13, 10, 36, 52, 13, 10, 99, 97, 114, 116, 13, 10, 103, 58, 32, 34, 51, 49, 117, 50, 107, 97, 100, 98, 108, 113, 53, 106, 34, 13, 10, 99, 111, 110, 116, 101, 110, 116, 45, 108, 101, 110, 103, 116, 104, 58, 32, 49, 57, 57, 13, 10, 118, 97, 114, 121, 58, 32, 65, 99, 99, 101, 112, 116, 45, 69, 110, 99, 111, 100, 105, 110, 103, 13, 10, 100, 97, 116, 101, 58, 32, 87, 101, 100, 44, 32, 48, 51, 32, 74, 117, 108, 32, 50, 48, 50, 52, 32, 49, 55, 58, 52, 54, 58, 49, 55, 32, 71, 77, 84, 13, 10, 120, 45, 101, 110, 118, 111, 121, 45, 117, 112, 115, 116, 114, 101, 97, 109, 45, 115, 101, 114, 118, 105, 99, 101, 45, 116, 105, 109, 101, 58, 32, 51, 13, 10, 115, 101, 114, 118, 101, 114, 58, 32, 101, 110, 118, 111, 121, 13, 10, 13, 10, 91, 34, 
90, 65, 82, 34, 44, 34, 73, 83, 75, 34, 44, 34, 73, 76, 83, 34, 44, 34, 82, 79, 78, 34, 44, 34, 71, 66, 80, 34, 44, 34, 66, 82, 76, 34, 44, 34} rbuf := []byte{36, 45, 49, 13, 10, 1, 0, 15, 0, 3, 89, 130, 0, 32, 99, 111, 110, 115, 117, 109, 101, 114, 45, 102, 114, 97, 117, 100, 100, 101, 116, 101, 99, 116, 105, 111, 110, 115, 101, 114, 118, 105, 99, 101, 45, 49, 0, 0, 0, 1, 244, 0, 0, 0, 1, 3, 32, 0, 0, 0, 17, 170, 173, 222, 0, 0, 141, 2, 1, 1, 1, 0, 101, 112, 116, 45, 114, 97, 110, 103, 101, 115, 58, 32, 98, 121, 116, 101, 115, 13, 10, 108, 97, 115, 116, 45, 109, 111, 100, 105, 102, 105, 101, 100, 58, 32, 70, 114, 105, 44, 32, 48, 55, 32, 74, 117, 110, 32, 50, 48, 50, 52, 32, 48, 48, 58, 53, 55} - assert.True(t, isRedis(buf)) - assert.True(t, isRedis(rbuf)) + assert.True(t, isRedis(NewLargeBufferFrom(buf))) + assert.True(t, isRedis(NewLargeBufferFrom(rbuf))) } func TestGetRedisDb(t *testing.T) { diff --git a/pkg/ebpf/common/sql_detect_mysql.go b/pkg/ebpf/common/sql_detect_mysql.go index 9fe176f05..d5199f299 100644 --- a/pkg/ebpf/common/sql_detect_mysql.go +++ b/pkg/ebpf/common/sql_detect_mysql.go @@ -84,23 +84,25 @@ func mysqlPreparedStatements(b []byte) (string, string, string) { return op, table, sql } -func handleMySQL(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffer, responseBuffer []byte) (request.Span, error) { +func handleMySQL(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffer, responseBuffer *LargeBuffer) (request.Span, error) { var ( op, table, stmt string span request.Span ) - if len(requestBuffer) < sqlprune.MySQLHdrSize+1 { + if requestBuffer.Len() < sqlprune.MySQLHdrSize+1 { slog.Debug("MySQL request too short") return span, errFallback } - if len(responseBuffer) < sqlprune.MySQLHdrSize+1 { + if responseBuffer.Len() < sqlprune.MySQLHdrSize+1 { slog.Debug("MySQL response too short") return span, errFallback } + reqRaw := requestBuffer.UnsafeView() + respRaw := responseBuffer.UnsafeView() - sqlCommand := 
sqlprune.SQLParseCommandID(request.DBMySQL, requestBuffer) - sqlError := sqlprune.SQLParseError(request.DBMySQL, responseBuffer) + sqlCommand := sqlprune.SQLParseCommandID(request.DBMySQL, reqRaw) + sqlError := sqlprune.SQLParseError(request.DBMySQL, respRaw) switch sqlCommand { case "STMT_PREPARE": @@ -111,13 +113,13 @@ func handleMySQL(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffe // On the PREPARE command, the statement ID is the first 4 bytes after the header and command ID // in the response buffer. - stmtID := sqlprune.SQLParseStatementID(request.DBMySQL, responseBuffer) + stmtID := sqlprune.SQLParseStatementID(request.DBMySQL, respRaw) if stmtID == 0 { slog.Debug("MySQL PREPARE command with invalid statement ID") return span, errFallback } - _, _, stmt = detectSQL(string(requestBuffer[sqlprune.MySQLHdrSize+1:])) + _, _, stmt = detectSQL(string(reqRaw[sqlprune.MySQLHdrSize+1:])) parseCtx.mysqlPreparedStatements.Add(mysqlPreparedStatementsKey{ connInfo: event.ConnInfo, stmtID: stmtID, @@ -127,7 +129,7 @@ func handleMySQL(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffe case "STMT_EXECUTE": // On the EXECUTE command, the statement ID is the first 4 bytes after the header and command ID // in the request buffer. 
- stmtID := sqlprune.SQLParseStatementID(request.DBMySQL, requestBuffer) + stmtID := sqlprune.SQLParseStatementID(request.DBMySQL, reqRaw) if stmtID == 0 { slog.Debug("MySQL EXECUTE command with invalid statement ID") return span, errFallback @@ -144,9 +146,9 @@ func handleMySQL(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffe } op, table = sqlprune.SQLParseOperationAndTable(stmt) case "QUERY": - op, table, stmt = detectSQL(string(requestBuffer[sqlprune.MySQLHdrSize+1:])) + op, table, stmt = detectSQL(string(reqRaw[sqlprune.MySQLHdrSize+1:])) default: - slog.Debug("MySQL command ID unhandled", "commandID", requestBuffer[sqlprune.MySQLHdrSize]) + slog.Debug("MySQL command ID unhandled", "commandID", reqRaw[sqlprune.MySQLHdrSize]) return span, errFallback } diff --git a/pkg/ebpf/common/sql_detect_postgres.go b/pkg/ebpf/common/sql_detect_postgres.go index 1810b1be0..7f2d1767c 100644 --- a/pkg/ebpf/common/sql_detect_postgres.go +++ b/pkg/ebpf/common/sql_detect_postgres.go @@ -192,7 +192,7 @@ type postgresMessage struct { } type postgresMessageIterator struct { - buf []byte + r *LargeBufferReader err error eof bool } @@ -202,19 +202,24 @@ func (it *postgresMessageIterator) isEOF() bool { } func (it *postgresMessageIterator) next() (msg postgresMessage) { - if it.err != nil || len(it.buf) == 0 { + if it.err != nil || it.r.Remaining() == 0 { it.eof = true return } - if len(it.buf) < sqlprune.PostgresHdrSize { + if it.r.Remaining() < sqlprune.PostgresHdrSize { it.err = errors.New("remaining buffer too short for message header") return } - msgType := sqlprune.SQLParseCommandID(request.DBPostgres, it.buf) - it.buf = it.buf[1:] - size := int32(binary.BigEndian.Uint32(it.buf[:4])) - it.buf = it.buf[4:] + // Read the 5-byte header (type byte + 4-byte size) atomically. + // SQLParseCommandID needs buf[0] as the type byte; it requires len(buf) >= PostgresHdrSize (5). 
+ hdrBuf, err := it.r.ReadN(sqlprune.PostgresHdrSize) + if err != nil { + it.err = err + return + } + msgType := sqlprune.SQLParseCommandID(request.DBPostgres, hdrBuf) + size := int32(binary.BigEndian.Uint32(hdrBuf[1:5])) if size < sqlprune.PostgresHdrSize-1 { it.err = errors.New("malformed Postgres message") @@ -222,38 +227,50 @@ func (it *postgresMessageIterator) next() (msg postgresMessage) { } payloadSize := size - sqlprune.PostgresHdrSize + 1 - if len(it.buf) < int(payloadSize) { - it.err = fmt.Errorf("remaining buffer too short for message data: expected %d bytes, got %d", payloadSize, len(it.buf)) + if it.r.Remaining() < int(payloadSize) { + it.err = fmt.Errorf("remaining buffer too short for message data: expected %d bytes, got %d", payloadSize, it.r.Remaining()) return } - data := it.buf[:payloadSize] - it.buf = it.buf[payloadSize:] + // ReadN is safe: all uses of msg.data convert it to a Go string before the next + // it.next() call, so scratch overwrite between iterations is not a concern. + // Use empty non-nil slice for zero-length payloads to match []byte{} semantics. 
+ data := []byte{} + if payloadSize > 0 { + data, err = it.r.ReadN(int(payloadSize)) + if err != nil { + it.err = err + return + } + } msg = postgresMessage{typ: msgType, data: data} return } -func handlePostgres(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffer, responseBuffer []byte) (request.Span, error) { +func handlePostgres(parseCtx *EBPFParseContext, event *TCPRequestInfo, requestBuffer, responseBuffer *LargeBufferReader) (request.Span, error) { var ( hasSpan bool op, table, stmt string span request.Span ) - if len(requestBuffer) < sqlprune.PostgresHdrSize+1 { + if requestBuffer.Remaining() < sqlprune.PostgresHdrSize+1 { slog.Debug("Postgres request too short") return span, errFallback } - if len(responseBuffer) < sqlprune.PostgresHdrSize+1 { + if responseBuffer.Remaining() < sqlprune.PostgresHdrSize+1 { slog.Debug("Postgres response too short") return span, errFallback } + // ReadN(remaining) for response — materialized once for sqlprune.SQLParseError. + respRaw, _ := responseBuffer.ReadN(responseBuffer.Remaining()) + var ( msg postgresMessage - it = &postgresMessageIterator{buf: requestBuffer} - sqlError = sqlprune.SQLParseError(request.DBPostgres, responseBuffer) + it = &postgresMessageIterator{r: requestBuffer} + sqlError = sqlprune.SQLParseError(request.DBPostgres, respRaw) ) Loop: diff --git a/pkg/ebpf/common/sql_detect_postgres_test.go b/pkg/ebpf/common/sql_detect_postgres_test.go index 248b215b1..57f5aded7 100644 --- a/pkg/ebpf/common/sql_detect_postgres_test.go +++ b/pkg/ebpf/common/sql_detect_postgres_test.go @@ -82,7 +82,7 @@ func TestPostgresMessagesIterator(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got []postgresMessage - it := &postgresMessageIterator{buf: tt.buf} + it := &postgresMessageIterator{r: NewLargeBufferFrom(tt.buf).NewReader()} for { msg := it.next() if it.isEOF() { @@ -112,8 +112,11 @@ func TestPostgresMessagesIteratorNoAllocs(t *testing.T) { return b }() + lb := 
NewLargeBufferFrom(buf) + r := lb.NewReader() allocs := testing.AllocsPerRun(1000, func() { - it := &postgresMessageIterator{buf: buf} + r.Reset() + it := postgresMessageIterator{r: r} for { it.next() diff --git a/pkg/ebpf/common/sql_detect_transform.go b/pkg/ebpf/common/sql_detect_transform.go index 8b8738ca6..1148028c8 100644 --- a/pkg/ebpf/common/sql_detect_transform.go +++ b/pkg/ebpf/common/sql_detect_transform.go @@ -58,20 +58,21 @@ func isASCII(s string) bool { return true } -func detectSQLPayload(useHeuristics bool, b []byte) (string, string, string, request.SQLKind) { - sqlKind := sqlKind(b) +func detectSQLPayload(useHeuristics bool, b *LargeBuffer) (string, string, string, request.SQLKind) { + raw := b.UnsafeView() + sqlKind := sqlKind(raw) if !useHeuristics { if sqlKind == request.DBGeneric { return "", "", "", sqlKind } } - op, table, sql := detectSQL(string(b)) + op, table, sql := detectSQL(string(raw)) if !validSQL(op, table, sqlKind) { switch sqlKind { case request.DBPostgres: - op, table, sql = postgresPreparedStatements(b) + op, table, sql = postgresPreparedStatements(raw) case request.DBMySQL: - op, table, sql = mysqlPreparedStatements(b) + op, table, sql = mysqlPreparedStatements(raw) } } diff --git a/pkg/ebpf/common/sql_detect_transform_test.go b/pkg/ebpf/common/sql_detect_transform_test.go index c28bddbe7..1884005ae 100644 --- a/pkg/ebpf/common/sql_detect_transform_test.go +++ b/pkg/ebpf/common/sql_detect_transform_test.go @@ -185,12 +185,12 @@ func TestPostgresQueryParsing(t *testing.T) { }, } { t.Run(ts.name, func(t *testing.T) { - op, table, sql, _ := detectSQLPayload(false, ts.bytes) + op, table, sql, _ := detectSQLPayload(false, NewLargeBufferFrom(ts.bytes)) assert.Equal(t, ts.op, op) assert.Equal(t, ts.table, table) assert.Equal(t, ts.sql, sql) - op, table, sql, _ = detectSQLPayload(true, ts.bytes) + op, table, sql, _ = detectSQLPayload(true, NewLargeBufferFrom(ts.bytes)) assert.Equal(t, ts.op, op) assert.Equal(t, ts.table, table) 
assert.Equal(t, ts.sql, sql) diff --git a/pkg/ebpf/common/tcp_detect_transform.go b/pkg/ebpf/common/tcp_detect_transform.go index 3338f8a53..c6849639f 100644 --- a/pkg/ebpf/common/tcp_detect_transform.go +++ b/pkg/ebpf/common/tcp_detect_transform.go @@ -43,14 +43,14 @@ func ReadTCPRequestIntoSpan(parseCtx *EBPFParseContext, cfg *config.EBPFTracer, requestBuffer, responseBuffer := getBuffers(parseCtx, event) if cfg.ProtocolDebug { - fmt.Printf("[>] %v\n", requestBuffer) - fmt.Printf("[<] %v\n", responseBuffer) + fmt.Printf("[>] %v\n", requestBuffer.UnsafeView()) + fmt.Printf("[<] %v\n", responseBuffer.UnsafeView()) } // We might know already the protocol for this event switch event.ProtocolType { case ProtocolTypeKafka: - k, ignore, err := ProcessPossibleKafkaEvent(event, requestBuffer, responseBuffer, parseCtx.kafkaTopicUUIDToName) + k, ignore, err := ProcessPossibleKafkaEvent(event, requestBuffer.NewReader(), responseBuffer.NewReader(), parseCtx.kafkaTopicUUIDToName) if ignore && err == nil { return request.Span{}, true, nil // parsed kafka event, but we don't want to create a span for it } @@ -82,7 +82,7 @@ func ReadTCPRequestIntoSpan(parseCtx *EBPFParseContext, cfg *config.EBPFTracer, return span, false, nil case ProtocolTypePostgres: - span, err := handlePostgres(parseCtx, event, requestBuffer, responseBuffer) + span, err := handlePostgres(parseCtx, event, requestBuffer.NewReader(), responseBuffer.NewReader()) if errors.Is(err, errFallback) { slog.Debug("Postgres: falling back to generic handler") break @@ -136,13 +136,13 @@ func ReadTCPRequestIntoSpan(parseCtx *EBPFParseContext, cfg *config.EBPFTracer, switch { case isRedis(requestBuffer) && isRedis(responseBuffer): - op, text, ok := parseRedisRequest(string(requestBuffer)) + op, text, ok := parseRedisRequest(string(requestBuffer.UnsafeView())) if ok { var status int var redisErr request.DBError if op == "" { - op, text, ok = parseRedisRequest(string(responseBuffer)) + op, text, ok = 
parseRedisRequest(string(responseBuffer.UnsafeView())) if !ok || op == "" { return request.Span{}, true, nil // ignore if we couldn't parse it } @@ -179,7 +179,7 @@ func ReadTCPRequestIntoSpan(parseCtx *EBPFParseContext, cfg *config.EBPFTracer, return request.Span{}, true, nil // ignore for now, next event will be parsed } else { // we should not arrive here, leave it for completeness - k, ignore, err := ProcessPossibleKafkaEvent(event, requestBuffer, responseBuffer, parseCtx.kafkaTopicUUIDToName) + k, ignore, err := ProcessPossibleKafkaEvent(event, requestBuffer.NewReader(), responseBuffer.NewReader(), parseCtx.kafkaTopicUUIDToName) if ignore && err == nil { return request.Span{}, true, nil // parsed kafka event, but we don't want to create a span for it } @@ -190,25 +190,25 @@ func ReadTCPRequestIntoSpan(parseCtx *EBPFParseContext, cfg *config.EBPFTracer, } if cfg.ProtocolDebug { - fmt.Printf("![>] %v\n", requestBuffer) - fmt.Printf("![<] %v\n", responseBuffer) + fmt.Printf("![>] %v\n", requestBuffer.UnsafeView()) + fmt.Printf("![<] %v\n", responseBuffer.UnsafeView()) } return request.Span{}, true, nil // ignore if we couldn't parse it } -func getBuffers(parseCtx *EBPFParseContext, event *TCPRequestInfo) (req []byte, resp []byte) { +func getBuffers(parseCtx *EBPFParseContext, event *TCPRequestInfo) (req *LargeBuffer, resp *LargeBuffer) { l := int(event.Len) if l < 0 || len(event.Buf) < l { l = len(event.Buf) } - req = event.Buf[:l] + req = NewLargeBufferFrom(event.Buf[:l]) l = int(event.RespLen) if l < 0 || len(event.Rbuf) < l { l = len(event.Rbuf) } - resp = event.Rbuf[:l] + resp = NewLargeBufferFrom(event.Rbuf[:l]) if event.HasLargeBuffers == 1 { if b, ok := extractTCPLargeBuffer(parseCtx, event.Tp.TraceId, packetTypeRequest, directionByPacketType(packetTypeRequest, !event.IsServer), event.ConnInfo); ok { diff --git a/pkg/ebpf/common/tcp_detect_transform_test.go b/pkg/ebpf/common/tcp_detect_transform_test.go index 7b91fdfc9..da17cdfd2 100644 --- 
a/pkg/ebpf/common/tcp_detect_transform_test.go +++ b/pkg/ebpf/common/tcp_detect_transform_test.go @@ -168,7 +168,7 @@ func TestRedisDetection(t *testing.T) { } { lines := strings.Split(s, "|") test := strings.Join(lines, "\r\n") - assert.True(t, isRedis([]uint8(test))) + assert.True(t, isRedis(NewLargeBufferFrom([]uint8(test)))) assert.True(t, isRedisOp([]uint8(test))) } @@ -182,7 +182,7 @@ func TestRedisDetection(t *testing.T) { } { lines := strings.Split(s, "|") test := strings.Join(lines, "\r\n") - assert.False(t, isRedis([]uint8(test))) + assert.False(t, isRedis(NewLargeBufferFrom([]uint8(test)))) assert.False(t, isRedisOp([]uint8(test))) } } @@ -191,7 +191,7 @@ func TestTCPReqKafkaParsing(t *testing.T) { // kafka message b := []byte{0, 0, 0, 94, 0, 1, 0, 11, 0, 0, 0, 224, 0, 6, 115, 97, 114, 97, 109, 97, 255, 255, 255, 255, 0, 0, 1, 244, 0, 0, 0, 1, 6, 64, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 0, 0, 0, 1, 0, 9, 105, 109, 112, 111, 114, 116, 97, 110, 116, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, 0, 0, 0, 0, 0, 0, 0} r := makeTCPReq(string(b), 343534) - k, _, err := ProcessKafkaRequest(b, nil) + k, _, err := ProcessKafkaRequest(NewLargeBufferFrom(b).NewReader(), nil) require.NoError(t, err) s := TCPToKafkaToSpan(&r, k) assert.NotNil(t, s) @@ -240,7 +240,7 @@ func TestTCPReqMQTTHeuristicFailure(t *testing.T) { } // Verify the heuristic passes but full parsing fails - assert.True(t, isMQTT(b), "packet should pass isMQTT heuristic") + assert.True(t, isMQTT(NewLargeBufferFrom(b)), "packet should pass isMQTT heuristic") _, _, err := ProcessMQTTEvent(b) require.Error(t, err, "full MQTT parsing should fail") diff --git a/pkg/ebpf/common/tcp_large_buffer.go b/pkg/ebpf/common/tcp_large_buffer.go index 9d2b81678..5cc095e52 100644 --- a/pkg/ebpf/common/tcp_large_buffer.go +++ b/pkg/ebpf/common/tcp_large_buffer.go @@ -11,16 +11,11 @@ import ( "go.opentelemetry.io/obi/pkg/internal/ebpf/ringbuf" ) -type ( - 
largeBufferKey struct { - traceID [16]uint8 - packetType, direction uint8 - connInfo BpfConnectionInfoT - } - largeBuffer struct { - buf []byte - } -) +type largeBufferKey struct { + traceID [16]uint8 + packetType, direction uint8 + connInfo BpfConnectionInfoT +} const ( largeBufferActionInit = iota @@ -43,22 +38,26 @@ func appendTCPLargeBuffer(parseCtx *EBPFParseContext, record *ringbuf.Record) (r } if parseCtx.protocolDebug { - fmt.Printf(">>> LargeBufferAppend: (packet=%d direction=%d action=%d size=%d)\n%s\n", event.PacketType, event.Direction, event.Action, event.Len, string(record.RawSample[hdrSize:hdrSize+event.Len])) + fmt.Printf(">>> LargeBufferAppend: (packet=%d direction=%d action=%d size=%d)\n%s\n", + event.PacketType, event.Direction, event.Action, event.Len, + string(record.RawSample[hdrSize:hdrSize+event.Len])) } + chunk := record.RawSample[hdrSize : hdrSize+event.Len] + switch event.Action { case largeBufferActionInit: - newBuffer := make([]byte, event.Len) - copy(newBuffer, record.RawSample[hdrSize:]) - parseCtx.largeBuffers.Add(key, &largeBuffer{ - buf: newBuffer, - }) + lb := NewLargeBuffer() + lb.AppendChunk(chunk) + parseCtx.largeBuffers.Add(key, lb) + case largeBufferActionAppend: lb, ok := parseCtx.largeBuffers.Get(key) if !ok { return request.Span{}, true, nil } - lb.buf = append(lb.buf, record.RawSample[hdrSize:hdrSize+event.Len]...) 
+ lb.AppendChunk(chunk) + default: return request.Span{}, true, fmt.Errorf("invalid large buffer action: %d", event.Action) } @@ -66,7 +65,12 @@ func appendTCPLargeBuffer(parseCtx *EBPFParseContext, record *ringbuf.Record) (r return request.Span{}, true, nil } -func extractTCPLargeBuffer(parseCtx *EBPFParseContext, traceID [16]uint8, packetType, direction uint8, connInfo BpfConnectionInfoT) ([]byte, bool) { +func extractTCPLargeBuffer( + parseCtx *EBPFParseContext, + traceID [16]uint8, + packetType, direction uint8, + connInfo BpfConnectionInfoT, +) (*LargeBuffer, bool) { key := largeBufferKey{ traceID: traceID, packetType: packetType, @@ -74,18 +78,20 @@ func extractTCPLargeBuffer(parseCtx *EBPFParseContext, traceID [16]uint8, packet connInfo: connInfo, } - //nolint:gocritic - if lb, ok := parseCtx.largeBuffers.Get(key); ok { + lb, ok := parseCtx.largeBuffers.Get(key) + if !ok { if parseCtx.protocolDebug { - fmt.Printf("<<< LargeBufferExtract: (packet=%d direction=%d len=%d)\n%s\n", key.packetType, key.direction, len(lb.buf), string(lb.buf)) - } - parseCtx.largeBuffers.Remove(key) - return lb.buf, true - } else { - if parseCtx.protocolDebug { - fmt.Printf("<<< LargeBufferExtract: not found!(packet=%d direction=%d)\n", key.packetType, key.direction) + fmt.Printf("<<< LargeBufferExtract: not found! 
(packet=%d direction=%d)\n", key.packetType, key.direction) } + return nil, false + } + + if parseCtx.protocolDebug { + fmt.Printf("<<< LargeBufferExtract: (packet=%d direction=%d len=%d)\n%s\n", + key.packetType, key.direction, lb.Len(), lb.UnsafeView()) } - return nil, false + parseCtx.largeBuffers.Remove(key) + + return lb, true } diff --git a/pkg/ebpf/common/tcp_large_buffer_test.go b/pkg/ebpf/common/tcp_large_buffer_test.go index 6042c91d4..dd8f6d177 100644 --- a/pkg/ebpf/common/tcp_large_buffer_test.go +++ b/pkg/ebpf/common/tcp_large_buffer_test.go @@ -22,7 +22,7 @@ func TestTCPLargeBuffers(t *testing.T) { verifyLargeBuffer := func(traceID [16]uint8, packetType, direction uint8, connInfo BpfConnectionInfoT, expectedBuf string) { buf, ok := extractTCPLargeBuffer(pctx, traceID, packetType, direction, connInfo) require.True(t, ok, "Expected to find large buffer") - require.Equal(t, expectedBuf, unix.ByteSliceToString(buf), "Buffer content mismatch") + require.Equal(t, expectedBuf, unix.ByteSliceToString(buf.UnsafeView()), "Buffer content mismatch") } firstEvent := TCPLargeBufferHeader{ diff --git a/pkg/internal/ebpf/kafkaparser/common.go b/pkg/internal/ebpf/kafkaparser/common.go index 3df5ff671..443294fac 100644 --- a/pkg/internal/ebpf/kafkaparser/common.go +++ b/pkg/internal/ebpf/kafkaparser/common.go @@ -33,10 +33,7 @@ const ( APIKeyMetadata KafkaAPIKey = 3 ) -type ( - UUID [UUIDLen]byte - Offset = int -) +type UUID [UUIDLen]byte type KafkaRequestHeader struct { MessageSize int32 @@ -51,88 +48,121 @@ type KafkaResponseHeader struct { CorrelationID int32 } -func ParseKafkaRequestHeader(pkt []byte) (*KafkaRequestHeader, Offset, error) { - if len(pkt) < MinKafkaRequestLen { - return nil, 0, errors.New("packet too short for Kafka request header") +// byteReader is the sequential-read interface satisfied by *LargeBuffer. +// Defined here so sub-packages don't need to import ebpfcommon (which would be circular). 
+type byteReader interface { + ReadN(n int) ([]byte, error) + Peek(n int) ([]byte, error) + Skip(n int) error + Remaining() int +} + +func ParseKafkaRequestHeader(r byteReader) (*KafkaRequestHeader, error) { + if r.Remaining() < MinKafkaRequestLen { + return nil, errors.New("packet too short for Kafka request header") + } + + msgSizeBytes, err := r.ReadN(Int32Len) + if err != nil { + return nil, err + } + apiKeyBytes, err := r.ReadN(Int16Len) + if err != nil { + return nil, err + } + apiVersionBytes, err := r.ReadN(Int16Len) + if err != nil { + return nil, err + } + correlationIDBytes, err := r.ReadN(Int32Len) + if err != nil { + return nil, err } header := &KafkaRequestHeader{ - MessageSize: int32(binary.BigEndian.Uint32(pkt[0:4])), - APIKey: KafkaAPIKey(int16(binary.BigEndian.Uint16(pkt[4:6]))), - APIVersion: int16(binary.BigEndian.Uint16(pkt[6:8])), - CorrelationID: int32(binary.BigEndian.Uint32(pkt[8:12])), + MessageSize: int32(binary.BigEndian.Uint32(msgSizeBytes)), + APIKey: KafkaAPIKey(int16(binary.BigEndian.Uint16(apiKeyBytes))), + APIVersion: int16(binary.BigEndian.Uint16(apiVersionBytes)), + CorrelationID: int32(binary.BigEndian.Uint32(correlationIDBytes)), } - clientIDSize := int16(binary.BigEndian.Uint16(pkt[12:14])) - err := validateKafkaRequestHeader(header) + clientIDSizeBytes, err := r.ReadN(Int16Len) if err != nil { - return nil, 0, err + return nil, err + } + clientIDSize := int16(binary.BigEndian.Uint16(clientIDSizeBytes)) + + if err := validateKafkaRequestHeader(header); err != nil { + return nil, err } if clientIDSize < 0 { - return nil, 0, errors.New("invalid client ID size") + return nil, errors.New("invalid client ID size") } - offset := MinKafkaRequestLen if clientIDSize == 0 { header.ClientID = "" - return header, offset, nil + return header, nil } - if offset+int(clientIDSize) > len(pkt) { - return nil, 0, errors.New("packet too short for client ID") + if r.Remaining() < int(clientIDSize) { + return nil, errors.New("packet too short for 
client ID") } - header.ClientID = string(pkt[offset : offset+int(clientIDSize)]) - offset += int(clientIDSize) - offset, err = skipTaggedFields(pkt, header, offset) + clientIDBytes, err := r.ReadN(int(clientIDSize)) if err != nil { - return nil, 0, err + return nil, err + } + header.ClientID = string(clientIDBytes) + + if err := skipTaggedFields(r, header); err != nil { + return nil, err } - return header, offset, nil + return header, nil } -func ParseKafkaResponseHeader(pkt []byte, requestHeader *KafkaRequestHeader) (*KafkaResponseHeader, Offset, error) { - if len(pkt) < MinKafkaResponseLen { - return nil, 0, errors.New("packet too short for Kafka response header") +func ParseKafkaResponseHeader(r byteReader, requestHeader *KafkaRequestHeader) (*KafkaResponseHeader, error) { + if r.Remaining() < MinKafkaResponseLen { + return nil, errors.New("packet too short for Kafka response header") + } + msgSizeBytes, err := r.ReadN(Int32Len) + if err != nil { + return nil, err + } + correlationIDBytes, err := r.ReadN(Int32Len) + if err != nil { + return nil, err } header := &KafkaResponseHeader{ - MessageSize: int32(binary.BigEndian.Uint32(pkt[0:4])), - CorrelationID: int32(binary.BigEndian.Uint32(pkt[4:8])), + MessageSize: int32(binary.BigEndian.Uint32(msgSizeBytes)), + CorrelationID: int32(binary.BigEndian.Uint32(correlationIDBytes)), } - offset := MinKafkaResponseLen - err := validateKafkaResponseHeader(header, requestHeader) - if err != nil { - return nil, 0, err + if err := validateKafkaResponseHeader(header, requestHeader); err != nil { + return nil, err } - offset, err = skipTaggedFields(pkt, requestHeader, offset) - if err != nil { - return nil, 0, err + if err := skipTaggedFields(r, requestHeader); err != nil { + return nil, err } - return header, offset, nil + return header, nil } -func skipTaggedFields(pkt []byte, header *KafkaRequestHeader, offset Offset) (Offset, error) { +func skipTaggedFields(r byteReader, header *KafkaRequestHeader) error { if 
!isFlexible(header) { - return offset, nil // no tagged fields to skip for non-flexible versions + return nil // no tagged fields to skip for non-flexible versions } - taggedFieldsLen, offset, err := readUnsignedVarint(pkt[offset:], offset) + taggedFieldsLen, err := readUnsignedVarint(r) if err != nil { - return 0, err + return err } - for range taggedFieldsLen { - _, offset, err = readUnsignedVarint(pkt[offset:], offset) // read tag ID - if err != nil { - return 0, err + if _, err = readUnsignedVarint(r); err != nil { // read tag ID + return err } - var tagLen int - tagLen, offset, err = readUnsignedVarint(pkt[offset:], offset) // read tag length + tagLen, err := readUnsignedVarint(r) // read tag length if err != nil { - return 0, err + return err } - offset, err = skipBytes(pkt, offset, tagLen) // skip tag value - if err != nil { - return 0, err + if err = r.Skip(tagLen); err != nil { // skip tag value + return err } } - return offset, nil + return nil } func validateKafkaRequestHeader(header *KafkaRequestHeader) error { @@ -202,42 +232,49 @@ func isFlexible(header *KafkaRequestHeader) bool { } } -func readArrayLength(pkt []byte, header *KafkaRequestHeader, offset Offset) (int, Offset, error) { +func readArrayLength(r byteReader, header *KafkaRequestHeader) (int, error) { if isFlexible(header) { - size, offset, err := readUnsignedVarint(pkt[offset:], offset) + size, err := readUnsignedVarint(r) + if err != nil { + return 0, err + } if size == 0 { - return 0, offset, nil // return 0 for null + return 0, nil // return 0 for null } - return size - 1, offset, err - } else { - return readInt32(pkt, offset) + return size - 1, nil } + return readInt32(r) } -func readUUID(pkt []byte, offset Offset) (*UUID, Offset, error) { - if offset+UUIDLen > len(pkt) { - return nil, offset, errors.New("packet too short for topic UUID") +func readUUID(r byteReader) (*UUID, error) { + b, err := r.ReadN(UUIDLen) + if err != nil { + return nil, errors.New("packet too short for topic 
UUID") } - uuid := (UUID)(pkt[offset : offset+UUIDLen]) - return &uuid, offset + UUIDLen, nil + var uuid UUID + copy(uuid[:], b) + return &uuid, nil } -func readString(pkt []byte, header *KafkaRequestHeader, offset Offset, nullable bool) (string, Offset, error) { - size, offset, err := readStringLength(pkt, header, offset, nullable) +func readString(r byteReader, header *KafkaRequestHeader, nullable bool) (string, error) { + size, err := readStringLength(r, header, nullable) if err != nil { - return "", offset, err + return "", err } if nullable && size == 0 { - return "", offset, nil // return empty string for null + return "", nil // return empty string for null + } + if r.Remaining() < size { + return "", errors.New("string size exceeds packet size") } - if offset+size > len(pkt) { - return "", 0, errors.New("string size exceeds packet size") + b, err := r.ReadN(size) + if err != nil { + return "", errors.New("string size exceeds packet size") } - if !validateKafkaString(pkt[offset:offset+size], size) { - return "", 0, errors.New("invalid characters in string") + if !validateKafkaString(b, size) { + return "", errors.New("invalid characters in string") } - str := string(pkt[offset : offset+size]) - return str, offset + size, nil + return string(b), nil } func validateKafkaString(pkt []byte, size int) bool { @@ -251,80 +288,79 @@ func validateKafkaString(pkt []byte, size int) bool { return true } -func readStringLength(pkt []byte, header *KafkaRequestHeader, offset Offset, nullable bool) (int, Offset, error) { +func readStringLength(r byteReader, header *KafkaRequestHeader, nullable bool) (int, error) { if !isFlexible(header) { // length is stored as a fixed size int16 - if offset+Int16Len > len(pkt) { - return 0, 0, errors.New("packet too short for string length") + if r.Remaining() < Int16Len { + return 0, errors.New("packet too short for string length") } - size := int16(binary.BigEndian.Uint16(pkt[offset:])) + b, err := r.ReadN(Int16Len) + if err != nil { + 
return 0, errors.New("packet too short for string length") + } + size := int16(binary.BigEndian.Uint16(b)) if nullable && size == -1 { - return 0, offset + Int16Len, nil // return 0 for null + return 0, nil // return 0 for null } if size < 1 { - return 0, 0, errors.New("invalid string size") + return 0, errors.New("invalid string size") } - return int(size), offset + Int16Len, nil + return int(size), nil } // length is stored as a varint - size, offset, err := readUnsignedVarint(pkt[offset:], offset) + size, err := readUnsignedVarint(r) if err != nil { - return 0, 0, err + return 0, err } if nullable && size == 0 { - return 0, offset, nil // return 0 for null + return 0, nil // return 0 for null } if size <= 0 { - return 0, 0, errors.New("invalid string size") + return 0, errors.New("invalid string size") } size-- // size is stored as a varint, so we subtract 1 if size < 0 { - return 0, 0, errors.New("invalid string size") + return 0, errors.New("invalid string size") } - return size, offset, nil + return size, nil } -func readUnsignedVarint(data []byte, offset Offset) (int, Offset, error) { +func readUnsignedVarint(r byteReader) (int, error) { value := 0 i := 0 - for idx := range data { - if idx > len(data) { - return 0, 0, errors.New("offset exceeds data length") + for { + if r.Remaining() == 0 { + return 0, errors.New("data ended before varint was complete") + } + b, err := r.ReadN(1) + if err != nil { + return 0, err } - b := data[idx] - if (b & 0x80) == 0 { - value |= int(b) << i - return value, offset + idx + 1, nil + if (b[0] & 0x80) == 0 { + value |= int(b[0]) << i + return value, nil } - value |= int(b&0x7F) << i + value |= int(b[0]&0x7F) << i i += 7 if i > 28 { - return 0, 0, errors.New("illegal varint") + return 0, errors.New("illegal varint") } } - return 0, 0, errors.New("data ended before varint was complete") } -func skipBytes(pkt []byte, offset Offset, length int) (Offset, error) { - if offset+length > len(pkt) { - return 0, errors.New("offset and 
length exceed packet size") - } - return offset + length, nil -} - -func readInt32(data []byte, offset Offset) (int, Offset, error) { - if offset+Int32Len > len(data) { - return 0, 0, errors.New("data too short for uint32") +func readInt32(r byteReader) (int, error) { + b, err := r.ReadN(Int32Len) + if err != nil { + return 0, errors.New("data too short for int32") } - value := int(binary.BigEndian.Uint32(data[offset:])) - return value, offset + Int32Len, nil + return int(binary.BigEndian.Uint32(b)), nil } -func readInt64(data []byte, offset Offset) (int64, Offset, error) { - if offset+Int64Len > len(data) { - return 0, 0, errors.New("data too short for uint32") +func readInt64(r byteReader) (int64, error) { + b, err := r.ReadN(Int64Len) + if err != nil { + return 0, errors.New("data too short for int64") } - value := int64(binary.BigEndian.Uint64(data[offset:])) - return value, offset + Int64Len, nil + return int64(binary.BigEndian.Uint64(b)), nil } diff --git a/pkg/internal/ebpf/kafkaparser/common_test.go b/pkg/internal/ebpf/kafkaparser/common_test.go index 7826d4588..5486f245d 100644 --- a/pkg/internal/ebpf/kafkaparser/common_test.go +++ b/pkg/internal/ebpf/kafkaparser/common_test.go @@ -152,7 +152,8 @@ func TestParseKafkaRequestHeader(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - header, offset, err := ParseKafkaRequestHeader(tt.packet) + r := newBytesReader(tt.packet) + header, err := ParseKafkaRequestHeader(r) if tt.expectErr { assert.Error(t, err) @@ -168,11 +169,11 @@ func TestParseKafkaRequestHeader(t *testing.T) { assert.Equal(t, tt.expected.CorrelationID, header.CorrelationID) assert.Equal(t, tt.expected.ClientID, header.ClientID) - expectedOffset := MinKafkaRequestLen + len(tt.expected.ClientID) + expectedConsumed := MinKafkaRequestLen + len(tt.expected.ClientID) if tt.flexible { - expectedOffset++ // Account for tagged fields byte + expectedConsumed++ // Account for tagged fields byte } - assert.Equal(t, 
expectedOffset, offset) + assert.Equal(t, expectedConsumed, r.Pos()) }) } } @@ -404,7 +405,8 @@ func TestReadArrayLength(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - length, offset, err := readArrayLength(tt.packet, tt.header, tt.offset) + r := newBytesReader(tt.packet[tt.offset:]) + length, err := readArrayLength(r, tt.header) if tt.expectErr { assert.Error(t, err) @@ -413,7 +415,7 @@ func TestReadArrayLength(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.expectedLength, length) - assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedOffset-tt.offset, r.Pos()) }) } } @@ -481,7 +483,8 @@ func TestReadUUID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - uuid, offset, err := readUUID(tt.packet, tt.offset) + r := newBytesReader(tt.packet[tt.offset:]) + uuid, err := readUUID(r) if tt.expectErr { assert.Error(t, err) @@ -491,7 +494,7 @@ func TestReadUUID(t *testing.T) { require.NoError(t, err) require.NotNil(t, uuid) assert.Equal(t, tt.expectedUUID, *uuid) - assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedOffset-tt.offset, r.Pos()) }) } } @@ -562,7 +565,8 @@ func TestReadString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - str, offset, err := readString(tt.packet, tt.header, tt.offset, tt.nullable) + r := newBytesReader(tt.packet[tt.offset:]) + str, err := readString(r, tt.header, tt.nullable) if tt.expectErr { assert.Error(t, err) @@ -571,43 +575,43 @@ func TestReadString(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.expectedString, str) - assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedOffset-tt.offset, r.Pos()) }) } } func TestReadUnsignedVarint(t *testing.T) { tests := []struct { - name string - data []byte - offset int - expectedValue int - expectedOffset int - expectErr bool + name string + data []byte + offset int + expectedValue int + expectedBytes int // bytes consumed 
+ expectErr bool }{ { - name: "single byte varint", - data: []byte{0x05}, - offset: 0, - expectedValue: 5, - expectedOffset: 1, - expectErr: false, + name: "single byte varint", + data: []byte{0x05}, + offset: 0, + expectedValue: 5, + expectedBytes: 1, + expectErr: false, }, { - name: "multi-byte varint", - data: []byte{0x96, 0x01}, // 150 in varint - offset: 0, - expectedValue: 150, - expectedOffset: 2, - expectErr: false, + name: "multi-byte varint", + data: []byte{0x96, 0x01}, // 150 in varint + offset: 0, + expectedValue: 150, + expectedBytes: 2, + expectErr: false, }, { - name: "large varint", - data: []byte{0xFF, 0xFF, 0x7F}, // Large number - offset: 0, - expectedValue: 2097151, - expectedOffset: 3, - expectErr: false, + name: "large varint", + data: []byte{0xFF, 0xFF, 0x7F}, // Large number + offset: 0, + expectedValue: 2097151, + expectedBytes: 3, + expectErr: false, }, { name: "incomplete varint", @@ -625,7 +629,8 @@ func TestReadUnsignedVarint(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - value, offset, err := readUnsignedVarint(tt.data, tt.offset) + r := newBytesReader(tt.data[tt.offset:]) + value, err := readUnsignedVarint(r) if tt.expectErr { assert.Error(t, err) @@ -634,48 +639,49 @@ func TestReadUnsignedVarint(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.expectedValue, value) - assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedBytes, r.Pos()) }) } } -func TestSkipBytes(t *testing.T) { +func TestSkip(t *testing.T) { tests := []struct { - name string - packet []byte - offset int - length int - expectedOffset int - expectErr bool + name string + packet []byte + offset int + length int + expectedBytes int // bytes consumed by Skip + expectErr bool }{ { - name: "valid skip", - packet: make([]byte, 20), - offset: 5, - length: 10, - expectedOffset: 15, - expectErr: false, + name: "valid skip", + packet: make([]byte, 20), + offset: 5, + length: 10, + expectedBytes: 10, + expectErr: 
false, }, { name: "skip exceeds packet", packet: make([]byte, 10), offset: 5, - length: 10, // 5 + 10 > 10 + length: 10, // 5 remaining, but skip 10 expectErr: true, }, { - name: "skip zero bytes", - packet: make([]byte, 10), - offset: 3, - length: 0, - expectedOffset: 3, - expectErr: false, + name: "skip zero bytes", + packet: make([]byte, 10), + offset: 3, + length: 0, + expectedBytes: 0, + expectErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - offset, err := skipBytes(tt.packet, tt.offset, tt.length) + r := newBytesReader(tt.packet[tt.offset:]) + err := r.Skip(tt.length) if tt.expectErr { assert.Error(t, err) @@ -683,7 +689,7 @@ func TestSkipBytes(t *testing.T) { } require.NoError(t, err) - assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedBytes, r.Pos()) }) } } @@ -703,7 +709,7 @@ func TestParseKafkaRequestHeaderTruncation(t *testing.T) { for i := 1; i < len(validPacket); i++ { t.Run(fmt.Sprintf("truncated_at_%d", i), func(t *testing.T) { truncated := validPacket[:i] - _, _, err := ParseKafkaRequestHeader(truncated) + _, err := ParseKafkaRequestHeader(newBytesReader(truncated)) assert.Error(t, err, "expected error for truncated packet at position %d", i) }) } diff --git a/pkg/internal/ebpf/kafkaparser/fetch.go b/pkg/internal/ebpf/kafkaparser/fetch.go index 0d748d937..7815998a5 100644 --- a/pkg/internal/ebpf/kafkaparser/fetch.go +++ b/pkg/internal/ebpf/kafkaparser/fetch.go @@ -19,15 +19,14 @@ type FetchRequest struct { Topics []*FetchTopic } -func ParseFetchRequest(pkt []byte, header *KafkaRequestHeader, offset Offset) (*FetchRequest, error) { - offset, err := fetchRequestSkipUntilTopics(pkt, header, offset) - if err != nil { +func ParseFetchRequest(r byteReader, header *KafkaRequestHeader) (*FetchRequest, error) { + if err := fetchRequestSkipUntilTopics(r, header); err != nil { return nil, err } - if offset >= len(pkt) { + if r.Remaining() == 0 { return nil, errors.New("offset exceeds packet size") } - 
topics, err := parseFetchTopics(pkt, header, offset) + topics, err := parseFetchTopics(r, header) if err != nil { return nil, err } @@ -39,8 +38,8 @@ func ParseFetchRequest(pkt []byte, header *KafkaRequestHeader, offset Offset) (* }, nil } -func fetchRequestSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) (Offset, error) { - var err error +func fetchRequestSkipUntilTopics(r byteReader, header *KafkaRequestHeader) error { + var skipLen int switch { case header.APIVersion >= 15: /* @@ -53,14 +52,12 @@ func fetchRequestSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offset session_epoch => INT32 Topics => ... */ - offset, err = skipBytes(pkt, offset, - Int32Len+ // max_wait_ms - Int32Len+ // min_bytes - Int32Len+ // max_bytes - Int8Len+ // isolation_level - Int32Len+ // session_id - Int32Len, // session_epoch - ) + skipLen = Int32Len + // max_wait_ms + Int32Len + // min_bytes + Int32Len + // max_bytes + Int8Len + // isolation_level + Int32Len + // session_id + Int32Len // session_epoch case header.APIVersion >= 7: /* Fetch Request (Version: 7-14) => replica_id max_wait_ms min_bytes max_bytes isolation_level session_id session_epoch ... @@ -73,15 +70,13 @@ func fetchRequestSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offset session_epoch => INT32 Topics => ... */ - offset, err = skipBytes(pkt, offset, - Int32Len+ // replica_id - Int32Len+ // max_wait_ms - Int32Len+ // min_bytes - Int32Len+ // max_bytes - Int8Len+ // isolation_level - Int32Len+ // session_id - Int32Len, // session_epoch - ) + skipLen = Int32Len + // replica_id + Int32Len + // max_wait_ms + Int32Len + // min_bytes + Int32Len + // max_bytes + Int8Len + // isolation_level + Int32Len + // session_id + Int32Len // session_epoch case header.APIVersion >= 4: /* @@ -93,29 +88,26 @@ func fetchRequestSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offset isolation_level => INT8 Topics => ... 
*/ - offset, err = skipBytes(pkt, offset, - Int32Len+ // replica_id - Int32Len+ // max_wait_ms - Int32Len+ // min_bytes - Int32Len+ // max_bytes - Int8Len, // isolation_level - ) + skipLen = Int32Len + // replica_id + Int32Len + // max_wait_ms + Int32Len + // min_bytes + Int32Len + // max_bytes + Int8Len // isolation_level } - if err != nil { - return 0, err + if skipLen > 0 { + return r.Skip(skipLen) } - return offset, nil + return nil } -func parseFetchTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) ([]*FetchTopic, error) { - topicsLen, offset, err := readArrayLength(pkt, header, offset) +func parseFetchTopics(r byteReader, header *KafkaRequestHeader) ([]*FetchTopic, error) { + topicsLen, err := readArrayLength(r, header) if err != nil { return nil, err } var topics []*FetchTopic - var topic *FetchTopic for range topicsLen { - topic, offset, err = parseFetchTopic(pkt, header, offset) + topic, err := parseFetchTopic(r, header) if err != nil { // return the Topics parsed so far, even if one topic failed return topics, nil @@ -124,10 +116,10 @@ func parseFetchTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) ([] topics = append(topics, topic) } } - return topics, err + return topics, nil } -func parseFetchTopic(pkt []byte, header *KafkaRequestHeader, offset Offset) (*FetchTopic, Offset, error) { +func parseFetchTopic(r byteReader, header *KafkaRequestHeader) (*FetchTopic, error) { var topic FetchTopic var err error if header.APIVersion >= 13 { @@ -139,9 +131,9 @@ func parseFetchTopic(pkt []byte, header *KafkaRequestHeader, offset Offset) (*Fe topic_id => UUID */ var topicUUID *UUID - topicUUID, offset, err = readUUID(pkt, offset) + topicUUID, err = readUUID(r) if err != nil { - return nil, offset, err + return nil, err } topic.UUID = topicUUID } else { @@ -150,49 +142,47 @@ func parseFetchTopic(pkt []byte, header *KafkaRequestHeader, offset Offset) (*Fe topic => STRING / COMPACT_STRING */ var topicName string - topicName, offset, err = 
readString(pkt, header, offset, false) + topicName, err = readString(r, header, false) if err != nil { - return nil, offset, err + return nil, err } topic.Name = topicName } - partitionCount, offset, err := readArrayLength(pkt, header, offset) + partitionCount, err := readArrayLength(r, header) if err != nil { // if we fail to read Partition count, we can still return the topic - return &topic, offset, nil + return &topic, nil } if partitionCount != 1 { // no need to capture multiple partitions for fetch request, if its 1 Partition we can add it to the span - offset, err = skipFetchPartitions(pkt, header, offset, partitionCount) - if err != nil { + if err = skipFetchPartitions(r, header, partitionCount); err != nil { // if we fail to skip partitions, we can still return the topic - return &topic, offset, nil + return &topic, nil } } else { var partition *FetchPartition - partition, offset, err = parseFetchPartition(pkt, header, offset) + partition, err = parseFetchPartition(r, header) if err != nil { // if we fail to parse Partition, we can still return the topic - return &topic, offset, nil + return &topic, nil } topic.Partition = partition } - offset, err = skipTaggedFields(pkt, header, offset) - if err != nil { - return &topic, offset, nil + if err = skipTaggedFields(r, header); err != nil { + return &topic, nil } - return &topic, offset, nil + return &topic, nil } -func parseFetchPartition(pkt []byte, header *KafkaRequestHeader, offset Offset) (*FetchPartition, Offset, error) { +func parseFetchPartition(r byteReader, header *KafkaRequestHeader) (*FetchPartition, error) { /* partitions => Partition fetch_offset log_start_offset partition_max_bytes Partition => INT32 fetch_offset => INT64 */ - partition, offset, err := readInt32(pkt, offset) + partition, err := readInt32(r) if err != nil { - return nil, offset, err + return nil, err } if header.APIVersion >= 9 { /* @@ -202,22 +192,21 @@ func parseFetchPartition(pkt []byte, header *KafkaRequestHeader, offset 
Offset) current_leader_epoch => INT32 fetch_offset => INT64 */ - offset, err = skipBytes(pkt, offset, Int32Len) // current_leader_epoch - if err != nil { - return nil, offset, err + if err = r.Skip(Int32Len); err != nil { // current_leader_epoch + return nil, err } } - fetchOffset, offset, err := readInt64(pkt, offset) + fetchOffset, err := readInt64(r) if err != nil { - return nil, offset, err + return nil, err } return &FetchPartition{ Partition: partition, FetchOffset: fetchOffset, - }, offset, nil + }, nil } -func skipFetchPartitions(pkt []byte, header *KafkaRequestHeader, offset Offset, partitionCount int) (Offset, error) { +func skipFetchPartitions(r byteReader, header *KafkaRequestHeader, partitionCount int) error { var fetchPartitionLen int switch { case header.APIVersion >= 12: @@ -273,9 +262,5 @@ func skipFetchPartitions(pkt []byte, header *KafkaRequestHeader, offset Offset, Int64Len + // fetch_offset Int32Len // partition_max_bytes } - offset, err := skipBytes(pkt, offset, fetchPartitionLen*partitionCount) - if err != nil { - return 0, err - } - return offset, nil + return r.Skip(fetchPartitionLen * partitionCount) } diff --git a/pkg/internal/ebpf/kafkaparser/fetch_test.go b/pkg/internal/ebpf/kafkaparser/fetch_test.go index 9dfaa7c14..dc943a129 100644 --- a/pkg/internal/ebpf/kafkaparser/fetch_test.go +++ b/pkg/internal/ebpf/kafkaparser/fetch_test.go @@ -354,7 +354,7 @@ func TestParseFetchRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req, err := ParseFetchRequest(tt.packet, tt.header, 0) + req, err := ParseFetchRequest(newBytesReader(tt.packet), tt.header) if tt.expectErr { assert.Error(t, err) @@ -398,7 +398,7 @@ func TestParseFetchRequestTruncation(t *testing.T) { for i := 1; i < len(validPacket); i++ { t.Run(fmt.Sprintf("truncated_at_%d", i), func(t *testing.T) { truncated := validPacket[:i] - _, err := ParseFetchRequest(truncated, header, 0) + _, err := ParseFetchRequest(newBytesReader(truncated), header) 
assert.Error(t, err, "expected error for truncated packet at position %d for version %d", i, version) }) } diff --git a/pkg/internal/ebpf/kafkaparser/metadata.go b/pkg/internal/ebpf/kafkaparser/metadata.go index 968577bc0..a8dcecd15 100644 --- a/pkg/internal/ebpf/kafkaparser/metadata.go +++ b/pkg/internal/ebpf/kafkaparser/metadata.go @@ -25,12 +25,11 @@ type MetadataResponse struct { Topics []*MetadataTopic } -func ParseMetadataResponse(pkt []byte, header *KafkaRequestHeader, offset int) (*MetadataResponse, error) { - offset, err := metadataResponseSkipUntilTopics(pkt, header, offset) - if err != nil { +func ParseMetadataResponse(r byteReader, header *KafkaRequestHeader) (*MetadataResponse, error) { + if err := metadataResponseSkipUntilTopics(r, header); err != nil { return nil, err } - topics, err := parseMetadataTopics(pkt, header, offset) + topics, err := parseMetadataTopics(r, header) if err != nil { return nil, err } @@ -42,73 +41,61 @@ func ParseMetadataResponse(pkt []byte, header *KafkaRequestHeader, offset int) ( }, nil } -func metadataResponseSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) (Offset, error) { - var err error - offset, err = skipBytes(pkt, offset, Int32Len) // throttle_time_ms - if err != nil { - return 0, err +func metadataResponseSkipUntilTopics(r byteReader, header *KafkaRequestHeader) error { + if err := r.Skip(Int32Len); err != nil { // throttle_time_ms + return err } - offset, err = skipMetadataResponseBrokers(pkt, header, offset) - if err != nil { - return 0, err + if err := skipMetadataResponseBrokers(r, header); err != nil { + return err } - clusterIDLen, offset, err := readStringLength(pkt, header, offset, true) - if err != nil { - return 0, err - } - offset, err = skipBytes(pkt, offset, clusterIDLen+Int32Len) // cluster_id + controller_id + clusterIDLen, err := readStringLength(r, header, true) if err != nil { - return 0, err + return err } - return offset, nil + return r.Skip(clusterIDLen + Int32Len) // 
cluster_id + controller_id } -func skipMetadataResponseBrokers(pkt []byte, header *KafkaRequestHeader, offset Offset) (Offset, error) { - brokersLen, offset, err := readArrayLength(pkt, header, offset) +func skipMetadataResponseBrokers(r byteReader, header *KafkaRequestHeader) error { + brokersLen, err := readArrayLength(r, header) if err != nil { - return 0, err + return err } for range brokersLen { - offset, err = skipBytes(pkt, offset, Int32Len) // node_id - if err != nil { - return 0, err + if err = r.Skip(Int32Len); err != nil { // node_id + return err } var hostLen int - hostLen, offset, err = readStringLength(pkt, header, offset, false) + hostLen, err = readStringLength(r, header, false) if err != nil { - return 0, err + return err } - offset, err = skipBytes(pkt, offset, hostLen+Int32Len) // host + port - if err != nil { - return 0, err + if err = r.Skip(hostLen + Int32Len); err != nil { // host + port + return err } var rackLen int - rackLen, offset, err = readStringLength(pkt, header, offset, true) + rackLen, err = readStringLength(r, header, true) if err != nil { - return 0, err + return err } - offset, err = skipBytes(pkt, offset, rackLen) // rack - if err != nil { - return 0, err + if err = r.Skip(rackLen); err != nil { // rack + return err } - offset, err = skipTaggedFields(pkt, header, offset) - if err != nil { - return 0, err + if err = skipTaggedFields(r, header); err != nil { + return err } } - return offset, nil + return nil } -func parseMetadataTopics(pkt []byte, header *KafkaRequestHeader, offset int) ([]*MetadataTopic, error) { - topicsLen, offset, err := readArrayLength(pkt, header, offset) +func parseMetadataTopics(r byteReader, header *KafkaRequestHeader) ([]*MetadataTopic, error) { + topicsLen, err := readArrayLength(r, header) if err != nil { return nil, err } var topics []*MetadataTopic - var topic *MetadataTopic for i := range topicsLen { - topic, offset, err = parseMetadataTopic(pkt, header, offset, i == topicsLen-1) + topic, err := 
parseMetadataTopic(r, header, i == topicsLen-1) if err != nil { // return the Topics parsed so far, even if one topic failed return topics, nil @@ -117,10 +104,10 @@ func parseMetadataTopics(pkt []byte, header *KafkaRequestHeader, offset int) ([] topics = append(topics, topic) } } - return topics, err + return topics, nil } -func parseMetadataTopic(pkt []byte, header *KafkaRequestHeader, offset int, isLast bool) (*MetadataTopic, int, error) { +func parseMetadataTopic(r byteReader, header *KafkaRequestHeader, isLast bool) (*MetadataTopic, error) { var topic MetadataTopic /* Metadata Response (Version: 10, 11, 12 and 13) @@ -139,39 +126,36 @@ func parseMetadataTopic(pkt []byte, header *KafkaRequestHeader, offset int, isLa offline_replicas => INT32 topic_authorized_operations => INT32 */ - offset, err := skipBytes(pkt, offset, Int16Len) // error_code - if err != nil { - return nil, offset, err + if err := r.Skip(Int16Len); err != nil { // error_code + return nil, err } isNullable := header.APIVersion >= 12 - topicName, offset, err := readString(pkt, header, offset, isNullable) + topicName, err := readString(r, header, isNullable) if err != nil { - return nil, offset, err + return nil, err } topic.Name = topicName - topicUUID, offset, err := readUUID(pkt, offset) + topicUUID, err := readUUID(r) if err != nil { - return nil, offset, err + return nil, err } topic.UUID = *topicUUID // optimization: no need to continue reading if this is the last topic if isLast { - return &topic, offset, nil + return &topic, nil } - partitionsCount, offset, err := readArrayLength(pkt, header, offset) + partitionsCount, err := readArrayLength(r, header) if err != nil { - return nil, offset, err + return nil, err } - offset, err = skipBytes(pkt, offset, (partitionsCount*partitionLen)+ // partitions - Int32Len, // topic_authorized_operations - ) - if err != nil { + skipBytes := partitionsCount*partitionLen + // partitions + Int32Len // topic_authorized_operations + if err = 
r.Skip(skipBytes); err != nil { // if we can't read partitions, we can still return the topic - return &topic, offset, nil + return &topic, nil } - offset, err = skipTaggedFields(pkt, header, offset) - if err != nil { - return &topic, offset, nil + if err = skipTaggedFields(r, header); err != nil { + return &topic, nil } - return &topic, offset, nil + return &topic, nil } diff --git a/pkg/internal/ebpf/kafkaparser/metadata_test.go b/pkg/internal/ebpf/kafkaparser/metadata_test.go index c18863e39..466661ed6 100644 --- a/pkg/internal/ebpf/kafkaparser/metadata_test.go +++ b/pkg/internal/ebpf/kafkaparser/metadata_test.go @@ -338,7 +338,7 @@ func TestParseMetadataResponse(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - resp, err := ParseMetadataResponse(tt.packet, tt.header, 0) + resp, err := ParseMetadataResponse(newBytesReader(tt.packet), tt.header) if tt.expectErr { assert.Error(t, err) @@ -380,7 +380,7 @@ func TestParseMetadataResponseTruncation(t *testing.T) { for i := 1; i < len(validPacket); i++ { t.Run(fmt.Sprintf("truncated_at_%d", i), func(t *testing.T) { truncated := validPacket[:i] - _, err := ParseMetadataResponse(truncated, header, 0) + _, err := ParseMetadataResponse(newBytesReader(truncated), header) assert.Error(t, err, "expected error for truncated packet at position %d for version %d", i, version) }) } @@ -402,7 +402,7 @@ func TestParseMetadataResponseAllVersions(t *testing.T) { // Create a valid packet for this version validPacket := createValidMetadataPacket(version) - resp, err := ParseMetadataResponse(validPacket, header, 0) + resp, err := ParseMetadataResponse(newBytesReader(validPacket), header) require.NoError(t, err, "unexpected error for version %d", version) require.NotNil(t, resp) diff --git a/pkg/internal/ebpf/kafkaparser/produce.go b/pkg/internal/ebpf/kafkaparser/produce.go index 0752f1c6c..dbc3dbb1a 100644 --- a/pkg/internal/ebpf/kafkaparser/produce.go +++ b/pkg/internal/ebpf/kafkaparser/produce.go @@ 
-14,12 +14,11 @@ type ProduceRequest struct { Topics []*ProduceTopic } -func ParseProduceRequest(pkt []byte, header *KafkaRequestHeader, offset Offset) (*ProduceRequest, error) { - offset, err := produceRequestSkipUntilTopics(pkt, header, offset) - if err != nil { +func ParseProduceRequest(r byteReader, header *KafkaRequestHeader) (*ProduceRequest, error) { + if err := produceRequestSkipUntilTopics(r, header); err != nil { return nil, err } - topics, err := parseProduceTopics(pkt, header, offset) + topics, err := parseProduceTopics(r, header) if err != nil { return nil, err } @@ -31,7 +30,7 @@ func ParseProduceRequest(pkt []byte, header *KafkaRequestHeader, offset Offset) }, nil } -func produceRequestSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) (Offset, error) { +func produceRequestSkipUntilTopics(r byteReader, header *KafkaRequestHeader) error { /* Produce Request (Version: 3-12) => transactional_id acks timeout_ms [topic_data] _tagged_fields transactional_id => NULLABLE_STRING / COMPACT_NULLABLE_STRING @@ -39,34 +38,28 @@ func produceRequestSkipUntilTopics(pkt []byte, header *KafkaRequestHeader, offse timeout_ms => INT32 topic_data => Name [partition_data] _tagged_fields */ - transactionIDSize, offset, err := readStringLength(pkt, header, offset, true) + transactionIDSize, err := readStringLength(r, header, true) if err != nil { - return 0, err + return err } - offset, err = skipBytes(pkt, offset, - transactionIDSize+ // transactional_id - Int16Len+ // acks + return r.Skip( + transactionIDSize + // transactional_id + Int16Len + // acks Int32Len, // timeout_ms ) - if err != nil { - return 0, err - } - return offset, nil } -func parseProduceTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) ([]*ProduceTopic, error) { - topicsLen, offset, err := readArrayLength(pkt, header, offset) +func parseProduceTopics(r byteReader, header *KafkaRequestHeader) ([]*ProduceTopic, error) { + topicsLen, err := readArrayLength(r, header) if err != 
nil { return nil, err } var topics []*ProduceTopic - var topic *ProduceTopic - // parse each topic if topicsLen <= 0 { return topics, nil } // read single topic for now, because skipping records is complicated - topic, _, err = parseProduceTopic(pkt, header, offset) + topic, err := parseProduceTopic(r, header) if err != nil { // return the Topics parsed so far, even if one topic failed return topics, nil @@ -74,34 +67,34 @@ func parseProduceTopics(pkt []byte, header *KafkaRequestHeader, offset Offset) ( if topic != nil { topics = append(topics, topic) } - return topics, err + return topics, nil } -func parseProduceTopic(pkt []byte, header *KafkaRequestHeader, offset Offset) (*ProduceTopic, Offset, error) { +func parseProduceTopic(r byteReader, header *KafkaRequestHeader) (*ProduceTopic, error) { var topic ProduceTopic /* Topics => topic [partitions] _tagged_fields topic => STRING / COMPACT_STRING */ - topicName, offset, err := readString(pkt, header, offset, false) + topicName, err := readString(r, header, false) if err != nil { - return nil, offset, err + return nil, err } topic.Name = topicName - partitionsLen, offset, err := readArrayLength(pkt, header, offset) + partitionsLen, err := readArrayLength(r, header) if err != nil { // return the topic even if partitions can't be read - return &topic, offset, nil + return &topic, nil } if partitionsLen != 1 { - // if more then 1 Partition, we just won't report Partition - return &topic, offset, nil + // if more than 1 Partition, we just won't report Partition + return &topic, nil } // read single Partition for now, because skipping records is complicated - firstPartition, offset, err := readInt32(pkt, offset) + firstPartition, err := readInt32(r) if err != nil { - return &topic, offset, nil + return &topic, nil } topic.Partition = &firstPartition - return &topic, offset, nil + return &topic, nil } diff --git a/pkg/internal/ebpf/kafkaparser/produce_test.go b/pkg/internal/ebpf/kafkaparser/produce_test.go index 
d68965a51..edcf98b0a 100644 --- a/pkg/internal/ebpf/kafkaparser/produce_test.go +++ b/pkg/internal/ebpf/kafkaparser/produce_test.go @@ -283,7 +283,7 @@ func TestParseProduceRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req, err := ParseProduceRequest(tt.packet, tt.header, 0) + req, err := ParseProduceRequest(newBytesReader(tt.packet), tt.header) if tt.expectErr { assert.Error(t, err) @@ -452,7 +452,8 @@ func TestProduceRequestSkipUntilTopics(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - offset, err := produceRequestSkipUntilTopics(tt.packet, tt.header, 0) + r := newBytesReader(tt.packet) + err := produceRequestSkipUntilTopics(r, tt.header) if tt.expectErr { assert.Error(t, err) @@ -460,7 +461,7 @@ func TestProduceRequestSkipUntilTopics(t *testing.T) { } require.NoError(t, err) - assert.Equal(t, tt.expectedOffset, offset) + assert.Equal(t, tt.expectedOffset, r.Pos()) }) } } @@ -484,7 +485,7 @@ func TestParseProduceRequestTruncation(t *testing.T) { for i := 1; i < len(validPacket); i++ { t.Run(fmt.Sprintf("truncated_at_%d", i), func(t *testing.T) { truncated := validPacket[:i] - _, err := ParseProduceRequest(truncated, header, 0) + _, err := ParseProduceRequest(newBytesReader(truncated), header) assert.Error(t, err, "expected error for truncated packet at position %d for version %d", i, version) }) } @@ -506,7 +507,7 @@ func TestParseProduceRequestAllVersions(t *testing.T) { // Create a valid packet for this version validPacket := createValidProducePacket(version) - req, err := ParseProduceRequest(validPacket, header, 0) + req, err := ParseProduceRequest(newBytesReader(validPacket), header) require.NoError(t, err, "unexpected error for version %d", version) require.NotNil(t, req) @@ -664,7 +665,7 @@ func TestParseProduceRequestEdgeCases(t *testing.T) { } packet := tt.packet() - _, err := ParseProduceRequest(packet, header, 0) + _, err := ParseProduceRequest(newBytesReader(packet), header) 
if tt.expectErr { assert.Error(t, err) diff --git a/pkg/internal/ebpf/kafkaparser/reader_test.go b/pkg/internal/ebpf/kafkaparser/reader_test.go new file mode 100644 index 000000000..85432dc27 --- /dev/null +++ b/pkg/internal/ebpf/kafkaparser/reader_test.go @@ -0,0 +1,50 @@ +// Copyright The OpenTelemetry Authors +// SPDX-License-Identifier: Apache-2.0 + +package kafkaparser + +import "fmt" + +// bytesReader is a test helper that wraps a []byte to implement the byteReader interface. +// It is used in tests to call parser functions without depending on LargeBuffer. +type bytesReader struct { + data []byte + pos int +} + +func newBytesReader(data []byte) *bytesReader { + return &bytesReader{data: data} +} + +func (r *bytesReader) ReadN(n int) ([]byte, error) { + if r.pos+n > len(r.data) { + return nil, fmt.Errorf("ReadN: requested %d bytes but only %d remaining", n, len(r.data)-r.pos) + } + s := r.data[r.pos : r.pos+n] + r.pos += n + return s, nil +} + +func (r *bytesReader) Peek(n int) ([]byte, error) { + if r.pos+n > len(r.data) { + return nil, fmt.Errorf("Peek: requested %d bytes but only %d remaining", n, len(r.data)-r.pos) + } + return r.data[r.pos : r.pos+n], nil +} + +func (r *bytesReader) Skip(n int) error { + if r.pos+n > len(r.data) { + return fmt.Errorf("Skip: requested %d bytes but only %d remaining", n, len(r.data)-r.pos) + } + r.pos += n + return nil +} + +func (r *bytesReader) Remaining() int { + return len(r.data) - r.pos +} + +// Pos returns the current read position (bytes consumed so far). +func (r *bytesReader) Pos() int { + return r.pos +}