Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions .github/workflows/run_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
enable-cache: true
cache-dependency-glob: "**/pyproject.toml"
python-version: "3.10"

- name: Install the project
run: uv sync --extra dev --resolution lowest-direct

- name: Test package
run: uv run --frozen pytest

benchmark:
needs: [format]
runs-on: ubuntu-24.04
if:
github.event_name == 'workflow_dispatch' || (github.event_name ==
'pull_request' && contains(github.event.pull_request.labels.*.name,
'run-benchmarks')) || (github.event_name == 'push' && github.ref ==
'refs/heads/main')
steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5
with:
Expand All @@ -66,10 +90,13 @@ jobs:
- name: "Set up Python"
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.13"

- name: Install the project
run: uv sync --extra dev --resolution lowest-direct
run: uv sync --extra dev

- name: Test package
run: uv run --extra dev --resolution lowest-direct pytest
- name: Run benchmarks
uses: CodSpeedHQ/action@v3
with:
token: ${{ secrets.CODSPEED_TOKEN }}
run: uv run --frozen pytest tests/benchmark --codspeed
Empty file added tests/benchmark/__init__.py
Empty file.
139 changes: 139 additions & 0 deletions tests/benchmark/test_numpy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
"""Benchmark tests for quaxed functions on quantities."""

from collections.abc import Callable
from typing import Any, TypeAlias, TypedDict
from typing_extensions import Unpack

import jax
import jax.numpy as jnp
import pytest
from jax._src.stages import Compiled

import quax

from ..myarray import MyArray


Args: TypeAlias = tuple[Any, ...]

x = jnp.linspace(0, 1, 1000)
xm = MyArray(x)


def process_func(func: Callable[..., Any], args: Args) -> tuple[Compiled, Args]:
"""JIT and compile the function."""
return jax.jit(quax.quaxify(func)), args


class ParameterizationKWArgs(TypedDict):
"""Keyword arguments for a pytest parameterization."""

argvalues: list[tuple[Callable[..., Any], Args]]
ids: list[str]


def process_pytest_argvalues(
process_fn: Callable[[Callable[..., Any], Args], tuple[Callable[..., Any], Args]],
argvalues: list[tuple[Callable[..., Any], Unpack[tuple[Args, ...]]]],
) -> ParameterizationKWArgs:
"""Process the argvalues."""
# Get the ID for each parameterization
get_types = lambda args: tuple(str(type(a)) for a in args)
ids: list[str] = []
processed_argvalues: list[tuple[Compiled, Args]] = []

for func, *many_args in argvalues:
for args in many_args:
ids.append(f"{func.__name__}-{get_types(args)}")
processed_argvalues.append(process_fn(func, args))

# Process the argvalues and return the parameterization, with IDs
return {"argvalues": processed_argvalues, "ids": ids}


funcs_and_args: list[tuple[Callable[..., Any], Unpack[tuple[Args, ...]]]] = [
(jnp.abs, (xm,)),
(jnp.acos, (xm,)),
(jnp.acosh, (xm,)),
(jnp.add, (xm, xm)),
(jnp.asin, (xm,)),
(jnp.asinh, (xm,)),
(jnp.atan, (xm,)),
(jnp.atan2, (xm, xm)),
(jnp.atanh, (xm,)),
# bitwise_and
# bitwise_left_shift
# bitwise_invert
# bitwise_or
# bitwise_right_shift
# bitwise_xor
(jnp.ceil, (xm,)),
(jnp.conj, (xm,)),
(jnp.cos, (xm,)),
(jnp.cosh, (xm,)),
(jnp.divide, (xm, xm)),
(jnp.equal, (xm, xm)),
(jnp.exp, (xm,)),
(jnp.expm1, (xm,)),
(jnp.floor, (xm,)),
(jnp.floor_divide, (xm, xm)),
(jnp.greater, (xm, xm)),
(jnp.greater_equal, (xm, xm)),
(jnp.imag, (xm,)),
(jnp.isfinite, (xm,)),
(jnp.isinf, (xm,)),
(jnp.isnan, (xm,)),
(jnp.less, (xm, xm)),
(jnp.less_equal, (xm, xm)),
(jnp.log, (xm,)),
(jnp.log1p, (xm,)),
(jnp.log2, (xm,)),
(jnp.log10, (xm,)),
(jnp.logaddexp, (xm, xm)),
(jnp.logical_and, (xm, xm)),
(jnp.logical_not, (xm,)),
(jnp.logical_or, (xm, xm)),
(jnp.logical_xor, (xm, xm)),
(jnp.multiply, (xm, xm)),
(jnp.negative, (xm,)),
(jnp.not_equal, (xm, xm)),
(jnp.positive, (xm,)),
(jnp.power, (xm, 2.0)),
(jnp.real, (xm,)),
(jnp.remainder, (xm, xm)),
(jnp.round, (xm,)),
(jnp.sign, (xm,)),
(jnp.sin, (xm,)),
(jnp.sinh, (xm,)),
(jnp.square, (xm,)),
(jnp.sqrt, (xm,)),
(jnp.subtract, (xm, xm)),
(jnp.tan, (xm,)),
(jnp.tanh, (xm,)),
(jnp.trunc, (xm,)),
]


# =============================================================================


@pytest.mark.parametrize(
("func", "args"), **process_pytest_argvalues(process_func, funcs_and_args)
)
@pytest.mark.benchmark(group="quaxed", max_time=1.0, warmup=False)
def test_jit_compile(func, args):
"""Test the speed of jitting a function."""
_ = func.lower(*args).compile()


@pytest.mark.parametrize(
("func", "args"), **process_pytest_argvalues(process_func, funcs_and_args)
)
@pytest.mark.benchmark(
group="quaxed",
max_time=1.0, # NOTE: max_time is ignored
warmup=True,
)
def test_execute(func, args):
"""Test the speed of calling the function."""
_ = jax.block_until_ready(func(*args))