Skip to content

feat: provide ctx.signal #7878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions docs/guide/test-context.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ The first argument for each test callback is a test context.
```ts
import { it } from 'vitest'

it('should work', (ctx) => {
it('should work', ({ task }) => {
// prints name of the test
console.log(ctx.task.name)
console.log(task.name)
})
```

Expand Down Expand Up @@ -65,6 +65,21 @@ it('math is hard', ({ skip }) => {
})
```

#### `context.signal` <Version>3.2.0</Version> {#context-signal}

A signal object that can be aborted by Vitest. The signal is aborted in these situations:

- Test times out
- User manually cancelled the test run with Ctrl+C
- [`vitest.cancelCurrentRun`](/advanced/api/vitest#cancelcurrentrun) was called programmatically
- Another test failed in parallel and the [`bail`](/config/#bail) flag is set

```ts
it('stop request when test times out', async ({ signal }) => {
await fetch('/resource', { signal })
}, 2000)
```

## Extend Test Context

Vitest provides two different ways to help you extend the test context.
Expand Down
10 changes: 5 additions & 5 deletions packages/browser/src/client/tester/runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ export function createBrowserRunner(
const currentFailures = 1 + previousFailures

if (currentFailures >= this.config.bail) {
rpc().onCancel('test-failure')
this.onCancel('test-failure')
rpc().cancelCurrentRun('test-failure')
this.cancel('test-failure')
}
}
}
Expand All @@ -81,8 +81,8 @@ export function createBrowserRunner(
}
}

onCancel = (reason: CancelReason) => {
super.onCancel?.(reason)
cancel = (reason: CancelReason) => {
super.cancel?.(reason)
globalChannel.postMessage({ type: 'cancel', reason })
}

Expand Down Expand Up @@ -196,7 +196,7 @@ export async function initiateRunner(
cachedRunner = runner

onCancel.then((reason) => {
runner.onCancel?.(reason)
runner.cancel?.(reason)
})

const [diffOptions] = await Promise.all([
Expand Down
2 changes: 1 addition & 1 deletion packages/browser/src/node/rpc.ts
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ export function setupBrowserRpc(globalServer: ParentBrowserProject, defaultMocke
const mod = globalServer.vite.moduleGraph.getModuleById(id)
return mod?.transformResult?.map
},
onCancel(reason) {
cancelCurrentRun(reason) {
vitest.cancelCurrentRun(reason)
},
async resolveId(id, importer) {
Expand Down
2 changes: 1 addition & 1 deletion packages/browser/src/node/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export interface WebSocketBrowserHandlers {
onCollected: (method: TestExecutionMethod, files: RunnerTestFile[]) => Promise<void>
onTaskUpdate: (method: TestExecutionMethod, packs: TaskResultPack[], events: TaskEventPack[]) => void
onAfterSuiteRun: (meta: AfterSuiteRunMeta) => void
onCancel: (reason: CancelReason) => void
cancelCurrentRun: (reason: CancelReason) => void
getCountOfFailedTests: () => number
readSnapshotFile: (id: string) => Promise<string | null>
saveSnapshotFile: (id: string, content: string) => Promise<void>
Expand Down
44 changes: 40 additions & 4 deletions packages/runner/src/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type {
SuiteCollector,
Test,
TestContext,
WriteableTestContext,
} from './types/tasks'
import { getSafeTimers } from '@vitest/utils'
import { PendingError } from './errors'
Expand Down Expand Up @@ -36,6 +37,7 @@ export function withTimeout<T extends (...args: any[]) => any>(
timeout: number,
isHook = false,
stackTraceError?: Error,
onTimeout?: (args: T extends (...args: infer A) => any ? A : never, error: Error) => void,
): T {
if (timeout <= 0 || timeout === Number.POSITIVE_INFINITY) {
return fn
Expand All @@ -58,7 +60,9 @@ export function withTimeout<T extends (...args: any[]) => any>(
timer.unref?.()

function rejectTimeoutError() {
reject_(makeTimeoutError(isHook, timeout, stackTraceError))
const error = makeTimeoutError(isHook, timeout, stackTraceError)
onTimeout?.(args, error)
reject_(error)
}

function resolve(result: unknown) {
Expand Down Expand Up @@ -102,14 +106,34 @@ export function withTimeout<T extends (...args: any[]) => any>(
}) as T
}

const abortControllers = new WeakMap<TestContext, AbortController>()

export function abortIfTimeout([context]: [TestContext?], error: Error): void {
if (context) {
abortContextSignal(context, error)
}
}

export function abortContextSignal(context: TestContext, error: Error): void {
const abortController = abortControllers.get(context)
abortController?.abort(error)
}

export function createTestContext(
test: Test,
runner: VitestRunner,
): TestContext {
const context = function () {
throw new Error('done() callback is deprecated, use promise instead')
} as unknown as TestContext
} as unknown as WriteableTestContext

const abortController = abortControllers.get(context) || (() => {
const abortController = new AbortController()
abortControllers.set(context, abortController)
return abortController
})()

context.signal = abortController.signal
context.task = test

context.skip = (condition?: boolean | string, note?: string): never => {
Expand All @@ -129,14 +153,26 @@ export function createTestContext(
context.onTestFailed = (handler, timeout) => {
test.onFailed ||= []
test.onFailed.push(
withTimeout(handler, timeout ?? runner.config.hookTimeout, true, new Error('STACK_TRACE_ERROR')),
withTimeout(
handler,
timeout ?? runner.config.hookTimeout,
true,
new Error('STACK_TRACE_ERROR'),
(_, error) => abortController.abort(error),
),
)
}

context.onTestFinished = (handler, timeout) => {
test.onFinished ||= []
test.onFinished.push(
withTimeout(handler, timeout ?? runner.config.hookTimeout, true, new Error('STACK_TRACE_ERROR')),
withTimeout(
handler,
timeout ?? runner.config.hookTimeout,
true,
new Error('STACK_TRACE_ERROR'),
(_, error) => abortController.abort(error),
),
)
}

Expand Down
5 changes: 5 additions & 0 deletions packages/runner/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,8 @@ export class PendingError extends Error {
this.taskId = task.id
}
}

export class AbortError extends Error {
name = 'AbortError'
code = 20
}
21 changes: 18 additions & 3 deletions packages/runner/src/hooks.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ import type {
OnTestFinishedHandler,
TaskHook,
TaskPopulated,
TestContext,
} from './types/tasks'
import { assertTypes } from '@vitest/utils'
import { withTimeout } from './context'
import { abortContextSignal, abortIfTimeout, withTimeout } from './context'
import { withFixtures } from './fixture'
import { getCurrentSuite, getRunner } from './suite'
import { getCurrentTest } from './test-state'
Expand All @@ -21,7 +22,7 @@ function getDefaultHookTimeout() {
const CLEANUP_TIMEOUT_KEY = Symbol.for('VITEST_CLEANUP_TIMEOUT')
const CLEANUP_STACK_TRACE_KEY = Symbol.for('VITEST_CLEANUP_STACK_TRACE')

export function getBeforeHookCleanupCallback(hook: Function, result: any): Function | undefined {
export function getBeforeHookCleanupCallback(hook: Function, result: any, context?: TestContext): Function | undefined {
if (typeof result === 'function') {
const timeout
= CLEANUP_TIMEOUT_KEY in hook && typeof hook[CLEANUP_TIMEOUT_KEY] === 'number'
Expand All @@ -31,7 +32,17 @@ export function getBeforeHookCleanupCallback(hook: Function, result: any): Funct
= CLEANUP_STACK_TRACE_KEY in hook && hook[CLEANUP_STACK_TRACE_KEY] instanceof Error
? hook[CLEANUP_STACK_TRACE_KEY]
: undefined
return withTimeout(result, timeout, true, stackTraceError)
return withTimeout(
result,
timeout,
true,
stackTraceError,
(_, error) => {
if (context) {
abortContextSignal(context, error)
}
},
)
}
}

Expand Down Expand Up @@ -136,6 +147,7 @@ export function beforeEach<ExtraContext = object>(
timeout ?? getDefaultHookTimeout(),
true,
stackTraceError,
abortIfTimeout,
),
{
[CLEANUP_TIMEOUT_KEY]: timeout,
Expand Down Expand Up @@ -174,6 +186,7 @@ export function afterEach<ExtraContext = object>(
timeout ?? getDefaultHookTimeout(),
true,
new Error('STACK_TRACE_ERROR'),
abortIfTimeout,
),
)
}
Expand Down Expand Up @@ -206,6 +219,7 @@ export const onTestFailed: TaskHook<OnTestFailedHandler> = createTestHook(
timeout ?? getDefaultHookTimeout(),
true,
new Error('STACK_TRACE_ERROR'),
abortIfTimeout,
),
)
},
Expand Down Expand Up @@ -244,6 +258,7 @@ export const onTestFinished: TaskHook<OnTestFinishedHandler> = createTestHook(
timeout ?? getDefaultHookTimeout(),
true,
new Error('STACK_TRACE_ERROR'),
abortIfTimeout,
),
)
},
Expand Down
57 changes: 41 additions & 16 deletions packages/runner/src/run.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,17 @@ import type {
TaskUpdateEvent,
Test,
TestContext,
WriteableTestContext,
} from './types/tasks'
import { shuffle } from '@vitest/utils'
import { processError } from '@vitest/utils/error'
import { collectTests } from './collect'
import { PendingError } from './errors'
import { abortContextSignal } from './context'
import { AbortError, PendingError } from './errors'
import { callFixtureCleanup } from './fixture'
import { getBeforeHookCleanupCallback } from './hooks'
import { getFn, getHooks } from './map'
import { setCurrentTest } from './test-state'
import { addRunningTest, getRunningTests, setCurrentTest } from './test-state'
import { limitConcurrency } from './utils/limit-concurrency'
import { partitionSuiteChildren } from './utils/suite'
import { hasFailed, hasTests } from './utils/tasks'
Expand Down Expand Up @@ -87,12 +89,14 @@ async function callTestHooks(
return
}

const context = test.context as WriteableTestContext

const onTestFailed = test.context.onTestFailed
const onTestFinished = test.context.onTestFinished
test.context.onTestFailed = () => {
context.onTestFailed = () => {
throw new Error(`Cannot call "onTestFailed" inside a test hook.`)
}
test.context.onTestFinished = () => {
context.onTestFinished = () => {
throw new Error(`Cannot call "onTestFinished" inside a test hook.`)
}

Expand All @@ -115,8 +119,8 @@ async function callTestHooks(
}
}

test.context.onTestFailed = onTestFailed
test.context.onTestFinished = onTestFinished
context.onTestFailed = onTestFailed
context.onTestFinished = onTestFinished
}

export async function callSuiteHook<T extends keyof SuiteHooks>(
Expand Down Expand Up @@ -145,7 +149,11 @@ export async function callSuiteHook<T extends keyof SuiteHooks>(
}

async function runHook(hook: Function) {
return getBeforeHookCleanupCallback(hook, await hook(...args))
return getBeforeHookCleanupCallback(
hook,
await hook(...args),
name === 'beforeEach' ? args[0] : undefined,
)
}

if (sequence === 'parallel') {
Expand Down Expand Up @@ -274,6 +282,7 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
}
updateTask('test-prepare', test, runner)

const cleanupRunningTest = addRunningTest(test)
setCurrentTest(test)

const suite = test.suite || test.file
Expand Down Expand Up @@ -374,6 +383,7 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
}
updateTask('test-finished', test, runner)
setCurrentTest(undefined)
cleanupRunningTest()
return
}

Expand Down Expand Up @@ -405,6 +415,7 @@ export async function runTest(test: Test, runner: VitestRunner): Promise<void> {
}
}

cleanupRunningTest()
setCurrentTest(undefined)

test.result.duration = now() - start
Expand Down Expand Up @@ -588,21 +599,35 @@ export async function runFiles(files: File[], runner: VitestRunner): Promise<voi
}

export async function startTests(specs: string[] | FileSpecification[], runner: VitestRunner): Promise<File[]> {
const paths = specs.map(f => typeof f === 'string' ? f : f.filepath)
await runner.onBeforeCollect?.(paths)
const cancel = runner.cancel?.bind(runner)
runner.cancel = (reason) => {
const error = new AbortError('The test run was aborted by the user.')
getRunningTests().forEach(test =>
abortContextSignal(test.context, error),
)
return cancel?.(reason)
}

const files = await collectTests(specs, runner)
try {
const paths = specs.map(f => typeof f === 'string' ? f : f.filepath)
await runner.onBeforeCollect?.(paths)

await runner.onCollected?.(files)
await runner.onBeforeRunFiles?.(files)
const files = await collectTests(specs, runner)

await runFiles(files, runner)
await runner.onCollected?.(files)
await runner.onBeforeRunFiles?.(files)

await runner.onAfterRunFiles?.(files)
await runFiles(files, runner)

await finishSendTasksUpdate(runner)
await runner.onAfterRunFiles?.(files)

return files
await finishSendTasksUpdate(runner)

return files
}
finally {
runner.cancel = cancel
}
}

async function publicCollect(specs: string[] | FileSpecification[], runner: VitestRunner): Promise<File[]> {
Expand Down
Loading
Loading