diff --git a/pkg/engine/engine_test.go b/pkg/engine/engine_test.go
index df12b2109409..2d735c906930 100644
--- a/pkg/engine/engine_test.go
+++ b/pkg/engine/engine_test.go
@@ -13,17 +13,16 @@ import (
 	"github.com/stretchr/testify/assert"
 
-	"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/gitlab/v2"
-	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
-
 	"github.com/trufflesecurity/trufflehog/v3/pkg/config"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/custom_detectors"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/decoders"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/detectors"
+	"github.com/trufflesecurity/trufflehog/v3/pkg/detectors/gitlab/v2"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/engine/ahocorasick"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/engine/defaults"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/custom_detectorspb"
+	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/detectorspb"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/source_metadatapb"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/pb/sourcespb"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
 
@@ -317,8 +316,8 @@ aws_secret_access_key = 5dkLVuqpZhD6V3Zym1hivdSHOzh6FGPjwplXD+5f`,
 		{
 			name: "secret with mixed whitespace before",
 			content: `first line
-
-
+
+
 AKIA2OGYBAH6STMMNXNN
 aws_secret_access_key = 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`,
 			expectedLine: 4,
@@ -1248,7 +1247,7 @@ def test_something():
 	conf := Config{
 		Concurrency:   1,
 		Decoders:      decoders.DefaultDecoders(),
-		Detectors:     DefaultDetectors(),
+		Detectors:     defaults.DefaultDetectors(),
 		Verify:        false,
 		SourceManager: sourceManager,
 		Dispatcher:    NewPrinterDispatcher(new(discardPrinter)),
diff --git a/pkg/handlers/apk.go b/pkg/handlers/apk.go
index c83665522189..402f7cc39f70 100644
--- a/pkg/handlers/apk.go
+++ b/pkg/handlers/apk.go
@@ -12,9 +12,9 @@ import (
 	"strings"
 	"time"
 
+	"github.com/avast/apkparser"
 	dextk "github.com/csnewman/dextk"
 
-	"github.com/avast/apkparser"
 	logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/engine/defaults"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/iobuf"
@@ -66,45 +66,40 @@ func newAPKHandler() *apkHandler {
 }
 
 // HandleFile processes apk formatted files.
-func (h *apkHandler) HandleFile(ctx logContext.Context, input fileReader) (chan []byte, error) {
-	apkChan := make(chan []byte, defaultBufferSize)
+func (h *apkHandler) HandleFile(ctx logContext.Context, input fileReader) chan DataOrErr {
+	apkChan := make(chan DataOrErr, defaultBufferSize)
 
 	go func() {
-		ctx, cancel := logContext.WithTimeout(ctx, maxTimeout)
-		defer cancel()
 		defer close(apkChan)
 
-		// Update the metrics for the file processing.
-		start := time.Now()
-		var err error
-		defer func() {
-			h.measureLatencyAndHandleErrors(start, err)
-			h.metrics.incFilesProcessed()
-		}()
-
 		// Defer a panic recovery to handle any panics that occur during the APK processing.
 		defer func() {
 			if r := recover(); r != nil {
 				// Return the panic as an error.
+				var panicErr error
 				if e, ok := r.(error); ok {
-					err = e
+					panicErr = e
 				} else {
-					err = fmt.Errorf("panic occurred: %v", r)
+					panicErr = fmt.Errorf("panic occurred: %v", r)
 				}
-				ctx.Logger().Error(err, "Panic occurred when reading apk archive")
+				ctx.Logger().Error(panicErr, "Panic occurred when reading apk archive")
 			}
 		}()
 
-		if err = h.processAPK(ctx, input, apkChan); err != nil {
-			ctx.Logger().Error(err, "error processing apk content")
+		start := time.Now()
+		err := h.processAPK(ctx, input, apkChan)
+		if err == nil {
+			h.metrics.incFilesProcessed()
 		}
+
+		h.measureLatencyAndHandleErrors(ctx, start, err, apkChan)
 	}()
-	return apkChan, nil
+
+	return apkChan
 }
 
 // processAPK processes the apk file and sends the extracted data to the provided channel.
-func (h *apkHandler) processAPK(ctx logContext.Context, input fileReader, apkChan chan []byte) error {
-
+func (h *apkHandler) processAPK(ctx logContext.Context, input fileReader, apkChan chan DataOrErr) error {
 	// Create a ZIP reader from the input fileReader
 	zipReader, err := createZipReader(input)
 	if err != nil {
@@ -132,7 +127,7 @@ func (h *apkHandler) processAPK(ctx logContext.Context, input fileReader, apkCha
 }
 
 // processResources processes the resources.arsc file and sends the extracted data to the provided channel.
-func (h *apkHandler) processResources(ctx logContext.Context, resTable *apkparser.ResourceTable, apkChan chan []byte) error {
+func (h *apkHandler) processResources(ctx logContext.Context, resTable *apkparser.ResourceTable, apkChan chan DataOrErr) error {
 	if resTable == nil {
 		return errors.New("ResourceTable is nil")
 	}
@@ -144,7 +139,7 @@ func (h *apkHandler) processResources(ctx logContext.Context, resTable *apkparse
 }
 
 // processFile processes the file and sends the extracted data to the provided channel.
-func (h *apkHandler) processFile(ctx logContext.Context, file *zip.File, resTable *apkparser.ResourceTable, apkChan chan []byte) error {
+func (h *apkHandler) processFile(ctx logContext.Context, file *zip.File, resTable *apkparser.ResourceTable, apkChan chan DataOrErr) error {
 	// check if the file is empty
 	if file.UncompressedSize64 == 0 {
 		return nil
@@ -177,7 +172,7 @@ func (h *apkHandler) processFile(ctx logContext.Context, file *zip.File, resTabl
 }
 
 // handleAPKFileContent sends the extracted data to the provided channel via the handleNonArchiveContent function.
-func (h *apkHandler) handleAPKFileContent(ctx logContext.Context, rdr io.Reader, fileName string, apkChan chan []byte) error {
+func (h *apkHandler) handleAPKFileContent(ctx logContext.Context, rdr io.Reader, fileName string, apkChan chan DataOrErr) error {
 	mimeReader, err := newMimeTypeReader(rdr)
 	if err != nil {
 		return fmt.Errorf("failed to create mimeTypeReader for file %s: %w", fileName, err)
diff --git a/pkg/handlers/apk_test.go b/pkg/handlers/apk_test.go
index 0790d451118c..e7d12de7693f 100644
--- a/pkg/handlers/apk_test.go
+++ b/pkg/handlers/apk_test.go
@@ -19,7 +19,6 @@ func TestAPKHandler(t *testing.T) {
 		expectedChunks  int
 		expectedSecrets int
 		matchString     string
-		expectErr       bool
 	}{
 		"apk_with_3_leaked_keys": {
 			"https://github.com/joeleonjr/leakyAPK/raw/refs/heads/main/aws_leak.apk",
 			942,
 			// we're just looking for a string match. There is one extra string match in the APK (but only 3 detected secrets).
4, "AKIA2UC3BSXMLSCLTUUS", - false, }, } @@ -47,11 +45,7 @@ func TestAPKHandler(t *testing.T) { } defer newReader.Close() - archiveChan, err := handler.HandleFile(logContext.Background(), newReader) - if testCase.expectErr { - assert.NoError(t, err) - return - } + archiveChan := handler.HandleFile(logContext.Background(), newReader) chunkCount := 0 secretCount := 0 @@ -59,7 +53,7 @@ func TestAPKHandler(t *testing.T) { matched := false for chunk := range archiveChan { chunkCount++ - if re.Match(chunk) { + if re.Match(chunk.Data) { secretCount++ matched = true } @@ -82,7 +76,7 @@ func TestOpenInvalidAPK(t *testing.T) { assert.NoError(t, err) defer rdr.Close() - archiveChan := make(chan []byte) + archiveChan := make(chan DataOrErr) err = handler.processAPK(ctx, rdr, archiveChan) assert.Contains(t, err.Error(), "zip: not a valid zip file") @@ -106,7 +100,7 @@ func TestOpenValidZipInvalidAPK(t *testing.T) { assert.NoError(t, err) defer newReader.Close() - archiveChan := make(chan []byte) + archiveChan := make(chan DataOrErr) ctx := logContext.AddLogger(context.Background()) err = handler.processAPK(ctx, newReader, archiveChan) diff --git a/pkg/handlers/ar.go b/pkg/handlers/ar.go index c7d6dd5aac86..6720035fcdaf 100644 --- a/pkg/handlers/ar.go +++ b/pkg/handlers/ar.go @@ -22,16 +22,16 @@ func newARHandler() *arHandler { // HandleFile processes AR formatted files. This function needs to be implemented to extract or // manage data from AR files according to specific requirements. -func (h *arHandler) HandleFile(ctx logContext.Context, input fileReader) (chan []byte, error) { - archiveChan := make(chan []byte, defaultBufferSize) +func (h *arHandler) HandleFile(ctx logContext.Context, input fileReader) chan DataOrErr { + dataOrErrChan := make(chan DataOrErr, defaultBufferSize) if feature.ForceSkipArchives.Load() { - close(archiveChan) - return archiveChan, nil + close(dataOrErrChan) + return dataOrErrChan } go func() { - defer close(archiveChan) + defer close(dataOrErrChan) // Defer a panic recovery to handle any panics that occur during the AR processing. defer func() { @@ -53,19 +53,19 @@ func (h *arHandler) HandleFile(ctx logContext.Context, input fileReader) (chan [ return } - err = h.processARFiles(ctx, arReader, archiveChan) + err = h.processARFiles(ctx, arReader, dataOrErrChan) if err == nil { h.metrics.incFilesProcessed() } // Update the metrics for the file processing and handle any errors. 
-		h.measureLatencyAndHandleErrors(start, err)
+		h.measureLatencyAndHandleErrors(ctx, start, err, dataOrErrChan)
 	}()
 
-	return archiveChan, nil
+	return dataOrErrChan
 }
 
-func (h *arHandler) processARFiles(ctx logContext.Context, reader *deb.Ar, archiveChan chan []byte) error {
+func (h *arHandler) processARFiles(ctx logContext.Context, reader *deb.Ar, dataOrErrChan chan DataOrErr) error {
 	for {
 		select {
 		case <-ctx.Done():
@@ -88,7 +88,7 @@ func (h *arHandler) processARFiles(ctx logContext.Context, reader *deb.Ar, archi
 				return fmt.Errorf("error creating mime-type reader: %w", err)
 			}
 
-			if err := h.handleNonArchiveContent(fileCtx, rdr, archiveChan); err != nil {
+			if err := h.handleNonArchiveContent(fileCtx, rdr, dataOrErrChan); err != nil {
 				fileCtx.Logger().Error(err, "error handling archive content in AR")
 				h.metrics.incErrors()
 			}
diff --git a/pkg/handlers/ar_test.go b/pkg/handlers/ar_test.go
index 59285ca47efd..8658fdafcdff 100644
--- a/pkg/handlers/ar_test.go
+++ b/pkg/handlers/ar_test.go
@@ -23,12 +23,12 @@ func TestHandleARFile(t *testing.T) {
 	defer rdr.Close()
 
 	handler := newARHandler()
-	archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
+	dataOrErrChan := handler.HandleFile(context.AddLogger(ctx), rdr)
 	assert.NoError(t, err)
 
 	wantChunkCount := 102
 	count := 0
-	for range archiveChan {
+	for range dataOrErrChan {
 		count++
 	}
 
diff --git a/pkg/handlers/archive.go b/pkg/handlers/archive.go
index 3e940b232733..204fd080833a 100644
--- a/pkg/handlers/archive.go
+++ b/pkg/handlers/archive.go
@@ -44,16 +44,16 @@ func newArchiveHandler() *archiveHandler {
 // utilizing a single output channel. It first tries to identify the input as an archive. If it is an archive,
 // it processes it accordingly; otherwise, it handles the input as non-archive content.
-// The function returns a channel that will receive the extracted data bytes and an error if the initial setup fails.
-func (h *archiveHandler) HandleFile(ctx logContext.Context, input fileReader) (chan []byte, error) {
-	dataChan := make(chan []byte, defaultBufferSize)
+// The function returns a channel from which the caller can read the extracted data and any errors encountered during processing.
+func (h *archiveHandler) HandleFile(ctx logContext.Context, input fileReader) chan DataOrErr {
+	dataOrErrChan := make(chan DataOrErr, defaultBufferSize)
 
 	if feature.ForceSkipArchives.Load() {
-		close(dataChan)
-		return dataChan, nil
+		close(dataOrErrChan)
+		return dataOrErrChan
 	}
 
 	go func() {
-		defer close(dataChan)
+		defer close(dataOrErrChan)
 
 		// The underlying 7zip library may panic when attempting to open an archive.
 		// This is due to an Index Out Of Range (IOOR) error when reading the archive header.
@@ -71,16 +71,16 @@ func (h *archiveHandler) HandleFile(ctx logContext.Context, input fileReader) (c
 		}()
 
 		start := time.Now()
-		err := h.openArchive(ctx, 0, input, dataChan)
+		err := h.openArchive(ctx, 0, input, dataOrErrChan)
 		if err == nil {
 			h.metrics.incFilesProcessed()
 		}
 
 		// Update the metrics for the file processing and handle any errors.
-		h.measureLatencyAndHandleErrors(start, err)
+		h.measureLatencyAndHandleErrors(ctx, start, err, dataOrErrChan)
 	}()
 
-	return dataChan, nil
+	return dataOrErrChan
 }
 
 var ErrMaxDepthReached = errors.New("max archive depth reached")
 
@@ -89,7 +89,12 @@ var ErrMaxDepthReached = errors.New("max archive depth reached")
 // It takes a reader from which it attempts to identify and process the archive format. Depending on the archive type,
 // it either decompresses or extracts the contents directly, sending data to the provided channel.
 // Returns an error if the archive cannot be processed due to issues like exceeding maximum depth or unsupported formats.
-func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader fileReader, archiveChan chan []byte) error {
+func (h *archiveHandler) openArchive(
+	ctx logContext.Context,
+	depth int,
+	reader fileReader,
+	dataOrErrChan chan DataOrErr,
+) error {
 	ctx.Logger().V(4).Info("Starting archive processing", "depth", depth)
 	defer ctx.Logger().V(4).Info("Finished archive processing", "depth", depth)
 
@@ -104,7 +109,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
 
 	if reader.format == nil {
 		if depth > 0 {
-			return h.handleNonArchiveContent(ctx, newMimeTypeReaderFromFileReader(reader), archiveChan)
+			return h.handleNonArchiveContent(ctx, newMimeTypeReaderFromFileReader(reader), dataOrErrChan)
 		}
 		return fmt.Errorf("unknown archive format")
 	}
@@ -132,9 +137,9 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
 		}
 		defer rdr.Close()
 
-		return h.openArchive(ctx, depth+1, rdr, archiveChan)
+		return h.openArchive(ctx, depth+1, rdr, dataOrErrChan)
 	case archiver.Extractor:
-		err := archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), reader, nil, h.extractorHandler(archiveChan))
+		err := archive.Extract(logContext.WithValue(ctx, depthKey, depth+1), reader, nil, h.extractorHandler(dataOrErrChan))
 		if err != nil {
 			return fmt.Errorf("error extracting archive with format: %s: %w", reader.format.Name(), err)
 		}
@@ -148,7 +153,7 @@ func (h *archiveHandler) openArchive(ctx logContext.Context, depth int, reader f
 // It logs the extraction, checks for cancellation, and decides whether to skip the file based on its name or type,
 // particularly for binary files if configured to skip. If the file is not skipped, it recursively calls openArchive
 // to handle nested archives or to continue processing based on the file's content and depth in the archive structure.
-func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.Context, archiver.File) error {
+func (h *archiveHandler) extractorHandler(dataOrErrChan chan DataOrErr) func(context.Context, archiver.File) error {
 	return func(ctx context.Context, file archiver.File) error {
 		lCtx := logContext.WithValues(
 			logContext.AddLogger(ctx),
@@ -220,6 +225,6 @@ func (h *archiveHandler) extractorHandler(archiveChan chan []byte) func(context.
 		h.metrics.observeFileSize(fileSize)
 
 		lCtx.Logger().V(4).Info("Processed file successfully", "filename", file.Name(), "size", file.Size())
-		return h.openArchive(lCtx, depth, rdr, archiveChan)
+		return h.openArchive(lCtx, depth, rdr, dataOrErrChan)
 	}
 }
diff --git a/pkg/handlers/archive_test.go b/pkg/handlers/archive_test.go
index a24ab98a1775..3463b6220c90 100644
--- a/pkg/handlers/archive_test.go
+++ b/pkg/handlers/archive_test.go
@@ -91,7 +91,7 @@ func TestArchiveHandler(t *testing.T) {
 			}
 			defer newReader.Close()
 
-			archiveChan, err := handler.HandleFile(logContext.Background(), newReader)
+			dataOrErrChan := handler.HandleFile(logContext.Background(), newReader)
 			if testCase.expectErr {
 				assert.NoError(t, err)
 				return
@@ -100,9 +100,9 @@ func TestArchiveHandler(t *testing.T) {
 			count := 0
 			re := regexp.MustCompile(testCase.matchString)
 			matched := false
-			for chunk := range archiveChan {
+			for chunk := range dataOrErrChan {
 				count++
-				if re.Match(chunk) {
+				if re.Match(chunk.Data) {
 					matched = true
 				}
 			}
@@ -123,8 +123,8 @@ func TestOpenInvalidArchive(t *testing.T) {
 	assert.NoError(t, err)
 	defer rdr.Close()
 
-	archiveChan := make(chan []byte)
+	dataOrErrChan := make(chan DataOrErr)
 
-	err = handler.openArchive(ctx, 0, rdr, archiveChan)
+	err = handler.openArchive(ctx, 0, rdr, dataOrErrChan)
 	assert.Error(t, err)
 }
diff --git a/pkg/handlers/default.go b/pkg/handlers/default.go
index c39a0c8e878e..5bcfb2d63ea8 100644
--- a/pkg/handlers/default.go
+++ b/pkg/handlers/default.go
@@ -3,6 +3,7 @@ package handlers
 
 import (
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"time"
 
@@ -30,37 +31,53 @@ func newDefaultHandler(handlerType handlerType) *defaultHandler {
 // utilizing a single output channel. It first tries to identify the input as an archive. If it is an archive,
 // it processes it accordingly; otherwise, it handles the input as non-archive content.
-// The function returns a channel that will receive the extracted data bytes and an error if the initial setup fails.
-func (h *defaultHandler) HandleFile(ctx logContext.Context, input fileReader) (chan []byte, error) {
+// The function returns a channel from which the caller can read the extracted data and any errors encountered during processing.
+func (h *defaultHandler) HandleFile(ctx logContext.Context, input fileReader) chan DataOrErr {
 	// Shared channel for both archive and non-archive content.
-	dataChan := make(chan []byte, defaultBufferSize)
+	dataOrErrChan := make(chan DataOrErr, defaultBufferSize)
 
 	go func() {
-		defer close(dataChan)
+		defer close(dataOrErrChan)
 
 		start := time.Now()
-		err := h.handleNonArchiveContent(ctx, newMimeTypeReaderFromFileReader(input), dataChan)
+		err := h.handleNonArchiveContent(ctx, newMimeTypeReaderFromFileReader(input), dataOrErrChan)
 		if err == nil {
 			h.metrics.incFilesProcessed()
 		}
 
 		// Update the metrics for the file processing and handle errors.
-		h.measureLatencyAndHandleErrors(start, err)
+		h.measureLatencyAndHandleErrors(ctx, start, err, dataOrErrChan)
 	}()
 
-	return dataChan, nil
+	return dataOrErrChan
 }
 
 // measureLatencyAndHandleErrors measures the latency of the file processing and updates the metrics accordingly.
 // It also records errors and timeouts in the metrics.
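The rewritten helper below no longer only logs failures; it forwards them to the consumer with common.CancellableWrite, so a slow or departed reader cannot block the producer goroutine forever. A hedged sketch of that select-based write, assuming semantics like trufflehog's common.CancellableWrite (the generic helper here is illustrative, not the library's actual signature):

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// cancellableWrite gives up on the send as soon as the context ends,
// instead of blocking indefinitely on an unread channel.
func cancellableWrite[T any](ctx context.Context, ch chan<- T, v T) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case ch <- v:
		return nil
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
	defer cancel()

	ch := make(chan string) // unbuffered and never read: the write must give up
	err := cancellableWrite(ctx, ch, "data")
	fmt.Println(err) // context deadline exceeded
}
```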
-func (h *defaultHandler) measureLatencyAndHandleErrors(start time.Time, err error) {
+func (h *defaultHandler) measureLatencyAndHandleErrors(
+	ctx logContext.Context,
+	start time.Time,
+	err error,
+	dataErrChan chan<- DataOrErr,
+) {
 	if err == nil {
 		h.metrics.observeHandleFileLatency(time.Since(start).Milliseconds())
 		return
 	}
+	dataOrErr := DataOrErr{}
 
 	h.metrics.incErrors()
 	if errors.Is(err, context.DeadlineExceeded) {
 		h.metrics.incFileProcessingTimeouts()
+		dataOrErr.Err = fmt.Errorf("%w: error processing chunk", err)
+		if err := common.CancellableWrite(ctx, dataErrChan, dataOrErr); err != nil {
+			ctx.Logger().Error(err, "error writing to data channel")
+		}
+		return
+	}
+
+	dataOrErr.Err = err
+	if err := common.CancellableWrite(ctx, dataErrChan, dataOrErr); err != nil {
+		ctx.Logger().Error(err, "error writing to data channel")
 	}
 }
 
@@ -69,7 +86,11 @@ func (h *defaultHandler) measureLatencyAndHandleErrors(start time.Time, err erro
-// on the type, particularly for binary files. It manages reading file chunks and writing them to the archive channel,
+// on the type, particularly for binary files. It manages reading file chunks and writing them to the provided channel,
 // effectively collecting the final bytes for further processing. This function is a key component in ensuring that all
 // file content, regardless of being an archive or not, is handled appropriately.
-func (h *defaultHandler) handleNonArchiveContent(ctx logContext.Context, reader mimeTypeReader, archiveChan chan []byte) error {
+func (h *defaultHandler) handleNonArchiveContent(
+	ctx logContext.Context,
+	reader mimeTypeReader,
+	dataOrErrChan chan DataOrErr,
+) error {
 	mimeExt := reader.mimeExt
 
 	if common.SkipFile(mimeExt) || common.IsBinary(mimeExt) {
@@ -82,13 +103,18 @@ func (h *defaultHandler) handleNonArchiveContent(ctx logContext.Context, reader
 
 	chunkReader := sources.NewChunkReader()
 	for data := range chunkReader(ctx, reader) {
+		dataOrErr := DataOrErr{}
 		if err := data.Error(); err != nil {
-			ctx.Logger().Error(err, "error reading chunk")
 			h.metrics.incErrors()
+			dataOrErr.Err = fmt.Errorf("%w: error reading chunk", err)
+			if writeErr := common.CancellableWrite(ctx, dataOrErrChan, dataOrErr); writeErr != nil {
+				return fmt.Errorf("%w: error writing to data channel", writeErr)
+			}
 			continue
 		}
 
-		if err := common.CancellableWrite(ctx, archiveChan, data.Bytes()); err != nil {
+		dataOrErr.Data = data.Bytes()
+		if err := common.CancellableWrite(ctx, dataOrErrChan, dataOrErr); err != nil {
 			return err
 		}
 		h.metrics.incBytesProcessed(len(data.Bytes()))
diff --git a/pkg/handlers/default_test.go b/pkg/handlers/default_test.go
index 3d071ad6f382..613ce9dffdea 100644
--- a/pkg/handlers/default_test.go
+++ b/pkg/handlers/default_test.go
@@ -23,12 +23,12 @@ func TestHandleNonArchiveFile(t *testing.T) {
 	defer rdr.Close()
 
 	handler := newDefaultHandler(defaultHandlerType)
-	archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
+	dataOrErrChan := handler.HandleFile(context.AddLogger(ctx), rdr)
 	assert.NoError(t, err)
 
 	wantChunkCount := 6
 	count := 0
-	for range archiveChan {
+	for range dataOrErrChan {
 		count++
 	}
 
diff --git a/pkg/handlers/handlers.go b/pkg/handlers/handlers.go
index 3f49cf62538a..2b72f2041a3f 100644
--- a/pkg/handlers/handlers.go
+++ b/pkg/handlers/handlers.go
@@ -92,26 +92,37 @@ func newMimeTypeReader(r io.Reader) (mimeTypeReader, error) {
 }
 
 // newFileReader creates a fileReader from an io.Reader, optionally using BufferedFileWriter for certain formats.
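The rewrite below switches newFileReader to a named error return so a single deferred function can close the BufferedReaderSeeker whenever any later step fails, folding the close error into the original with the multi-%w wrapping Go 1.20 added (the diff itself uses the same two-%w pattern). A minimal sketch of the idiom, with os.File standing in for the buffered reader; openAndValidate is a hypothetical example, not part of this codebase:

```go
package main

import (
	"fmt"
	"os"
)

// openAndValidate shows the idiom: a named error return lets one deferred
// block release the resource on every failure path.
func openAndValidate(path string) (_ *os.File, err error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	// If any later step fails, close the file before returning and
	// join the close error to the original one.
	defer func() {
		if err != nil {
			if closeErr := f.Close(); closeErr != nil {
				err = fmt.Errorf("%w; error closing file: %w", err, closeErr)
			}
		}
	}()

	fi, err := f.Stat()
	if err != nil {
		return nil, err
	}
	if fi.Size() == 0 {
		return nil, fmt.Errorf("file %s is empty", path)
	}
	return f, nil
}

func main() {
	if _, err := openAndValidate("does-not-exist.txt"); err != nil {
		fmt.Println(err)
	}
}
```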
-func newFileReader(r io.Reader, options ...readerOption) (fileReader, error) {
-	var (
-		fReader fileReader
-		cfg     readerConfig
-	)
+// The caller is responsible for closing the reader when it is no longer needed.
+func newFileReader(r io.Reader, options ...readerOption) (fReader fileReader, err error) {
+	var cfg readerConfig
 
 	for _, opt := range options {
 		opt(&cfg)
 	}
-
+	// To detect the MIME type of the input data, we need a reader that supports seeking.
+	// This allows us to read the data multiple times if necessary without losing the original position.
+	// We use a BufferedReaderSeeker to wrap the original reader, enabling this functionality.
 	fReader.BufferedReadSeeker = iobuf.NewBufferedReaderSeeker(r)
 
-	mime, err := mimetype.DetectReader(fReader)
+	// If an error occurs during MIME type detection, it is important we close the BufferedReaderSeeker
+	// to release any resources it holds (checked out buffers or temp file).
+	defer func() {
+		if err != nil {
+			if closeErr := fReader.Close(); closeErr != nil {
+				err = fmt.Errorf("%w; error closing reader: %w", err, closeErr)
+			}
+		}
+	}()
+
+	var mime *mimetype.MIME
+	mime, err = mimetype.DetectReader(fReader)
 	if err != nil {
 		return fReader, fmt.Errorf("unable to detect MIME type: %w", err)
 	}
 	fReader.mime = mime
 
 	// Reset the reader to the beginning because DetectReader consumes the reader.
-	if _, err := fReader.Seek(0, io.SeekStart); err != nil {
+	if _, err = fReader.Seek(0, io.SeekStart); err != nil {
 		return fReader, fmt.Errorf("error resetting reader after MIME detection: %w", err)
 	}
@@ -132,7 +143,8 @@ func newFileReader(r io.Reader, options ...readerOption) (fileReader, error) {
 		return fReader, nil
 	}
 
-	format, _, err := archiver.Identify("", fReader)
+	var format archiver.Format
+	format, _, err = archiver.Identify("", fReader)
 	switch {
 	case err == nil:
 		fReader.isGenericArchive = true
@@ -147,18 +159,29 @@ func newFileReader(r io.Reader, options ...readerOption) (fileReader, error) {
 
 	// Reset the reader to the beginning again to allow the handler to read from the start.
 	// This is necessary because Identify consumes the reader.
-	if _, err := fReader.Seek(0, io.SeekStart); err != nil {
+	if _, err = fReader.Seek(0, io.SeekStart); err != nil {
 		return fReader, fmt.Errorf("error resetting reader after archive identification: %w", err)
 	}
 
 	return fReader, nil
 }
 
+// DataOrErr represents a result that can either contain data or an error.
+// The Data field holds the byte slice of data, and the Err field holds any error that occurred.
+// This structure is used to handle asynchronous file processing where each chunk of data
+// or potential error needs to be communicated back to the caller. It allows for
+// efficient streaming of file contents while also providing a way to propagate errors
+// that may occur during the file handling process.
+type DataOrErr struct {
+	Data []byte
+	Err  error
+}
+
 // FileHandler represents a handler for files.
 // It has a single method, HandleFile, which takes a context and a fileReader as input,
-// and returns a channel of byte slices and an error.
+// and returns a channel of DataOrErr values.
 type FileHandler interface {
-	HandleFile(ctx logContext.Context, reader fileReader) (chan []byte, error)
+	HandleFile(ctx logContext.Context, reader fileReader) chan DataOrErr
 }
 
 // fileHandlingConfig encapsulates configuration settings that control the behavior of file processing.
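The DataOrErr type and the narrowed FileHandler interface above codify the producer pattern every handler in this diff follows: allocate a buffered channel, spawn one goroutine that defers close, and send failures in-band instead of returning them. A self-contained sketch of that shape, where handleFile and the local DataOrErr are simplified stand-ins for the real interface:

```go
package main

import (
	"fmt"
	"io"
	"strings"
)

// DataOrErr is a simplified stand-in for the struct defined above.
type DataOrErr struct {
	Data []byte
	Err  error
}

// handleFile streams chunks from r; errors travel the same channel as data.
func handleFile(r io.Reader) chan DataOrErr {
	out := make(chan DataOrErr, 8)
	go func() {
		defer close(out) // the consumer's range loop depends on this
		buf := make([]byte, 5)
		for {
			n, err := r.Read(buf)
			if n > 0 {
				// Copy the bytes: buf is reused on the next iteration.
				out <- DataOrErr{Data: append([]byte(nil), buf[:n]...)}
			}
			if err == io.EOF {
				return
			}
			if err != nil {
				out <- DataOrErr{Err: err} // report in-band, then stop
				return
			}
		}
	}()
	return out
}

func main() {
	for item := range handleFile(strings.NewReader("hello, world")) {
		fmt.Printf("data=%q err=%v\n", item.Data, item.Err)
	}
}
```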
@@ -295,14 +318,25 @@ var maxTimeout = time.Duration(60) * time.Second
 
 func SetArchiveMaxTimeout(timeout time.Duration) { maxTimeout = timeout }
 
 // HandleFile orchestrates the complete file handling process for a given file.
-// It determines the MIME type of the file, selects the appropriate handler based on this type, and processes the file.
-// This function initializes the handling process and delegates to the specific handler to manage file
-// extraction or processing. Errors at any stage result in an error return value.
-// Successful handling passes the file content through a channel to be chunked and reported.
-// The function will close the reader when it has consumed all the data.
+// It determines the MIME type of the file,
+// selects the appropriate handler based on this type, and processes the file.
+// This function initializes the handling process and delegates to the specific
+// handler to manage file extraction or processing.
+//
+// The function will return nil (success) in the following cases:
+// - If the reader is empty (ErrEmptyReader)
+// - If skipArchives option is true and the file is detected as an archive
+// - If all chunks are processed successfully without critical errors
 //
-// If the skipArchives option is set to true and the detected MIME type is a known archive type,
-// the function will skip processing the file and return nil.
+// The function will return an error in the following cases:
+// - If the reader is nil
+// - If there's an error creating the file reader
+// - If there's an error closing the reader
+// - If a critical error occurs during chunk processing (context cancellation, deadline exceeded, or ErrProcessingFatal)
+// - If there's an error reporting a chunk
+//
+// Non-critical errors during chunk processing are logged
+// but do not cause the function to return an error.
 func HandleFile(
 	ctx logContext.Context,
 	reader io.Reader,
@@ -311,7 +345,7 @@ func HandleFile(
 	options ...func(*fileHandlingConfig),
 ) error {
 	if reader == nil {
-		return fmt.Errorf("reader is nil")
+		return errors.New("reader is nil")
 	}
 
 	readerOption := withFileExtension(getFileExtension(chunkSkel))
@@ -321,7 +355,7 @@ func HandleFile(
 			ctx.Logger().V(5).Info("empty reader, skipping file")
 			return nil
 		}
-		return fmt.Errorf("failed to create file reader to handle file: %w", err)
+		return fmt.Errorf("unable to HandleFile, error creating file reader: %w", err)
 	}
 	defer func() {
 		// Ensure all data is read to prevent broken pipe.
@@ -347,39 +381,41 @@ func HandleFile(
 	defer cancel()
 
 	handler := selectHandler(mimeT, rdr.isGenericArchive)
-	archiveChan, err := handler.HandleFile(processingCtx, rdr) // Delegate to the specific handler to process the file.
-	if err != nil {
-		return fmt.Errorf("error handling file: %w", err)
-	}
+	dataOrErrChan := handler.HandleFile(processingCtx, rdr) // Delegate to the specific handler to process the file.
 
-	return handleChunks(processingCtx, archiveChan, chunkSkel, reporter)
+	return handleChunksWithError(processingCtx, dataOrErrChan, chunkSkel, reporter)
 }
 
-// handleChunks reads data from the handlerChan and uses it to fill chunks according to a predefined skeleton (chunkSkel).
-// Each filled chunk is reported using the provided reporter. This function manages the lifecycle of the channel,
-// handling the termination condition when the channel closes and ensuring the cancellation of the operation if the context
-// is done. It returns true if all chunks are processed successfully, otherwise returns false on errors or cancellation.
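handleChunks is replaced just below by handleChunksWithError, which keeps the same select shape: receive until the channel closes, log per-chunk errors and continue, and bail out if the context is cancelled. A compilable sketch of that drain loop, where drain and the local DataOrErr are illustrative stand-ins for the real function:

```go
package main

import (
	"context"
	"fmt"
)

// DataOrErr is a simplified stand-in for the type added in this PR.
type DataOrErr struct {
	Data []byte
	Err  error
}

// drain receives until the channel closes or the context ends,
// treating per-item errors as non-fatal.
func drain(ctx context.Context, ch <-chan DataOrErr) error {
	for {
		select {
		case item, ok := <-ch:
			if !ok {
				return nil // channel closed: all chunks processed
			}
			if item.Err != nil {
				fmt.Println("non-fatal chunk error:", item.Err)
				continue
			}
			fmt.Printf("report %d bytes\n", len(item.Data))
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

func main() {
	ch := make(chan DataOrErr, 1)
	ch <- DataOrErr{Data: []byte("chunk")}
	close(ch)
	fmt.Println(drain(context.Background(), ch))
}
```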
-func handleChunks(
+// handleChunksWithError processes data and errors received from the dataErrChan channel.
+// For each DataOrErr received:
+// - If it contains data, the function creates a chunk based on chunkSkel and reports it through the reporter.
+// - If it contains an error, the function logs the error.
+// The function also listens for context cancellation to gracefully terminate processing if the context is done.
+// It returns nil upon successful processing of all data, or the first encountered error.
+func handleChunksWithError(
 	ctx logContext.Context,
-	handlerChan chan []byte,
+	dataErrChan chan DataOrErr,
 	chunkSkel *sources.Chunk,
 	reporter sources.ChunkReporter,
 ) error {
-	if handlerChan == nil {
-		return fmt.Errorf("handler channel is nil")
-	}
-
 	for {
 		select {
-		case data, open := <-handlerChan:
-			if !open {
-				ctx.Logger().V(5).Info("handler channel closed, all chunks processed")
+		case dataOrErr, ok := <-dataErrChan:
+			if !ok {
+				// Channel closed, processing complete.
+				ctx.Logger().V(5).Info("dataErrChan closed, all chunks processed")
 				return nil
 			}
-			chunk := *chunkSkel
-			chunk.Data = data
-			if err := reporter.ChunkOk(ctx, chunk); err != nil {
-				return fmt.Errorf("error reporting chunk: %w", err)
+			if dataOrErr.Err != nil {
+				ctx.Logger().Error(dataOrErr.Err, "error processing chunk")
+				continue
+			}
+			if len(dataOrErr.Data) > 0 {
+				chunk := *chunkSkel
+				chunk.Data = dataOrErr.Data
+				if err := reporter.ChunkOk(ctx, chunk); err != nil {
+					return fmt.Errorf("error reporting chunk: %w", err)
+				}
 			}
 		case <-ctx.Done():
 			return ctx.Err()
diff --git a/pkg/handlers/handlers_test.go b/pkg/handlers/handlers_test.go
index f4c3c54f26c8..64460a0a2da0 100644
--- a/pkg/handlers/handlers_test.go
+++ b/pkg/handlers/handlers_test.go
@@ -3,6 +3,7 @@ package handlers
 import (
 	"archive/zip"
 	"bytes"
+	"errors"
 	"fmt"
 	"io"
 	"net/http"
@@ -11,13 +12,13 @@ import (
 	"path/filepath"
 	"strings"
 	"testing"
+	"testing/iotest"
 	"time"
 
 	"github.com/stretchr/testify/assert"
 	diskbufferreader "github.com/trufflesecurity/disk-buffer-reader"
 
 	"github.com/trufflesecurity/trufflehog/v3/pkg/context"
-	logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
 	"github.com/trufflesecurity/trufflehog/v3/pkg/sources"
 )
 
@@ -61,7 +62,7 @@ func TestHandleHTTPJson(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk, 1)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -85,7 +86,7 @@ func TestHandleHTTPJsonZip(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk, 1)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -115,7 +116,7 @@ func BenchmarkHandleHTTPJsonZip(b *testing.B) {
 		b.StartTimer()
 		go func() {
 			defer close(chunkCh)
-			err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+			err := HandleFile(context.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 			assert.NoError(b, err)
 		}()
 
@@ -158,7 +159,7 @@ func TestSkipArchive(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh}, WithSkipArchives(true))
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh}, WithSkipArchives(true))
 		assert.NoError(t, err)
 	}()
 
@@ -177,7 +178,7 @@ func TestHandleNestedArchives(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -196,7 +197,7 @@ func TestHandleCompressedZip(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -215,7 +216,7 @@ func TestHandleNestedCompressedArchive(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -234,7 +235,7 @@ func TestExtractTarContent(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -339,7 +340,7 @@ func TestExtractTarContentWithEmptyFile(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk, 1)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -359,7 +360,7 @@ func TestHandleTar(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk, 1)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), file, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -411,7 +412,7 @@ func TestHandleLargeHTTPJson(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk, 1)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), resp.Body, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -438,7 +439,7 @@ func TestHandlePipe(t *testing.T) {
 	chunkCh := make(chan *sources.Chunk, 1)
 	go func() {
 		defer close(chunkCh)
-		err := HandleFile(logContext.Background(), r, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
+		err := HandleFile(context.Background(), r, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh})
 		assert.NoError(t, err)
 	}()
 
@@ -684,3 +685,99 @@ func setupTempGitRepoCommon(t *testing.T, fileName string, fileSize int, isUnsup
 
 	return tempDir
 }
+
+func TestHandleFileNewFileReaderFailure(t *testing.T) {
+	customReader := iotest.ErrReader(errors.New("simulated newFileReader error"))
+
+	chunkSkel := &sources.Chunk{}
+	chunkCh := make(chan *sources.Chunk)
+	reporter := sources.ChanReporter{Ch: chunkCh}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	err := HandleFile(ctx, customReader, chunkSkel, reporter)
+
+	assert.Error(t, err, "HandleFile should return an error when newFileReader fails")
+}
+
+// errorInjectingReader is a custom io.Reader that injects an error after reading a certain number of bytes.
+type errorInjectingReader struct {
+	reader        io.Reader
+	injectAfter   int64 // Number of bytes after which to inject the error
+	injected      bool
+	bytesRead     int64
+	errorToInject error
+}
+
+func (eir *errorInjectingReader) Read(p []byte) (int, error) {
+	if eir.injectAfter > 0 && eir.bytesRead >= eir.injectAfter && !eir.injected {
+		eir.injected = true
+		return 0, eir.errorToInject
+	}
+
+	n, err := eir.reader.Read(p)
+	eir.bytesRead += int64(n)
+	return n, err
+}
+
+// TestHandleGitCatFileWithPipeError tests that when an error is injected during the HandleFile processing,
+// the error is reported and the git cat-file command completes successfully.
+func TestHandleGitCatFileWithPipeError(t *testing.T) {
+	fileName := "largefile_with_error.bin"
+	fileSize := 100 * 1024               // 100 KB
+	injectErrorAfter := int64(50 * 1024) // Inject error after 50 KB
+	simulatedError := errors.New("simulated error during newFileReader")
+
+	gitDir := setupTempGitRepo(t, fileName, fileSize)
+	defer os.RemoveAll(gitDir)
+
+	commitHash := getGitCommitHash(t, gitDir)
+
+	cmd := exec.Command("git", "-C", gitDir, "cat-file", "blob", fmt.Sprintf("%s:%s", commitHash, fileName))
+
+	var stderr bytes.Buffer
+	cmd.Stderr = &stderr
+
+	stdout, err := cmd.StdoutPipe()
+	assert.NoError(t, err, "Failed to create stdout pipe")
+
+	err = cmd.Start()
+	assert.NoError(t, err, "Failed to start git cat-file command")
+
+	// Wrap the stdout with errorInjectingReader to simulate an error after reading injectErrorAfter bytes.
+	wrappedReader := &errorInjectingReader{
+		reader:        stdout,
+		injectAfter:   injectErrorAfter,
+		injected:      false,
+		bytesRead:     0,
+		errorToInject: simulatedError,
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+
+	chunkCh := make(chan *sources.Chunk, 1000)
+
+	go func() {
+		defer close(chunkCh)
+		err = HandleFile(ctx, wrappedReader, &sources.Chunk{}, sources.ChanReporter{Ch: chunkCh}, WithSkipArchives(false))
+		assert.NoError(t, err, "HandleFile should not return an error")
+	}()
+
+	for range chunkCh {
+	}
+
+	err = cmd.Wait()
+	assert.NoError(t, err, "git cat-file command should complete without error")
+}
+
+// getGitCommitHash retrieves the current commit hash of the Git repository.
+func getGitCommitHash(t *testing.T, gitDir string) string {
+	t.Helper()
+	cmd := exec.Command("git", "-C", gitDir, "rev-parse", "HEAD")
+	hashBytes, err := cmd.Output()
+	assert.NoError(t, err, "Failed to get commit hash")
+	commitHash := strings.TrimSpace(string(hashBytes))
+	return commitHash
+}
diff --git a/pkg/handlers/rpm.go b/pkg/handlers/rpm.go
index b9b397f6d493..edd138c48296 100644
--- a/pkg/handlers/rpm.go
+++ b/pkg/handlers/rpm.go
@@ -22,16 +22,16 @@ func newRPMHandler() *rpmHandler {
 
 // HandleFile processes RPM formatted files. Further implementation is required to appropriately
 // handle RPM specific archive operations.
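A side note on the error-injection tests added above: they hand-roll an errorInjectingReader, and the standard library's testing/iotest package ships similar primitives that are handy for quick experiments with this kind of failure path. A short, runnable sketch (the wiring mirrors the tests' intent, not their exact behavior):

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"strings"
	"testing/iotest"
)

func main() {
	// Fails on the very first Read, like the iotest.ErrReader used in
	// TestHandleFileNewFileReaderFailure above.
	r := iotest.ErrReader(errors.New("simulated failure"))
	if _, err := io.ReadAll(r); err != nil {
		fmt.Println(err)
	}

	// Succeeds for the payload, then returns an error instead of EOF on
	// the second read, similar in spirit to errorInjectingReader.
	r2 := iotest.TimeoutReader(strings.NewReader("payload"))
	b, err := io.ReadAll(r2)
	fmt.Printf("%q %v\n", b, err)
}
```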
-func (h *rpmHandler) HandleFile(ctx logContext.Context, input fileReader) (chan []byte, error) {
-	archiveChan := make(chan []byte, defaultBufferSize)
+func (h *rpmHandler) HandleFile(ctx logContext.Context, input fileReader) chan DataOrErr {
+	dataOrErrChan := make(chan DataOrErr, defaultBufferSize)
 
 	if feature.ForceSkipArchives.Load() {
-		close(archiveChan)
-		return archiveChan, nil
+		close(dataOrErrChan)
+		return dataOrErrChan
 	}
 
 	go func() {
-		defer close(archiveChan)
+		defer close(dataOrErrChan)
 
 		// Defer a panic recovery to handle any panics that occur during the RPM processing.
 		defer func() {
@@ -59,19 +59,23 @@ func (h *rpmHandler) HandleFile(ctx logContext.Context, input fileReader) (chan
 			return
 		}
 
-		err = h.processRPMFiles(ctx, reader, archiveChan)
+		err = h.processRPMFiles(ctx, reader, dataOrErrChan)
 		if err == nil {
 			h.metrics.incFilesProcessed()
 		}
 
 		// Update the metrics for the file processing and handle any errors.
-		h.measureLatencyAndHandleErrors(start, err)
+		h.measureLatencyAndHandleErrors(ctx, start, err, dataOrErrChan)
 	}()
 
-	return archiveChan, nil
+	return dataOrErrChan
 }
 
-func (h *rpmHandler) processRPMFiles(ctx logContext.Context, reader rpmutils.PayloadReader, archiveChan chan []byte) error {
+func (h *rpmHandler) processRPMFiles(
+	ctx logContext.Context,
+	reader rpmutils.PayloadReader,
+	dataOrErrChan chan DataOrErr,
+) error {
 	for {
 		select {
 		case <-ctx.Done():
@@ -94,7 +98,7 @@ func (h *rpmHandler) processRPMFiles(ctx logContext.Context, reader rpmutils.Pay
 				return fmt.Errorf("error creating mime-type reader: %w", err)
 			}
 
-			if err := h.handleNonArchiveContent(fileCtx, rdr, archiveChan); err != nil {
+			if err := h.handleNonArchiveContent(fileCtx, rdr, dataOrErrChan); err != nil {
 				fileCtx.Logger().Error(err, "error handling archive content in RPM")
 				h.metrics.incErrors()
 			}
diff --git a/pkg/handlers/rpm_test.go b/pkg/handlers/rpm_test.go
index f90d7b672fa0..7ed0e7ad7f6e 100644
--- a/pkg/handlers/rpm_test.go
+++ b/pkg/handlers/rpm_test.go
@@ -23,12 +23,12 @@ func TestHandleRPMFile(t *testing.T) {
 	defer rdr.Close()
 
 	handler := newRPMHandler()
-	archiveChan, err := handler.HandleFile(context.AddLogger(ctx), rdr)
+	dataOrErrChan := handler.HandleFile(context.AddLogger(ctx), rdr)
 	assert.NoError(t, err)
 
 	wantChunkCount := 179
 	count := 0
-	for range archiveChan {
+	for range dataOrErrChan {
 		count++
 	}
 
diff --git a/pkg/iobuf/bufferedreaderseeker.go b/pkg/iobuf/bufferedreaderseeker.go
index 47ba5119e5ba..1c3a623cbc8c 100644
--- a/pkg/iobuf/bufferedreaderseeker.go
+++ b/pkg/iobuf/bufferedreaderseeker.go
@@ -266,6 +266,10 @@ func (br *BufferedReadSeeker) readToEnd() error {
 }
 
 func (br *BufferedReadSeeker) writeData(data []byte) error {
+	if br.buf == nil {
+		br.buf = br.bufPool.Get()
+	}
+
 	_, err := br.buf.Write(data)
 	if err != nil {
 		return err
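The final hunk guards against a nil buffer by lazily checking one out of the pool on first write, so BufferedReadSeeker no longer assumes the buffer was acquired up front. A sketch of that lazy-acquisition idiom using the standard library's sync.Pool (BufferedReadSeeker's actual pool type may differ):

```go
package main

import (
	"bytes"
	"fmt"
	"sync"
)

var bufPool = sync.Pool{New: func() any { return new(bytes.Buffer) }}

type bufferedWriter struct {
	buf *bytes.Buffer // nil until the first write
}

// writeData acquires a pooled buffer only when it is actually needed,
// mirroring the nil check added to writeData above.
func (w *bufferedWriter) writeData(data []byte) error {
	if w.buf == nil {
		w.buf = bufPool.Get().(*bytes.Buffer)
	}
	_, err := w.buf.Write(data)
	return err
}

func main() {
	w := &bufferedWriter{}
	if err := w.writeData([]byte("lazily buffered")); err != nil {
		fmt.Println("write failed:", err)
		return
	}
	fmt.Println(w.buf.String())
}
```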