Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/behaviors.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,30 @@
import { equals } from '@vitest/expect'

import type { AnyFunction, WithMatchers } from './types.ts'
import type {
AnyCallable,
AnyFunction,
ExtractParameters,
ExtractReturnType,
WithMatchers,
} from './types.ts'

export interface WhenOptions {
times?: number
}

export interface BehaviorStack<TFunc extends AnyFunction> {
use: (args: Parameters<TFunc>) => BehaviorEntry<Parameters<TFunc>> | undefined
export interface BehaviorStack<TFunc extends AnyCallable> {
use: (
args: ExtractParameters<TFunc>,
) => BehaviorEntry<ExtractParameters<TFunc>> | undefined

getAll: () => readonly BehaviorEntry<Parameters<TFunc>>[]
getAll: () => readonly BehaviorEntry<ExtractParameters<TFunc>>[]

getUnmatchedCalls: () => readonly Parameters<TFunc>[]
getUnmatchedCalls: () => readonly ExtractParameters<TFunc>[]

bindArgs: (
args: WithMatchers<Parameters<TFunc>>,
args: WithMatchers<ExtractParameters<TFunc>>,
options: WhenOptions,
) => BoundBehaviorStack<ReturnType<TFunc>>
) => BoundBehaviorStack<ExtractReturnType<TFunc>>
}

export interface BoundBehaviorStack<TReturn> {
Expand Down Expand Up @@ -55,10 +63,10 @@ export interface BehaviorOptions<TValue> {
}

export const createBehaviorStack = <
TFunc extends AnyFunction,
TFunc extends AnyCallable,
>(): BehaviorStack<TFunc> => {
const behaviors: BehaviorEntry<Parameters<TFunc>>[] = []
const unmatchedCalls: Parameters<TFunc>[] = []
const behaviors: BehaviorEntry<ExtractParameters<TFunc>>[] = []
const unmatchedCalls: ExtractParameters<TFunc>[] = []

return {
getAll: () => behaviors,
Expand Down
4 changes: 2 additions & 2 deletions src/debug.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {

import { type Behavior, BehaviorType } from './behaviors'
import { getBehaviorStack, validateSpy } from './stubs'
import type { AnyFunction, MockInstance } from './types'
import type { AnyCallable, MockInstance } from './types'

export interface DebugResult {
name: string
Expand All @@ -20,7 +20,7 @@ export interface Stubbing {
calls: readonly unknown[][]
}

export const getDebug = <TFunc extends AnyFunction>(
export const getDebug = <TFunc extends AnyCallable>(
spy: TFunc | MockInstance<TFunc>,
): DebugResult => {
const target = validateSpy<TFunc>(spy)
Expand Down
21 changes: 13 additions & 8 deletions src/stubs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,21 @@ import {
createBehaviorStack,
} from './behaviors.ts'
import { NotAMockFunctionError } from './errors.ts'
import type { AnyFunction, MockInstance } from './types.ts'
import type {
AnyCallable,
AnyFunction,
ExtractParameters,
MockInstance,
} from './types.ts'

const BEHAVIORS_KEY = Symbol('behaviors')

interface WhenStubImplementation<TFunc extends AnyFunction> {
(...args: Parameters<TFunc>): unknown
interface WhenStubImplementation<TFunc extends AnyCallable> {
(...args: ExtractParameters<TFunc>): unknown
[BEHAVIORS_KEY]: BehaviorStack<TFunc>
}

export const configureStub = <TFunc extends AnyFunction>(
export const configureStub = <TFunc extends AnyCallable>(
maybeSpy: unknown,
): BehaviorStack<TFunc> => {
const spy = validateSpy<TFunc>(maybeSpy)
Expand All @@ -26,10 +31,10 @@ export const configureStub = <TFunc extends AnyFunction>(
const behaviors = createBehaviorStack<TFunc>()
const fallbackImplementation = spy.getMockImplementation()

const implementation = (...args: Parameters<TFunc>) => {
const implementation = (...args: ExtractParameters<TFunc>) => {
const behavior = behaviors.use(args)?.behavior ?? {
type: BehaviorType.DO,
callback: fallbackImplementation,
callback: fallbackImplementation as AnyFunction | undefined,
}

switch (behavior.type) {
Expand Down Expand Up @@ -63,7 +68,7 @@ export const configureStub = <TFunc extends AnyFunction>(
return behaviors
}

export const validateSpy = <TFunc extends AnyFunction>(
export const validateSpy = <TFunc extends AnyCallable>(
maybeSpy: unknown,
): MockInstance<TFunc> => {
if (
Expand All @@ -81,7 +86,7 @@ export const validateSpy = <TFunc extends AnyFunction>(
throw new NotAMockFunctionError(maybeSpy)
}

export const getBehaviorStack = <TFunc extends AnyFunction>(
export const getBehaviorStack = <TFunc extends AnyCallable>(
spy: MockInstance<TFunc>,
): BehaviorStack<TFunc> | undefined => {
const existingImplementation = spy.getMockImplementation() as
Expand Down
30 changes: 25 additions & 5 deletions src/types.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
/** Common type definitions. */
import type { AsymmetricMatcher } from '@vitest/expect'

/** Any function, for use in `extends` */
/** Any function. */
export type AnyFunction = (...args: never[]) => unknown

/** Any constructor. */
export type AnyConstructor = new (...args: never[]) => unknown

/** Any callable, for use in `extends` */
export type AnyCallable = AnyFunction | AnyConstructor

/** Extract parameters from either a function or constructor. */
export type ExtractParameters<T> = T extends new (...args: infer P) => unknown
? P
: T extends (...args: infer P) => unknown
? P
: never

/** Extract return type from either a function or constructor */
export type ExtractReturnType<T> = T extends new (...args: never[]) => infer R
? R
: T extends (...args: never[]) => infer R
? R
: never

/** Accept a value or an AsymmetricMatcher in an arguments array */
export type WithMatchers<T extends unknown[]> = {
[K in keyof T]: T[K] | AsymmetricMatcher<unknown>
[K in keyof T]: AsymmetricMatcher<unknown> | T[K]
}

/**
Expand All @@ -15,13 +35,13 @@ export type WithMatchers<T extends unknown[]> = {
* Used to ensure backwards compatibility
* with older versions of Vitest.
*/
export interface MockInstance<TFunc extends AnyFunction = AnyFunction> {
export interface MockInstance<TFunc extends AnyCallable = AnyCallable> {
getMockName(): string
getMockImplementation(): TFunc | undefined
mockImplementation: (impl: TFunc) => this
mock: MockContext<TFunc>
}

export interface MockContext<TFunc extends AnyFunction> {
calls: Parameters<TFunc>[]
export interface MockContext<TFunc extends AnyCallable> {
calls: ExtractParameters<TFunc>[]
}
18 changes: 12 additions & 6 deletions src/vitest-when.ts
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import type { WhenOptions } from './behaviors.ts'
import { type DebugResult, getDebug } from './debug.ts'
import { configureStub } from './stubs.ts'
import type { AnyFunction, MockInstance, WithMatchers } from './types.ts'
import type {
AnyCallable,
ExtractParameters,
ExtractReturnType,
MockInstance,
WithMatchers,
} from './types.ts'

export { type Behavior, BehaviorType, type WhenOptions } from './behaviors.ts'
export type { DebugResult, Stubbing } from './debug.ts'
export * from './errors.ts'

export interface StubWrapper<TFunc extends AnyFunction> {
calledWith<TArgs extends Parameters<TFunc>>(
export interface StubWrapper<TFunc extends AnyCallable> {
calledWith<TArgs extends ExtractParameters<TFunc>>(
...args: WithMatchers<TArgs>
): Stub<TArgs, ReturnType<TFunc>>
): Stub<TArgs, ExtractReturnType<TFunc>>
}

export interface Stub<TArgs extends unknown[], TReturn> {
Expand All @@ -21,7 +27,7 @@ export interface Stub<TArgs extends unknown[], TReturn> {
thenDo: (...callbacks: ((...args: TArgs) => TReturn)[]) => void
}

export const when = <TFunc extends AnyFunction>(
export const when = <TFunc extends AnyCallable>(
spy: TFunc | MockInstance<TFunc>,
options: WhenOptions = {},
): StubWrapper<TFunc> => {
Expand All @@ -46,7 +52,7 @@ export interface DebugOptions {
log?: boolean
}

export const debug = <TFunc extends AnyFunction>(
export const debug = <TFunc extends AnyCallable>(
spy: TFunc | MockInstance<TFunc>,
options: DebugOptions = {},
): DebugResult => {
Expand Down
14 changes: 14 additions & 0 deletions test/typing.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,20 @@ describe('vitest-when type signatures', () => {
subject.when(simple).calledWith(expect.any(Number))
subject.when(complex).calledWith(expect.objectContaining({ a: 1 }))
})

it('should accept a class constructor', () => {
// eslint-disable-next-line @typescript-eslint/no-extraneous-class
class TestClass {
constructor(input: number) {
throw new Error(`TestClass(${input})`)
}
}

subject.when(TestClass).calledWith(42)

// @ts-expect-error: args wrong type
subject.when(TestClass).calledWith('42')
})
})

function untyped(...args: any[]): any {
Expand Down
Loading