Updated script that can be controled by Nodejs web app

This commit is contained in:
mac OS
2024-11-25 12:24:18 +07:00
parent c440eda1f4
commit 8b0ab2bd3a
8662 changed files with 1803808 additions and 34 deletions

View File

@ -0,0 +1,39 @@
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)
from .._core import (
MockClock as MockClock,
wait_all_tasks_blocked as wait_all_tasks_blocked,
)
from .._threads import (
active_thread_count as active_thread_count,
wait_all_threads_completed as wait_all_threads_completed,
)
from .._util import fixup_module_metadata
from ._check_streams import (
check_half_closeable_stream as check_half_closeable_stream,
check_one_way_stream as check_one_way_stream,
check_two_way_stream as check_two_way_stream,
)
from ._checkpoints import (
assert_checkpoints as assert_checkpoints,
assert_no_checkpoints as assert_no_checkpoints,
)
from ._memory_streams import (
MemoryReceiveStream as MemoryReceiveStream,
MemorySendStream as MemorySendStream,
lockstep_stream_one_way_pair as lockstep_stream_one_way_pair,
lockstep_stream_pair as lockstep_stream_pair,
memory_stream_one_way_pair as memory_stream_one_way_pair,
memory_stream_pair as memory_stream_pair,
memory_stream_pump as memory_stream_pump,
)
from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener
from ._raises_group import Matcher as Matcher, RaisesGroup as RaisesGroup
from ._sequencer import Sequencer as Sequencer
from ._trio_test import trio_test as trio_test
################################################################
fixup_module_metadata(__name__, globals())
del fixup_module_metadata

View File

@ -0,0 +1,572 @@
# Generic stream tests
from __future__ import annotations
import random
import sys
from contextlib import contextmanager, suppress
from typing import (
TYPE_CHECKING,
Awaitable,
Callable,
Generator,
Generic,
Tuple,
TypeVar,
)
from .. import CancelScope, _core
from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream
from .._highlevel_generic import aclose_forcefully
from ._checkpoints import assert_checkpoints
if TYPE_CHECKING:
from types import TracebackType
from typing_extensions import ParamSpec, TypeAlias
ArgsT = ParamSpec("ArgsT")
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
Res1 = TypeVar("Res1", bound=AsyncResource)
Res2 = TypeVar("Res2", bound=AsyncResource)
StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]]
class _ForceCloseBoth(Generic[Res1, Res2]):
def __init__(self, both: tuple[Res1, Res2]) -> None:
self._first, self._second = both
async def __aenter__(self) -> tuple[Res1, Res2]:
return self._first, self._second
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
try:
await aclose_forcefully(self._first)
finally:
await aclose_forcefully(self._second)
# This is used in this file instead of pytest.raises in order to avoid a dependency
# on pytest, as the check_* functions are publicly exported.
@contextmanager
def _assert_raises(
expected_exc: type[BaseException],
wrapped: bool = False,
) -> Generator[None, None, None]:
__tracebackhide__ = True
try:
yield
except BaseExceptionGroup as exc:
assert wrapped, "caught exceptiongroup, but expected an unwrapped exception"
# assert in except block ignored below
assert len(exc.exceptions) == 1 # noqa: PT017
assert isinstance(exc.exceptions[0], expected_exc) # noqa: PT017
except expected_exc:
assert not wrapped, "caught exception, but expected an exceptiongroup"
else:
raise AssertionError(f"expected exception: {expected_exc}")
async def check_one_way_stream(
stream_maker: StreamMaker[SendStream, ReceiveStream],
clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None,
) -> None:
"""Perform a number of generic tests on a custom one-way stream
implementation.
Args:
stream_maker: An async (!) function which returns a connected
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`)
pair.
clogged_stream_maker: Either None, or an async function similar to
stream_maker, but with the extra property that the returned stream
is in a state where ``send_all`` and
``wait_send_all_might_not_block`` will block until ``receive_some``
has been called. This allows for more thorough testing of some edge
cases, especially around ``wait_send_all_might_not_block``.
Raises:
AssertionError: if a test fails.
"""
async with _ForceCloseBoth(await stream_maker()) as (s, r):
assert isinstance(s, SendStream)
assert isinstance(r, ReceiveStream)
async def do_send_all(data: bytes | bytearray | memoryview) -> None:
with assert_checkpoints(): # We're testing that it doesn't return anything.
assert await s.send_all(data) is None # type: ignore[func-returns-value]
async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray:
with assert_checkpoints():
return await r.receive_some(max_bytes)
async def checked_receive_1(expected: bytes) -> None:
assert await do_receive_some(1) == expected
async def do_aclose(resource: AsyncResource) -> None:
with assert_checkpoints():
await resource.aclose()
# Simple sending/receiving
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
nursery.start_soon(checked_receive_1, b"x")
async def send_empty_then_y() -> None:
# Streams should tolerate sending b"" without giving it any
# special meaning.
await do_send_all(b"")
await do_send_all(b"y")
async with _core.open_nursery() as nursery:
nursery.start_soon(send_empty_then_y)
nursery.start_soon(checked_receive_1, b"y")
# ---- Checking various argument types ----
# send_all accepts bytearray and memoryview
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, bytearray(b"1"))
nursery.start_soon(checked_receive_1, b"1")
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, memoryview(b"2"))
nursery.start_soon(checked_receive_1, b"2")
# max_bytes must be a positive integer
with _assert_raises(ValueError):
await r.receive_some(-1)
with _assert_raises(ValueError):
await r.receive_some(0)
with _assert_raises(TypeError):
await r.receive_some(1.5) # type: ignore[arg-type]
# it can also be missing or None
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
assert await do_receive_some() == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all, b"x")
assert await do_receive_some(None) == b"x"
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(do_receive_some, 1)
nursery.start_soon(do_receive_some, 1)
# Method always has to exist, and an empty stream with a blocked
# receive_some should *always* allow send_all. (Technically it's legal
# for send_all to wait until receive_some is called to run, though; a
# stream doesn't *have* to have any internal buffering. That's why we
# start a concurrent receive_some call, then cancel it.)
async def simple_check_wait_send_all_might_not_block(
scope: CancelScope,
) -> None:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(
simple_check_wait_send_all_might_not_block,
nursery.cancel_scope,
)
nursery.start_soon(do_receive_some, 1)
# closing the r side leads to BrokenResourceError on the s side
# (eventually)
async def expect_broken_stream_on_send() -> None:
with _assert_raises(_core.BrokenResourceError):
while True:
await do_send_all(b"x" * 100)
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_broken_stream_on_send)
nursery.start_soon(do_aclose, r)
# once detected, the stream stays broken
with _assert_raises(_core.BrokenResourceError):
await do_send_all(b"x" * 100)
# r closed -> ClosedResourceError on the receive side
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
# we can close the same stream repeatedly, it's fine
await do_aclose(r)
await do_aclose(r)
# closing the sender side
await do_aclose(s)
# now trying to send raises ClosedResourceError
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"x" * 100)
# even if it's an empty send
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"")
# ditto for wait_send_all_might_not_block
with _assert_raises(_core.ClosedResourceError):
with assert_checkpoints():
await s.wait_send_all_might_not_block()
# and again, repeated closing is fine
await do_aclose(s)
await do_aclose(s)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
# if send-then-graceful-close, receiver gets data then b""
async def send_then_close() -> None:
await do_send_all(b"y")
await do_aclose(s)
async def receive_send_then_close() -> None:
# We want to make sure that if the sender closes the stream before
# we read anything, then we still get all the data. But some
# streams might block on the do_send_all call. So we let the
# sender get as far as it can, then we receive.
await _core.wait_all_tasks_blocked()
await checked_receive_1(b"y")
await checked_receive_1(b"")
await do_aclose(r)
async with _core.open_nursery() as nursery:
nursery.start_soon(send_then_close)
nursery.start_soon(receive_send_then_close)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
await aclose_forcefully(r)
with _assert_raises(_core.BrokenResourceError):
while True:
await do_send_all(b"x" * 100)
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
async with _ForceCloseBoth(await stream_maker()) as (s, r):
await aclose_forcefully(s)
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"123")
# after the sender does a forceful close, the receiver might either
# get BrokenResourceError or a clean b""; either is OK. Not OK would be
# if it freezes, or returns data.
with suppress(_core.BrokenResourceError):
await checked_receive_1(b"")
# cancelled aclose still closes
async with _ForceCloseBoth(await stream_maker()) as (s, r):
with _core.CancelScope() as scope:
scope.cancel()
await r.aclose()
with _core.CancelScope() as scope:
scope.cancel()
await s.aclose()
with _assert_raises(_core.ClosedResourceError):
await do_send_all(b"123")
with _assert_raises(_core.ClosedResourceError):
await do_receive_some(4096)
# Check that we can still gracefully close a stream after an operation has
# been cancelled. This can be challenging if cancellation can leave the
# stream internals in an inconsistent state, e.g. for
# SSLStream. Unfortunately this test isn't very thorough; the really
# challenging case for something like SSLStream is it gets cancelled
# *while* it's sending data on the underlying, not before. But testing
# that requires some special-case handling of the particular stream setup;
# we can't do it here. Maybe we could do a bit better with
# https://github.com/python-trio/trio/issues/77
async with _ForceCloseBoth(await stream_maker()) as (s, r):
async def expect_cancelled(
afn: Callable[ArgsT, Awaitable[object]],
*args: ArgsT.args,
**kwargs: ArgsT.kwargs,
) -> None:
with _assert_raises(_core.Cancelled):
await afn(*args, **kwargs)
with _core.CancelScope() as scope:
scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_cancelled, do_send_all, b"x")
nursery.start_soon(expect_cancelled, do_receive_some, 1)
async with _core.open_nursery() as nursery:
nursery.start_soon(do_aclose, s)
nursery.start_soon(do_aclose, r)
# Check that if a task is blocked in receive_some, then closing the
# receive stream causes it to wake up.
async with _ForceCloseBoth(await stream_maker()) as (s, r):
async def receive_expecting_closed():
with _assert_raises(_core.ClosedResourceError):
await r.receive_some(10)
async with _core.open_nursery() as nursery:
nursery.start_soon(receive_expecting_closed)
await _core.wait_all_tasks_blocked()
await aclose_forcefully(r)
# check wait_send_all_might_not_block, if we can
if clogged_stream_maker is not None:
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
record: list[str] = []
async def waiter(cancel_scope: CancelScope) -> None:
record.append("waiter sleeping")
with assert_checkpoints():
await s.wait_send_all_might_not_block()
record.append("waiter wokeup")
cancel_scope.cancel()
async def receiver() -> None:
# give wait_send_all_might_not_block a chance to block
await _core.wait_all_tasks_blocked()
record.append("receiver starting")
while True:
await r.receive_some(16834)
async with _core.open_nursery() as nursery:
nursery.start_soon(waiter, nursery.cancel_scope)
await _core.wait_all_tasks_blocked()
nursery.start_soon(receiver)
assert record == [
"waiter sleeping",
"receiver starting",
"waiter wokeup",
]
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
# simultaneous wait_send_all_might_not_block fails
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.wait_send_all_might_not_block)
nursery.start_soon(s.wait_send_all_might_not_block)
# and simultaneous send_all and wait_send_all_might_not_block (NB
# this test might destroy the stream b/c we end up cancelling
# send_all and e.g. SSLStream can't handle that, so we have to
# recreate afterwards)
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.wait_send_all_might_not_block)
nursery.start_soon(s.send_all, b"123")
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
# send_all and send_all blocked simultaneously should also raise
# (but again this might destroy the stream)
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s.send_all, b"123")
nursery.start_soon(s.send_all, b"123")
# closing the receiver causes wait_send_all_might_not_block to return,
# with or without an exception
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async def sender() -> None:
try:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
except _core.BrokenResourceError: # pragma: no cover
pass
async def receiver() -> None:
await _core.wait_all_tasks_blocked()
await aclose_forcefully(r)
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
# and again with the call starting after the close
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
await aclose_forcefully(r)
try:
with assert_checkpoints():
await s.wait_send_all_might_not_block()
except _core.BrokenResourceError: # pragma: no cover
pass
# Check that if a task is blocked in a send-side method, then closing
# the send stream causes it to wake up.
async def close_soon(s: SendStream) -> None:
await _core.wait_all_tasks_blocked()
await aclose_forcefully(s)
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async with _core.open_nursery() as nursery:
nursery.start_soon(close_soon, s)
with _assert_raises(_core.ClosedResourceError):
await s.send_all(b"xyzzy")
async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
async with _core.open_nursery() as nursery:
nursery.start_soon(close_soon, s)
with _assert_raises(_core.ClosedResourceError):
await s.wait_send_all_might_not_block()
async def check_two_way_stream(
stream_maker: StreamMaker[Stream, Stream],
clogged_stream_maker: StreamMaker[Stream, Stream] | None,
) -> None:
"""Perform a number of generic tests on a custom two-way stream
implementation.
This is similar to :func:`check_one_way_stream`, except that the maker
functions are expected to return objects implementing the
:class:`~trio.abc.Stream` interface.
This function tests a *superset* of what :func:`check_one_way_stream`
checks if you call this, then you don't need to also call
:func:`check_one_way_stream`.
"""
await check_one_way_stream(stream_maker, clogged_stream_maker)
async def flipped_stream_maker() -> tuple[Stream, Stream]:
return (await stream_maker())[::-1]
flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None
if clogged_stream_maker is not None:
async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]:
return (await clogged_stream_maker())[::-1]
else:
flipped_clogged_stream_maker = None
await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker)
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
assert isinstance(s1, Stream)
assert isinstance(s2, Stream)
# Duplex can be a bit tricky, might as well check it as well
DUPLEX_TEST_SIZE = 2**20
CHUNK_SIZE_MAX = 2**14
r = random.Random(0)
i = r.getrandbits(8 * DUPLEX_TEST_SIZE)
test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little")
async def sender(
s: Stream,
data: bytes | bytearray | memoryview,
seed: int,
) -> None:
r = random.Random(seed)
m = memoryview(data)
while m:
chunk_size = r.randint(1, CHUNK_SIZE_MAX)
await s.send_all(m[:chunk_size])
m = m[chunk_size:]
async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None:
r = random.Random(seed)
got = bytearray()
while len(got) < len(data):
chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX))
assert chunk
got += chunk
assert got == data
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, s1, test_data, 0)
nursery.start_soon(sender, s2, test_data[::-1], 1)
nursery.start_soon(receiver, s1, test_data[::-1], 2)
nursery.start_soon(receiver, s2, test_data, 3)
async def expect_receive_some_empty() -> None:
assert await s2.receive_some(10) == b""
await s2.aclose()
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_receive_some_empty)
nursery.start_soon(s1.aclose)
async def check_half_closeable_stream(
stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream],
clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None,
) -> None:
"""Perform a number of generic tests on a custom half-closeable stream
implementation.
This is similar to :func:`check_two_way_stream`, except that the maker
functions are expected to return objects that implement the
:class:`~trio.abc.HalfCloseableStream` interface.
This function tests a *superset* of what :func:`check_two_way_stream`
checks if you call this, then you don't need to also call
:func:`check_two_way_stream`.
"""
await check_two_way_stream(stream_maker, clogged_stream_maker)
async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
assert isinstance(s1, HalfCloseableStream)
assert isinstance(s2, HalfCloseableStream)
async def send_x_then_eof(s: HalfCloseableStream) -> None:
await s.send_all(b"x")
with assert_checkpoints():
await s.send_eof()
async def expect_x_then_eof(r: HalfCloseableStream) -> None:
await _core.wait_all_tasks_blocked()
assert await r.receive_some(10) == b"x"
assert await r.receive_some(10) == b""
async with _core.open_nursery() as nursery:
nursery.start_soon(send_x_then_eof, s1)
nursery.start_soon(expect_x_then_eof, s2)
# now sending is disallowed
with _assert_raises(_core.ClosedResourceError):
await s1.send_all(b"y")
# but we can do send_eof again
with assert_checkpoints():
await s1.send_eof()
# and we can still send stuff back the other way
async with _core.open_nursery() as nursery:
nursery.start_soon(send_x_then_eof, s2)
nursery.start_soon(expect_x_then_eof, s1)
if clogged_stream_maker is not None:
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
# send_all and send_eof simultaneously is not ok
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s1.send_all, b"x")
await _core.wait_all_tasks_blocked()
nursery.start_soon(s1.send_eof)
async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
# wait_send_all_might_not_block and send_eof simultaneously is not
# ok either
with _assert_raises(_core.BusyResourceError, wrapped=True):
async with _core.open_nursery() as nursery:
nursery.start_soon(s1.wait_send_all_might_not_block)
await _core.wait_all_tasks_blocked()
nursery.start_soon(s1.send_eof)

View File

@ -0,0 +1,69 @@
from __future__ import annotations
from contextlib import AbstractContextManager, contextmanager
from typing import TYPE_CHECKING
from .. import _core
if TYPE_CHECKING:
from collections.abc import Generator
@contextmanager
def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]:
"""Check if checkpoints are executed in a block of code."""
__tracebackhide__ = True
task = _core.current_task()
orig_cancel = task._cancel_points
orig_schedule = task._schedule_points
try:
yield
if expected and (
task._cancel_points == orig_cancel or task._schedule_points == orig_schedule
):
raise AssertionError("assert_checkpoints block did not yield!")
finally:
if not expected and (
task._cancel_points != orig_cancel or task._schedule_points != orig_schedule
):
raise AssertionError("assert_no_checkpoints block yielded!")
def assert_checkpoints() -> AbstractContextManager[None]:
"""Use as a context manager to check that the code inside the ``with``
block either exits with an exception or executes at least one
:ref:`checkpoint <checkpoints>`.
Raises:
AssertionError: if no checkpoint was executed.
Example:
Check that :func:`trio.sleep` is a checkpoint, even if it doesn't
block::
with trio.testing.assert_checkpoints():
await trio.sleep(0)
"""
__tracebackhide__ = True
return _assert_yields_or_not(True)
def assert_no_checkpoints() -> AbstractContextManager[None]:
"""Use as a context manager to check that the code inside the ``with``
block does not execute any :ref:`checkpoints <checkpoints>`.
Raises:
AssertionError: if a checkpoint was executed.
Example:
Synchronous code never contains any checkpoints, but we can double-check
that::
send_channel, receive_channel = trio.open_memory_channel(10)
with trio.testing.assert_no_checkpoints():
send_channel.send_nowait(None)
"""
__tracebackhide__ = True
return _assert_yields_or_not(False)

View File

@ -0,0 +1,578 @@
# This should eventually be cleaned up and become public, but for right now I'm just
# implementing enough to test DTLS.
# TODO:
# - user-defined routers
# - TCP
# - UDP broadcast
from __future__ import annotations
import contextlib
import errno
import ipaddress
import os
import socket
import sys
from typing import (
TYPE_CHECKING,
Any,
Iterable,
NoReturn,
TypeVar,
Union,
overload,
)
import attrs
import trio
from trio._util import NoPublicConstructor, final
if TYPE_CHECKING:
import builtins
from socket import AddressFamily, SocketKind
from types import TracebackType
from typing_extensions import Buffer, Self, TypeAlias
IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
def _family_for(ip: IPAddress) -> int:
if isinstance(ip, ipaddress.IPv4Address):
return trio.socket.AF_INET
elif isinstance(ip, ipaddress.IPv6Address):
return trio.socket.AF_INET6
raise NotImplementedError("Unhandled IPAddress instance type") # pragma: no cover
def _wildcard_ip_for(family: int) -> IPAddress:
if family == trio.socket.AF_INET:
return ipaddress.ip_address("0.0.0.0")
elif family == trio.socket.AF_INET6:
return ipaddress.ip_address("::")
raise NotImplementedError("Unhandled ip address family") # pragma: no cover
# not used anywhere
def _localhost_ip_for(family: int) -> IPAddress: # pragma: no cover
if family == trio.socket.AF_INET:
return ipaddress.ip_address("127.0.0.1")
elif family == trio.socket.AF_INET6:
return ipaddress.ip_address("::1")
raise NotImplementedError("Unhandled ip address family")
def _fake_err(code: int) -> NoReturn:
raise OSError(code, os.strerror(code))
def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int:
written = 0
for buf in buffers: # pragma: no branch
next_piece = data[written : written + memoryview(buf).nbytes]
with memoryview(buf) as mbuf:
mbuf[: len(next_piece)] = next_piece
written += len(next_piece)
if written == len(data): # pragma: no branch
break
return written
T_UDPEndpoint = TypeVar("T_UDPEndpoint", bound="UDPEndpoint")
@attrs.frozen
class UDPEndpoint:
ip: IPAddress
port: int
def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]:
sockaddr: tuple[str, int] | tuple[str, int, int, int] = (
self.ip.compressed,
self.port,
)
if isinstance(self.ip, ipaddress.IPv6Address):
sockaddr += (0, 0) # type: ignore[assignment]
return sockaddr
@classmethod
def from_python_sockaddr(
cls: type[T_UDPEndpoint],
sockaddr: tuple[str, int] | tuple[str, int, int, int],
) -> T_UDPEndpoint:
ip, port = sockaddr[:2]
return cls(ip=ipaddress.ip_address(ip), port=port)
@attrs.frozen
class UDPBinding:
local: UDPEndpoint
# remote: UDPEndpoint # ??
@attrs.frozen
class UDPPacket:
source: UDPEndpoint
destination: UDPEndpoint
payload: bytes = attrs.field(repr=lambda p: p.hex())
# not used/tested anywhere
def reply(self, payload: bytes) -> UDPPacket: # pragma: no cover
return UDPPacket(
source=self.destination,
destination=self.source,
payload=payload,
)
@attrs.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
fake_net: FakeNet
def socket(self, family: int, type_: int, proto: int) -> FakeSocket: # type: ignore[override]
return FakeSocket._create(self.fake_net, family, type_, proto)
@attrs.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
fake_net: FakeNet
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
raise NotImplementedError("FakeNet doesn't do fake DNS yet")
@final
class FakeNet:
def __init__(self) -> None:
# When we need to pick an arbitrary unique ip address/port, use these:
self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts() # untested
self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts() # untested
self._auto_port_iter = iter(range(50000, 65535))
self._bound: dict[UDPBinding, FakeSocket] = {}
self.route_packet = None
def _bind(self, binding: UDPBinding, socket: FakeSocket) -> None:
if binding in self._bound:
_fake_err(errno.EADDRINUSE)
self._bound[binding] = socket
def enable(self) -> None:
trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))
def send_packet(self, packet: UDPPacket) -> None:
if self.route_packet is None:
self.deliver_packet(packet)
else:
self.route_packet(packet)
def deliver_packet(self, packet: UDPPacket) -> None:
binding = UDPBinding(local=packet.destination)
if binding in self._bound:
self._bound[binding]._deliver_packet(packet)
else:
# No valid destination, so drop it
pass
@final
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
def __init__(
self,
fake_net: FakeNet,
family: AddressFamily,
type: SocketKind,
proto: int,
):
self._fake_net = fake_net
if not family: # pragma: no cover
family = trio.socket.AF_INET
if not type: # pragma: no cover
type = trio.socket.SOCK_STREAM # noqa: A001 # name shadowing builtin
if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
if type != trio.socket.SOCK_DGRAM:
raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")
self._family = family
self._type = type
self._proto = proto
self._closed = False
self._packet_sender, self._packet_receiver = trio.open_memory_channel[
UDPPacket
](float("inf"))
# This is the source-of-truth for what port etc. this socket is bound to
self._binding: UDPBinding | None = None
@property
def type(self) -> SocketKind:
return self._type
@property
def family(self) -> AddressFamily:
return self._family
@property
def proto(self) -> int:
return self._proto
def _check_closed(self) -> None:
if self._closed:
_fake_err(errno.EBADF)
def close(self) -> None:
if self._closed:
return
self._closed = True
if self._binding is not None:
del self._fake_net._bound[self._binding]
self._packet_receiver.close()
async def _resolve_address_nocp(
self,
address: object,
*,
local: bool,
) -> tuple[str, int]:
return await trio._socket._resolve_address_nocp( # type: ignore[no-any-return]
self.type,
self.family,
self.proto,
address=address,
ipv6_v6only=False,
local=local,
)
def _deliver_packet(self, packet: UDPPacket) -> None:
# sending to a closed socket -- UDP packets get dropped
with contextlib.suppress(trio.BrokenResourceError):
self._packet_sender.send_nowait(packet)
################################################################
# Actual IO operation implementations
################################################################
async def bind(self, addr: object) -> None:
self._check_closed()
if self._binding is not None:
_fake_err(errno.EINVAL)
await trio.lowlevel.checkpoint()
ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True)
assert _ == [], "TODO: handle other values?"
ip = ipaddress.ip_address(ip_str)
assert _family_for(ip) == self.family
# We convert binds to INET_ANY into binds to localhost
if ip == ipaddress.ip_address("0.0.0.0"):
ip = ipaddress.ip_address("127.0.0.1")
elif ip == ipaddress.ip_address("::"):
ip = ipaddress.ip_address("::1")
if port == 0:
port = next(self._fake_net._auto_port_iter)
binding = UDPBinding(local=UDPEndpoint(ip, port))
self._fake_net._bind(binding, self)
self._binding = binding
async def connect(self, peer: object) -> NoReturn:
raise NotImplementedError("FakeNet does not (yet) support connected sockets")
async def _sendmsg(
self,
buffers: Iterable[Buffer],
ancdata: Iterable[tuple[int, int, Buffer]] = (),
flags: int = 0,
address: Any | None = None,
) -> int:
self._check_closed()
await trio.lowlevel.checkpoint()
if address is not None:
address = await self._resolve_address_nocp(address, local=False)
if ancdata:
raise NotImplementedError("FakeNet doesn't support ancillary data")
if flags:
raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")
if address is None:
_fake_err(errno.ENOTCONN)
destination = UDPEndpoint.from_python_sockaddr(address)
if self._binding is None:
await self.bind((_wildcard_ip_for(self.family).compressed, 0))
payload = b"".join(buffers)
assert self._binding is not None
packet = UDPPacket(
source=self._binding.local,
destination=destination,
payload=payload,
)
self._fake_net.send_packet(packet)
return len(payload)
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
):
sendmsg = _sendmsg
async def _recvmsg_into(
self,
buffers: Iterable[Buffer],
ancbufsize: int = 0,
flags: int = 0,
) -> tuple[int, list[tuple[int, int, bytes]], int, Any]:
if ancbufsize != 0:
raise NotImplementedError("FakeNet doesn't support ancillary data")
if flags != 0:
raise NotImplementedError("FakeNet doesn't support any recv flags")
if self._binding is None:
# I messed this up a few times when writing tests ... but it also never happens
# in any of the existing tests, so maybe it could be intentional...
raise NotImplementedError(
"The code will most likely hang if you try to receive on a fakesocket "
"without a binding. If that is not the case, or you explicitly want to "
"test that, remove this warning.",
)
self._check_closed()
ancdata: list[tuple[int, int, bytes]] = []
msg_flags = 0
packet = await self._packet_receiver.receive()
address = packet.source.as_python_sockaddr()
written = _scatter(packet.payload, buffers)
if written < len(packet.payload):
msg_flags |= trio.socket.MSG_TRUNC
return written, ancdata, msg_flags, address
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
):
recvmsg_into = _recvmsg_into
################################################################
# Simple state query stuff
################################################################
def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]:
self._check_closed()
if self._binding is not None:
return self._binding.local.as_python_sockaddr()
elif self.family == trio.socket.AF_INET:
return ("0.0.0.0", 0)
else:
assert self.family == trio.socket.AF_INET6
return ("::", 0)
# TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError.
def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]:
self._check_closed()
if self._binding is not None:
assert hasattr(
self._binding,
"remote",
), "This method seems to assume that self._binding has a remote UDPEndpoint"
if self._binding.remote is not None: # pragma: no cover
assert isinstance(
self._binding.remote,
UDPEndpoint,
), "Self._binding.remote should be a UDPEndpoint"
return self._binding.remote.as_python_sockaddr()
_fake_err(errno.ENOTCONN)
@overload
def getsockopt(self, /, level: int, optname: int) -> int: ...
@overload
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
def getsockopt(
self,
/,
level: int,
optname: int,
buflen: int | None = None,
) -> int | bytes:
self._check_closed()
raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})")
@overload
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
@overload
def setsockopt(
self,
/,
level: int,
optname: int,
value: None,
optlen: int,
) -> None: ...
def setsockopt(
self,
/,
level: int,
optname: int,
value: int | Buffer | None,
optlen: int | None = None,
) -> None:
self._check_closed()
if (level, optname) == (
trio.socket.IPPROTO_IPV6,
trio.socket.IPV6_V6ONLY,
) and not value:
raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")
raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)")
################################################################
# Various boilerplate and trivial stubs
################################################################
def __enter__(self) -> Self:
return self
def __exit__(
self,
exc_type: builtins.type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.close()
async def send(self, data: Buffer, flags: int = 0) -> int:
return await self.sendto(data, flags, None)
@overload
async def sendto(
self,
__data: Buffer,
__address: tuple[object, ...] | str | Buffer,
) -> int: ...
@overload
async def sendto(
self,
__data: Buffer,
__flags: int,
__address: tuple[object, ...] | str | None | Buffer,
) -> int: ...
async def sendto(self, *args: Any) -> int:
data: Buffer
flags: int
address: tuple[object, ...] | str | Buffer
if len(args) == 2:
data, address = args
flags = 0
elif len(args) == 3:
data, flags, address = args
else:
raise TypeError("wrong number of arguments")
return await self._sendmsg([data], [], flags, address)
async def recv(self, bufsize: int, flags: int = 0) -> bytes:
data, address = await self.recvfrom(bufsize, flags)
return data
async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int:
got_bytes, address = await self.recvfrom_into(buf, nbytes, flags)
return got_bytes
async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]:
data, ancdata, msg_flags, address = await self._recvmsg(bufsize, flags)
return data, address
async def recvfrom_into(
self,
buf: Buffer,
nbytes: int = 0,
flags: int = 0,
) -> tuple[int, Any]:
if nbytes != 0 and nbytes != memoryview(buf).nbytes:
raise NotImplementedError("partial recvfrom_into")
got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into(
[buf],
0,
flags,
)
return got_nbytes, address
async def _recvmsg(
self,
bufsize: int,
ancbufsize: int = 0,
flags: int = 0,
) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]:
buf = bytearray(bufsize)
got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into(
[buf],
ancbufsize,
flags,
)
return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)
if sys.platform != "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
):
recvmsg = _recvmsg
def fileno(self) -> int:
raise NotImplementedError("can't get fileno() for FakeNet sockets")
def detach(self) -> int:
raise NotImplementedError("can't detach() a FakeNet socket")
def get_inheritable(self) -> bool:
return False
def set_inheritable(self, inheritable: bool) -> None:
if inheritable:
raise NotImplementedError("FakeNet can't make inheritable sockets")
if sys.platform == "win32" or (
not TYPE_CHECKING and hasattr(socket.socket, "share")
):
def share(self, process_id: int) -> bytes:
raise NotImplementedError("FakeNet can't share sockets")

View File

@ -0,0 +1,626 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar
from .. import _core, _util
from .._highlevel_generic import StapledStream
from ..abc import ReceiveStream, SendStream
if TYPE_CHECKING:
from typing_extensions import TypeAlias
AsyncHook: TypeAlias = Callable[[], Awaitable[object]]
# Would be nice to exclude awaitable here, but currently not possible.
SyncHook: TypeAlias = Callable[[], object]
SendStreamT = TypeVar("SendStreamT", bound=SendStream)
ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream)
################################################################
# In-memory streams - Unbounded buffer version
################################################################
class _UnboundedByteQueue:
def __init__(self) -> None:
self._data = bytearray()
self._closed = False
self._lot = _core.ParkingLot()
self._fetch_lock = _util.ConflictDetector(
"another task is already fetching data",
)
# This object treats "close" as being like closing the send side of a
# channel: so after close(), calling put() raises ClosedResourceError, and
# calling the get() variants drains the buffer and then returns an empty
# bytearray.
def close(self) -> None:
self._closed = True
self._lot.unpark_all()
def close_and_wipe(self) -> None:
self._data = bytearray()
self.close()
def put(self, data: bytes | bytearray | memoryview) -> None:
if self._closed:
raise _core.ClosedResourceError("virtual connection closed")
self._data += data
self._lot.unpark_all()
def _check_max_bytes(self, max_bytes: int | None) -> None:
if max_bytes is None:
return
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
def _get_impl(self, max_bytes: int | None) -> bytearray:
assert self._closed or self._data
if max_bytes is None:
max_bytes = len(self._data)
if self._data:
chunk = self._data[:max_bytes]
del self._data[:max_bytes]
assert chunk
return chunk
else:
return bytearray()
def get_nowait(self, max_bytes: int | None = None) -> bytearray:
with self._fetch_lock:
self._check_max_bytes(max_bytes)
if not self._closed and not self._data:
raise _core.WouldBlock
return self._get_impl(max_bytes)
async def get(self, max_bytes: int | None = None) -> bytearray:
with self._fetch_lock:
self._check_max_bytes(max_bytes)
if not self._closed and not self._data:
await self._lot.park()
else:
await _core.checkpoint()
return self._get_impl(max_bytes)
@_util.final
class MemorySendStream(SendStream):
"""An in-memory :class:`~trio.abc.SendStream`.
Args:
send_all_hook: An async function, or None. Called from
:meth:`send_all`. Can do whatever you like.
wait_send_all_might_not_block_hook: An async function, or None. Called
from :meth:`wait_send_all_might_not_block`. Can do whatever you
like.
close_hook: A synchronous function, or None. Called from :meth:`close`
and :meth:`aclose`. Can do whatever you like.
.. attribute:: send_all_hook
wait_send_all_might_not_block_hook
close_hook
All of these hooks are also exposed as attributes on the object, and
you can change them at any time.
"""
def __init__(
self,
send_all_hook: AsyncHook | None = None,
wait_send_all_might_not_block_hook: AsyncHook | None = None,
close_hook: SyncHook | None = None,
):
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream",
)
self._outgoing = _UnboundedByteQueue()
self.send_all_hook = send_all_hook
self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook
self.close_hook = close_hook
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
"""Places the given data into the object's internal buffer, and then
calls the :attr:`send_all_hook` (if any).
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
self._outgoing.put(data)
if self.send_all_hook is not None:
await self.send_all_hook()
async def wait_send_all_might_not_block(self) -> None:
"""Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and
then returns immediately.
"""
# Execute two checkpoints so that we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
# check for being closed:
self._outgoing.put(b"")
if self.wait_send_all_might_not_block_hook is not None:
await self.wait_send_all_might_not_block_hook()
def close(self) -> None:
"""Marks this stream as closed, and then calls the :attr:`close_hook`
(if any).
"""
# XXX should this cancel any pending calls to the send_all_hook and
# wait_send_all_might_not_block_hook? Those are the only places where
# send_all and wait_send_all_might_not_block can be blocked.
#
# The way we set things up, send_all_hook is memory_stream_pump, and
# wait_send_all_might_not_block_hook is unset. memory_stream_pump is
# synchronous. So normally, send_all and wait_send_all_might_not_block
# cannot block at all.
self._outgoing.close()
if self.close_hook is not None:
self.close_hook()
async def aclose(self) -> None:
"""Same as :meth:`close`, but async."""
self.close()
await _core.checkpoint()
async def get_data(self, max_bytes: int | None = None) -> bytearray:
"""Retrieves data from the internal buffer, blocking if necessary.
Args:
max_bytes (int or None): The maximum amount of data to
retrieve. None (the default) means to retrieve all the data
that's present (but still blocks until at least one byte is
available).
Returns:
If this stream has been closed, an empty bytearray. Otherwise, the
requested data.
"""
return await self._outgoing.get(max_bytes)
def get_data_nowait(self, max_bytes: int | None = None) -> bytearray:
"""Retrieves data from the internal buffer, but doesn't block.
See :meth:`get_data` for details.
Raises:
trio.WouldBlock: if no data is available to retrieve.
"""
return self._outgoing.get_nowait(max_bytes)
@_util.final
class MemoryReceiveStream(ReceiveStream):
"""An in-memory :class:`~trio.abc.ReceiveStream`.
Args:
receive_some_hook: An async function, or None. Called from
:meth:`receive_some`. Can do whatever you like.
close_hook: A synchronous function, or None. Called from :meth:`close`
and :meth:`aclose`. Can do whatever you like.
.. attribute:: receive_some_hook
close_hook
Both hooks are also exposed as attributes on the object, and you can
change them at any time.
"""
def __init__(
self,
receive_some_hook: AsyncHook | None = None,
close_hook: SyncHook | None = None,
):
self._conflict_detector = _util.ConflictDetector(
"another task is using this stream",
)
self._incoming = _UnboundedByteQueue()
self._closed = False
self.receive_some_hook = receive_some_hook
self.close_hook = close_hook
async def receive_some(self, max_bytes: int | None = None) -> bytearray:
"""Calls the :attr:`receive_some_hook` (if any), and then retrieves
data from the internal buffer, blocking if necessary.
"""
# Execute two checkpoints so we have more of a chance to detect
# buggy user code that calls this twice at the same time.
with self._conflict_detector:
await _core.checkpoint()
await _core.checkpoint()
if self._closed:
raise _core.ClosedResourceError
if self.receive_some_hook is not None:
await self.receive_some_hook()
# self._incoming's closure state tracks whether we got an EOF.
# self._closed tracks whether we, ourselves, are closed.
# self.close() sends an EOF to wake us up and sets self._closed,
# so after we wake up we have to check self._closed again.
data = await self._incoming.get(max_bytes)
if self._closed:
raise _core.ClosedResourceError
return data
def close(self) -> None:
"""Discards any pending data from the internal buffer, and marks this
stream as closed.
"""
self._closed = True
self._incoming.close_and_wipe()
if self.close_hook is not None:
self.close_hook()
async def aclose(self) -> None:
"""Same as :meth:`close`, but async."""
self.close()
await _core.checkpoint()
def put_data(self, data: bytes | bytearray | memoryview) -> None:
"""Appends the given data to the internal buffer."""
self._incoming.put(data)
def put_eof(self) -> None:
"""Adds an end-of-file marker to the internal buffer."""
self._incoming.close()
def memory_stream_pump(
memory_send_stream: MemorySendStream,
memory_receive_stream: MemoryReceiveStream,
*,
max_bytes: int | None = None,
) -> bool:
"""Take data out of the given :class:`MemorySendStream`'s internal buffer,
and put it into the given :class:`MemoryReceiveStream`'s internal buffer.
Args:
memory_send_stream (MemorySendStream): The stream to get data from.
memory_receive_stream (MemoryReceiveStream): The stream to put data into.
max_bytes (int or None): The maximum amount of data to transfer in this
call, or None to transfer all available data.
Returns:
True if it successfully transferred some data, or False if there was no
data to transfer.
This is used to implement :func:`memory_stream_one_way_pair` and
:func:`memory_stream_pair`; see the latter's docstring for an example
of how you might use it yourself.
"""
try:
data = memory_send_stream.get_data_nowait(max_bytes)
except _core.WouldBlock:
return False
try:
if not data:
memory_receive_stream.put_eof()
else:
memory_receive_stream.put_data(data)
except _core.ClosedResourceError:
raise _core.BrokenResourceError("MemoryReceiveStream was closed") from None
return True
def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]:
"""Create a connected, pure-Python, unidirectional stream with infinite
buffering and flexible configuration options.
You can think of this as being a no-operating-system-involved
Trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe`
returns the streams in the wrong order we follow the superior convention
that data flows from left to right).
Returns:
A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where
the :class:`MemorySendStream` has its hooks set up so that it calls
:func:`memory_stream_pump` from its
:attr:`~MemorySendStream.send_all_hook` and
:attr:`~MemorySendStream.close_hook`.
The end result is that data automatically flows from the
:class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're
also free to rearrange things however you like. For example, you can
temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you
want to simulate a stall in data transmission. Or see
:func:`memory_stream_pair` for a more elaborate example.
"""
send_stream = MemorySendStream()
recv_stream = MemoryReceiveStream()
def pump_from_send_stream_to_recv_stream() -> None:
memory_stream_pump(send_stream, recv_stream)
async def async_pump_from_send_stream_to_recv_stream() -> None:
pump_from_send_stream_to_recv_stream()
send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream
send_stream.close_hook = pump_from_send_stream_to_recv_stream
return send_stream, recv_stream
def _make_stapled_pair(
one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]],
) -> tuple[
StapledStream[SendStreamT, ReceiveStreamT],
StapledStream[SendStreamT, ReceiveStreamT],
]:
pipe1_send, pipe1_recv = one_way_pair()
pipe2_send, pipe2_recv = one_way_pair()
stream1 = StapledStream(pipe1_send, pipe2_recv)
stream2 = StapledStream(pipe2_send, pipe1_recv)
return stream1, stream2
def memory_stream_pair() -> tuple[
StapledStream[MemorySendStream, MemoryReceiveStream],
StapledStream[MemorySendStream, MemoryReceiveStream],
]:
"""Create a connected, pure-Python, bidirectional stream with infinite
buffering and flexible configuration options.
This is a convenience function that creates two one-way streams using
:func:`memory_stream_one_way_pair`, and then uses
:class:`~trio.StapledStream` to combine them into a single bidirectional
stream.
This is like a no-operating-system-involved, Trio-streamsified version of
:func:`socket.socketpair`.
Returns:
A pair of :class:`~trio.StapledStream` objects that are connected so
that data automatically flows from one to the other in both directions.
After creating a stream pair, you can send data back and forth, which is
enough for simple tests::
left, right = memory_stream_pair()
await left.send_all(b"123")
assert await right.receive_some() == b"123"
await right.send_all(b"456")
assert await left.receive_some() == b"456"
But if you read the docs for :class:`~trio.StapledStream` and
:func:`memory_stream_one_way_pair`, you'll see that all the pieces
involved in wiring this up are public APIs, so you can adjust to suit the
requirements of your tests. For example, here's how to tweak a stream so
that data flowing from left to right trickles in one byte at a time (but
data flowing from right to left proceeds at full speed)::
left, right = memory_stream_pair()
async def trickle():
# left is a StapledStream, and left.send_stream is a MemorySendStream
# right is a StapledStream, and right.recv_stream is a MemoryReceiveStream
while memory_stream_pump(left.send_stream, right.recv_stream, max_bytes=1):
# Pause between each byte
await trio.sleep(1)
# Normally this send_all_hook calls memory_stream_pump directly without
# passing in a max_bytes. We replace it with our custom version:
left.send_stream.send_all_hook = trickle
And here's a simple test using our modified stream objects::
async def sender():
await left.send_all(b"12345")
await left.send_eof()
async def receiver():
async for data in right:
print(data)
async with trio.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
By default, this will print ``b"12345"`` and then immediately exit; with
our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then
sleeps 1 second, then prints ``b"2"``, etc.
Pro-tip: you can insert sleep calls (like in our example above) to
manipulate the flow of data across tasks... and then use
:class:`MockClock` and its :attr:`~MockClock.autojump_threshold`
functionality to keep your test suite running quickly.
If you want to stress test a protocol implementation, one nice trick is to
use the :mod:`random` module (preferably with a fixed seed) to move random
numbers of bytes at a time, and insert random sleeps in between them. You
can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if
you want to manipulate things on the receiving side, and not just the
sending side.
"""
return _make_stapled_pair(memory_stream_one_way_pair)
################################################################
# In-memory streams - Lockstep version
################################################################
class _LockstepByteQueue:
def __init__(self) -> None:
self._data = bytearray()
self._sender_closed = False
self._receiver_closed = False
self._receiver_waiting = False
self._waiters = _core.ParkingLot()
self._send_conflict_detector = _util.ConflictDetector(
"another task is already sending",
)
self._receive_conflict_detector = _util.ConflictDetector(
"another task is already receiving",
)
def _something_happened(self) -> None:
self._waiters.unpark_all()
# Always wakes up when one side is closed, because everyone always reacts
# to that.
async def _wait_for(self, fn: Callable[[], bool]) -> None:
while True:
if fn():
break
if self._sender_closed or self._receiver_closed:
break
await self._waiters.park()
await _core.checkpoint()
def close_sender(self) -> None:
self._sender_closed = True
self._something_happened()
def close_receiver(self) -> None:
self._receiver_closed = True
self._something_happened()
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
with self._send_conflict_detector:
if self._sender_closed:
raise _core.ClosedResourceError
if self._receiver_closed:
raise _core.BrokenResourceError
assert not self._data
self._data += data
self._something_happened()
await self._wait_for(lambda: self._data == b"")
if self._sender_closed:
raise _core.ClosedResourceError
if self._data and self._receiver_closed:
raise _core.BrokenResourceError
async def wait_send_all_might_not_block(self) -> None:
with self._send_conflict_detector:
if self._sender_closed:
raise _core.ClosedResourceError
if self._receiver_closed:
await _core.checkpoint()
return
await self._wait_for(lambda: self._receiver_waiting)
if self._sender_closed:
raise _core.ClosedResourceError
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
with self._receive_conflict_detector:
# Argument validation
if max_bytes is not None:
max_bytes = operator.index(max_bytes)
if max_bytes < 1:
raise ValueError("max_bytes must be >= 1")
# State validation
if self._receiver_closed:
raise _core.ClosedResourceError
# Wake wait_send_all_might_not_block and wait for data
self._receiver_waiting = True
self._something_happened()
try:
await self._wait_for(lambda: self._data != b"")
finally:
self._receiver_waiting = False
if self._receiver_closed:
raise _core.ClosedResourceError
# Get data, possibly waking send_all
if self._data:
# Neat trick: if max_bytes is None, then obj[:max_bytes] is
# the same as obj[:].
got = self._data[:max_bytes]
del self._data[:max_bytes]
self._something_happened()
return got
else:
assert self._sender_closed
return b""
class _LockstepSendStream(SendStream):
def __init__(self, lbq: _LockstepByteQueue):
self._lbq = lbq
def close(self) -> None:
self._lbq.close_sender()
async def aclose(self) -> None:
self.close()
await _core.checkpoint()
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
await self._lbq.send_all(data)
async def wait_send_all_might_not_block(self) -> None:
await self._lbq.wait_send_all_might_not_block()
class _LockstepReceiveStream(ReceiveStream):
def __init__(self, lbq: _LockstepByteQueue):
self._lbq = lbq
def close(self) -> None:
self._lbq.close_receiver()
async def aclose(self) -> None:
self.close()
await _core.checkpoint()
async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
return await self._lbq.receive_some(max_bytes)
def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]:
"""Create a connected, pure Python, unidirectional stream where data flows
in lockstep.
Returns:
A tuple
(:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`).
This stream has *absolutely no* buffering. Each call to
:meth:`~trio.abc.SendStream.send_all` will block until all the given data
has been returned by a call to
:meth:`~trio.abc.ReceiveStream.receive_some`.
This can be useful for testing flow control mechanisms in an extreme case,
or for setting up "clogged" streams to use with
:func:`check_one_way_stream` and friends.
In addition to fulfilling the :class:`~trio.abc.SendStream` and
:class:`~trio.abc.ReceiveStream` interfaces, the return objects
also have a synchronous ``close`` method.
"""
lbq = _LockstepByteQueue()
return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq)
def lockstep_stream_pair() -> tuple[
StapledStream[SendStream, ReceiveStream],
StapledStream[SendStream, ReceiveStream],
]:
"""Create a connected, pure-Python, bidirectional stream where data flows
in lockstep.
Returns:
A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`).
This is a convenience function that creates two one-way streams using
:func:`lockstep_stream_one_way_pair`, and then uses
:class:`~trio.StapledStream` to combine them into a single bidirectional
stream.
"""
return _make_stapled_pair(lockstep_stream_one_way_pair)

View File

@ -0,0 +1,36 @@
from .. import socket as tsocket
from .._highlevel_socket import SocketListener, SocketStream
async def open_stream_to_socket_listener(
socket_listener: SocketListener,
) -> SocketStream:
"""Connect to the given :class:`~trio.SocketListener`.
This is particularly useful in tests when you want to let a server pick
its own port, and then connect to it::
listeners = await trio.open_tcp_listeners(0)
client = await trio.testing.open_stream_to_socket_listener(listeners[0])
Args:
socket_listener (~trio.SocketListener): The
:class:`~trio.SocketListener` to connect to.
Returns:
SocketStream: a stream connected to the given listener.
"""
family = socket_listener.socket.family
sockaddr = socket_listener.socket.getsockname()
if family in (tsocket.AF_INET, tsocket.AF_INET6):
sockaddr = list(sockaddr)
if sockaddr[0] == "0.0.0.0":
sockaddr[0] = "127.0.0.1"
if sockaddr[0] == "::":
sockaddr[0] = "::1"
sockaddr = tuple(sockaddr)
sock = tsocket.socket(family=family)
await sock.connect(sockaddr)
return SocketStream(sock)

View File

@ -0,0 +1,568 @@
from __future__ import annotations
import re
import sys
from typing import (
TYPE_CHECKING,
Callable,
ContextManager,
Generic,
Literal,
Pattern,
Sequence,
cast,
overload,
)
from trio._util import final
if TYPE_CHECKING:
import builtins
# sphinx will *only* work if we use types.TracebackType, and import
# *inside* TYPE_CHECKING. No other combination works.....
import types
from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback
from typing_extensions import TypeGuard, TypeVar
MatchE = TypeVar(
"MatchE",
bound=BaseException,
default=BaseException,
covariant=True,
)
else:
from typing import TypeVar
MatchE = TypeVar("MatchE", bound=BaseException, covariant=True)
# RaisesGroup doesn't work with a default.
E = TypeVar("E", bound=BaseException, covariant=True)
# These two typevars are special cased in sphinx config to workaround lookup bugs.
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
@final
class _ExceptionInfo(Generic[MatchE]):
"""Minimal re-implementation of pytest.ExceptionInfo, only used if pytest is not available. Supports a subset of its features necessary for functionality of :class:`trio.testing.RaisesGroup` and :class:`trio.testing.Matcher`."""
_excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None
def __init__(
self,
excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None,
):
self._excinfo = excinfo
def fill_unfilled(
self,
exc_info: tuple[type[MatchE], MatchE, types.TracebackType],
) -> None:
"""Fill an unfilled ExceptionInfo created with ``for_later()``."""
assert self._excinfo is None, "ExceptionInfo was already filled"
self._excinfo = exc_info
@classmethod
def for_later(cls) -> _ExceptionInfo[MatchE]:
"""Return an unfilled ExceptionInfo."""
return cls(None)
# Note, special cased in sphinx config, since "type" conflicts.
@property
def type(self) -> type[MatchE]:
"""The exception class."""
assert (
self._excinfo is not None
), ".type can only be used after the context manager exits"
return self._excinfo[0]
@property
def value(self) -> MatchE:
"""The exception value."""
assert (
self._excinfo is not None
), ".value can only be used after the context manager exits"
return self._excinfo[1]
@property
def tb(self) -> types.TracebackType:
"""The exception raw traceback."""
assert (
self._excinfo is not None
), ".tb can only be used after the context manager exits"
return self._excinfo[2]
def exconly(self, tryshort: bool = False) -> str:
raise NotImplementedError(
"This is a helper method only available if you use RaisesGroup with the pytest package installed",
)
def errisinstance(
self,
exc: builtins.type[BaseException] | tuple[builtins.type[BaseException], ...],
) -> bool:
raise NotImplementedError(
"This is a helper method only available if you use RaisesGroup with the pytest package installed",
)
def getrepr(
self,
showlocals: bool = False,
style: str = "long",
abspath: bool = False,
tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True,
funcargs: bool = False,
truncate_locals: bool = True,
chain: bool = True,
) -> ReprExceptionInfo | ExceptionChainRepr:
raise NotImplementedError(
"This is a helper method only available if you use RaisesGroup with the pytest package installed",
)
# Type checkers are not able to do conditional types depending on installed packages, so
# we've added signatures for all helpers to _ExceptionInfo, and then always use that.
# If this ends up leading to problems, we can resort to always using _ExceptionInfo and
# users that want to use getrepr/errisinstance/exconly can write helpers on their own, or
# we reimplement them ourselves...or get this merged in upstream pytest.
if TYPE_CHECKING:
ExceptionInfo = _ExceptionInfo
else:
try:
from pytest import ExceptionInfo # noqa: PT013
except ImportError: # pragma: no cover
ExceptionInfo = _ExceptionInfo
# copied from pytest.ExceptionInfo
def _stringify_exception(exc: BaseException) -> str:
return "\n".join(
[
getattr(exc, "message", str(exc)),
*getattr(exc, "__notes__", []),
],
)
# String patterns default to including the unicode flag.
_regex_no_flags = re.compile("").flags
@final
class Matcher(Generic[MatchE]):
"""Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments.
The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter.
:meth:`trio.testing.Matcher.matches` can also be used standalone to check individual exceptions.
Examples::
with RaisesGroups(Matcher(ValueError, match="string"))
...
with RaisesGroups(Matcher(check=lambda x: x.args == (3, "hello"))):
...
with RaisesGroups(Matcher(check=lambda x: type(x) is ValueError)):
...
"""
# At least one of the three parameters must be passed.
@overload
def __init__(
self: Matcher[MatchE],
exception_type: type[MatchE],
match: str | Pattern[str] = ...,
check: Callable[[MatchE], bool] = ...,
): ...
@overload
def __init__(
self: Matcher[BaseException], # Give E a value.
*,
match: str | Pattern[str],
# If exception_type is not provided, check() must do any typechecks itself.
check: Callable[[BaseException], bool] = ...,
): ...
@overload
def __init__(self, *, check: Callable[[BaseException], bool]): ...
def __init__(
self,
exception_type: type[MatchE] | None = None,
match: str | Pattern[str] | None = None,
check: Callable[[MatchE], bool] | None = None,
):
if exception_type is None and match is None and check is None:
raise ValueError("You must specify at least one parameter to match on.")
if exception_type is not None and not issubclass(exception_type, BaseException):
raise ValueError(
f"exception_type {exception_type} must be a subclass of BaseException",
)
self.exception_type = exception_type
self.match: Pattern[str] | None
if isinstance(match, str):
self.match = re.compile(match)
else:
self.match = match
self.check = check
def matches(self, exception: BaseException) -> TypeGuard[MatchE]:
"""Check if an exception matches the requirements of this Matcher.
Examples::
assert Matcher(ValueError).matches(my_exception):
# is equivalent to
assert isinstance(my_exception, ValueError)
# this can be useful when checking e.g. the ``__cause__`` of an exception.
with pytest.raises(ValueError) as excinfo:
...
assert Matcher(SyntaxError, match="foo").matches(excinfo.value.__cause__)
# above line is equivalent to
assert isinstance(excinfo.value.__cause__, SyntaxError)
assert re.search("foo", str(excinfo.value.__cause__)
"""
if self.exception_type is not None and not isinstance(
exception,
self.exception_type,
):
return False
if self.match is not None and not re.search(
self.match,
_stringify_exception(exception),
):
return False
# If exception_type is None check() accepts BaseException.
# If non-none, we have done an isinstance check above.
return self.check is None or self.check(cast(MatchE, exception))
def __str__(self) -> str:
reqs = []
if self.exception_type is not None:
reqs.append(self.exception_type.__name__)
if (match := self.match) is not None:
# If no flags were specified, discard the redundant re.compile() here.
reqs.append(
f"match={match.pattern if match.flags == _regex_no_flags else match!r}",
)
if self.check is not None:
reqs.append(f"check={self.check!r}")
return f'Matcher({", ".join(reqs)})'
# typing this has been somewhat of a nightmare, with the primary difficulty making
# the return type of __enter__ correct. Ideally it would function like this
# with RaisesGroup(RaisesGroup(ValueError)) as excinfo:
# ...
# assert_type(excinfo.value, ExceptionGroup[ExceptionGroup[ValueError]])
# in addition to all the simple cases, but getting all the way to the above seems maybe
# impossible. The type being RaisesGroup[RaisesGroup[ValueError]] is probably also fine,
# as long as I add fake properties corresponding to the properties of exceptiongroup. But
# I had trouble with it handling recursive cases properly.
# Current solution settles on the above giving BaseExceptionGroup[RaisesGroup[ValueError]], and it not
# being a type error to do `with RaisesGroup(ValueError()): ...` - but that will error on runtime.
# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups
if TYPE_CHECKING:
SuperClass = BaseExceptionGroup
else:
# At runtime, use a redundant Generic base class which effectively gets ignored.
SuperClass = Generic
@final
class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]):
"""Contextmanager for checking for an expected `ExceptionGroup`.
This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538
The catching behaviour differs from :ref:`except* <except_star>` in multiple different ways, being much stricter by default. By using ``allow_unwrapped=True`` and ``flatten_subgroups=True`` you can match ``except*`` fully when expecting a single exception.
#. All specified exceptions must be present, *and no others*.
* If you expect a variable number of exceptions you need to use ``pytest.raises(ExceptionGroup)`` and manually check the contained exceptions. Consider making use of :func:`Matcher.matches`.
#. It will only catch exceptions wrapped in an exceptiongroup by default.
* With ``allow_unwrapped=True`` you can specify a single expected exception or `Matcher` and it will match the exception even if it is not inside an `ExceptionGroup`. If you expect one of several different exception types you need to use a `Matcher` object.
#. By default it cares about the full structure with nested `ExceptionGroup`'s. You can specify nested `ExceptionGroup`'s by passing `RaisesGroup` objects as expected exceptions.
* With ``flatten_subgroups=True`` it will "flatten" the raised `ExceptionGroup`, extracting all exceptions inside any nested :class:`ExceptionGroup`, before matching.
It currently does not care about the order of the exceptions, so ``RaisesGroups(ValueError, TypeError)`` is equivalent to ``RaisesGroups(TypeError, ValueError)``.
This class is not as polished as ``pytest.raises``, and is currently not as helpful in e.g. printing diffs when strings don't match, suggesting you use ``re.escape``, etc.
Examples::
with RaisesGroups(ValueError):
raise ExceptionGroup("", (ValueError(),))
with RaisesGroups(ValueError, ValueError, Matcher(TypeError, match="expected int")):
...
with RaisesGroups(KeyboardInterrupt, match="hello", check=lambda x: type(x) is BaseExceptionGroup):
...
with RaisesGroups(RaisesGroups(ValueError)):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
# flatten_subgroups
with RaisesGroups(ValueError, flatten_subgroups=True):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
# allow_unwrapped
with RaisesGroups(ValueError, allow_unwrapped=True):
raise ValueError
`RaisesGroup.matches` can also be used directly to check a standalone exception group.
The matching algorithm is greedy, which means cases such as this may fail::
with RaisesGroups(ValueError, Matcher(ValueError, match="hello")):
raise ExceptionGroup("", (ValueError("hello"), ValueError("goodbye")))
even though it generally does not care about the order of the exceptions in the group.
To avoid the above you should specify the first ValueError with a Matcher as well.
It is also not typechecked perfectly, and that's likely not possible with the current approach. Most common usage should work without issue though.
"""
# needed for pyright, since BaseExceptionGroup.__new__ takes two arguments
if TYPE_CHECKING:
def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: ...
# allow_unwrapped=True requires: singular exception, exception not being
# RaisesGroup instance, match is None, check is None
@overload
def __init__(
self,
exception: type[E] | Matcher[E],
*,
allow_unwrapped: Literal[True],
flatten_subgroups: bool = False,
match: None = None,
check: None = None,
): ...
# flatten_subgroups = True also requires no nested RaisesGroup
@overload
def __init__(
self,
exception: type[E] | Matcher[E],
*other_exceptions: type[E] | Matcher[E],
allow_unwrapped: Literal[False] = False,
flatten_subgroups: Literal[True],
match: str | Pattern[str] | None = None,
check: Callable[[BaseExceptionGroup[E]], bool] | None = None,
): ...
@overload
def __init__(
self,
exception: type[E] | Matcher[E] | E,
*other_exceptions: type[E] | Matcher[E] | E,
allow_unwrapped: Literal[False] = False,
flatten_subgroups: Literal[False] = False,
match: str | Pattern[str] | None = None,
check: Callable[[BaseExceptionGroup[E]], bool] | None = None,
): ...
def __init__(
self,
exception: type[E] | Matcher[E] | E,
*other_exceptions: type[E] | Matcher[E] | E,
allow_unwrapped: bool = False,
flatten_subgroups: bool = False,
match: str | Pattern[str] | None = None,
check: Callable[[BaseExceptionGroup[E]], bool] | None = None,
):
self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = (
exception,
*other_exceptions,
)
self.flatten_subgroups: bool = flatten_subgroups
self.allow_unwrapped = allow_unwrapped
self.match_expr = match
self.check = check
self.is_baseexceptiongroup = False
if allow_unwrapped and other_exceptions:
raise ValueError(
"You cannot specify multiple exceptions with `allow_unwrapped=True.`"
" If you want to match one of multiple possible exceptions you should"
" use a `Matcher`."
" E.g. `Matcher(check=lambda e: isinstance(e, (...)))`",
)
if allow_unwrapped and isinstance(exception, RaisesGroup):
raise ValueError(
"`allow_unwrapped=True` has no effect when expecting a `RaisesGroup`."
" You might want it in the expected `RaisesGroup`, or"
" `flatten_subgroups=True` if you don't care about the structure.",
)
if allow_unwrapped and (match is not None or check is not None):
raise ValueError(
"`allow_unwrapped=True` bypasses the `match` and `check` parameters"
" if the exception is unwrapped. If you intended to match/check the"
" exception you should use a `Matcher` object. If you want to match/check"
" the exceptiongroup when the exception *is* wrapped you need to"
" do e.g. `if isinstance(exc.value, ExceptionGroup):"
" assert RaisesGroup(...).matches(exc.value)` afterwards.",
)
# verify `expected_exceptions` and set `self.is_baseexceptiongroup`
for exc in self.expected_exceptions:
if isinstance(exc, RaisesGroup):
if self.flatten_subgroups:
raise ValueError(
"You cannot specify a nested structure inside a RaisesGroup with"
" `flatten_subgroups=True`. The parameter will flatten subgroups"
" in the raised exceptiongroup before matching, which would never"
" match a nested structure.",
)
self.is_baseexceptiongroup |= exc.is_baseexceptiongroup
elif isinstance(exc, Matcher):
# The Matcher could match BaseExceptions through the other arguments
# but `self.is_baseexceptiongroup` is only used for printing.
if exc.exception_type is None:
continue
# Matcher __init__ assures it's a subclass of BaseException
self.is_baseexceptiongroup |= not issubclass(
exc.exception_type,
Exception,
)
elif isinstance(exc, type) and issubclass(exc, BaseException):
self.is_baseexceptiongroup |= not issubclass(exc, Exception)
else:
raise ValueError(
f'Invalid argument "{exc!r}" must be exception type, Matcher, or'
" RaisesGroup.",
)
def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]:
self.excinfo: ExceptionInfo[BaseExceptionGroup[E]] = ExceptionInfo.for_later()
return self.excinfo
def _unroll_exceptions(
self,
exceptions: Sequence[BaseException],
) -> Sequence[BaseException]:
"""Used if `flatten_subgroups=True`."""
res: list[BaseException] = []
for exc in exceptions:
if isinstance(exc, BaseExceptionGroup):
res.extend(self._unroll_exceptions(exc.exceptions))
else:
res.append(exc)
return res
def matches(
self,
exc_val: BaseException | None,
) -> TypeGuard[BaseExceptionGroup[E]]:
"""Check if an exception matches the requirements of this RaisesGroup.
Example::
with pytest.raises(TypeError) as excinfo:
...
assert RaisesGroups(ValueError).matches(excinfo.value.__cause__)
# the above line is equivalent to
myexc = excinfo.value.__cause
assert isinstance(myexc, BaseExceptionGroup)
assert len(myexc.exceptions) == 1
assert isinstance(myexc.exceptions[0], ValueError)
"""
if exc_val is None:
return False
# TODO: print/raise why a match fails, in a way that works properly in nested cases
# maybe have a list of strings logging failed matches, that __exit__ can
# recursively step through and print on a failing match.
if not isinstance(exc_val, BaseExceptionGroup):
if self.allow_unwrapped:
exp_exc = self.expected_exceptions[0]
if isinstance(exp_exc, Matcher) and exp_exc.matches(exc_val):
return True
if isinstance(exp_exc, type) and isinstance(exc_val, exp_exc):
return True
return False
if self.match_expr is not None and not re.search(
self.match_expr,
_stringify_exception(exc_val),
):
return False
if self.check is not None and not self.check(exc_val):
return False
remaining_exceptions = list(self.expected_exceptions)
actual_exceptions: Sequence[BaseException] = exc_val.exceptions
if self.flatten_subgroups:
actual_exceptions = self._unroll_exceptions(actual_exceptions)
# important to check the length *after* flattening subgroups
if len(actual_exceptions) != len(self.expected_exceptions):
return False
# it should be possible to get RaisesGroup.matches typed so as not to
# need type: ignore, but I'm not sure that's possible while also having it
# transparent for the end user.
for e in actual_exceptions:
for rem_e in remaining_exceptions:
if (
(isinstance(rem_e, type) and isinstance(e, rem_e))
or (isinstance(rem_e, RaisesGroup) and rem_e.matches(e))
or (isinstance(rem_e, Matcher) and rem_e.matches(e))
):
remaining_exceptions.remove(rem_e) # type: ignore[arg-type]
break
else:
return False
return True
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: types.TracebackType | None,
) -> bool:
__tracebackhide__ = True
assert (
exc_type is not None
), f"DID NOT RAISE any exception, expected {self.expected_type()}"
assert (
self.excinfo is not None
), "Internal error - should have been constructed in __enter__"
if not self.matches(exc_val):
return False
# Cast to narrow the exception type now that it's verified.
exc_info = cast(
"tuple[type[BaseExceptionGroup[E]], BaseExceptionGroup[E], types.TracebackType]",
(exc_type, exc_val, exc_tb),
)
self.excinfo.fill_unfilled(exc_info)
return True
def expected_type(self) -> str:
subexcs = []
for e in self.expected_exceptions:
if isinstance(e, Matcher):
subexcs.append(str(e))
elif isinstance(e, RaisesGroup):
subexcs.append(e.expected_type())
elif isinstance(e, type):
subexcs.append(e.__name__)
else: # pragma: no cover
raise AssertionError("unknown type")
group_type = "Base" if self.is_baseexceptiongroup else ""
return f"{group_type}ExceptionGroup({', '.join(subexcs)})"

View File

@ -0,0 +1,87 @@
from __future__ import annotations
from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING
import attrs
from .. import Event, _core, _util
if TYPE_CHECKING:
from collections.abc import AsyncIterator
@_util.final
@attrs.define(eq=False, slots=False)
class Sequencer:
"""A convenience class for forcing code in different tasks to run in an
explicit linear order.
Instances of this class implement a ``__call__`` method which returns an
async context manager. The idea is that you pass a sequence number to
``__call__`` to say where this block of code should go in the linear
sequence. Block 0 starts immediately, and then block N doesn't start until
block N-1 has finished.
Example:
An extremely elaborate way to print the numbers 0-5, in order::
async def worker1(seq):
async with seq(0):
print(0)
async with seq(4):
print(4)
async def worker2(seq):
async with seq(2):
print(2)
async with seq(5):
print(5)
async def worker3(seq):
async with seq(1):
print(1)
async with seq(3):
print(3)
async def main():
seq = trio.testing.Sequencer()
async with trio.open_nursery() as nursery:
nursery.start_soon(worker1, seq)
nursery.start_soon(worker2, seq)
nursery.start_soon(worker3, seq)
"""
_sequence_points: defaultdict[int, Event] = attrs.field(
factory=lambda: defaultdict(Event),
init=False,
)
_claimed: set[int] = attrs.field(factory=set, init=False)
_broken: bool = attrs.field(default=False, init=False)
@asynccontextmanager
async def __call__(self, position: int) -> AsyncIterator[None]:
if position in self._claimed:
raise RuntimeError(f"Attempted to reuse sequence point {position}")
if self._broken:
raise RuntimeError("sequence broken!")
self._claimed.add(position)
if position != 0:
try:
await self._sequence_points[position].wait()
except _core.Cancelled:
self._broken = True
for event in self._sequence_points.values():
event.set()
raise RuntimeError(
"Sequencer wait cancelled -- sequence broken",
) from None
else:
if self._broken:
raise RuntimeError("sequence broken!")
try:
yield
finally:
self._sequence_points[position + 1].set()

View File

@ -0,0 +1,50 @@
from __future__ import annotations
from functools import partial, wraps
from typing import TYPE_CHECKING, TypeVar
from .. import _core
from ..abc import Clock, Instrument
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from typing_extensions import ParamSpec
ArgsT = ParamSpec("ArgsT")
RetT = TypeVar("RetT")
def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]:
"""Converts an async test function to be synchronous, running via Trio.
Usage::
@trio_test
async def test_whatever():
await ...
If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or
:class:`~trio.abc.Instrument` ABCs, then those are passed to :meth:`trio.run()`.
"""
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
__tracebackhide__ = True
clocks = [c for c in kwargs.values() if isinstance(c, Clock)]
if not clocks:
clock = None
elif len(clocks) == 1:
clock = clocks[0]
else:
raise ValueError("too many clocks spoil the broth!")
instruments = [i for i in kwargs.values() if isinstance(i, Instrument)]
return _core.run(
partial(fn, *args, **kwargs),
clock=clock,
instruments=instruments,
)
return wrapper