diff --git a/packages/vitest/src/node/pools/pool.ts b/packages/vitest/src/node/pools/pool.ts index a1f8c6ede1cb..b7a34b6e7c1f 100644 --- a/packages/vitest/src/node/pools/pool.ts +++ b/packages/vitest/src/node/pools/pool.ts @@ -34,7 +34,7 @@ export class Pool { private activeTasks: ActiveTask[] = [] private sharedRunners: PoolRunner[] = [] private exitPromises: Promise[] = [] - private _isCancelling: boolean = false + private cancellingPromise: Promise | null = null constructor(private options: Options, private logger: Logger) {} @@ -47,9 +47,9 @@ export class Pool { } async run(task: PoolTask, method: 'run' | 'collect'): Promise { - // Prevent new tasks from being queued during cancellation - if (this._isCancelling) { - throw new Error('[vitest-pool]: Cannot run tasks while pool is cancelling') + // Wait for any ongoing cancellation to complete before accepting new tasks + if (this.cancellingPromise) { + await this.cancellingPromise } // Every runner related failure should make this promise reject so that it's picked by pool. @@ -167,28 +167,32 @@ export class Pool { } async cancel(): Promise { - // Set flag to prevent new tasks from being queued - this._isCancelling = true + // Create a promise to track cancellation completion + this.cancellingPromise = (async () => { + const pendingTasks = this.queue.splice(0) - const pendingTasks = this.queue.splice(0) - - if (pendingTasks.length) { - const error = new Error('Cancelled') - pendingTasks.forEach(task => task.resolver.reject(error)) - } + if (pendingTasks.length) { + const error = new Error('Cancelled') + pendingTasks.forEach(task => task.resolver.reject(error)) + } - const activeTasks = this.activeTasks.splice(0) - await Promise.all(activeTasks.map(task => task.cancelTask())) + const activeTasks = this.activeTasks.splice(0) + await Promise.all(activeTasks.map(task => task.cancelTask())) - const sharedRunners = this.sharedRunners.splice(0) - await Promise.all(sharedRunners.map(runner => runner.stop())) + const sharedRunners = this.sharedRunners.splice(0) + await Promise.all(sharedRunners.map(runner => runner.stop())) - await Promise.all(this.exitPromises.splice(0)) + await Promise.all(this.exitPromises.splice(0)) - this.workerIds.forEach((_, id) => this.freeWorkerId(id)) + this.workerIds.forEach((_, id) => this.freeWorkerId(id)) + })() - // Reset flag after cancellation completes - this._isCancelling = false + try { + await this.cancellingPromise + } + finally { + this.cancellingPromise = null + } } async close(): Promise { diff --git a/test/cli/fixtures/bail-race/add.spec.js b/test/cli/fixtures/bail-race/add.spec.js new file mode 100644 index 000000000000..991ec40b26cd --- /dev/null +++ b/test/cli/fixtures/bail-race/add.spec.js @@ -0,0 +1,9 @@ +import { expect, test } from 'vitest' + +test('adds two numbers', () => { + expect(2 + 3).toBe(5) +}) + +test('fails adding two numbers', () => { + expect(2 + 3).toBe(6) +}) diff --git a/test/cli/test/bail-race.test.ts b/test/cli/test/bail-race.test.ts new file mode 100644 index 000000000000..fd8e133236e1 --- /dev/null +++ b/test/cli/test/bail-race.test.ts @@ -0,0 +1,61 @@ +import { resolve } from 'pathe' +import { expect, test } from 'vitest' +import { runVitest } from '../../test-utils' + +test('cancels previous run before starting new one', async () => { + const results: Record[] = [] + + const { ctx: vitest, buildTestTree } = await runVitest({ + root: resolve(import.meta.dirname, '../fixtures/bail-race'), + bail: 1, + pool: 'threads', + reporters: [{ + onTestRunEnd() { + results.push(buildTestTree()) + }, + }], + }) + + if (!vitest) { + throw new Error('Vitest context is not available') + } + + let rounds = 0 + + while (vitest.state.errorsSet.size === 0) { + await vitest.start() + + if (rounds >= 2) { + break + } + + rounds++ + } + + if (vitest.state.errorsSet.size > 0) { + throw vitest.state.errorsSet.values().next().value + } + + expect(results).toMatchInlineSnapshot(` + [ + { + "add.spec.js": { + "adds two numbers": "passed", + "fails adding two numbers": "failed", + }, + }, + { + "add.spec.js": { + "adds two numbers": "passed", + "fails adding two numbers": "failed", + }, + }, + { + "add.spec.js": { + "adds two numbers": "passed", + "fails adding two numbers": "failed", + }, + }, + ] + `) +}) diff --git a/test/test-utils/index.ts b/test/test-utils/index.ts index 2eca071ef630..ff6744064fa4 100644 --- a/test/test-utils/index.ts +++ b/test/test-utils/index.ts @@ -141,6 +141,9 @@ export async function runVitest( vitest: cli, stdout: cli.stdout, stderr: cli.stderr, + buildTestTree: () => { + return buildTestTree(ctx?.state.getTestModules() || []) + }, waitForClose: async () => { await new Promise(resolve => ctx!.onClose(resolve)) return ctx?.closingPromise