diff --git a/src/cli/actions/defaultAction.ts b/src/cli/actions/defaultAction.ts index 3cadb9b11..40958563f 100644 --- a/src/cli/actions/defaultAction.ts +++ b/src/cli/actions/defaultAction.ts @@ -13,6 +13,7 @@ import { generateDefaultSkillName } from '../../core/skill/skillUtils.js'; import { RepomixError, rethrowValidationErrorIfZodError } from '../../shared/errorHandle.js'; import { logger } from '../../shared/logger.js'; import { splitPatterns } from '../../shared/patternUtils.js'; +import type { RepomixProgressCallback } from '../../shared/types.js'; import { reportResults } from '../cliReport.js'; import { Spinner } from '../cliSpinner.js'; import { promptSkillLocation, resolveAndPrepareSkillDir } from '../prompts/skillPrompts.js'; @@ -28,6 +29,7 @@ export const runDefaultAction = async ( directories: string[], cwd: string, cliOptions: CliOptions, + progressCallback?: RepomixProgressCallback, ): Promise => { logger.trace('Loaded CLI options:', cliOptions); @@ -113,16 +115,20 @@ export const runDefaultAction = async ( const targetPaths = stdinFilePaths ? [cwd] : directories.map((directory) => path.resolve(cwd, directory)); - packResult = await pack( - targetPaths, - config, - (message) => { - spinner.update(message); - }, - {}, - stdinFilePaths, - packOptions, - ); + const handleProgress: RepomixProgressCallback = (message) => { + spinner.update(message); + if (progressCallback) { + try { + Promise.resolve(progressCallback(message)).catch((error) => { + logger.trace('progressCallback error:', error); + }); + } catch (error) { + logger.trace('progressCallback error:', error); + } + } + }; + + packResult = await pack(targetPaths, config, handleProgress, {}, stdinFilePaths, packOptions); spinner.succeed('Packing completed successfully!'); } catch (error) { diff --git a/src/shared/types.ts b/src/shared/types.ts index 6a30d6d9b..e0c49862b 100644 --- a/src/shared/types.ts +++ b/src/shared/types.ts @@ -1 +1 @@ -export type RepomixProgressCallback = (message: string) => void; +export type RepomixProgressCallback = (message: string) => void | Promise; diff --git a/tests/cli/actions/defaultAction.test.ts b/tests/cli/actions/defaultAction.test.ts index 920e94631..748993e39 100644 --- a/tests/cli/actions/defaultAction.test.ts +++ b/tests/cli/actions/defaultAction.test.ts @@ -6,6 +6,7 @@ import * as configLoader from '../../../src/config/configLoad.js'; import * as fileStdin from '../../../src/core/file/fileStdin.js'; import * as packageJsonParser from '../../../src/core/file/packageJsonParse.js'; import * as packager from '../../../src/core/packager.js'; +import * as loggerModule from '../../../src/shared/logger.js'; import { createMockConfig } from '../../testing/testUtils.js'; vi.mock('../../../src/core/packager'); @@ -156,6 +157,95 @@ describe('defaultAction', () => { expect(mockSpinner.fail).toHaveBeenCalledWith('Error during packing'); }); + describe('progressCallback', () => { + it('should forward progress messages to the provided callback', async () => { + // Configure pack mock to invoke its 3rd argument (progressCallback) + vi.mocked(packager.pack).mockImplementation(async (_paths, _config, progressCallback = () => {}) => { + progressCallback('Searching for files...'); + progressCallback('Processing files...'); + return { + totalFiles: 10, + totalCharacters: 1000, + totalTokens: 200, + fileCharCounts: {}, + fileTokenCounts: {}, + suspiciousFilesResults: [], + suspiciousGitDiffResults: [], + suspiciousGitLogResults: [], + processedFiles: [], + safeFilePaths: [], + gitDiffTokenCount: 0, + gitLogTokenCount: 0, + skippedFiles: [], + }; + }); + + const callback = vi.fn(); + await runDefaultAction(['.'], process.cwd(), {}, callback); + + expect(callback).toHaveBeenCalledWith('Searching for files...'); + expect(callback).toHaveBeenCalledWith('Processing files...'); + }); + + it('should isolate async callback errors without affecting pack flow', async () => { + vi.mocked(packager.pack).mockImplementation(async (_paths, _config, progressCallback = () => {}) => { + progressCallback('test message'); + // Allow microtask to process the rejected promise + await new Promise((resolve) => setTimeout(resolve, 10)); + return { + totalFiles: 10, + totalCharacters: 1000, + totalTokens: 200, + fileCharCounts: {}, + fileTokenCounts: {}, + suspiciousFilesResults: [], + suspiciousGitDiffResults: [], + suspiciousGitLogResults: [], + processedFiles: [], + safeFilePaths: [], + gitDiffTokenCount: 0, + gitLogTokenCount: 0, + skippedFiles: [], + }; + }); + + const rejectingCallback = vi.fn().mockRejectedValue(new Error('callback error')); + const result = await runDefaultAction(['.'], process.cwd(), {}, rejectingCallback); + + expect(result.packResult.totalFiles).toBe(10); + expect(loggerModule.logger.trace).toHaveBeenCalledWith('progressCallback error:', expect.any(Error)); + }); + + it('should still update spinner even when callback throws synchronously', async () => { + vi.mocked(packager.pack).mockImplementation(async (_paths, _config, progressCallback = () => {}) => { + progressCallback('test message'); + return { + totalFiles: 10, + totalCharacters: 1000, + totalTokens: 200, + fileCharCounts: {}, + fileTokenCounts: {}, + suspiciousFilesResults: [], + suspiciousGitDiffResults: [], + suspiciousGitLogResults: [], + processedFiles: [], + safeFilePaths: [], + gitDiffTokenCount: 0, + gitLogTokenCount: 0, + skippedFiles: [], + }; + }); + + const throwingCallback = vi.fn().mockImplementation(() => { + throw new Error('sync error'); + }); + await runDefaultAction(['.'], process.cwd(), {}, throwingCallback); + + // Spinner should still be updated despite callback failure + expect(mockSpinner.update).toHaveBeenCalledWith('test message'); + }); + }); + describe('buildCliConfig', () => { it('should handle custom include patterns', () => { const options = {