51 lines
1.4 KiB
Python

from __future__ import annotations
from functools import partial, wraps
from typing import TYPE_CHECKING, TypeVar
from .. import _core
from ..abc import Clock, Instrument
# Result type of the wrapped coroutine; this TypeVar exists at runtime.
RetT = TypeVar("RetT")

if TYPE_CHECKING:
    # Typing-only names: this block never executes at runtime, so these
    # imports (including typing_extensions) are not needed to run the module.
    from collections.abc import Awaitable, Callable

    from typing_extensions import ParamSpec

    # Captures the wrapped test function's parameter list so the synchronous
    # wrapper advertises the same signature to type checkers.
    ArgsT = ParamSpec("ArgsT")
def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]:
    """Wrap an async test function so it can be called synchronously under Trio.

    Calling the returned function runs ``fn(*args, **kwargs)`` to completion
    via :func:`trio.run` and returns its result.

    Usage::

        @trio_test
        async def test_whatever():
            await ...

    Keyword arguments (such as pytest fixtures) are inspected.

    - Any value that is a :class:`~trio.abc.Clock` becomes the run's clock.
    - Every value that is a :class:`~trio.abc.Instrument` is installed as
      an instrument.

    Positional arguments are not inspected.

    Raises:
        ValueError: if more than one keyword argument is a
            :class:`~trio.abc.Clock`.
    """

    @wraps(fn)
    def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
        # pytest convention: hides this wrapper frame from failure tracebacks
        # so they point at the test body instead.
        __tracebackhide__ = True
        fixture_values = kwargs.values()

        clock_fixtures = [obj for obj in fixture_values if isinstance(obj, Clock)]
        # A run takes a single clock; an ambiguous choice is rejected rather
        # than silently picking one.
        if len(clock_fixtures) > 1:
            raise ValueError("too many clocks spoil the broth!")
        chosen_clock = clock_fixtures[0] if clock_fixtures else None

        # Filtered independently of the clocks, so an object that is both a
        # Clock and an Instrument is used in both roles.
        instrument_fixtures = [
            obj for obj in fixture_values if isinstance(obj, Instrument)
        ]

        return _core.run(
            partial(fn, *args, **kwargs),
            clock=chosen_clock,
            instruments=instrument_fixtures,
        )

    return wrapper