Updated script so it can be controlled by a Node.js web app
39
lib/python3.13/site-packages/trio/testing/__init__.py
Normal file
@@ -0,0 +1,39 @@
# Uses `from x import y as y` for compatibility with `pyright --verifytypes` (#2625)

from .._core import (
    MockClock as MockClock,
    wait_all_tasks_blocked as wait_all_tasks_blocked,
)
from .._threads import (
    active_thread_count as active_thread_count,
    wait_all_threads_completed as wait_all_threads_completed,
)
from .._util import fixup_module_metadata
from ._check_streams import (
    check_half_closeable_stream as check_half_closeable_stream,
    check_one_way_stream as check_one_way_stream,
    check_two_way_stream as check_two_way_stream,
)
from ._checkpoints import (
    assert_checkpoints as assert_checkpoints,
    assert_no_checkpoints as assert_no_checkpoints,
)
from ._memory_streams import (
    MemoryReceiveStream as MemoryReceiveStream,
    MemorySendStream as MemorySendStream,
    lockstep_stream_one_way_pair as lockstep_stream_one_way_pair,
    lockstep_stream_pair as lockstep_stream_pair,
    memory_stream_one_way_pair as memory_stream_one_way_pair,
    memory_stream_pair as memory_stream_pair,
    memory_stream_pump as memory_stream_pump,
)
from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener
from ._raises_group import Matcher as Matcher, RaisesGroup as RaisesGroup
from ._sequencer import Sequencer as Sequencer
from ._trio_test import trio_test as trio_test

################################################################


fixup_module_metadata(__name__, globals())
del fixup_module_metadata
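
Usage sketch for the re-exports above (illustrative only, not part of this commit; it uses only the public trio.testing API):

import trio
import trio.testing

async def main() -> None:
    async with trio.open_nursery() as nursery:
        nursery.start_soon(trio.sleep, 1)
        # Blocks until every other task is idle; handy for deterministic tests.
        await trio.testing.wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()

trio.run(main)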
572
lib/python3.13/site-packages/trio/testing/_check_streams.py
Normal file
@@ -0,0 +1,572 @@
# Generic stream tests

from __future__ import annotations

import random
import sys
from contextlib import contextmanager, suppress
from typing import (
    TYPE_CHECKING,
    Awaitable,
    Callable,
    Generator,
    Generic,
    Tuple,
    TypeVar,
)

from .. import CancelScope, _core
from .._abc import AsyncResource, HalfCloseableStream, ReceiveStream, SendStream, Stream
from .._highlevel_generic import aclose_forcefully
from ._checkpoints import assert_checkpoints

if TYPE_CHECKING:
    from types import TracebackType

    from typing_extensions import ParamSpec, TypeAlias

    ArgsT = ParamSpec("ArgsT")

if sys.version_info < (3, 11):
    from exceptiongroup import BaseExceptionGroup

Res1 = TypeVar("Res1", bound=AsyncResource)
Res2 = TypeVar("Res2", bound=AsyncResource)
StreamMaker: TypeAlias = Callable[[], Awaitable[Tuple[Res1, Res2]]]


class _ForceCloseBoth(Generic[Res1, Res2]):
    def __init__(self, both: tuple[Res1, Res2]) -> None:
        self._first, self._second = both

    async def __aenter__(self) -> tuple[Res1, Res2]:
        return self._first, self._second

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        try:
            await aclose_forcefully(self._first)
        finally:
            await aclose_forcefully(self._second)


# This is used in this file instead of pytest.raises in order to avoid a dependency
# on pytest, as the check_* functions are publicly exported.
@contextmanager
def _assert_raises(
    expected_exc: type[BaseException],
    wrapped: bool = False,
) -> Generator[None, None, None]:
    __tracebackhide__ = True
    try:
        yield
    except BaseExceptionGroup as exc:
        assert wrapped, "caught exceptiongroup, but expected an unwrapped exception"
        # assert in except block ignored below
        assert len(exc.exceptions) == 1  # noqa: PT017
        assert isinstance(exc.exceptions[0], expected_exc)  # noqa: PT017
    except expected_exc:
        assert not wrapped, "caught exception, but expected an exceptiongroup"
    else:
        raise AssertionError(f"expected exception: {expected_exc}")
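
# Illustrative sketch (not part of the upstream file): _assert_raises mirrors
# pytest.raises without the pytest dependency. Roughly:
#
#     with _assert_raises(ValueError):
#         raise ValueError("bad argument")              # passes
#
#     with _assert_raises(ValueError, wrapped=True):
#         raise BaseExceptionGroup("", [ValueError()])  # passes: one wrapped match
#
# Exiting the block without raising fails with "expected exception: ...".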


async def check_one_way_stream(
    stream_maker: StreamMaker[SendStream, ReceiveStream],
    clogged_stream_maker: StreamMaker[SendStream, ReceiveStream] | None,
) -> None:
    """Perform a number of generic tests on a custom one-way stream
    implementation.

    Args:
      stream_maker: An async (!) function which returns a connected
          (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`)
          pair.
      clogged_stream_maker: Either None, or an async function similar to
          stream_maker, but with the extra property that the returned stream
          is in a state where ``send_all`` and
          ``wait_send_all_might_not_block`` will block until ``receive_some``
          has been called. This allows for more thorough testing of some edge
          cases, especially around ``wait_send_all_might_not_block``.

    Raises:
      AssertionError: if a test fails.

    """
    async with _ForceCloseBoth(await stream_maker()) as (s, r):
        assert isinstance(s, SendStream)
        assert isinstance(r, ReceiveStream)

        async def do_send_all(data: bytes | bytearray | memoryview) -> None:
            with assert_checkpoints():  # We're testing that it doesn't return anything.
                assert await s.send_all(data) is None  # type: ignore[func-returns-value]

        async def do_receive_some(max_bytes: int | None = None) -> bytes | bytearray:
            with assert_checkpoints():
                return await r.receive_some(max_bytes)

        async def checked_receive_1(expected: bytes) -> None:
            assert await do_receive_some(1) == expected

        async def do_aclose(resource: AsyncResource) -> None:
            with assert_checkpoints():
                await resource.aclose()

        # Simple sending/receiving
        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_send_all, b"x")
            nursery.start_soon(checked_receive_1, b"x")

        async def send_empty_then_y() -> None:
            # Streams should tolerate sending b"" without giving it any
            # special meaning.
            await do_send_all(b"")
            await do_send_all(b"y")

        async with _core.open_nursery() as nursery:
            nursery.start_soon(send_empty_then_y)
            nursery.start_soon(checked_receive_1, b"y")

        # ---- Checking various argument types ----

        # send_all accepts bytearray and memoryview
        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_send_all, bytearray(b"1"))
            nursery.start_soon(checked_receive_1, b"1")

        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_send_all, memoryview(b"2"))
            nursery.start_soon(checked_receive_1, b"2")

        # max_bytes must be a positive integer
        with _assert_raises(ValueError):
            await r.receive_some(-1)
        with _assert_raises(ValueError):
            await r.receive_some(0)
        with _assert_raises(TypeError):
            await r.receive_some(1.5)  # type: ignore[arg-type]
        # it can also be missing or None
        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_send_all, b"x")
            assert await do_receive_some() == b"x"
        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_send_all, b"x")
            assert await do_receive_some(None) == b"x"

        with _assert_raises(_core.BusyResourceError, wrapped=True):
            async with _core.open_nursery() as nursery:
                nursery.start_soon(do_receive_some, 1)
                nursery.start_soon(do_receive_some, 1)

        # Method always has to exist, and an empty stream with a blocked
        # receive_some should *always* allow send_all. (Technically it's legal
        # for send_all to wait until receive_some is called to run, though; a
        # stream doesn't *have* to have any internal buffering. That's why we
        # start a concurrent receive_some call, then cancel it.)
        async def simple_check_wait_send_all_might_not_block(
            scope: CancelScope,
        ) -> None:
            with assert_checkpoints():
                await s.wait_send_all_might_not_block()
            scope.cancel()

        async with _core.open_nursery() as nursery:
            nursery.start_soon(
                simple_check_wait_send_all_might_not_block,
                nursery.cancel_scope,
            )
            nursery.start_soon(do_receive_some, 1)

        # closing the r side leads to BrokenResourceError on the s side
        # (eventually)
        async def expect_broken_stream_on_send() -> None:
            with _assert_raises(_core.BrokenResourceError):
                while True:
                    await do_send_all(b"x" * 100)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(expect_broken_stream_on_send)
            nursery.start_soon(do_aclose, r)

        # once detected, the stream stays broken
        with _assert_raises(_core.BrokenResourceError):
            await do_send_all(b"x" * 100)

        # r closed -> ClosedResourceError on the receive side
        with _assert_raises(_core.ClosedResourceError):
            await do_receive_some(4096)

        # we can close the same stream repeatedly, it's fine
        await do_aclose(r)
        await do_aclose(r)

        # closing the sender side
        await do_aclose(s)

        # now trying to send raises ClosedResourceError
        with _assert_raises(_core.ClosedResourceError):
            await do_send_all(b"x" * 100)

        # even if it's an empty send
        with _assert_raises(_core.ClosedResourceError):
            await do_send_all(b"")

        # ditto for wait_send_all_might_not_block
        with _assert_raises(_core.ClosedResourceError):
            with assert_checkpoints():
                await s.wait_send_all_might_not_block()

        # and again, repeated closing is fine
        await do_aclose(s)
        await do_aclose(s)

    async with _ForceCloseBoth(await stream_maker()) as (s, r):
        # if send-then-graceful-close, receiver gets data then b""
        async def send_then_close() -> None:
            await do_send_all(b"y")
            await do_aclose(s)

        async def receive_send_then_close() -> None:
            # We want to make sure that if the sender closes the stream before
            # we read anything, then we still get all the data. But some
            # streams might block on the do_send_all call. So we let the
            # sender get as far as it can, then we receive.
            await _core.wait_all_tasks_blocked()
            await checked_receive_1(b"y")
            await checked_receive_1(b"")
            await do_aclose(r)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(send_then_close)
            nursery.start_soon(receive_send_then_close)

    async with _ForceCloseBoth(await stream_maker()) as (s, r):
        await aclose_forcefully(r)

        with _assert_raises(_core.BrokenResourceError):
            while True:
                await do_send_all(b"x" * 100)

        with _assert_raises(_core.ClosedResourceError):
            await do_receive_some(4096)

    async with _ForceCloseBoth(await stream_maker()) as (s, r):
        await aclose_forcefully(s)

        with _assert_raises(_core.ClosedResourceError):
            await do_send_all(b"123")

        # after the sender does a forceful close, the receiver might either
        # get BrokenResourceError or a clean b""; either is OK. Not OK would be
        # if it freezes, or returns data.
        with suppress(_core.BrokenResourceError):
            await checked_receive_1(b"")

    # cancelled aclose still closes
    async with _ForceCloseBoth(await stream_maker()) as (s, r):
        with _core.CancelScope() as scope:
            scope.cancel()
            await r.aclose()

        with _core.CancelScope() as scope:
            scope.cancel()
            await s.aclose()

        with _assert_raises(_core.ClosedResourceError):
            await do_send_all(b"123")

        with _assert_raises(_core.ClosedResourceError):
            await do_receive_some(4096)

    # Check that we can still gracefully close a stream after an operation has
    # been cancelled. This can be challenging if cancellation can leave the
    # stream internals in an inconsistent state, e.g. for
    # SSLStream. Unfortunately this test isn't very thorough; the really
    # challenging case for something like SSLStream is it gets cancelled
    # *while* it's sending data on the underlying, not before. But testing
    # that requires some special-case handling of the particular stream setup;
    # we can't do it here. Maybe we could do a bit better with
    # https://github.com/python-trio/trio/issues/77
    async with _ForceCloseBoth(await stream_maker()) as (s, r):

        async def expect_cancelled(
            afn: Callable[ArgsT, Awaitable[object]],
            *args: ArgsT.args,
            **kwargs: ArgsT.kwargs,
        ) -> None:
            with _assert_raises(_core.Cancelled):
                await afn(*args, **kwargs)

        with _core.CancelScope() as scope:
            scope.cancel()
            async with _core.open_nursery() as nursery:
                nursery.start_soon(expect_cancelled, do_send_all, b"x")
                nursery.start_soon(expect_cancelled, do_receive_some, 1)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(do_aclose, s)
            nursery.start_soon(do_aclose, r)

    # Check that if a task is blocked in receive_some, then closing the
    # receive stream causes it to wake up.
    async with _ForceCloseBoth(await stream_maker()) as (s, r):

        async def receive_expecting_closed():
            with _assert_raises(_core.ClosedResourceError):
                await r.receive_some(10)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(receive_expecting_closed)
            await _core.wait_all_tasks_blocked()
            await aclose_forcefully(r)

    # check wait_send_all_might_not_block, if we can
    if clogged_stream_maker is not None:
        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
            record: list[str] = []

            async def waiter(cancel_scope: CancelScope) -> None:
                record.append("waiter sleeping")
                with assert_checkpoints():
                    await s.wait_send_all_might_not_block()
                record.append("waiter wokeup")
                cancel_scope.cancel()

            async def receiver() -> None:
                # give wait_send_all_might_not_block a chance to block
                await _core.wait_all_tasks_blocked()
                record.append("receiver starting")
                while True:
                    await r.receive_some(16834)

            async with _core.open_nursery() as nursery:
                nursery.start_soon(waiter, nursery.cancel_scope)
                await _core.wait_all_tasks_blocked()
                nursery.start_soon(receiver)

            assert record == [
                "waiter sleeping",
                "receiver starting",
                "waiter wokeup",
            ]

        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
            # simultaneous wait_send_all_might_not_block fails
            with _assert_raises(_core.BusyResourceError, wrapped=True):
                async with _core.open_nursery() as nursery:
                    nursery.start_soon(s.wait_send_all_might_not_block)
                    nursery.start_soon(s.wait_send_all_might_not_block)

            # and simultaneous send_all and wait_send_all_might_not_block (NB
            # this test might destroy the stream b/c we end up cancelling
            # send_all and e.g. SSLStream can't handle that, so we have to
            # recreate afterwards)
            with _assert_raises(_core.BusyResourceError, wrapped=True):
                async with _core.open_nursery() as nursery:
                    nursery.start_soon(s.wait_send_all_might_not_block)
                    nursery.start_soon(s.send_all, b"123")

        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
            # send_all and send_all blocked simultaneously should also raise
            # (but again this might destroy the stream)
            with _assert_raises(_core.BusyResourceError, wrapped=True):
                async with _core.open_nursery() as nursery:
                    nursery.start_soon(s.send_all, b"123")
                    nursery.start_soon(s.send_all, b"123")

        # closing the receiver causes wait_send_all_might_not_block to return,
        # with or without an exception
        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):

            async def sender() -> None:
                try:
                    with assert_checkpoints():
                        await s.wait_send_all_might_not_block()
                except _core.BrokenResourceError:  # pragma: no cover
                    pass

            async def receiver() -> None:
                await _core.wait_all_tasks_blocked()
                await aclose_forcefully(r)

            async with _core.open_nursery() as nursery:
                nursery.start_soon(sender)
                nursery.start_soon(receiver)

        # and again with the call starting after the close
        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
            await aclose_forcefully(r)
            try:
                with assert_checkpoints():
                    await s.wait_send_all_might_not_block()
            except _core.BrokenResourceError:  # pragma: no cover
                pass

        # Check that if a task is blocked in a send-side method, then closing
        # the send stream causes it to wake up.
        async def close_soon(s: SendStream) -> None:
            await _core.wait_all_tasks_blocked()
            await aclose_forcefully(s)

        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
            async with _core.open_nursery() as nursery:
                nursery.start_soon(close_soon, s)
                with _assert_raises(_core.ClosedResourceError):
                    await s.send_all(b"xyzzy")

        async with _ForceCloseBoth(await clogged_stream_maker()) as (s, r):
            async with _core.open_nursery() as nursery:
                nursery.start_soon(close_soon, s)
                with _assert_raises(_core.ClosedResourceError):
                    await s.wait_send_all_might_not_block()
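
# Illustrative sketch (not part of the upstream file): check_one_way_stream is
# normally driven from a test, with the maker returning a fresh connected pair,
# e.g. built from the in-memory streams in trio.testing:
#
#     async def maker():
#         return memory_stream_one_way_pair()
#
#     async def test_my_stream():
#         await check_one_way_stream(maker, None)  # None: skip clogged-stream tests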


async def check_two_way_stream(
    stream_maker: StreamMaker[Stream, Stream],
    clogged_stream_maker: StreamMaker[Stream, Stream] | None,
) -> None:
    """Perform a number of generic tests on a custom two-way stream
    implementation.

    This is similar to :func:`check_one_way_stream`, except that the maker
    functions are expected to return objects implementing the
    :class:`~trio.abc.Stream` interface.

    This function tests a *superset* of what :func:`check_one_way_stream`
    checks – if you call this, then you don't need to also call
    :func:`check_one_way_stream`.

    """
    await check_one_way_stream(stream_maker, clogged_stream_maker)

    async def flipped_stream_maker() -> tuple[Stream, Stream]:
        return (await stream_maker())[::-1]

    flipped_clogged_stream_maker: Callable[[], Awaitable[tuple[Stream, Stream]]] | None

    if clogged_stream_maker is not None:

        async def flipped_clogged_stream_maker() -> tuple[Stream, Stream]:
            return (await clogged_stream_maker())[::-1]

    else:
        flipped_clogged_stream_maker = None
    await check_one_way_stream(flipped_stream_maker, flipped_clogged_stream_maker)

    async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
        assert isinstance(s1, Stream)
        assert isinstance(s2, Stream)

        # Duplex can be a bit tricky, might as well check it as well
        DUPLEX_TEST_SIZE = 2**20
        CHUNK_SIZE_MAX = 2**14

        r = random.Random(0)
        i = r.getrandbits(8 * DUPLEX_TEST_SIZE)
        test_data = i.to_bytes(DUPLEX_TEST_SIZE, "little")

        async def sender(
            s: Stream,
            data: bytes | bytearray | memoryview,
            seed: int,
        ) -> None:
            r = random.Random(seed)
            m = memoryview(data)
            while m:
                chunk_size = r.randint(1, CHUNK_SIZE_MAX)
                await s.send_all(m[:chunk_size])
                m = m[chunk_size:]

        async def receiver(s: Stream, data: bytes | bytearray, seed: int) -> None:
            r = random.Random(seed)
            got = bytearray()
            while len(got) < len(data):
                chunk = await s.receive_some(r.randint(1, CHUNK_SIZE_MAX))
                assert chunk
                got += chunk
            assert got == data

        async with _core.open_nursery() as nursery:
            nursery.start_soon(sender, s1, test_data, 0)
            nursery.start_soon(sender, s2, test_data[::-1], 1)
            nursery.start_soon(receiver, s1, test_data[::-1], 2)
            nursery.start_soon(receiver, s2, test_data, 3)

        async def expect_receive_some_empty() -> None:
            assert await s2.receive_some(10) == b""
            await s2.aclose()

        async with _core.open_nursery() as nursery:
            nursery.start_soon(expect_receive_some_empty)
            nursery.start_soon(s1.aclose)


async def check_half_closeable_stream(
    stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream],
    clogged_stream_maker: StreamMaker[HalfCloseableStream, HalfCloseableStream] | None,
) -> None:
    """Perform a number of generic tests on a custom half-closeable stream
    implementation.

    This is similar to :func:`check_two_way_stream`, except that the maker
    functions are expected to return objects that implement the
    :class:`~trio.abc.HalfCloseableStream` interface.

    This function tests a *superset* of what :func:`check_two_way_stream`
    checks – if you call this, then you don't need to also call
    :func:`check_two_way_stream`.

    """
    await check_two_way_stream(stream_maker, clogged_stream_maker)

    async with _ForceCloseBoth(await stream_maker()) as (s1, s2):
        assert isinstance(s1, HalfCloseableStream)
        assert isinstance(s2, HalfCloseableStream)

        async def send_x_then_eof(s: HalfCloseableStream) -> None:
            await s.send_all(b"x")
            with assert_checkpoints():
                await s.send_eof()

        async def expect_x_then_eof(r: HalfCloseableStream) -> None:
            await _core.wait_all_tasks_blocked()
            assert await r.receive_some(10) == b"x"
            assert await r.receive_some(10) == b""

        async with _core.open_nursery() as nursery:
            nursery.start_soon(send_x_then_eof, s1)
            nursery.start_soon(expect_x_then_eof, s2)

        # now sending is disallowed
        with _assert_raises(_core.ClosedResourceError):
            await s1.send_all(b"y")

        # but we can do send_eof again
        with assert_checkpoints():
            await s1.send_eof()

        # and we can still send stuff back the other way
        async with _core.open_nursery() as nursery:
            nursery.start_soon(send_x_then_eof, s2)
            nursery.start_soon(expect_x_then_eof, s1)

    if clogged_stream_maker is not None:
        async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
            # send_all and send_eof simultaneously is not ok
            with _assert_raises(_core.BusyResourceError, wrapped=True):
                async with _core.open_nursery() as nursery:
                    nursery.start_soon(s1.send_all, b"x")
                    await _core.wait_all_tasks_blocked()
                    nursery.start_soon(s1.send_eof)

        async with _ForceCloseBoth(await clogged_stream_maker()) as (s1, s2):
            # wait_send_all_might_not_block and send_eof simultaneously is not
            # ok either
            with _assert_raises(_core.BusyResourceError, wrapped=True):
                async with _core.open_nursery() as nursery:
                    nursery.start_soon(s1.wait_send_all_might_not_block)
                    await _core.wait_all_tasks_blocked()
                    nursery.start_soon(s1.send_eof)
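
Usage sketch for the checks above (illustrative only, not part of this commit; it uses only the public trio.testing API):

import trio
from trio.testing import check_two_way_stream, memory_stream_pair

async def maker():
    # memory_stream_pair() returns two connected, bidirectional streams.
    return memory_stream_pair()

# Passing None as clogged_stream_maker skips the clogged-stream edge cases.
trio.run(check_two_way_stream, maker, None)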
69
lib/python3.13/site-packages/trio/testing/_checkpoints.py
Normal file
@@ -0,0 +1,69 @@
from __future__ import annotations

from contextlib import AbstractContextManager, contextmanager
from typing import TYPE_CHECKING

from .. import _core

if TYPE_CHECKING:
    from collections.abc import Generator


@contextmanager
def _assert_yields_or_not(expected: bool) -> Generator[None, None, None]:
    """Check if checkpoints are executed in a block of code."""
    __tracebackhide__ = True
    task = _core.current_task()
    orig_cancel = task._cancel_points
    orig_schedule = task._schedule_points
    try:
        yield
        if expected and (
            task._cancel_points == orig_cancel or task._schedule_points == orig_schedule
        ):
            raise AssertionError("assert_checkpoints block did not yield!")
    finally:
        if not expected and (
            task._cancel_points != orig_cancel or task._schedule_points != orig_schedule
        ):
            raise AssertionError("assert_no_checkpoints block yielded!")


def assert_checkpoints() -> AbstractContextManager[None]:
    """Use as a context manager to check that the code inside the ``with``
    block either exits with an exception or executes at least one
    :ref:`checkpoint <checkpoints>`.

    Raises:
      AssertionError: if no checkpoint was executed.

    Example:
      Check that :func:`trio.sleep` is a checkpoint, even if it doesn't
      block::

          with trio.testing.assert_checkpoints():
              await trio.sleep(0)

    """
    __tracebackhide__ = True
    return _assert_yields_or_not(True)


def assert_no_checkpoints() -> AbstractContextManager[None]:
    """Use as a context manager to check that the code inside the ``with``
    block does not execute any :ref:`checkpoints <checkpoints>`.

    Raises:
      AssertionError: if a checkpoint was executed.

    Example:
      Synchronous code never contains any checkpoints, but we can double-check
      that::

          send_channel, receive_channel = trio.open_memory_channel(10)
          with trio.testing.assert_no_checkpoints():
              send_channel.send_nowait(None)

    """
    __tracebackhide__ = True
    return _assert_yields_or_not(False)
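
Usage sketch for the helpers above (illustrative only, not part of this commit):

import trio
from trio.testing import assert_checkpoints, assert_no_checkpoints

async def main() -> None:
    # trio.sleep(0) must execute a checkpoint, even though it never blocks...
    with assert_checkpoints():
        await trio.sleep(0)
    # ...while purely synchronous work must not.
    with assert_no_checkpoints():
        sum(range(10))

trio.run(main)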
578
lib/python3.13/site-packages/trio/testing/_fake_net.py
Normal file
@@ -0,0 +1,578 @@
# This should eventually be cleaned up and become public, but for right now I'm just
# implementing enough to test DTLS.

# TODO:
# - user-defined routers
# - TCP
# - UDP broadcast

from __future__ import annotations

import contextlib
import errno
import ipaddress
import os
import socket
import sys
from typing import (
    TYPE_CHECKING,
    Any,
    Iterable,
    NoReturn,
    TypeVar,
    Union,
    overload,
)

import attrs

import trio
from trio._util import NoPublicConstructor, final

if TYPE_CHECKING:
    import builtins
    from socket import AddressFamily, SocketKind
    from types import TracebackType

    from typing_extensions import Buffer, Self, TypeAlias

IPAddress: TypeAlias = Union[ipaddress.IPv4Address, ipaddress.IPv6Address]


def _family_for(ip: IPAddress) -> int:
    if isinstance(ip, ipaddress.IPv4Address):
        return trio.socket.AF_INET
    elif isinstance(ip, ipaddress.IPv6Address):
        return trio.socket.AF_INET6
    raise NotImplementedError("Unhandled IPAddress instance type")  # pragma: no cover


def _wildcard_ip_for(family: int) -> IPAddress:
    if family == trio.socket.AF_INET:
        return ipaddress.ip_address("0.0.0.0")
    elif family == trio.socket.AF_INET6:
        return ipaddress.ip_address("::")
    raise NotImplementedError("Unhandled ip address family")  # pragma: no cover


# not used anywhere
def _localhost_ip_for(family: int) -> IPAddress:  # pragma: no cover
    if family == trio.socket.AF_INET:
        return ipaddress.ip_address("127.0.0.1")
    elif family == trio.socket.AF_INET6:
        return ipaddress.ip_address("::1")
    raise NotImplementedError("Unhandled ip address family")


def _fake_err(code: int) -> NoReturn:
    raise OSError(code, os.strerror(code))


def _scatter(data: bytes, buffers: Iterable[Buffer]) -> int:
    written = 0
    for buf in buffers:  # pragma: no branch
        next_piece = data[written : written + memoryview(buf).nbytes]
        with memoryview(buf) as mbuf:
            mbuf[: len(next_piece)] = next_piece
        written += len(next_piece)
        if written == len(data):  # pragma: no branch
            break
    return written
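
# Illustrative sketch (not part of the upstream file): _scatter fills each
# buffer in turn and reports how many bytes fit, e.g.
#
#     bufs = [bytearray(3), bytearray(3)]
#     _scatter(b"hello", bufs)   # returns 5; bufs == [b"hel", b"lo\x00"]
#
# _recvmsg_into() below compares the return value with the payload length to
# decide whether to set MSG_TRUNC.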


T_UDPEndpoint = TypeVar("T_UDPEndpoint", bound="UDPEndpoint")


@attrs.frozen
class UDPEndpoint:
    ip: IPAddress
    port: int

    def as_python_sockaddr(self) -> tuple[str, int] | tuple[str, int, int, int]:
        sockaddr: tuple[str, int] | tuple[str, int, int, int] = (
            self.ip.compressed,
            self.port,
        )
        if isinstance(self.ip, ipaddress.IPv6Address):
            sockaddr += (0, 0)  # type: ignore[assignment]
        return sockaddr

    @classmethod
    def from_python_sockaddr(
        cls: type[T_UDPEndpoint],
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
    ) -> T_UDPEndpoint:
        ip, port = sockaddr[:2]
        return cls(ip=ipaddress.ip_address(ip), port=port)


@attrs.frozen
class UDPBinding:
    local: UDPEndpoint
    # remote: UDPEndpoint # ??


@attrs.frozen
class UDPPacket:
    source: UDPEndpoint
    destination: UDPEndpoint
    payload: bytes = attrs.field(repr=lambda p: p.hex())

    # not used/tested anywhere
    def reply(self, payload: bytes) -> UDPPacket:  # pragma: no cover
        return UDPPacket(
            source=self.destination,
            destination=self.source,
            payload=payload,
        )


@attrs.frozen
class FakeSocketFactory(trio.abc.SocketFactory):
    fake_net: FakeNet

    def socket(self, family: int, type_: int, proto: int) -> FakeSocket:  # type: ignore[override]
        return FakeSocket._create(self.fake_net, family, type_, proto)


@attrs.frozen
class FakeHostnameResolver(trio.abc.HostnameResolver):
    fake_net: FakeNet

    async def getaddrinfo(
        self,
        host: bytes | None,
        port: bytes | str | int | None,
        family: int = 0,
        type: int = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")

    async def getnameinfo(
        self,
        sockaddr: tuple[str, int] | tuple[str, int, int, int],
        flags: int,
    ) -> tuple[str, str]:
        raise NotImplementedError("FakeNet doesn't do fake DNS yet")


@final
class FakeNet:
    def __init__(self) -> None:
        # When we need to pick an arbitrary unique ip address/port, use these:
        self._auto_ipv4_iter = ipaddress.IPv4Network("1.0.0.0/8").hosts()  # untested
        self._auto_ipv6_iter = ipaddress.IPv6Network("1::/16").hosts()  # untested
        self._auto_port_iter = iter(range(50000, 65535))

        self._bound: dict[UDPBinding, FakeSocket] = {}

        self.route_packet = None

    def _bind(self, binding: UDPBinding, socket: FakeSocket) -> None:
        if binding in self._bound:
            _fake_err(errno.EADDRINUSE)
        self._bound[binding] = socket

    def enable(self) -> None:
        trio.socket.set_custom_socket_factory(FakeSocketFactory(self))
        trio.socket.set_custom_hostname_resolver(FakeHostnameResolver(self))

    def send_packet(self, packet: UDPPacket) -> None:
        if self.route_packet is None:
            self.deliver_packet(packet)
        else:
            self.route_packet(packet)

    def deliver_packet(self, packet: UDPPacket) -> None:
        binding = UDPBinding(local=packet.destination)
        if binding in self._bound:
            self._bound[binding]._deliver_packet(packet)
        else:
            # No valid destination, so drop it
            pass
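
# Illustrative sketch (not part of the upstream file): tests can interpose on
# traffic by assigning FakeNet.route_packet, e.g. to simulate total packet loss:
#
#     fake_net = FakeNet()
#
#     def drop_everything(packet: UDPPacket) -> None:
#         pass  # a "router" that never calls fake_net.deliver_packet()
#
#     fake_net.route_packet = drop_everything
#     fake_net.enable()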


@final
class FakeSocket(trio.socket.SocketType, metaclass=NoPublicConstructor):
    def __init__(
        self,
        fake_net: FakeNet,
        family: AddressFamily,
        type: SocketKind,
        proto: int,
    ):
        self._fake_net = fake_net

        if not family:  # pragma: no cover
            family = trio.socket.AF_INET
        if not type:  # pragma: no cover
            type = trio.socket.SOCK_STREAM  # noqa: A001  # name shadowing builtin

        if family not in (trio.socket.AF_INET, trio.socket.AF_INET6):
            raise NotImplementedError(f"FakeNet doesn't (yet) support family={family}")
        if type != trio.socket.SOCK_DGRAM:
            raise NotImplementedError(f"FakeNet doesn't (yet) support type={type}")

        self._family = family
        self._type = type
        self._proto = proto

        self._closed = False

        self._packet_sender, self._packet_receiver = trio.open_memory_channel[
            UDPPacket
        ](float("inf"))

        # This is the source-of-truth for what port etc. this socket is bound to
        self._binding: UDPBinding | None = None

    @property
    def type(self) -> SocketKind:
        return self._type

    @property
    def family(self) -> AddressFamily:
        return self._family

    @property
    def proto(self) -> int:
        return self._proto

    def _check_closed(self) -> None:
        if self._closed:
            _fake_err(errno.EBADF)

    def close(self) -> None:
        if self._closed:
            return
        self._closed = True
        if self._binding is not None:
            del self._fake_net._bound[self._binding]
        self._packet_receiver.close()

    async def _resolve_address_nocp(
        self,
        address: object,
        *,
        local: bool,
    ) -> tuple[str, int]:
        return await trio._socket._resolve_address_nocp(  # type: ignore[no-any-return]
            self.type,
            self.family,
            self.proto,
            address=address,
            ipv6_v6only=False,
            local=local,
        )

    def _deliver_packet(self, packet: UDPPacket) -> None:
        # sending to a closed socket -- UDP packets get dropped
        with contextlib.suppress(trio.BrokenResourceError):
            self._packet_sender.send_nowait(packet)

    ################################################################
    # Actual IO operation implementations
    ################################################################

    async def bind(self, addr: object) -> None:
        self._check_closed()
        if self._binding is not None:
            _fake_err(errno.EINVAL)
        await trio.lowlevel.checkpoint()
        ip_str, port, *_ = await self._resolve_address_nocp(addr, local=True)
        assert _ == [], "TODO: handle other values?"

        ip = ipaddress.ip_address(ip_str)
        assert _family_for(ip) == self.family
        # We convert binds to INET_ANY into binds to localhost
        if ip == ipaddress.ip_address("0.0.0.0"):
            ip = ipaddress.ip_address("127.0.0.1")
        elif ip == ipaddress.ip_address("::"):
            ip = ipaddress.ip_address("::1")
        if port == 0:
            port = next(self._fake_net._auto_port_iter)
        binding = UDPBinding(local=UDPEndpoint(ip, port))
        self._fake_net._bind(binding, self)
        self._binding = binding

    async def connect(self, peer: object) -> NoReturn:
        raise NotImplementedError("FakeNet does not (yet) support connected sockets")

    async def _sendmsg(
        self,
        buffers: Iterable[Buffer],
        ancdata: Iterable[tuple[int, int, Buffer]] = (),
        flags: int = 0,
        address: Any | None = None,
    ) -> int:
        self._check_closed()

        await trio.lowlevel.checkpoint()

        if address is not None:
            address = await self._resolve_address_nocp(address, local=False)
        if ancdata:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags:
            raise NotImplementedError(f"FakeNet send flags must be 0, not {flags}")

        if address is None:
            _fake_err(errno.ENOTCONN)

        destination = UDPEndpoint.from_python_sockaddr(address)

        if self._binding is None:
            await self.bind((_wildcard_ip_for(self.family).compressed, 0))

        payload = b"".join(buffers)

        assert self._binding is not None
        packet = UDPPacket(
            source=self._binding.local,
            destination=destination,
            payload=payload,
        )

        self._fake_net.send_packet(packet)

        return len(payload)

    if sys.platform != "win32" or (
        not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
    ):
        sendmsg = _sendmsg

    async def _recvmsg_into(
        self,
        buffers: Iterable[Buffer],
        ancbufsize: int = 0,
        flags: int = 0,
    ) -> tuple[int, list[tuple[int, int, bytes]], int, Any]:
        if ancbufsize != 0:
            raise NotImplementedError("FakeNet doesn't support ancillary data")
        if flags != 0:
            raise NotImplementedError("FakeNet doesn't support any recv flags")
        if self._binding is None:
            # I messed this up a few times when writing tests ... but it also never happens
            # in any of the existing tests, so maybe it could be intentional...
            raise NotImplementedError(
                "The code will most likely hang if you try to receive on a fakesocket "
                "without a binding. If that is not the case, or you explicitly want to "
                "test that, remove this warning.",
            )

        self._check_closed()

        ancdata: list[tuple[int, int, bytes]] = []
        msg_flags = 0

        packet = await self._packet_receiver.receive()
        address = packet.source.as_python_sockaddr()
        written = _scatter(packet.payload, buffers)
        if written < len(packet.payload):
            msg_flags |= trio.socket.MSG_TRUNC
        return written, ancdata, msg_flags, address

    if sys.platform != "win32" or (
        not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
    ):
        recvmsg_into = _recvmsg_into

    ################################################################
    # Simple state query stuff
    ################################################################

    def getsockname(self) -> tuple[str, int] | tuple[str, int, int, int]:
        self._check_closed()
        if self._binding is not None:
            return self._binding.local.as_python_sockaddr()
        elif self.family == trio.socket.AF_INET:
            return ("0.0.0.0", 0)
        else:
            assert self.family == trio.socket.AF_INET6
            return ("::", 0)

    # TODO: This method is not tested, and seems to make incorrect assumptions. It should maybe raise NotImplementedError.
    def getpeername(self) -> tuple[str, int] | tuple[str, int, int, int]:
        self._check_closed()
        if self._binding is not None:
            assert hasattr(
                self._binding,
                "remote",
            ), "This method seems to assume that self._binding has a remote UDPEndpoint"
            if self._binding.remote is not None:  # pragma: no cover
                assert isinstance(
                    self._binding.remote,
                    UDPEndpoint,
                ), "Self._binding.remote should be a UDPEndpoint"
                return self._binding.remote.as_python_sockaddr()
        _fake_err(errno.ENOTCONN)

    @overload
    def getsockopt(self, /, level: int, optname: int) -> int: ...

    @overload
    def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...

    def getsockopt(
        self,
        /,
        level: int,
        optname: int,
        buflen: int | None = None,
    ) -> int | bytes:
        self._check_closed()
        raise OSError(f"FakeNet doesn't implement getsockopt({level}, {optname})")

    @overload
    def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...

    @overload
    def setsockopt(
        self,
        /,
        level: int,
        optname: int,
        value: None,
        optlen: int,
    ) -> None: ...

    def setsockopt(
        self,
        /,
        level: int,
        optname: int,
        value: int | Buffer | None,
        optlen: int | None = None,
    ) -> None:
        self._check_closed()

        if (level, optname) == (
            trio.socket.IPPROTO_IPV6,
            trio.socket.IPV6_V6ONLY,
        ) and not value:
            raise NotImplementedError("FakeNet always has IPV6_V6ONLY=True")

        raise OSError(f"FakeNet doesn't implement setsockopt({level}, {optname}, ...)")

    ################################################################
    # Various boilerplate and trivial stubs
    ################################################################

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: builtins.type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        self.close()

    async def send(self, data: Buffer, flags: int = 0) -> int:
        return await self.sendto(data, flags, None)

    @overload
    async def sendto(
        self,
        __data: Buffer,
        __address: tuple[object, ...] | str | Buffer,
    ) -> int: ...

    @overload
    async def sendto(
        self,
        __data: Buffer,
        __flags: int,
        __address: tuple[object, ...] | str | None | Buffer,
    ) -> int: ...

    async def sendto(self, *args: Any) -> int:
        data: Buffer
        flags: int
        address: tuple[object, ...] | str | Buffer
        if len(args) == 2:
            data, address = args
            flags = 0
        elif len(args) == 3:
            data, flags, address = args
        else:
            raise TypeError("wrong number of arguments")
        return await self._sendmsg([data], [], flags, address)

    async def recv(self, bufsize: int, flags: int = 0) -> bytes:
        data, address = await self.recvfrom(bufsize, flags)
        return data

    async def recv_into(self, buf: Buffer, nbytes: int = 0, flags: int = 0) -> int:
        got_bytes, address = await self.recvfrom_into(buf, nbytes, flags)
        return got_bytes

    async def recvfrom(self, bufsize: int, flags: int = 0) -> tuple[bytes, Any]:
        data, ancdata, msg_flags, address = await self._recvmsg(bufsize, flags)
        return data, address

    async def recvfrom_into(
        self,
        buf: Buffer,
        nbytes: int = 0,
        flags: int = 0,
    ) -> tuple[int, Any]:
        if nbytes != 0 and nbytes != memoryview(buf).nbytes:
            raise NotImplementedError("partial recvfrom_into")
        got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into(
            [buf],
            0,
            flags,
        )
        return got_nbytes, address

    async def _recvmsg(
        self,
        bufsize: int,
        ancbufsize: int = 0,
        flags: int = 0,
    ) -> tuple[bytes, list[tuple[int, int, bytes]], int, Any]:
        buf = bytearray(bufsize)
        got_nbytes, ancdata, msg_flags, address = await self._recvmsg_into(
            [buf],
            ancbufsize,
            flags,
        )
        return (bytes(buf[:got_nbytes]), ancdata, msg_flags, address)

    if sys.platform != "win32" or (
        not TYPE_CHECKING and hasattr(socket.socket, "sendmsg")
    ):
        recvmsg = _recvmsg

    def fileno(self) -> int:
        raise NotImplementedError("can't get fileno() for FakeNet sockets")

    def detach(self) -> int:
        raise NotImplementedError("can't detach() a FakeNet socket")

    def get_inheritable(self) -> bool:
        return False

    def set_inheritable(self, inheritable: bool) -> None:
        if inheritable:
            raise NotImplementedError("FakeNet can't make inheritable sockets")

    if sys.platform == "win32" or (
        not TYPE_CHECKING and hasattr(socket.socket, "share")
    ):

        def share(self, process_id: int) -> bytes:
            raise NotImplementedError("FakeNet can't share sockets")
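
End-to-end usage sketch for the file above (illustrative only, not part of this commit; note that trio.testing._fake_net is private, as its header comment says, so this import path may change):

import trio
from trio.testing._fake_net import FakeNet

async def main() -> None:
    FakeNet().enable()  # reroute all trio.socket UDP traffic into memory
    s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
    s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
    await s2.bind(("127.0.0.1", 54321))
    await s1.sendto(b"ping", ("127.0.0.1", 54321))
    data, _addr = await s2.recvfrom(1024)
    assert data == b"ping"

trio.run(main)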
626
lib/python3.13/site-packages/trio/testing/_memory_streams.py
Normal file
@@ -0,0 +1,626 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import operator
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, TypeVar
|
||||
|
||||
from .. import _core, _util
|
||||
from .._highlevel_generic import StapledStream
|
||||
from ..abc import ReceiveStream, SendStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
AsyncHook: TypeAlias = Callable[[], Awaitable[object]]
|
||||
# Would be nice to exclude awaitable here, but currently not possible.
|
||||
SyncHook: TypeAlias = Callable[[], object]
|
||||
SendStreamT = TypeVar("SendStreamT", bound=SendStream)
|
||||
ReceiveStreamT = TypeVar("ReceiveStreamT", bound=ReceiveStream)
|
||||
|
||||
|
||||
################################################################
|
||||
# In-memory streams - Unbounded buffer version
|
||||
################################################################
|
||||
|
||||
|
||||
class _UnboundedByteQueue:
|
||||
def __init__(self) -> None:
|
||||
self._data = bytearray()
|
||||
self._closed = False
|
||||
self._lot = _core.ParkingLot()
|
||||
self._fetch_lock = _util.ConflictDetector(
|
||||
"another task is already fetching data",
|
||||
)
|
||||
|
||||
# This object treats "close" as being like closing the send side of a
|
||||
# channel: so after close(), calling put() raises ClosedResourceError, and
|
||||
# calling the get() variants drains the buffer and then returns an empty
|
||||
# bytearray.
|
||||
def close(self) -> None:
|
||||
self._closed = True
|
||||
self._lot.unpark_all()
|
||||
|
||||
def close_and_wipe(self) -> None:
|
||||
self._data = bytearray()
|
||||
self.close()
|
||||
|
||||
def put(self, data: bytes | bytearray | memoryview) -> None:
|
||||
if self._closed:
|
||||
raise _core.ClosedResourceError("virtual connection closed")
|
||||
self._data += data
|
||||
self._lot.unpark_all()
|
||||
|
||||
def _check_max_bytes(self, max_bytes: int | None) -> None:
|
||||
if max_bytes is None:
|
||||
return
|
||||
max_bytes = operator.index(max_bytes)
|
||||
if max_bytes < 1:
|
||||
raise ValueError("max_bytes must be >= 1")
|
||||
|
||||
def _get_impl(self, max_bytes: int | None) -> bytearray:
|
||||
assert self._closed or self._data
|
||||
if max_bytes is None:
|
||||
max_bytes = len(self._data)
|
||||
if self._data:
|
||||
chunk = self._data[:max_bytes]
|
||||
del self._data[:max_bytes]
|
||||
assert chunk
|
||||
return chunk
|
||||
else:
|
||||
return bytearray()
|
||||
|
||||
def get_nowait(self, max_bytes: int | None = None) -> bytearray:
|
||||
with self._fetch_lock:
|
||||
self._check_max_bytes(max_bytes)
|
||||
if not self._closed and not self._data:
|
||||
raise _core.WouldBlock
|
||||
return self._get_impl(max_bytes)
|
||||
|
||||
async def get(self, max_bytes: int | None = None) -> bytearray:
|
||||
with self._fetch_lock:
|
||||
self._check_max_bytes(max_bytes)
|
||||
if not self._closed and not self._data:
|
||||
await self._lot.park()
|
||||
else:
|
||||
await _core.checkpoint()
|
||||
return self._get_impl(max_bytes)
|
||||
|
||||
|
||||
@_util.final
|
||||
class MemorySendStream(SendStream):
|
||||
"""An in-memory :class:`~trio.abc.SendStream`.
|
||||
|
||||
Args:
|
||||
send_all_hook: An async function, or None. Called from
|
||||
:meth:`send_all`. Can do whatever you like.
|
||||
wait_send_all_might_not_block_hook: An async function, or None. Called
|
||||
from :meth:`wait_send_all_might_not_block`. Can do whatever you
|
||||
like.
|
||||
close_hook: A synchronous function, or None. Called from :meth:`close`
|
||||
and :meth:`aclose`. Can do whatever you like.
|
||||
|
||||
.. attribute:: send_all_hook
|
||||
wait_send_all_might_not_block_hook
|
||||
close_hook
|
||||
|
||||
All of these hooks are also exposed as attributes on the object, and
|
||||
you can change them at any time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
send_all_hook: AsyncHook | None = None,
|
||||
wait_send_all_might_not_block_hook: AsyncHook | None = None,
|
||||
close_hook: SyncHook | None = None,
|
||||
):
|
||||
self._conflict_detector = _util.ConflictDetector(
|
||||
"another task is using this stream",
|
||||
)
|
||||
self._outgoing = _UnboundedByteQueue()
|
||||
self.send_all_hook = send_all_hook
|
||||
self.wait_send_all_might_not_block_hook = wait_send_all_might_not_block_hook
|
||||
self.close_hook = close_hook
|
||||
|
||||
async def send_all(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Places the given data into the object's internal buffer, and then
|
||||
calls the :attr:`send_all_hook` (if any).
|
||||
|
||||
"""
|
||||
# Execute two checkpoints so we have more of a chance to detect
|
||||
# buggy user code that calls this twice at the same time.
|
||||
with self._conflict_detector:
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
self._outgoing.put(data)
|
||||
if self.send_all_hook is not None:
|
||||
await self.send_all_hook()
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
"""Calls the :attr:`wait_send_all_might_not_block_hook` (if any), and
|
||||
then returns immediately.
|
||||
|
||||
"""
|
||||
# Execute two checkpoints so that we have more of a chance to detect
|
||||
# buggy user code that calls this twice at the same time.
|
||||
with self._conflict_detector:
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
# check for being closed:
|
||||
self._outgoing.put(b"")
|
||||
if self.wait_send_all_might_not_block_hook is not None:
|
||||
await self.wait_send_all_might_not_block_hook()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Marks this stream as closed, and then calls the :attr:`close_hook`
|
||||
(if any).
|
||||
|
||||
"""
|
||||
# XXX should this cancel any pending calls to the send_all_hook and
|
||||
# wait_send_all_might_not_block_hook? Those are the only places where
|
||||
# send_all and wait_send_all_might_not_block can be blocked.
|
||||
#
|
||||
# The way we set things up, send_all_hook is memory_stream_pump, and
|
||||
# wait_send_all_might_not_block_hook is unset. memory_stream_pump is
|
||||
# synchronous. So normally, send_all and wait_send_all_might_not_block
|
||||
# cannot block at all.
|
||||
self._outgoing.close()
|
||||
if self.close_hook is not None:
|
||||
self.close_hook()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Same as :meth:`close`, but async."""
|
||||
self.close()
|
||||
await _core.checkpoint()
|
||||
|
||||
async def get_data(self, max_bytes: int | None = None) -> bytearray:
|
||||
"""Retrieves data from the internal buffer, blocking if necessary.
|
||||
|
||||
Args:
|
||||
max_bytes (int or None): The maximum amount of data to
|
||||
retrieve. None (the default) means to retrieve all the data
|
||||
that's present (but still blocks until at least one byte is
|
||||
available).
|
||||
|
||||
Returns:
|
||||
If this stream has been closed, an empty bytearray. Otherwise, the
|
||||
requested data.
|
||||
|
||||
"""
|
||||
return await self._outgoing.get(max_bytes)
|
||||
|
||||
def get_data_nowait(self, max_bytes: int | None = None) -> bytearray:
|
||||
"""Retrieves data from the internal buffer, but doesn't block.
|
||||
|
||||
See :meth:`get_data` for details.
|
||||
|
||||
Raises:
|
||||
trio.WouldBlock: if no data is available to retrieve.
|
||||
|
||||
"""
|
||||
return self._outgoing.get_nowait(max_bytes)
|
||||
|
||||
|
||||
@_util.final
|
||||
class MemoryReceiveStream(ReceiveStream):
|
||||
"""An in-memory :class:`~trio.abc.ReceiveStream`.
|
||||
|
||||
Args:
|
||||
receive_some_hook: An async function, or None. Called from
|
||||
:meth:`receive_some`. Can do whatever you like.
|
||||
close_hook: A synchronous function, or None. Called from :meth:`close`
|
||||
and :meth:`aclose`. Can do whatever you like.
|
||||
|
||||
.. attribute:: receive_some_hook
|
||||
close_hook
|
||||
|
||||
Both hooks are also exposed as attributes on the object, and you can
|
||||
change them at any time.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
receive_some_hook: AsyncHook | None = None,
|
||||
close_hook: SyncHook | None = None,
|
||||
):
|
||||
self._conflict_detector = _util.ConflictDetector(
|
||||
"another task is using this stream",
|
||||
)
|
||||
self._incoming = _UnboundedByteQueue()
|
||||
self._closed = False
|
||||
self.receive_some_hook = receive_some_hook
|
||||
self.close_hook = close_hook
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytearray:
|
||||
"""Calls the :attr:`receive_some_hook` (if any), and then retrieves
|
||||
data from the internal buffer, blocking if necessary.
|
||||
|
||||
"""
|
||||
# Execute two checkpoints so we have more of a chance to detect
|
||||
# buggy user code that calls this twice at the same time.
|
||||
with self._conflict_detector:
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
if self._closed:
|
||||
raise _core.ClosedResourceError
|
||||
if self.receive_some_hook is not None:
|
||||
await self.receive_some_hook()
|
||||
# self._incoming's closure state tracks whether we got an EOF.
|
||||
# self._closed tracks whether we, ourselves, are closed.
|
||||
# self.close() sends an EOF to wake us up and sets self._closed,
|
||||
# so after we wake up we have to check self._closed again.
|
||||
data = await self._incoming.get(max_bytes)
|
||||
if self._closed:
|
||||
raise _core.ClosedResourceError
|
||||
return data
|
||||
|
||||
def close(self) -> None:
|
||||
"""Discards any pending data from the internal buffer, and marks this
|
||||
stream as closed.
|
||||
|
||||
"""
|
||||
self._closed = True
|
||||
self._incoming.close_and_wipe()
|
||||
if self.close_hook is not None:
|
||||
self.close_hook()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Same as :meth:`close`, but async."""
|
||||
self.close()
|
||||
await _core.checkpoint()
|
||||
|
||||
def put_data(self, data: bytes | bytearray | memoryview) -> None:
|
||||
"""Appends the given data to the internal buffer."""
|
||||
self._incoming.put(data)
|
||||
|
||||
def put_eof(self) -> None:
|
||||
"""Adds an end-of-file marker to the internal buffer."""
|
||||
self._incoming.close()


def memory_stream_pump(
    memory_send_stream: MemorySendStream,
    memory_receive_stream: MemoryReceiveStream,
    *,
    max_bytes: int | None = None,
) -> bool:
    """Take data out of the given :class:`MemorySendStream`'s internal buffer,
    and put it into the given :class:`MemoryReceiveStream`'s internal buffer.

    Args:
      memory_send_stream (MemorySendStream): The stream to get data from.
      memory_receive_stream (MemoryReceiveStream): The stream to put data into.
      max_bytes (int or None): The maximum amount of data to transfer in this
          call, or None to transfer all available data.

    Returns:
      True if it successfully transferred some data, or False if there was no
      data to transfer.

    This is used to implement :func:`memory_stream_one_way_pair` and
    :func:`memory_stream_pair`; see the latter's docstring for an example
    of how you might use it yourself.

    """
    try:
        data = memory_send_stream.get_data_nowait(max_bytes)
    except _core.WouldBlock:
        return False
    try:
        if not data:
            memory_receive_stream.put_eof()
        else:
            memory_receive_stream.put_data(data)
    except _core.ClosedResourceError:
        raise _core.BrokenResourceError("MemoryReceiveStream was closed") from None
    return True
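

# Minimal usage sketch (editor example): pumping by hand, without the hooks
# that memory_stream_one_way_pair() installs automatically. The name
# `_example_manual_pump` is hypothetical.
async def _example_manual_pump() -> None:
    send_stream = MemorySendStream()
    recv_stream = MemoryReceiveStream()
    await send_stream.send_all(b"hello")
    assert memory_stream_pump(send_stream, recv_stream)  # moved the bytes
    assert await recv_stream.receive_some() == b"hello"
    assert not memory_stream_pump(send_stream, recv_stream)  # nothing left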


def memory_stream_one_way_pair() -> tuple[MemorySendStream, MemoryReceiveStream]:
    """Create a connected, pure-Python, unidirectional stream with infinite
    buffering and flexible configuration options.

    You can think of this as being a no-operating-system-involved
    Trio-streamsified version of :func:`os.pipe` (except that :func:`os.pipe`
    returns the streams in the wrong order – we follow the superior convention
    that data flows from left to right).

    Returns:
      A tuple (:class:`MemorySendStream`, :class:`MemoryReceiveStream`), where
      the :class:`MemorySendStream` has its hooks set up so that it calls
      :func:`memory_stream_pump` from its
      :attr:`~MemorySendStream.send_all_hook` and
      :attr:`~MemorySendStream.close_hook`.

    The end result is that data automatically flows from the
    :class:`MemorySendStream` to the :class:`MemoryReceiveStream`. But you're
    also free to rearrange things however you like. For example, you can
    temporarily set the :attr:`~MemorySendStream.send_all_hook` to None if you
    want to simulate a stall in data transmission. Or see
    :func:`memory_stream_pair` for a more elaborate example.

    """
    send_stream = MemorySendStream()
    recv_stream = MemoryReceiveStream()

    def pump_from_send_stream_to_recv_stream() -> None:
        memory_stream_pump(send_stream, recv_stream)

    async def async_pump_from_send_stream_to_recv_stream() -> None:
        pump_from_send_stream_to_recv_stream()

    send_stream.send_all_hook = async_pump_from_send_stream_to_recv_stream
    send_stream.close_hook = pump_from_send_stream_to_recv_stream
    return send_stream, recv_stream
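

# Sketch (editor example): simulating a stall in transmission, as the
# docstring above suggests, by temporarily unhooking the automatic pump.
# The name `_example_stall` is hypothetical.
async def _example_stall() -> None:
    send_stream, recv_stream = memory_stream_one_way_pair()
    send_stream.send_all_hook = None  # data now piles up in send_stream
    await send_stream.send_all(b"stuck")
    # Nothing arrives until we pump manually:
    assert memory_stream_pump(send_stream, recv_stream)
    assert await recv_stream.receive_some() == b"stuck"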


def _make_stapled_pair(
    one_way_pair: Callable[[], tuple[SendStreamT, ReceiveStreamT]],
) -> tuple[
    StapledStream[SendStreamT, ReceiveStreamT],
    StapledStream[SendStreamT, ReceiveStreamT],
]:
    pipe1_send, pipe1_recv = one_way_pair()
    pipe2_send, pipe2_recv = one_way_pair()
    stream1 = StapledStream(pipe1_send, pipe2_recv)
    stream2 = StapledStream(pipe2_send, pipe1_recv)
    return stream1, stream2


def memory_stream_pair() -> tuple[
    StapledStream[MemorySendStream, MemoryReceiveStream],
    StapledStream[MemorySendStream, MemoryReceiveStream],
]:
    """Create a connected, pure-Python, bidirectional stream with infinite
    buffering and flexible configuration options.

    This is a convenience function that creates two one-way streams using
    :func:`memory_stream_one_way_pair`, and then uses
    :class:`~trio.StapledStream` to combine them into a single bidirectional
    stream.

    This is like a no-operating-system-involved, Trio-streamsified version of
    :func:`socket.socketpair`.

    Returns:
      A pair of :class:`~trio.StapledStream` objects that are connected so
      that data automatically flows from one to the other in both directions.

    After creating a stream pair, you can send data back and forth, which is
    enough for simple tests::

        left, right = memory_stream_pair()
        await left.send_all(b"123")
        assert await right.receive_some() == b"123"
        await right.send_all(b"456")
        assert await left.receive_some() == b"456"

    But if you read the docs for :class:`~trio.StapledStream` and
    :func:`memory_stream_one_way_pair`, you'll see that all the pieces
    involved in wiring this up are public APIs, so you can adjust to suit the
    requirements of your tests. For example, here's how to tweak a stream so
    that data flowing from left to right trickles in one byte at a time (but
    data flowing from right to left proceeds at full speed)::

        left, right = memory_stream_pair()

        async def trickle():
            # left is a StapledStream, and left.send_stream is a MemorySendStream
            # right is a StapledStream, and right.recv_stream is a MemoryReceiveStream
            while memory_stream_pump(left.send_stream, right.recv_stream, max_bytes=1):
                # Pause between each byte
                await trio.sleep(1)

        # Normally this send_all_hook calls memory_stream_pump directly without
        # passing in a max_bytes. We replace it with our custom version:
        left.send_stream.send_all_hook = trickle

    And here's a simple test using our modified stream objects::

        async def sender():
            await left.send_all(b"12345")
            await left.send_eof()

        async def receiver():
            async for data in right:
                print(data)

        async with trio.open_nursery() as nursery:
            nursery.start_soon(sender)
            nursery.start_soon(receiver)

    By default, this will print ``b"12345"`` and then immediately exit; with
    our trickle stream it instead sleeps 1 second, then prints ``b"1"``, then
    sleeps 1 second, then prints ``b"2"``, etc.

    Pro-tip: you can insert sleep calls (like in our example above) to
    manipulate the flow of data across tasks... and then use
    :class:`MockClock` and its :attr:`~MockClock.autojump_threshold`
    functionality to keep your test suite running quickly.

    If you want to stress test a protocol implementation, one nice trick is to
    use the :mod:`random` module (preferably with a fixed seed) to move random
    numbers of bytes at a time, and insert random sleeps in between them. You
    can also set up a custom :attr:`~MemoryReceiveStream.receive_some_hook` if
    you want to manipulate things on the receiving side, and not just the
    sending side.

    """
    return _make_stapled_pair(memory_stream_one_way_pair)
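

# Sketch (editor example) of the stress-testing pro-tip above: a custom
# send_all_hook that forwards a random number of bytes at a time, with random
# pauses, seeded for reproducibility. `_example_fuzzy_pair` is a hypothetical
# helper, not part of the public API.
def _example_fuzzy_pair():
    import random

    import trio

    rng = random.Random(0)  # fixed seed -> reproducible byte/pause pattern
    left, right = memory_stream_pair()

    async def fuzz() -> None:
        while memory_stream_pump(
            left.send_stream,
            right.recv_stream,
            max_bytes=rng.randint(1, 4),
        ):
            await trio.sleep(rng.random())

    left.send_stream.send_all_hook = fuzz
    return left, right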


################################################################
# In-memory streams - Lockstep version
################################################################


class _LockstepByteQueue:
    def __init__(self) -> None:
        self._data = bytearray()
        self._sender_closed = False
        self._receiver_closed = False
        self._receiver_waiting = False
        self._waiters = _core.ParkingLot()
        self._send_conflict_detector = _util.ConflictDetector(
            "another task is already sending",
        )
        self._receive_conflict_detector = _util.ConflictDetector(
            "another task is already receiving",
        )

    def _something_happened(self) -> None:
        self._waiters.unpark_all()

    # Always wakes up when one side is closed, because everyone always reacts
    # to that.
    async def _wait_for(self, fn: Callable[[], bool]) -> None:
        while True:
            if fn():
                break
            if self._sender_closed or self._receiver_closed:
                break
            await self._waiters.park()
        await _core.checkpoint()

    def close_sender(self) -> None:
        self._sender_closed = True
        self._something_happened()

    def close_receiver(self) -> None:
        self._receiver_closed = True
        self._something_happened()

    async def send_all(self, data: bytes | bytearray | memoryview) -> None:
        with self._send_conflict_detector:
            if self._sender_closed:
                raise _core.ClosedResourceError
            if self._receiver_closed:
                raise _core.BrokenResourceError
            assert not self._data
            self._data += data
            self._something_happened()
            await self._wait_for(lambda: self._data == b"")
            if self._sender_closed:
                raise _core.ClosedResourceError
            if self._data and self._receiver_closed:
                raise _core.BrokenResourceError

    async def wait_send_all_might_not_block(self) -> None:
        with self._send_conflict_detector:
            if self._sender_closed:
                raise _core.ClosedResourceError
            if self._receiver_closed:
                await _core.checkpoint()
                return
            await self._wait_for(lambda: self._receiver_waiting)
            if self._sender_closed:
                raise _core.ClosedResourceError

    async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
        with self._receive_conflict_detector:
            # Argument validation
            if max_bytes is not None:
                max_bytes = operator.index(max_bytes)
                if max_bytes < 1:
                    raise ValueError("max_bytes must be >= 1")
            # State validation
            if self._receiver_closed:
                raise _core.ClosedResourceError
            # Wake wait_send_all_might_not_block and wait for data
            self._receiver_waiting = True
            self._something_happened()
            try:
                await self._wait_for(lambda: self._data != b"")
            finally:
                self._receiver_waiting = False
            if self._receiver_closed:
                raise _core.ClosedResourceError
            # Get data, possibly waking send_all
            if self._data:
                # Neat trick: if max_bytes is None, then obj[:max_bytes] is
                # the same as obj[:].
                got = self._data[:max_bytes]
                del self._data[:max_bytes]
                self._something_happened()
                return got
            else:
                assert self._sender_closed
                return b""


class _LockstepSendStream(SendStream):
    def __init__(self, lbq: _LockstepByteQueue):
        self._lbq = lbq

    def close(self) -> None:
        self._lbq.close_sender()

    async def aclose(self) -> None:
        self.close()
        await _core.checkpoint()

    async def send_all(self, data: bytes | bytearray | memoryview) -> None:
        await self._lbq.send_all(data)

    async def wait_send_all_might_not_block(self) -> None:
        await self._lbq.wait_send_all_might_not_block()


class _LockstepReceiveStream(ReceiveStream):
    def __init__(self, lbq: _LockstepByteQueue):
        self._lbq = lbq

    def close(self) -> None:
        self._lbq.close_receiver()

    async def aclose(self) -> None:
        self.close()
        await _core.checkpoint()

    async def receive_some(self, max_bytes: int | None = None) -> bytes | bytearray:
        return await self._lbq.receive_some(max_bytes)


def lockstep_stream_one_way_pair() -> tuple[SendStream, ReceiveStream]:
    """Create a connected, pure-Python, unidirectional stream where data flows
    in lockstep.

    Returns:
      A tuple
      (:class:`~trio.abc.SendStream`, :class:`~trio.abc.ReceiveStream`).

    This stream has *absolutely no* buffering. Each call to
    :meth:`~trio.abc.SendStream.send_all` will block until all the given data
    has been returned by a call to
    :meth:`~trio.abc.ReceiveStream.receive_some`.

    This can be useful for testing flow control mechanisms in an extreme case,
    or for setting up "clogged" streams to use with
    :func:`check_one_way_stream` and friends.

    In addition to fulfilling the :class:`~trio.abc.SendStream` and
    :class:`~trio.abc.ReceiveStream` interfaces, the returned objects
    also have a synchronous ``close`` method.

    """
    lbq = _LockstepByteQueue()
    return _LockstepSendStream(lbq), _LockstepReceiveStream(lbq)
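

# Sketch (editor example): demonstrating the zero-buffering contract. The
# sender's send_all() only completes once the receiver has drained every
# byte. `_example_lockstep` is a hypothetical name.
async def _example_lockstep() -> None:
    import trio

    send_stream, receive_stream = lockstep_stream_one_way_pair()

    async def consume() -> None:
        while await receive_stream.receive_some(1):
            pass  # drain one byte at a time until EOF

    async with trio.open_nursery() as nursery:
        nursery.start_soon(consume)
        await send_stream.send_all(b"abc")  # blocks until consume() drains it
        await send_stream.aclose()  # EOF: receive_some() returns b"", loop ends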


def lockstep_stream_pair() -> tuple[
    StapledStream[SendStream, ReceiveStream],
    StapledStream[SendStream, ReceiveStream],
]:
    """Create a connected, pure-Python, bidirectional stream where data flows
    in lockstep.

    Returns:
      A tuple (:class:`~trio.StapledStream`, :class:`~trio.StapledStream`).

    This is a convenience function that creates two one-way streams using
    :func:`lockstep_stream_one_way_pair`, and then uses
    :class:`~trio.StapledStream` to combine them into a single bidirectional
    stream.

    """
    return _make_stapled_pair(lockstep_stream_one_way_pair)

36
lib/python3.13/site-packages/trio/testing/_network.py
Normal file
@ -0,0 +1,36 @@
from .. import socket as tsocket
from .._highlevel_socket import SocketListener, SocketStream


async def open_stream_to_socket_listener(
    socket_listener: SocketListener,
) -> SocketStream:
    """Connect to the given :class:`~trio.SocketListener`.

    This is particularly useful in tests when you want to let a server pick
    its own port, and then connect to it::

        listeners = await trio.open_tcp_listeners(0)
        client = await trio.testing.open_stream_to_socket_listener(listeners[0])

    Args:
      socket_listener (~trio.SocketListener): The
          :class:`~trio.SocketListener` to connect to.

    Returns:
      SocketStream: a stream connected to the given listener.

    """
    family = socket_listener.socket.family
    sockaddr = socket_listener.socket.getsockname()
    if family in (tsocket.AF_INET, tsocket.AF_INET6):
        sockaddr = list(sockaddr)
        if sockaddr[0] == "0.0.0.0":
            sockaddr[0] = "127.0.0.1"
        if sockaddr[0] == "::":
            sockaddr[0] = "::1"
        sockaddr = tuple(sockaddr)

    sock = tsocket.socket(family=family)
    await sock.connect(sockaddr)
    return SocketStream(sock)
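

# Sketch (editor example): the docstring pattern as a fuller round-trip test.
# `_example_echo_test` and `handler` are hypothetical names; note that a
# single receive_some() may in principle return a partial read, so this is a
# sketch rather than a robust test.
async def _example_echo_test() -> None:
    import trio

    async def handler(stream: SocketStream) -> None:
        data = await stream.receive_some()
        await stream.send_all(data)  # echo it back
        await stream.aclose()

    listeners = await trio.open_tcp_listeners(0)  # port 0: OS picks a port
    async with trio.open_nursery() as nursery:
        nursery.start_soon(trio.serve_listeners, handler, listeners)
        client = await open_stream_to_socket_listener(listeners[0])
        await client.send_all(b"ping")
        assert await client.receive_some() == b"ping"
        nursery.cancel_scope.cancel()  # shut the server down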

568
lib/python3.13/site-packages/trio/testing/_raises_group.py
Normal file
@ -0,0 +1,568 @@
from __future__ import annotations

import re
import sys
from typing import (
    TYPE_CHECKING,
    Callable,
    ContextManager,
    Generic,
    Literal,
    Pattern,
    Sequence,
    cast,
    overload,
)

from trio._util import final

if TYPE_CHECKING:
    import builtins

    # sphinx will *only* work if we use types.TracebackType, and import
    # *inside* TYPE_CHECKING. No other combination works...
    import types

    from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback
    from typing_extensions import TypeGuard, TypeVar

    MatchE = TypeVar(
        "MatchE",
        bound=BaseException,
        default=BaseException,
        covariant=True,
    )
else:
    from typing import TypeVar

    MatchE = TypeVar("MatchE", bound=BaseException, covariant=True)
# RaisesGroup doesn't work with a default.
E = TypeVar("E", bound=BaseException, covariant=True)
# These two typevars are special cased in sphinx config to work around lookup bugs.

if sys.version_info < (3, 11):
    from exceptiongroup import BaseExceptionGroup


@final
class _ExceptionInfo(Generic[MatchE]):
    """Minimal re-implementation of pytest.ExceptionInfo, only used if pytest
    is not available. Supports a subset of its features necessary for the
    functionality of :class:`trio.testing.RaisesGroup` and
    :class:`trio.testing.Matcher`."""

    _excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None

    def __init__(
        self,
        excinfo: tuple[type[MatchE], MatchE, types.TracebackType] | None,
    ):
        self._excinfo = excinfo

    def fill_unfilled(
        self,
        exc_info: tuple[type[MatchE], MatchE, types.TracebackType],
    ) -> None:
        """Fill an unfilled ExceptionInfo created with ``for_later()``."""
        assert self._excinfo is None, "ExceptionInfo was already filled"
        self._excinfo = exc_info

    @classmethod
    def for_later(cls) -> _ExceptionInfo[MatchE]:
        """Return an unfilled ExceptionInfo."""
        return cls(None)

    # Note, special cased in sphinx config, since "type" conflicts.
    @property
    def type(self) -> type[MatchE]:
        """The exception class."""
        assert (
            self._excinfo is not None
        ), ".type can only be used after the context manager exits"
        return self._excinfo[0]

    @property
    def value(self) -> MatchE:
        """The exception value."""
        assert (
            self._excinfo is not None
        ), ".value can only be used after the context manager exits"
        return self._excinfo[1]

    @property
    def tb(self) -> types.TracebackType:
        """The exception raw traceback."""
        assert (
            self._excinfo is not None
        ), ".tb can only be used after the context manager exits"
        return self._excinfo[2]

    def exconly(self, tryshort: bool = False) -> str:
        raise NotImplementedError(
            "This is a helper method only available if you use RaisesGroup with the pytest package installed",
        )

    def errisinstance(
        self,
        exc: builtins.type[BaseException] | tuple[builtins.type[BaseException], ...],
    ) -> bool:
        raise NotImplementedError(
            "This is a helper method only available if you use RaisesGroup with the pytest package installed",
        )

    def getrepr(
        self,
        showlocals: bool = False,
        style: str = "long",
        abspath: bool = False,
        tbfilter: bool | Callable[[_ExceptionInfo], Traceback] = True,
        funcargs: bool = False,
        truncate_locals: bool = True,
        chain: bool = True,
    ) -> ReprExceptionInfo | ExceptionChainRepr:
        raise NotImplementedError(
            "This is a helper method only available if you use RaisesGroup with the pytest package installed",
        )


# Type checkers are not able to do conditional types depending on installed
# packages, so we've added signatures for all helpers to _ExceptionInfo, and
# then always use that. If this ends up leading to problems, we can resort to
# always using _ExceptionInfo, and users that want to use
# getrepr/errisinstance/exconly can write helpers on their own, or we
# reimplement them ourselves... or get this merged in upstream pytest.
if TYPE_CHECKING:
    ExceptionInfo = _ExceptionInfo

else:
    try:
        from pytest import ExceptionInfo  # noqa: PT013
    except ImportError:  # pragma: no cover
        ExceptionInfo = _ExceptionInfo


# copied from pytest.ExceptionInfo
def _stringify_exception(exc: BaseException) -> str:
    return "\n".join(
        [
            getattr(exc, "message", str(exc)),
            *getattr(exc, "__notes__", []),
        ],
    )


# String patterns default to including the unicode flag.
_regex_no_flags = re.compile("").flags
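

# Sketch (editor example): the text that ``match=`` patterns are searched
# against is produced by _stringify_exception() above, and includes any
# PEP 678 notes attached to the exception (add_note() needs Python 3.11+).
# `_example_stringify` is a hypothetical name.
def _example_stringify() -> None:
    exc = ValueError("bad value")
    exc.add_note("see the issue tracker for details")
    assert _stringify_exception(exc) == "bad value\nsee the issue tracker for details"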


@final
class Matcher(Generic[MatchE]):
    """Helper class to be used together with RaisesGroup when you want to
    specify requirements on sub-exceptions. Specifying only the type is
    redundant, and it's also unnecessary when the type is a nested
    `RaisesGroup`, since it supports the same arguments.
    The type is checked with `isinstance`, and does not need to be an exact
    match. If an exact match is wanted you can use the ``check`` parameter.
    :meth:`trio.testing.Matcher.matches` can also be used standalone to check
    individual exceptions.

    Examples::

        with RaisesGroup(Matcher(ValueError, match="string")):
            ...
        with RaisesGroup(Matcher(check=lambda x: x.args == (3, "hello"))):
            ...
        with RaisesGroup(Matcher(check=lambda x: type(x) is ValueError)):
            ...

    """

    # At least one of the three parameters must be passed.
    @overload
    def __init__(
        self: Matcher[MatchE],
        exception_type: type[MatchE],
        match: str | Pattern[str] = ...,
        check: Callable[[MatchE], bool] = ...,
    ): ...

    @overload
    def __init__(
        self: Matcher[BaseException],  # Give E a value.
        *,
        match: str | Pattern[str],
        # If exception_type is not provided, check() must do any typechecks itself.
        check: Callable[[BaseException], bool] = ...,
    ): ...

    @overload
    def __init__(self, *, check: Callable[[BaseException], bool]): ...

    def __init__(
        self,
        exception_type: type[MatchE] | None = None,
        match: str | Pattern[str] | None = None,
        check: Callable[[MatchE], bool] | None = None,
    ):
        if exception_type is None and match is None and check is None:
            raise ValueError("You must specify at least one parameter to match on.")
        if exception_type is not None and not issubclass(exception_type, BaseException):
            raise ValueError(
                f"exception_type {exception_type} must be a subclass of BaseException",
            )
        self.exception_type = exception_type
        self.match: Pattern[str] | None
        if isinstance(match, str):
            self.match = re.compile(match)
        else:
            self.match = match
        self.check = check

    def matches(self, exception: BaseException) -> TypeGuard[MatchE]:
        """Check if an exception matches the requirements of this Matcher.

        Examples::

            assert Matcher(ValueError).matches(my_exception)
            # is equivalent to
            assert isinstance(my_exception, ValueError)

            # this can be useful when checking e.g. the ``__cause__`` of an exception.
            with pytest.raises(ValueError) as excinfo:
                ...
            assert Matcher(SyntaxError, match="foo").matches(excinfo.value.__cause__)
            # above line is equivalent to
            assert isinstance(excinfo.value.__cause__, SyntaxError)
            assert re.search("foo", str(excinfo.value.__cause__))

        """
        if self.exception_type is not None and not isinstance(
            exception,
            self.exception_type,
        ):
            return False
        if self.match is not None and not re.search(
            self.match,
            _stringify_exception(exception),
        ):
            return False
        # If exception_type is None check() accepts BaseException.
        # If non-None, we have done an isinstance check above.
        return self.check is None or self.check(cast(MatchE, exception))

    def __str__(self) -> str:
        reqs = []
        if self.exception_type is not None:
            reqs.append(self.exception_type.__name__)
        if (match := self.match) is not None:
            # If no flags were specified, discard the redundant re.compile() here.
            reqs.append(
                f"match={match.pattern if match.flags == _regex_no_flags else match!r}",
            )
        if self.check is not None:
            reqs.append(f"check={self.check!r}")
        return f'Matcher({", ".join(reqs)})'
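

# Sketch (editor example): using Matcher.matches standalone, outside of
# RaisesGroup, to assert on an exception object directly. `_example_matcher`
# is a hypothetical name.
def _example_matcher() -> None:
    exc = ValueError("expected int, got str")
    assert Matcher(ValueError, match="expected int").matches(exc)
    assert not Matcher(TypeError).matches(exc)
    # `check` receives the exception itself, so arbitrary predicates work:
    assert Matcher(check=lambda e: len(e.args) == 1).matches(exc)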


# Typing this has been somewhat of a nightmare, with the primary difficulty
# being making the return type of __enter__ correct. Ideally it would function
# like this
# with RaisesGroup(RaisesGroup(ValueError)) as excinfo:
#     ...
#     assert_type(excinfo.value, ExceptionGroup[ExceptionGroup[ValueError]])
# in addition to all the simple cases, but getting all the way to the above
# seems maybe impossible. The type being RaisesGroup[RaisesGroup[ValueError]]
# is probably also fine, as long as I add fake properties corresponding to the
# properties of exceptiongroup. But I had trouble with it handling recursive
# cases properly.

# Current solution settles on the above giving
# BaseExceptionGroup[RaisesGroup[ValueError]], and it not being a type error
# to do `with RaisesGroup(ValueError()): ...` - but that will error at runtime.

# We lie to type checkers that we inherit, so excinfo.value and
# sub-exceptiongroups can be treated as ExceptionGroups.
if TYPE_CHECKING:
    SuperClass = BaseExceptionGroup
else:
    # At runtime, use a redundant Generic base class which effectively gets ignored.
    SuperClass = Generic


@final
class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]):
    """Contextmanager for checking for an expected `ExceptionGroup`.
    This works similarly to ``pytest.raises``, and a version of it will
    hopefully be added upstream, after which this can be deprecated and
    removed. See https://github.com/pytest-dev/pytest/issues/11538

    The catching behaviour differs from :ref:`except* <except_star>` in
    several ways, being much stricter by default. By using
    ``allow_unwrapped=True`` and ``flatten_subgroups=True`` you can match
    ``except*`` fully when expecting a single exception.

    #. All specified exceptions must be present, *and no others*.

       * If you expect a variable number of exceptions you need to use
         ``pytest.raises(ExceptionGroup)`` and manually check the contained
         exceptions. Consider making use of :func:`Matcher.matches`.

    #. It will only catch exceptions wrapped in an exceptiongroup by default.

       * With ``allow_unwrapped=True`` you can specify a single expected
         exception or `Matcher` and it will match the exception even if it is
         not inside an `ExceptionGroup`. If you expect one of several
         different exception types you need to use a `Matcher` object.

    #. By default it cares about the full structure of nested
       `ExceptionGroup` instances. You can specify nested `ExceptionGroup`
       instances by passing `RaisesGroup` objects as expected exceptions.

       * With ``flatten_subgroups=True`` it will "flatten" the raised
         `ExceptionGroup`, extracting all exceptions inside any nested
         :class:`ExceptionGroup`, before matching.

    It currently does not care about the order of the exceptions, so
    ``RaisesGroup(ValueError, TypeError)`` is equivalent to
    ``RaisesGroup(TypeError, ValueError)``.

    This class is not as polished as ``pytest.raises``, and is currently not
    as helpful in e.g. printing diffs when strings don't match, suggesting you
    use ``re.escape``, etc.

    Examples::

        with RaisesGroup(ValueError):
            raise ExceptionGroup("", (ValueError(),))
        with RaisesGroup(ValueError, ValueError, Matcher(TypeError, match="expected int")):
            ...
        with RaisesGroup(KeyboardInterrupt, match="hello", check=lambda x: type(x) is BaseExceptionGroup):
            ...
        with RaisesGroup(RaisesGroup(ValueError)):
            raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))

        # flatten_subgroups
        with RaisesGroup(ValueError, flatten_subgroups=True):
            raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))

        # allow_unwrapped
        with RaisesGroup(ValueError, allow_unwrapped=True):
            raise ValueError

    `RaisesGroup.matches` can also be used directly to check a standalone
    exception group.

    The matching algorithm is greedy, which means cases such as this may fail::

        with RaisesGroup(ValueError, Matcher(ValueError, match="hello")):
            raise ExceptionGroup("", (ValueError("hello"), ValueError("goodbye")))

    even though it generally does not care about the order of the exceptions
    in the group. To avoid the above you should specify the first ValueError
    with a Matcher as well.

    It is also not typechecked perfectly, and that's likely not possible with
    the current approach. Most common usage should work without issue though.
    """

    # needed for pyright, since BaseExceptionGroup.__new__ takes two arguments
    if TYPE_CHECKING:

        def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: ...

    # allow_unwrapped=True requires: singular exception, exception not being
    # a RaisesGroup instance, match is None, check is None
    @overload
    def __init__(
        self,
        exception: type[E] | Matcher[E],
        *,
        allow_unwrapped: Literal[True],
        flatten_subgroups: bool = False,
        match: None = None,
        check: None = None,
    ): ...

    # flatten_subgroups = True also requires no nested RaisesGroup
    @overload
    def __init__(
        self,
        exception: type[E] | Matcher[E],
        *other_exceptions: type[E] | Matcher[E],
        allow_unwrapped: Literal[False] = False,
        flatten_subgroups: Literal[True],
        match: str | Pattern[str] | None = None,
        check: Callable[[BaseExceptionGroup[E]], bool] | None = None,
    ): ...

    @overload
    def __init__(
        self,
        exception: type[E] | Matcher[E] | E,
        *other_exceptions: type[E] | Matcher[E] | E,
        allow_unwrapped: Literal[False] = False,
        flatten_subgroups: Literal[False] = False,
        match: str | Pattern[str] | None = None,
        check: Callable[[BaseExceptionGroup[E]], bool] | None = None,
    ): ...

    def __init__(
        self,
        exception: type[E] | Matcher[E] | E,
        *other_exceptions: type[E] | Matcher[E] | E,
        allow_unwrapped: bool = False,
        flatten_subgroups: bool = False,
        match: str | Pattern[str] | None = None,
        check: Callable[[BaseExceptionGroup[E]], bool] | None = None,
    ):
        self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = (
            exception,
            *other_exceptions,
        )
        self.flatten_subgroups: bool = flatten_subgroups
        self.allow_unwrapped = allow_unwrapped
        self.match_expr = match
        self.check = check
        self.is_baseexceptiongroup = False

        if allow_unwrapped and other_exceptions:
            raise ValueError(
                "You cannot specify multiple exceptions with `allow_unwrapped=True`."
                " If you want to match one of multiple possible exceptions you should"
                " use a `Matcher`."
                " E.g. `Matcher(check=lambda e: isinstance(e, (...)))`",
            )
        if allow_unwrapped and isinstance(exception, RaisesGroup):
            raise ValueError(
                "`allow_unwrapped=True` has no effect when expecting a `RaisesGroup`."
                " You might want it in the expected `RaisesGroup`, or"
                " `flatten_subgroups=True` if you don't care about the structure.",
            )
        if allow_unwrapped and (match is not None or check is not None):
            raise ValueError(
                "`allow_unwrapped=True` bypasses the `match` and `check` parameters"
                " if the exception is unwrapped. If you intended to match/check the"
                " exception you should use a `Matcher` object. If you want to match/check"
                " the exceptiongroup when the exception *is* wrapped you need to"
                " do e.g. `if isinstance(exc.value, ExceptionGroup):"
                " assert RaisesGroup(...).matches(exc.value)` afterwards.",
            )

        # verify `expected_exceptions` and set `self.is_baseexceptiongroup`
        for exc in self.expected_exceptions:
            if isinstance(exc, RaisesGroup):
                if self.flatten_subgroups:
                    raise ValueError(
                        "You cannot specify a nested structure inside a RaisesGroup with"
                        " `flatten_subgroups=True`. The parameter will flatten subgroups"
                        " in the raised exceptiongroup before matching, which would never"
                        " match a nested structure.",
                    )
                self.is_baseexceptiongroup |= exc.is_baseexceptiongroup
            elif isinstance(exc, Matcher):
                # The Matcher could match BaseExceptions through the other arguments
                # but `self.is_baseexceptiongroup` is only used for printing.
                if exc.exception_type is None:
                    continue
                # Matcher __init__ assures it's a subclass of BaseException
                self.is_baseexceptiongroup |= not issubclass(
                    exc.exception_type,
                    Exception,
                )
            elif isinstance(exc, type) and issubclass(exc, BaseException):
                self.is_baseexceptiongroup |= not issubclass(exc, Exception)
            else:
                raise ValueError(
                    f'Invalid argument "{exc!r}" must be exception type, Matcher, or'
                    " RaisesGroup.",
                )

    def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]:
        self.excinfo: ExceptionInfo[BaseExceptionGroup[E]] = ExceptionInfo.for_later()
        return self.excinfo

    def _unroll_exceptions(
        self,
        exceptions: Sequence[BaseException],
    ) -> Sequence[BaseException]:
        """Used if `flatten_subgroups=True`."""
        res: list[BaseException] = []
        for exc in exceptions:
            if isinstance(exc, BaseExceptionGroup):
                res.extend(self._unroll_exceptions(exc.exceptions))
            else:
                res.append(exc)
        return res

    def matches(
        self,
        exc_val: BaseException | None,
    ) -> TypeGuard[BaseExceptionGroup[E]]:
        """Check if an exception matches the requirements of this RaisesGroup.

        Example::

            with pytest.raises(TypeError) as excinfo:
                ...
            assert RaisesGroup(ValueError).matches(excinfo.value.__cause__)
            # the above line is equivalent to
            myexc = excinfo.value.__cause__
            assert isinstance(myexc, BaseExceptionGroup)
            assert len(myexc.exceptions) == 1
            assert isinstance(myexc.exceptions[0], ValueError)
        """
        if exc_val is None:
            return False
        # TODO: print/raise why a match fails, in a way that works properly in nested cases
        # maybe have a list of strings logging failed matches, that __exit__ can
        # recursively step through and print on a failing match.
        if not isinstance(exc_val, BaseExceptionGroup):
            if self.allow_unwrapped:
                exp_exc = self.expected_exceptions[0]
                if isinstance(exp_exc, Matcher) and exp_exc.matches(exc_val):
                    return True
                if isinstance(exp_exc, type) and isinstance(exc_val, exp_exc):
                    return True
            return False

        if self.match_expr is not None and not re.search(
            self.match_expr,
            _stringify_exception(exc_val),
        ):
            return False
        if self.check is not None and not self.check(exc_val):
            return False

        remaining_exceptions = list(self.expected_exceptions)
        actual_exceptions: Sequence[BaseException] = exc_val.exceptions
        if self.flatten_subgroups:
            actual_exceptions = self._unroll_exceptions(actual_exceptions)

        # important to check the length *after* flattening subgroups
        if len(actual_exceptions) != len(self.expected_exceptions):
            return False

        # it should be possible to get RaisesGroup.matches typed so as not to
        # need type: ignore, but I'm not sure that's possible while also having it
        # transparent for the end user.
        for e in actual_exceptions:
            for rem_e in remaining_exceptions:
                if (
                    (isinstance(rem_e, type) and isinstance(e, rem_e))
                    or (isinstance(rem_e, RaisesGroup) and rem_e.matches(e))
                    or (isinstance(rem_e, Matcher) and rem_e.matches(e))
                ):
                    remaining_exceptions.remove(rem_e)  # type: ignore[arg-type]
                    break
            else:
                return False
        return True

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: types.TracebackType | None,
    ) -> bool:
        __tracebackhide__ = True
        assert (
            exc_type is not None
        ), f"DID NOT RAISE any exception, expected {self.expected_type()}"
        assert (
            self.excinfo is not None
        ), "Internal error - should have been constructed in __enter__"

        if not self.matches(exc_val):
            return False

        # Cast to narrow the exception type now that it's verified.
        exc_info = cast(
            "tuple[type[BaseExceptionGroup[E]], BaseExceptionGroup[E], types.TracebackType]",
            (exc_type, exc_val, exc_tb),
        )
        self.excinfo.fill_unfilled(exc_info)
        return True

    def expected_type(self) -> str:
        subexcs = []
        for e in self.expected_exceptions:
            if isinstance(e, Matcher):
                subexcs.append(str(e))
            elif isinstance(e, RaisesGroup):
                subexcs.append(e.expected_type())
            elif isinstance(e, type):
                subexcs.append(e.__name__)
            else:  # pragma: no cover
                raise AssertionError("unknown type")
        group_type = "Base" if self.is_baseexceptiongroup else ""
        return f"{group_type}ExceptionGroup({', '.join(subexcs)})"

87
lib/python3.13/site-packages/trio/testing/_sequencer.py
Normal file
@ -0,0 +1,87 @@
from __future__ import annotations

from collections import defaultdict
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING

import attrs

from .. import Event, _core, _util

if TYPE_CHECKING:
    from collections.abc import AsyncIterator


@_util.final
@attrs.define(eq=False, slots=False)
class Sequencer:
    """A convenience class for forcing code in different tasks to run in an
    explicit linear order.

    Instances of this class implement a ``__call__`` method which returns an
    async context manager. The idea is that you pass a sequence number to
    ``__call__`` to say where this block of code should go in the linear
    sequence. Block 0 starts immediately, and then block N doesn't start until
    block N-1 has finished.

    Example:
      An extremely elaborate way to print the numbers 0-5, in order::

         async def worker1(seq):
             async with seq(0):
                 print(0)
             async with seq(4):
                 print(4)

         async def worker2(seq):
             async with seq(2):
                 print(2)
             async with seq(5):
                 print(5)

         async def worker3(seq):
             async with seq(1):
                 print(1)
             async with seq(3):
                 print(3)

         async def main():
             seq = trio.testing.Sequencer()
             async with trio.open_nursery() as nursery:
                 nursery.start_soon(worker1, seq)
                 nursery.start_soon(worker2, seq)
                 nursery.start_soon(worker3, seq)

    """

    _sequence_points: defaultdict[int, Event] = attrs.field(
        factory=lambda: defaultdict(Event),
        init=False,
    )
    _claimed: set[int] = attrs.field(factory=set, init=False)
    _broken: bool = attrs.field(default=False, init=False)

    @asynccontextmanager
    async def __call__(self, position: int) -> AsyncIterator[None]:
        if position in self._claimed:
            raise RuntimeError(f"Attempted to reuse sequence point {position}")
        if self._broken:
            raise RuntimeError("sequence broken!")
        self._claimed.add(position)
        if position != 0:
            try:
                await self._sequence_points[position].wait()
            except _core.Cancelled:
                self._broken = True
                for event in self._sequence_points.values():
                    event.set()
                raise RuntimeError(
                    "Sequencer wait cancelled -- sequence broken",
                ) from None
            else:
                if self._broken:
                    raise RuntimeError("sequence broken!")
        try:
            yield
        finally:
            self._sequence_points[position + 1].set()
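

# Sketch (editor example): a two-task ordering assertion using the class
# above. `_example_sequencer` is a hypothetical name.
async def _example_sequencer() -> None:
    import trio

    seq = Sequencer()
    order: list[str] = []

    async def second() -> None:
        async with seq(1):  # waits until block 0 has finished
            order.append("second")

    async with trio.open_nursery() as nursery:
        nursery.start_soon(second)
        async with seq(0):
            order.append("first")

    assert order == ["first", "second"]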

50
lib/python3.13/site-packages/trio/testing/_trio_test.py
Normal file
@ -0,0 +1,50 @@
from __future__ import annotations

from functools import partial, wraps
from typing import TYPE_CHECKING, TypeVar

from .. import _core
from ..abc import Clock, Instrument

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from typing_extensions import ParamSpec

    ArgsT = ParamSpec("ArgsT")


RetT = TypeVar("RetT")


def trio_test(fn: Callable[ArgsT, Awaitable[RetT]]) -> Callable[ArgsT, RetT]:
    """Converts an async test function into a synchronous one, run via Trio.

    Usage::

        @trio_test
        async def test_whatever():
            await ...

    If a pytest fixture is passed in that subclasses the :class:`~trio.abc.Clock` or
    :class:`~trio.abc.Instrument` ABCs, then those are passed to :func:`trio.run`.
    """

    @wraps(fn)
    def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
        __tracebackhide__ = True
        clocks = [c for c in kwargs.values() if isinstance(c, Clock)]
        if not clocks:
            clock = None
        elif len(clocks) == 1:
            clock = clocks[0]
        else:
            raise ValueError("too many clocks spoil the broth!")
        instruments = [i for i in kwargs.values() if isinstance(i, Instrument)]
        return _core.run(
            partial(fn, *args, **kwargs),
            clock=clock,
            instruments=instruments,
        )

    return wrapper
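

# Sketch (editor example): trio_test picks any Clock out of the keyword
# arguments, so a MockClock (e.g. supplied as a pytest fixture) flows straight
# into trio.run. `_example_autojump_test` is a hypothetical name.
def _example_autojump_test() -> None:
    import trio
    from trio.testing import MockClock

    @trio_test
    async def test_sleepy(mock_clock: Clock) -> None:
        await trio.sleep(3600)  # autojump makes this return almost instantly

    # Under pytest, `mock_clock` would arrive as a fixture; here we pass it
    # by hand to show the mechanism.
    test_sleepy(mock_clock=MockClock(autojump_threshold=0))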