Updated script that can be controled by Nodejs web app

This commit is contained in:
mac OS
2024-11-25 12:24:18 +07:00
parent c440eda1f4
commit 8b0ab2bd3a
8662 changed files with 1803808 additions and 34 deletions

View File

@ -0,0 +1,246 @@
#!/usr/bin/env python3
"""This is a file that wraps calls to `pyright --verifytypes`, achieving two things:
1. give an error if docstrings are missing.
pyright will give a number of missing docstrings, and error messages, but not exit with a non-zero value.
2. filter out specific errors we don't care about.
this is largely due to 1, but also because Trio does some very complex stuff and --verifytypes has few to no ways of ignoring specific errors.
If this check is giving you false alarms, you can ignore them by adding logic to `has_docstring_at_runtime`, in the main loop in `check_type`, or by updating the json file.
"""
from __future__ import annotations
# this file is not run as part of the tests, instead it's run standalone from check.sh
import argparse
import json
import subprocess
import sys
from pathlib import Path
import trio
import trio.testing
# not needed if everything is working, but if somebody does something to generate
# tons of errors, we can be nice and stop them from getting 3*tons of output
printed_diagnostics: set[str] = set()
# TODO: consider checking manually without `--ignoreexternal`, and/or
# removing it from the below call later on.
def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]:
return subprocess.run(
[
"pyright",
# Specify a platform and version to keep imported modules consistent.
f"--pythonplatform={platform}",
"--pythonversion=3.8",
"--verifytypes=trio",
"--outputjson",
"--ignoreexternal",
],
capture_output=True,
)
def has_docstring_at_runtime(name: str) -> bool:
"""Pyright gives us an object identifier of xx.yy.zz
This function tries to decompose that into its constituent parts, such that we
can resolve it, in order to check whether it has a `__doc__` at runtime and
verifytypes misses it because we're doing overly fancy stuff.
"""
# This assert is solely for stopping isort from removing our imports of trio & trio.testing
# It could also be done with isort:skip, but that'd also disable import sorting and the like.
assert trio.testing
# figure out what part of the name is the module, so we can "import" it
name_parts = name.split(".")
assert name_parts[0] == "trio"
if name_parts[1] == "tests":
return True
# traverse down the remaining identifiers with getattr
obj = trio
try:
for obj_name in name_parts[1:]:
obj = getattr(obj, obj_name)
except AttributeError as exc:
# asynciowrapper does funky getattr stuff
if "AsyncIOWrapper" in str(exc) or name in (
# Symbols not existing on all platforms, so we can't dynamically inspect them.
# Manually confirmed to have docstrings but pyright doesn't see them due to
# export shenanigans. TODO: actually manually confirm that.
# In theory we could verify these at runtime, probably by running the script separately
# on separate platforms. It might also be a decent idea to work the other way around,
# a la test_static_tool_sees_class_members
# darwin
"trio.lowlevel.current_kqueue",
"trio.lowlevel.monitor_kevent",
"trio.lowlevel.wait_kevent",
"trio._core._io_kqueue._KqueueStatistics",
# windows
"trio._socket.SocketType.share",
"trio._core._io_windows._WindowsStatistics",
"trio._core._windows_cffi.Handle",
"trio.lowlevel.current_iocp",
"trio.lowlevel.monitor_completion_key",
"trio.lowlevel.readinto_overlapped",
"trio.lowlevel.register_with_iocp",
"trio.lowlevel.wait_overlapped",
"trio.lowlevel.write_overlapped",
"trio.lowlevel.WaitForSingleObject",
"trio.socket.fromshare",
# linux
# this test will fail on linux, but I don't develop on linux. So the next
# person to do so is very welcome to open a pull request and populate with
# objects
# TODO: these are erroring on all platforms, why?
"trio._highlevel_generic.StapledStream.send_stream",
"trio._highlevel_generic.StapledStream.receive_stream",
"trio._ssl.SSLStream.transport_stream",
"trio._file_io._HasFileNo",
"trio._file_io._HasFileNo.fileno",
):
return True
else:
print(
f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).",
file=sys.stderr,
)
return False
return bool(obj.__doc__)
def check_type(
platform: str,
full_diagnostics_file: Path | None,
expected_errors: list[object],
) -> list[object]:
# convince isort we use the trio import
assert trio
# run pyright, load output into json
res = run_pyright(platform)
current_result = json.loads(res.stdout)
if res.stderr:
print(res.stderr, file=sys.stderr)
if full_diagnostics_file:
with open(full_diagnostics_file, "a") as f:
json.dump(current_result, f, sort_keys=True, indent=4)
errors = []
for symbol in current_result["typeCompleteness"]["symbols"]:
diagnostics = symbol["diagnostics"]
name = symbol["name"]
for diagnostic in diagnostics:
message = diagnostic["message"]
if name in (
"trio._path.PosixPath",
"trio._path.WindowsPath",
) and message.startswith("Type of base class "):
continue
if name.startswith("trio._path.Path"):
if message.startswith("No docstring found for"):
continue
if message.startswith(
"Type is missing type annotation and could be inferred differently by type checkers",
):
continue
# ignore errors about missing docstrings if they're available at runtime
if message.startswith("No docstring found for"):
if has_docstring_at_runtime(symbol["name"]):
continue
else:
# Missing docstring messages include the name of the object.
# Other errors don't, so we add it.
message = f"{name}: {message}"
if message not in expected_errors and message not in printed_diagnostics:
print(f"new error: {message}", file=sys.stderr)
errors.append(message)
printed_diagnostics.add(message)
continue
return errors
def main(args: argparse.Namespace) -> int:
if args.full_diagnostics_file:
full_diagnostics_file = Path(args.full_diagnostics_file)
full_diagnostics_file.write_text("")
else:
full_diagnostics_file = None
errors_by_platform_file = Path(__file__).parent / "_check_type_completeness.json"
if errors_by_platform_file.exists():
with open(errors_by_platform_file) as f:
errors_by_platform = json.load(f)
else:
errors_by_platform = {"Linux": [], "Windows": [], "Darwin": [], "all": []}
changed = False
for platform in "Linux", "Windows", "Darwin":
platform_errors = errors_by_platform[platform] + errors_by_platform["all"]
print("*" * 20, f"\nChecking {platform}...")
errors = check_type(platform, full_diagnostics_file, platform_errors)
new_errors = [e for e in errors if e not in platform_errors]
missing_errors = [e for e in platform_errors if e not in errors]
if new_errors:
print(
f"New errors introduced in `pyright --verifytypes`. Fix them, or ignore them by modifying {errors_by_platform_file}, either manually or with '--overwrite-file'.",
file=sys.stderr,
)
changed = True
if missing_errors:
print(
f"Congratulations, you have resolved existing errors! Please remove them from {errors_by_platform_file}, either manually or with '--overwrite-file'.",
file=sys.stderr,
)
changed = True
print(missing_errors, file=sys.stderr)
errors_by_platform[platform] = errors
print("*" * 20)
# cut down the size of the json file by a lot, and make it easier to parse for
# humans, by moving errors that appear on all platforms to a separate category
errors_by_platform["all"] = []
for e in errors_by_platform["Linux"].copy():
if e in errors_by_platform["Darwin"] and e in errors_by_platform["Windows"]:
for platform in "Linux", "Windows", "Darwin":
errors_by_platform[platform].remove(e)
errors_by_platform["all"].append(e)
if changed and args.overwrite_file:
with open(errors_by_platform_file, "w") as f:
json.dump(errors_by_platform, f, indent=4, sort_keys=True)
# newline at end of file
f.write("\n")
# True -> 1 -> non-zero exit value -> error
return changed
parser = argparse.ArgumentParser()
parser.add_argument(
"--overwrite-file",
action="store_true",
default=False,
help="Use this flag to overwrite the current stored results. Either in CI together with a diff check, or to avoid having to manually correct it.",
)
parser.add_argument(
"--full-diagnostics-file",
type=Path,
default=None,
help="Use this for debugging, it will dump the output of all three pyright runs by platform into this file.",
)
args = parser.parse_args()
assert __name__ == "__main__", "This script should be run standalone"
sys.exit(main(args))

View File

@ -0,0 +1,24 @@
regular = "hi"
from .. import _deprecate
_deprecate.enable_attribute_deprecations(__name__)
# Make sure that we don't trigger infinite recursion when accessing module
# attributes in between calling enable_attribute_deprecations and defining
# __deprecated_attributes__:
import sys
this_mod = sys.modules[__name__]
assert this_mod.regular == "hi"
assert not hasattr(this_mod, "dep1")
__deprecated_attributes__ = {
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
"dep2": _deprecate.DeprecatedAttribute(
"value2",
"1.2",
issue=1,
instead="instead-string",
),
}

View File

@ -0,0 +1,54 @@
from __future__ import annotations
import inspect
from typing import NoReturn
import pytest
from ..testing import MockClock, trio_test
RUN_SLOW = True
SKIP_OPTIONAL_IMPORTS = False
def pytest_addoption(parser: pytest.Parser) -> None:
parser.addoption("--run-slow", action="store_true", help="run slow tests")
parser.addoption(
"--skip-optional-imports",
action="store_true",
help="skip tests that rely on libraries not required by trio itself",
)
def pytest_configure(config: pytest.Config) -> None:
global RUN_SLOW
RUN_SLOW = config.getoption("--run-slow", default=True)
global SKIP_OPTIONAL_IMPORTS
SKIP_OPTIONAL_IMPORTS = config.getoption("--skip-optional-imports", default=False)
@pytest.fixture
def mock_clock() -> MockClock:
return MockClock()
@pytest.fixture
def autojump_clock() -> MockClock:
return MockClock(autojump_threshold=0)
# FIXME: split off into a package (or just make part of Trio's public
# interface?), with config file to enable? and I guess a mark option too; I
# guess it's useful with the class- and file-level marking machinery (where
# the raw @trio_test decorator isn't enough).
@pytest.hookimpl(tryfirst=True)
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
if inspect.iscoroutinefunction(pyfuncitem.obj):
pyfuncitem.obj = trio_test(pyfuncitem.obj)
def skip_if_optional_else_raise(error: ImportError) -> NoReturn:
if SKIP_OPTIONAL_IMPORTS:
pytest.skip(error.msg, allow_module_level=True)
else: # pragma: no cover
raise error

View File

@ -0,0 +1,72 @@
from __future__ import annotations
import attrs
import pytest
from .. import abc as tabc
from ..lowlevel import Task
def test_instrument_implements_hook_methods() -> None:
attrs = {
"before_run": (),
"after_run": (),
"task_spawned": (Task,),
"task_scheduled": (Task,),
"before_task_step": (Task,),
"after_task_step": (Task,),
"task_exited": (Task,),
"before_io_wait": (3.3,),
"after_io_wait": (3.3,),
}
mayonnaise = tabc.Instrument()
for method_name, args in attrs.items():
assert hasattr(mayonnaise, method_name)
method = getattr(mayonnaise, method_name)
assert callable(method)
method(*args)
async def test_AsyncResource_defaults() -> None:
@attrs.define(slots=False)
class MyAR(tabc.AsyncResource):
record: list[str] = attrs.Factory(list)
async def aclose(self) -> None:
self.record.append("ac")
async with MyAR() as myar:
assert isinstance(myar, MyAR)
assert myar.record == []
assert myar.record == ["ac"]
def test_abc_generics() -> None:
# Pythons below 3.5.2 had a typing.Generic that would throw
# errors when instantiating or subclassing a parameterized
# version of a class with any __slots__. This is why RunVar
# (which has slots) is not generic. This tests that
# the generic ABCs are fine, because while they are slotted
# they don't actually define any slots.
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
__slots__ = ("x",)
def send_nowait(self, value: object) -> None:
raise RuntimeError
async def send(self, value: object) -> None:
raise RuntimeError # pragma: no cover
def clone(self) -> None:
raise RuntimeError # pragma: no cover
async def aclose(self) -> None:
pass # pragma: no cover
channel = SlottedChannel()
with pytest.raises(RuntimeError):
channel.send_nowait(None)

View File

@ -0,0 +1,413 @@
from __future__ import annotations
from typing import Union
import pytest
import trio
from trio import EndOfChannel, open_memory_channel
from ..testing import assert_checkpoints, wait_all_tasks_blocked
async def test_channel() -> None:
with pytest.raises(TypeError):
open_memory_channel(1.0)
with pytest.raises(ValueError, match="^max_buffer_size must be >= 0$"):
open_memory_channel(-1)
s, r = open_memory_channel[Union[int, str, None]](2)
repr(s) # smoke test
repr(r) # smoke test
s.send_nowait(1)
with assert_checkpoints():
await s.send(2)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
with assert_checkpoints():
assert await r.receive() == 1
assert r.receive_nowait() == 2
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
s.send_nowait("last")
await s.aclose()
with pytest.raises(trio.ClosedResourceError):
await s.send("too late")
with pytest.raises(trio.ClosedResourceError):
s.send_nowait("too late")
with pytest.raises(trio.ClosedResourceError):
s.clone()
await s.aclose()
assert r.receive_nowait() == "last"
with pytest.raises(EndOfChannel):
await r.receive()
await r.aclose()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
await r.aclose()
async def test_553(autojump_clock: trio.abc.Clock) -> None:
s, r = open_memory_channel[str](1)
with trio.move_on_after(10) as timeout_scope:
await r.receive()
assert timeout_scope.cancelled_caught
await s.send("Test for PR #553")
async def test_channel_multiple_producers() -> None:
async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None:
# We close our handle when we're done with it
async with send_channel:
for j in range(3 * i, 3 * (i + 1)):
await send_channel.send(j)
send_channel, receive_channel = open_memory_channel[int](0)
async with trio.open_nursery() as nursery:
# We hand out clones to all the new producers, and then close the
# original.
async with send_channel:
for i in range(10):
nursery.start_soon(producer, send_channel.clone(), i)
got = [value async for value in receive_channel]
got.sort()
assert got == list(range(30))
async def test_channel_multiple_consumers() -> None:
successful_receivers = set()
received = []
async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> None:
async for value in receive_channel:
successful_receivers.add(i)
received.append(value)
async with trio.open_nursery() as nursery:
send_channel, receive_channel = trio.open_memory_channel[int](1)
async with send_channel:
for i in range(5):
nursery.start_soon(consumer, receive_channel, i)
await wait_all_tasks_blocked()
for i in range(10):
await send_channel.send(i)
assert successful_receivers == set(range(5))
assert len(received) == 10
assert set(received) == set(range(10))
async def test_close_basics() -> None:
async def send_block(
s: trio.MemorySendChannel[None],
expect: type[BaseException],
) -> None:
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
await s.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
await r.aclose()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None:
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s2, r2 = open_memory_channel[int](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r2)
await wait_all_tasks_blocked()
await r2.aclose()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r2.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r2.receive()
async def test_close_sync() -> None:
async def send_block(
s: trio.MemorySendChannel[None],
expect: type[BaseException],
) -> None:
with pytest.raises(expect):
await s.send(None)
# closing send -> other send gets ClosedResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.ClosedResourceError)
await wait_all_tasks_blocked()
s.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
s.send_nowait(None)
with pytest.raises(trio.ClosedResourceError):
await s.send(None)
# and receive gets EndOfChannel
with pytest.raises(EndOfChannel):
r.receive_nowait()
with pytest.raises(EndOfChannel):
await r.receive()
# closing receive -> send gets BrokenResourceError
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(send_block, s, trio.BrokenResourceError)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
with pytest.raises(trio.BrokenResourceError):
await s.send(None)
# closing receive -> other receive gets ClosedResourceError
async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None:
with pytest.raises(trio.ClosedResourceError):
await r.receive()
s, r = open_memory_channel[None](0)
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_block, r)
await wait_all_tasks_blocked()
r.close()
# and it's persistent
with pytest.raises(trio.ClosedResourceError):
r.receive_nowait()
with pytest.raises(trio.ClosedResourceError):
await r.receive()
async def test_receive_channel_clone_and_close() -> None:
s, r = open_memory_channel[None](10)
r2 = r.clone()
r3 = r.clone()
s.send_nowait(None)
await r.aclose()
with r2:
pass
with pytest.raises(trio.ClosedResourceError):
r.clone()
with pytest.raises(trio.ClosedResourceError):
r2.clone()
# Can still send, r3 is still open
s.send_nowait(None)
await r3.aclose()
# But now the receiver is really closed
with pytest.raises(trio.BrokenResourceError):
s.send_nowait(None)
async def test_close_multiple_send_handles() -> None:
# With multiple send handles, closing one handle only wakes senders on
# that handle, but others can continue just fine
s1, r = open_memory_channel[str](0)
s2 = s1.clone()
async def send_will_close() -> None:
with pytest.raises(trio.ClosedResourceError):
await s1.send("nope")
async def send_will_succeed() -> None:
await s2.send("ok")
async with trio.open_nursery() as nursery:
nursery.start_soon(send_will_close)
nursery.start_soon(send_will_succeed)
await wait_all_tasks_blocked()
await s1.aclose()
assert await r.receive() == "ok"
async def test_close_multiple_receive_handles() -> None:
# With multiple receive handles, closing one handle only wakes receivers on
# that handle, but others can continue just fine
s, r1 = open_memory_channel[str](0)
r2 = r1.clone()
async def receive_will_close() -> None:
with pytest.raises(trio.ClosedResourceError):
await r1.receive()
async def receive_will_succeed() -> None:
assert await r2.receive() == "ok"
async with trio.open_nursery() as nursery:
nursery.start_soon(receive_will_close)
nursery.start_soon(receive_will_succeed)
await wait_all_tasks_blocked()
await r1.aclose()
await s.send("ok")
async def test_inf_capacity() -> None:
send, receive = open_memory_channel[int](float("inf"))
# It's accepted, and we can send all day without blocking
with send:
for i in range(10):
send.send_nowait(i)
got = [i async for i in receive]
assert got == list(range(10))
async def test_statistics() -> None:
s, r = open_memory_channel[None](2)
assert s.statistics() == r.statistics()
stats = s.statistics()
assert stats.current_buffer_used == 0
assert stats.max_buffer_size == 2
assert stats.open_send_channels == 1
assert stats.open_receive_channels == 1
assert stats.tasks_waiting_send == 0
assert stats.tasks_waiting_receive == 0
s.send_nowait(None)
assert s.statistics().current_buffer_used == 1
s2 = s.clone()
assert s.statistics().open_send_channels == 2
await s.aclose()
assert s2.statistics().open_send_channels == 1
r2 = r.clone()
assert s2.statistics().open_receive_channels == 2
await r2.aclose()
assert s2.statistics().open_receive_channels == 1
async with trio.open_nursery() as nursery:
s2.send_nowait(None) # fill up the buffer
assert s.statistics().current_buffer_used == 2
nursery.start_soon(s2.send, None)
nursery.start_soon(s2.send, None)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_send == 2
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_send == 0
# empty out the buffer again
try:
while True:
r.receive_nowait()
except trio.WouldBlock:
pass
async with trio.open_nursery() as nursery:
nursery.start_soon(r.receive)
await wait_all_tasks_blocked()
assert s.statistics().tasks_waiting_receive == 1
nursery.cancel_scope.cancel()
assert s.statistics().tasks_waiting_receive == 0
async def test_channel_fairness() -> None:
# We can remove an item we just sent, and send an item back in after, if
# no-one else is waiting.
s, r = open_memory_channel[Union[int, None]](1)
s.send_nowait(1)
assert r.receive_nowait() == 1
s.send_nowait(2)
assert r.receive_nowait() == 2
# But if someone else is waiting to receive, then they "own" the item we
# send, so we can't receive it (even though we run first):
result: int | None = None
async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None:
nonlocal result
result = await r.receive()
async with trio.open_nursery() as nursery:
nursery.start_soon(do_receive, r)
await wait_all_tasks_blocked()
s.send_nowait(2)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
assert result == 2
# And the analogous situation for send: if we free up a space, we can't
# immediately send something in it if someone is already waiting to do
# that
s, r = open_memory_channel[Union[int, None]](1)
s.send_nowait(1)
with pytest.raises(trio.WouldBlock):
s.send_nowait(None)
async with trio.open_nursery() as nursery:
nursery.start_soon(s.send, 2)
await wait_all_tasks_blocked()
assert r.receive_nowait() == 1
with pytest.raises(trio.WouldBlock):
s.send_nowait(3)
assert (await r.receive()) == 2
async def test_unbuffered() -> None:
s, r = open_memory_channel[int](0)
with pytest.raises(trio.WouldBlock):
r.receive_nowait()
with pytest.raises(trio.WouldBlock):
s.send_nowait(1)
async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
with assert_checkpoints():
await s.send(v)
async with trio.open_nursery() as nursery:
nursery.start_soon(do_send, s, 1)
with assert_checkpoints():
assert await r.receive() == 1
with pytest.raises(trio.WouldBlock):
r.receive_nowait()

View File

@ -0,0 +1,56 @@
from __future__ import annotations
import contextvars
from .. import _core
trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
"trio_testing_contextvar",
)
async def test_contextvars_default() -> None:
trio_testing_contextvar.set("main")
record: list[str] = []
async def child() -> None:
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
assert record == ["main"]
async def test_contextvars_set() -> None:
trio_testing_contextvar.set("main")
record: list[str] = []
async def child() -> None:
trio_testing_contextvar.set("child")
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
value = trio_testing_contextvar.get()
assert record == ["child"]
assert value == "main"
async def test_contextvars_copy() -> None:
trio_testing_contextvar.set("main")
context = contextvars.copy_context()
trio_testing_contextvar.set("second_main")
record: list[str] = []
async def child() -> None:
value = trio_testing_contextvar.get()
record.append(value)
async with _core.open_nursery() as nursery:
context.run(nursery.start_soon, child)
nursery.start_soon(child)
value = trio_testing_contextvar.get()
assert set(record) == {"main", "second_main"}
assert value == "second_main"

View File

@ -0,0 +1,283 @@
from __future__ import annotations
import inspect
import warnings
import pytest
from .._deprecate import (
TrioDeprecationWarning,
deprecated,
deprecated_alias,
warn_deprecated,
)
from . import module_with_deprecations
@pytest.fixture
def recwarn_always(recwarn: pytest.WarningsRecorder) -> pytest.WarningsRecorder:
warnings.simplefilter("always")
# ResourceWarnings about unclosed sockets can occur nondeterministically
# (during GC) which throws off the tests in this file
warnings.simplefilter("ignore", ResourceWarning)
return recwarn
def _here() -> tuple[str, int]:
frame = inspect.currentframe()
assert frame is not None
assert frame.f_back is not None
info = inspect.getframeinfo(frame.f_back)
return (info.filename, info.lineno)
def test_warn_deprecated(recwarn_always: pytest.WarningsRecorder) -> None:
def deprecated_thing() -> None:
warn_deprecated("ice", "1.2", issue=1, instead="water")
deprecated_thing()
filename, lineno = _here()
assert len(recwarn_always) == 1
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "ice is deprecated" in got.message.args[0]
assert "Trio 1.2" in got.message.args[0]
assert "water instead" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert got.filename == filename
assert got.lineno == lineno - 1
def test_warn_deprecated_no_instead_or_issue(
recwarn_always: pytest.WarningsRecorder,
) -> None:
# Explicitly no instead or issue
warn_deprecated("water", "1.3", issue=None, instead=None)
assert len(recwarn_always) == 1
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "water is deprecated" in got.message.args[0]
assert "no replacement" in got.message.args[0]
assert "Trio 1.3" in got.message.args[0]
def test_warn_deprecated_stacklevel(recwarn_always: pytest.WarningsRecorder) -> None:
def nested1() -> None:
nested2()
def nested2() -> None:
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
filename, lineno = _here()
nested1()
got = recwarn_always.pop(DeprecationWarning)
assert got.filename == filename
assert got.lineno == lineno + 1
def old() -> None: # pragma: no cover
pass
def new() -> None: # pragma: no cover
pass
def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None:
warn_deprecated(old, "1.0", issue=1, instead=new)
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.old is deprecated" in got.message.args[0]
assert "test_deprecate.new instead" in got.message.args[0]
@deprecated("1.5", issue=123, instead=new)
def deprecated_old() -> int:
return 3
def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None:
assert deprecated_old() == 3
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
assert "1.5" in got.message.args[0]
assert "test_deprecate.new" in got.message.args[0]
assert "issues/123" in got.message.args[0]
class Foo:
@deprecated("1.0", issue=123, instead="crying")
def method(self) -> int:
return 7
def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None:
f = Foo()
assert f.method() == 7
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
@deprecated("1.2", thing="the thing", issue=None, instead=None)
def deprecated_with_thing() -> int:
return 72
def test_deprecated_decorator_with_explicit_thing(
recwarn_always: pytest.WarningsRecorder,
) -> None:
assert deprecated_with_thing() == 72
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "the thing is deprecated" in got.message.args[0]
def new_hotness() -> str:
return "new hotness"
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None:
assert old_hotness() == "new hotness"
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
assert "1.23" in got.message.args[0]
assert "test_deprecate.new_hotness instead" in got.message.args[0]
assert "issues/1" in got.message.args[0]
assert isinstance(old_hotness.__doc__, str)
assert ".. deprecated:: 1.23" in old_hotness.__doc__
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
assert "issues/1>`__" in old_hotness.__doc__
class Alias:
def new_hotness_method(self) -> str:
return "new hotness method"
old_hotness_method = deprecated_alias(
"Alias.old_hotness_method",
new_hotness_method,
"3.21",
issue=1,
)
def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None:
obj = Alias()
assert obj.old_hotness_method() == "new hotness method"
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
msg = got.message.args[0]
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
assert "test_deprecate.Alias.new_hotness_method instead" in msg
@deprecated("2.1", issue=1, instead="hi")
def docstring_test1() -> None: # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead="hi")
def docstring_test2() -> None: # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=1, instead=None)
def docstring_test3() -> None: # pragma: no cover
"""Hello!"""
@deprecated("2.1", issue=None, instead=None)
def docstring_test4() -> None: # pragma: no cover
"""Hello!"""
def test_deprecated_docstring_munging() -> None:
assert (
docstring_test1.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test2.__doc__
== """Hello!
.. deprecated:: 2.1
Use hi instead.
"""
)
assert (
docstring_test3.__doc__
== """Hello!
.. deprecated:: 2.1
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
"""
)
assert (
docstring_test4.__doc__
== """Hello!
.. deprecated:: 2.1
"""
)
def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> None:
assert module_with_deprecations.regular == "hi"
assert len(recwarn_always) == 0
filename, lineno = _here()
assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined]
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert got.filename == filename
assert got.lineno == lineno + 1
assert "module_with_deprecations.dep1" in got.message.args[0]
assert "Trio 1.1" in got.message.args[0]
assert "/issues/1" in got.message.args[0]
assert "value1 instead" in got.message.args[0]
assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined]
got = recwarn_always.pop(DeprecationWarning)
assert isinstance(got.message, Warning)
assert "instead-string instead" in got.message.args[0]
with pytest.raises(AttributeError):
module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression"
def test_warning_class() -> None:
with pytest.deprecated_call():
warn_deprecated("foo", "bar", issue=None, instead=None)
# essentially the same as the above check
with pytest.warns(DeprecationWarning):
warn_deprecated("foo", "bar", issue=None, instead=None)
with pytest.warns(TrioDeprecationWarning):
warn_deprecated(
"foo",
"bar",
issue=None,
instead=None,
use_triodeprecationwarning=True,
)

View File

@ -0,0 +1,64 @@
from typing import Awaitable, Callable
import pytest
import trio
async def test_deprecation_warning_open_nursery() -> None:
with pytest.warns(
trio.TrioDeprecationWarning,
match="strict_exception_groups=False",
) as record:
async with trio.open_nursery(strict_exception_groups=False):
...
assert len(record) == 1
async with trio.open_nursery(strict_exception_groups=True):
...
async with trio.open_nursery():
...
def test_deprecation_warning_run() -> None:
async def foo() -> None: ...
async def foo_nursery() -> None:
# this should not raise a warning, even if it's implied loose
async with trio.open_nursery():
...
async def foo_loose_nursery() -> None:
# this should raise a warning, even if specifying the parameter is redundant
async with trio.open_nursery(strict_exception_groups=False):
...
def helper(fun: Callable[..., Awaitable[None]], num: int) -> None:
with pytest.warns(
trio.TrioDeprecationWarning,
match="strict_exception_groups=False",
) as record:
trio.run(fun, strict_exception_groups=False)
assert len(record) == num
helper(foo, 1)
helper(foo_nursery, 1)
helper(foo_loose_nursery, 2)
def test_deprecation_warning_start_guest_run() -> None:
# "The simplest possible "host" loop."
from .._core._tests.test_guest_mode import trivial_guest_run
async def trio_return(in_host: object) -> str:
await trio.lowlevel.checkpoint()
return "ok"
with pytest.warns(
trio.TrioDeprecationWarning,
match="strict_exception_groups=False",
) as record:
trivial_guest_run(
trio_return,
strict_exception_groups=False,
)
assert len(record) == 1

View File

@ -0,0 +1,900 @@
from __future__ import annotations
import random
from contextlib import asynccontextmanager
from itertools import count
from typing import TYPE_CHECKING, NoReturn
import attrs
import pytest
from trio._tests.pytest_plugin import skip_if_optional_else_raise
try:
import trustme
from OpenSSL import SSL
except ImportError as error:
skip_if_optional_else_raise(error)
import trio
import trio.testing
from trio import DTLSChannel, DTLSEndpoint
from trio.testing._fake_net import FakeNet, UDPPacket
from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
ca = trustme.CA()
server_cert = ca.issue_cert("example.com")
server_ctx = SSL.Context(SSL.DTLS_METHOD)
server_cert.configure_cert(server_ctx)
client_ctx = SSL.Context(SSL.DTLS_METHOD)
ca.configure_trust(client_ctx)
parametrize_ipv6 = pytest.mark.parametrize(
"ipv6",
[False, pytest.param(True, marks=binds_ipv6)],
ids=["ipv4", "ipv6"],
)
def endpoint(**kwargs: int | bool) -> DTLSEndpoint:
ipv6 = kwargs.pop("ipv6", False)
family = trio.socket.AF_INET6 if ipv6 else trio.socket.AF_INET
sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
return DTLSEndpoint(sock, **kwargs)
@asynccontextmanager
async def dtls_echo_server(
*,
autocancel: bool = True,
mtu: int | None = None,
ipv6: bool = False,
) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]:
with endpoint(ipv6=ipv6) as server:
localhost = "::1" if ipv6 else "127.0.0.1"
await server.socket.bind((localhost, 0))
async with trio.open_nursery() as nursery:
async def echo_handler(dtls_channel: DTLSChannel) -> None:
print(
"echo handler started: "
f"server {dtls_channel.endpoint.socket.getsockname()!r} "
f"client {dtls_channel.peer_address!r}",
)
if mtu is not None:
dtls_channel.set_ciphertext_mtu(mtu)
try:
print("server starting do_handshake")
await dtls_channel.do_handshake()
print("server finished do_handshake")
async for packet in dtls_channel:
print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}")
await dtls_channel.send(packet)
except trio.BrokenResourceError: # pragma: no cover
print("echo handler channel broken")
await nursery.start(server.serve, server_ctx, echo_handler)
yield server, server.socket.getsockname()
if autocancel:
nursery.cancel_scope.cancel()
@parametrize_ipv6
async def test_smoke(ipv6: bool) -> None:
async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address):
with endpoint(ipv6=ipv6) as client_endpoint:
client_channel = client_endpoint.connect(address, client_ctx)
with pytest.raises(trio.NeedHandshakeError):
client_channel.get_cleartext_mtu()
await client_channel.do_handshake()
await client_channel.send(b"hello")
assert await client_channel.receive() == b"hello"
await client_channel.send(b"goodbye")
assert await client_channel.receive() == b"goodbye"
with pytest.raises(
ValueError,
match="^openssl doesn't support sending empty DTLS packets$",
):
await client_channel.send(b"")
client_channel.set_ciphertext_mtu(1234)
cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
client_channel.set_ciphertext_mtu(4321)
assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
client_channel.set_ciphertext_mtu(1234)
assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
@slow
async def test_handshake_over_terrible_network(
autojump_clock: trio.testing.MockClock,
) -> None:
HANDSHAKES = 100
r = random.Random(0)
fn = FakeNet()
fn.enable()
# avoid spurious timeouts on slow machines
autojump_clock.autojump_threshold = 0.001
async with dtls_echo_server() as (_, address):
async with trio.open_nursery() as nursery:
async def route_packet(packet: UDPPacket) -> None:
while True:
op = r.choices(
["deliver", "drop", "dupe", "delay"],
weights=[0.7, 0.1, 0.1, 0.1],
)[0]
print(f"{packet.source} -> {packet.destination}: {op}")
if op == "drop":
return
elif op == "dupe":
fn.send_packet(packet)
elif op == "delay":
await trio.sleep(r.random() * 3)
# I wanted to test random packet corruption too, but it turns out
# openssl has a bug in the following scenario:
#
# - client sends ClientHello
# - server sends HelloVerifyRequest with cookie -- but cookie is
# invalid b/c either the ClientHello or HelloVerifyRequest was
# corrupted
# - client re-sends ClientHello with invalid cookie
# - server replies with new HelloVerifyRequest and correct cookie
#
# At this point, the client *should* switch to the new, valid
# cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
# the original, invalid cookie over and over. In theory we could
# work around this by detecting cookie changes and starting over
# with a whole new SSL object, but (a) it doesn't seem worth it, (b)
# when I tried then I ran into another issue where OpenSSL got stuck
# in an infinite loop sending alerts over and over, which I didn't
# dig into because see (a).
#
# elif op == "distort":
# payload = bytearray(packet.payload)
# payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
# packet = attrs.evolve(packet, payload=payload)
else:
assert op == "deliver"
print(
f"{packet.source} -> {packet.destination}: delivered"
f" {packet.payload.hex()}",
)
fn.deliver_packet(packet)
break
def route_packet_wrapper(packet: UDPPacket) -> None:
try: # noqa: SIM105 # suppressible-exception
nursery.start_soon(route_packet, packet)
except RuntimeError: # pragma: no cover
# We're exiting the nursery, so any remaining packets can just get
# dropped
pass
fn.route_packet = route_packet_wrapper # type: ignore[assignment] # TODO: Fix FakeNet typing
for i in range(HANDSHAKES):
print("#" * 80)
print("#" * 80)
print("#" * 80)
with endpoint() as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
print("client starting do_handshake")
await client.do_handshake()
print("client finished do_handshake")
msg = str(i).encode()
# Make multiple attempts to send data, because the network might
# drop it
while True:
with trio.move_on_after(10) as cscope:
await client.send(msg)
assert await client.receive() == msg
if not cscope.cancelled_caught:
break
async def test_implicit_handshake() -> None:
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
# Implicit handshake
await client.send(b"xyz")
assert await client.receive() == b"xyz"
async def test_full_duplex() -> None:
# Tests simultaneous send/receive, and also multiple methods implicitly invoking
# do_handshake simultaneously.
with endpoint() as server_endpoint, endpoint() as client_endpoint:
await server_endpoint.socket.bind(("127.0.0.1", 0))
async with trio.open_nursery() as server_nursery:
async def handler(channel: DTLSChannel) -> None:
async with trio.open_nursery() as nursery:
nursery.start_soon(channel.send, b"from server")
nursery.start_soon(channel.receive)
await server_nursery.start(server_endpoint.serve, server_ctx, handler)
client = client_endpoint.connect(
server_endpoint.socket.getsockname(),
client_ctx,
)
async with trio.open_nursery() as nursery:
nursery.start_soon(client.send, b"from client")
nursery.start_soon(client.receive)
server_nursery.cancel_scope.cancel()
async def test_channel_closing() -> None:
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
await client.do_handshake()
client.close()
with pytest.raises(trio.ClosedResourceError):
await client.send(b"abc")
with pytest.raises(trio.ClosedResourceError):
await client.receive()
# close is idempotent
client.close()
# can also aclose
await client.aclose()
async def test_serve_exits_cleanly_on_close() -> None:
async with dtls_echo_server(autocancel=False) as (server_endpoint, address):
server_endpoint.close()
# Testing that the nursery exits even without being cancelled
# close is idempotent
server_endpoint.close()
async def test_client_multiplex() -> None:
async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2):
with endpoint() as client_endpoint:
client1 = client_endpoint.connect(address1, client_ctx)
client2 = client_endpoint.connect(address2, client_ctx)
await client1.send(b"abc")
await client2.send(b"xyz")
assert await client2.receive() == b"xyz"
assert await client1.receive() == b"abc"
client_endpoint.close()
with pytest.raises(trio.ClosedResourceError):
await client1.send(b"xxx")
with pytest.raises(trio.ClosedResourceError):
await client2.receive()
with pytest.raises(trio.ClosedResourceError):
client_endpoint.connect(address1, client_ctx)
async def null_handler(_: object) -> None: # pragma: no cover
pass
async with trio.open_nursery() as nursery:
with pytest.raises(trio.ClosedResourceError):
await nursery.start(client_endpoint.serve, server_ctx, null_handler)
async def test_dtls_over_dgram_only() -> None:
with trio.socket.socket() as s:
with pytest.raises(ValueError, match="^DTLS requires a SOCK_DGRAM socket$"):
DTLSEndpoint(s)
async def test_double_serve() -> None:
async def null_handler(_: object) -> None: # pragma: no cover
pass
with endpoint() as server_endpoint:
await server_endpoint.socket.bind(("127.0.0.1", 0))
async with trio.open_nursery() as nursery:
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
with pytest.raises(trio.BusyResourceError):
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
nursery.cancel_scope.cancel()
async with trio.open_nursery() as nursery:
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
nursery.cancel_scope.cancel()
async def test_connect_to_non_server(autojump_clock: trio.abc.Clock) -> None:
fn = FakeNet()
fn.enable()
with endpoint() as client1, endpoint() as client2:
await client1.socket.bind(("127.0.0.1", 0))
# This should just time out
with trio.move_on_after(100) as cscope:
channel = client2.connect(client1.socket.getsockname(), client_ctx)
await channel.do_handshake()
assert cscope.cancelled_caught
async def test_incoming_buffer_overflow(autojump_clock: trio.abc.Clock) -> None:
fn = FakeNet()
fn.enable()
for buffer_size in [10, 20]:
async with dtls_echo_server() as (_, address):
with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
assert client_endpoint.incoming_packets_buffer == buffer_size
client = client_endpoint.connect(address, client_ctx)
for i in range(buffer_size + 15):
await client.send(str(i).encode())
await trio.sleep(1)
stats = client.statistics()
assert stats.incoming_packets_dropped_in_trio == 15
for i in range(buffer_size):
assert await client.receive() == str(i).encode()
await client.send(b"buffer clear now")
assert await client.receive() == b"buffer clear now"
async def test_server_socket_doesnt_crash_on_garbage(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
from trio._dtls import (
ContentType,
HandshakeFragment,
HandshakeType,
ProtocolVersion,
Record,
encode_handshake_fragment,
encode_record,
)
client_hello = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=encode_handshake_fragment(
HandshakeFragment(
msg_type=HandshakeType.client_hello,
msg_len=10,
msg_seq=0,
frag_offset=0,
frag_len=10,
frag=bytes(10),
),
),
),
)
client_hello_extended = client_hello + b"\x00"
client_hello_short = client_hello[:-1]
# cuts off in middle of handshake message header
client_hello_really_short = client_hello[:14]
client_hello_corrupt_record_len = bytearray(client_hello)
client_hello_corrupt_record_len[11] = 0xFF
client_hello_fragmented = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=encode_handshake_fragment(
HandshakeFragment(
msg_type=HandshakeType.client_hello,
msg_len=20,
msg_seq=0,
frag_offset=0,
frag_len=10,
frag=bytes(10),
),
),
),
)
client_hello_trailing_data_in_record = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=encode_handshake_fragment(
HandshakeFragment(
msg_type=HandshakeType.client_hello,
msg_len=20,
msg_seq=0,
frag_offset=0,
frag_len=10,
frag=bytes(10),
),
)
+ b"\x00",
),
)
handshake_empty = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=b"",
),
)
client_hello_truncated_in_cookie = encode_record(
Record(
content_type=ContentType.handshake,
version=ProtocolVersion.DTLS10,
epoch_seqno=0,
payload=bytes(2 + 32 + 1) + b"\xff",
),
)
async with dtls_echo_server() as (_, address):
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
for bad_packet in [
b"",
b"xyz",
client_hello_extended,
client_hello_short,
client_hello_really_short,
client_hello_corrupt_record_len,
client_hello_fragmented,
client_hello_trailing_data_in_record,
handshake_empty,
client_hello_truncated_in_cookie,
]:
await sock.sendto(bad_packet, address)
await trio.sleep(1)
async def test_invalid_cookie_rejected(autojump_clock: trio.abc.Clock) -> None:
fn = FakeNet()
fn.enable()
from trio._dtls import BadPacket, decode_client_hello_untrusted
with trio.CancelScope() as cscope:
# the first 11 bytes of ClientHello aren't protected by the cookie, so only test
# corrupting bytes after that.
offset_to_corrupt = count(11)
def route_packet(packet: UDPPacket) -> None:
try:
_, cookie, _ = decode_client_hello_untrusted(packet.payload)
except BadPacket:
pass
else:
if len(cookie) != 0:
# this is a challenge response packet
# let's corrupt the next offset so the handshake should fail
payload = bytearray(packet.payload)
offset = next(offset_to_corrupt)
if offset >= len(payload):
# We've tried all offsets. Clamp offset to the end of the
# payload, and terminate the test.
offset = len(payload) - 1
cscope.cancel()
payload[offset] ^= 0x01
packet = attrs.evolve(packet, payload=payload)
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO: Fix FakeNet typing
async with dtls_echo_server() as (_, address):
while True:
with endpoint() as client:
channel = client.connect(address, client_ctx)
await channel.do_handshake()
assert cscope.cancelled_caught
async def test_client_cancels_handshake_and_starts_new_one(
autojump_clock: trio.abc.Clock,
) -> None:
# if a client disappears during the handshake, and then starts a new handshake from
# scratch, then the first handler's channel should fail, and a new handler get
# started
fn = FakeNet()
fn.enable()
with endpoint() as server, endpoint() as client:
await server.socket.bind(("127.0.0.1", 0))
async with trio.open_nursery() as nursery:
first_time = True
async def handler(channel: DTLSChannel) -> None:
nonlocal first_time
if first_time:
first_time = False
print("handler: first time, cancelling connect")
connect_cscope.cancel()
await trio.sleep(0.5)
print("handler: handshake should fail now")
with pytest.raises(trio.BrokenResourceError):
await channel.do_handshake()
else:
print("handler: not first time, sending hello")
await channel.send(b"hello")
await nursery.start(server.serve, server_ctx, handler)
print("client: starting first connect")
with trio.CancelScope() as connect_cscope:
channel = client.connect(server.socket.getsockname(), client_ctx)
await channel.do_handshake()
assert connect_cscope.cancelled_caught
print("client: starting second connect")
channel = client.connect(server.socket.getsockname(), client_ctx)
assert await channel.receive() == b"hello"
# Give handlers a chance to finish
await trio.sleep(10)
nursery.cancel_scope.cancel()
async def test_swap_client_server() -> None:
with endpoint() as a, endpoint() as b:
await a.socket.bind(("127.0.0.1", 0))
await b.socket.bind(("127.0.0.1", 0))
async def echo_handler(channel: DTLSChannel) -> None:
async for packet in channel:
await channel.send(packet)
async def crashing_echo_handler(channel: DTLSChannel) -> None:
with pytest.raises(trio.BrokenResourceError):
await echo_handler(channel)
async with trio.open_nursery() as nursery:
await nursery.start(a.serve, server_ctx, crashing_echo_handler)
await nursery.start(b.serve, server_ctx, echo_handler)
b_to_a = b.connect(a.socket.getsockname(), client_ctx)
await b_to_a.send(b"b as client")
assert await b_to_a.receive() == b"b as client"
a_to_b = a.connect(b.socket.getsockname(), client_ctx)
await a_to_b.do_handshake()
with pytest.raises(trio.BrokenResourceError):
await b_to_a.send(b"association broken")
await a_to_b.send(b"a as client")
assert await a_to_b.receive() == b"a as client"
nursery.cancel_scope.cancel()
@slow
async def test_openssl_retransmit_doesnt_break_stuff() -> None:
# can't use autojump_clock here, because the point of the test is to wait for
# openssl's built-in retransmit timer to expire, which is hard-coded to use
# wall-clock time.
fn = FakeNet()
fn.enable()
blackholed = True
def route_packet(packet: UDPPacket) -> None:
if blackholed:
print("dropped packet", packet)
return
print("delivered packet", packet)
# packets.append(
# scapy.all.IP(
# src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
# )
# / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
# / packet.payload
# )
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (server_endpoint, address):
with endpoint() as client_endpoint:
async with trio.open_nursery() as nursery:
async def connecter() -> None:
client = client_endpoint.connect(address, client_ctx)
await client.do_handshake(initial_retransmit_timeout=1.5)
await client.send(b"hi")
assert await client.receive() == b"hi"
nursery.start_soon(connecter)
# openssl's default timeout is 1 second, so this ensures that it thinks
# the timeout has expired
await trio.sleep(1.1)
# disable blackholing and send a garbage packet to wake up openssl so it
# notices the timeout has expired
blackholed = False
await server_endpoint.socket.sendto(
b"xxx",
client_endpoint.socket.getsockname(),
)
# now the client task should finish connecting and exit cleanly
# scapy.all.wrpcap("/tmp/trace.pcap", packets)
async def test_initial_retransmit_timeout_configuration(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
blackholed = True
def route_packet(packet: UDPPacket) -> None:
nonlocal blackholed
if blackholed:
blackholed = False
else:
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (_, address):
for t in [1, 2, 4]:
with endpoint() as client:
before = trio.current_time()
blackholed = True
channel = client.connect(address, client_ctx)
await channel.do_handshake(initial_retransmit_timeout=t)
after = trio.current_time()
assert after - before == t
async def test_explicit_tiny_mtu_is_respected() -> None:
# ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
# be larger than that. (300 is still smaller than any real network though.)
MTU = 300
fn = FakeNet()
fn.enable()
def route_packet(packet: UDPPacket) -> None:
print(f"delivering {packet}")
print(f"payload size: {len(packet.payload)}")
assert len(packet.payload) <= MTU
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server(mtu=MTU) as (server, address):
with endpoint() as client:
channel = client.connect(address, client_ctx)
channel.set_ciphertext_mtu(MTU)
await channel.do_handshake()
await channel.send(b"hi")
assert await channel.receive() == b"hi"
@parametrize_ipv6
async def test_handshake_handles_minimum_network_mtu(
ipv6: bool,
autojump_clock: trio.abc.Clock,
) -> None:
# Fake network that has the minimum allowable MTU for whatever protocol we're using.
fn = FakeNet()
fn.enable()
mtu = 1280 - 48 if ipv6 else 576 - 28
def route_packet(packet: UDPPacket) -> None:
if len(packet.payload) > mtu:
print(f"dropping {packet}")
else:
print(f"delivering {packet}")
fn.deliver_packet(packet)
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
# See if we can successfully do a handshake -- some of the volleys will get dropped,
# and the retransmit logic should detect this and back off the MTU to something
# smaller until it succeeds.
async with dtls_echo_server(ipv6=ipv6) as (_, address):
with endpoint(ipv6=ipv6) as client_endpoint:
client = client_endpoint.connect(address, client_ctx)
# the handshake mtu backoff shouldn't affect the return value from
# get_cleartext_mtu, b/c that's under the user's control via
# set_ciphertext_mtu
client.set_ciphertext_mtu(9999)
await client.send(b"xyz")
assert await client.receive() == b"xyz"
assert client.get_cleartext_mtu() > 9000 # as vegeta said
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_system_task_cleaned_up_on_gc() -> None:
before_tasks = trio.lowlevel.current_statistics().tasks_living
# We put this into a sub-function so that everything automatically becomes garbage
# when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
# with coverage enabled -- I think we were hitting this bug:
# https://foss.heptapod.net/pypy/pypy/-/issues/3656
async def start_and_forget_endpoint() -> int:
e = endpoint()
# This connection/handshake attempt can't succeed. The only purpose is to force
# the endpoint to set up a receive loop.
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
await s.bind(("127.0.0.1", 0))
c = e.connect(s.getsockname(), client_ctx)
async with trio.open_nursery() as nursery:
nursery.start_soon(c.do_handshake)
await trio.testing.wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
during_tasks = trio.lowlevel.current_statistics().tasks_living
return during_tasks
with pytest.warns(ResourceWarning):
during_tasks = await start_and_forget_endpoint()
await trio.testing.wait_all_tasks_blocked()
gc_collect_harder()
await trio.testing.wait_all_tasks_blocked()
after_tasks = trio.lowlevel.current_statistics().tasks_living
assert before_tasks < during_tasks
assert before_tasks == after_tasks
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_gc_before_system_task_starts() -> None:
e = endpoint()
with pytest.warns(ResourceWarning):
del e
gc_collect_harder()
await trio.testing.wait_all_tasks_blocked()
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
async def test_gc_as_packet_received() -> None:
fn = FakeNet()
fn.enable()
e = endpoint()
await e.socket.bind(("127.0.0.1", 0))
e._ensure_receive_loop()
await trio.testing.wait_all_tasks_blocked()
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
await s.sendto(b"xxx", e.socket.getsockname())
# At this point, the endpoint's receive loop has been marked runnable because it
# just received a packet; closing the endpoint socket won't interrupt that. But by
# the time it wakes up to process the packet, the endpoint will be gone.
with pytest.warns(ResourceWarning):
del e
gc_collect_harder()
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
def test_gc_after_trio_exits() -> None:
async def main() -> DTLSEndpoint:
# We use fakenet just to make sure no real sockets can leak out of the test
# case - on pypy somehow the socket was outliving the gc_collect_harder call
# below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
# when called after trio exits, it doesn't need a real socket.
fn = FakeNet()
fn.enable()
return endpoint()
e = trio.run(main)
with pytest.warns(ResourceWarning):
del e
gc_collect_harder()
async def test_already_closed_socket_doesnt_crash() -> None:
with endpoint() as e:
# We close the socket before checkpointing, so the socket will already be closed
# when the system task starts up
e.socket.close()
# Now give it a chance to start up, and hopefully not crash
await trio.testing.wait_all_tasks_blocked()
async def test_socket_closed_while_processing_clienthello(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
# Check what happens if the socket is discovered to be closed when sending a
# HelloVerifyRequest, since that has its own sending logic
async with dtls_echo_server() as (server, address):
def route_packet(packet: UDPPacket) -> None:
fn.deliver_packet(packet)
server.socket.close()
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
with endpoint() as client_endpoint:
with trio.move_on_after(10):
client = client_endpoint.connect(address, client_ctx)
await client.do_handshake()
async def test_association_replaced_while_handshake_running(
autojump_clock: trio.abc.Clock,
) -> None:
fn = FakeNet()
fn.enable()
def route_packet(packet: UDPPacket) -> None:
pass
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
c1 = client_endpoint.connect(address, client_ctx)
async with trio.open_nursery() as nursery:
async def doomed_handshake() -> None:
with pytest.raises(trio.BrokenResourceError):
await c1.do_handshake()
nursery.start_soon(doomed_handshake)
await trio.sleep(10)
client_endpoint.connect(address, client_ctx)
async def test_association_replaced_before_handshake_starts() -> None:
fn = FakeNet()
fn.enable()
# This test shouldn't send any packets
def route_packet(packet: UDPPacket) -> NoReturn: # pragma: no cover
raise AssertionError()
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
c1 = client_endpoint.connect(address, client_ctx)
client_endpoint.connect(address, client_ctx)
with pytest.raises(trio.BrokenResourceError):
await c1.do_handshake()
async def test_send_to_closed_local_port() -> None:
# On Windows, sending a UDP packet to a closed local port can cause a weird
# ECONNRESET error later, inside the receive task. Make sure we're handling it
# properly.
async with dtls_echo_server() as (_, address):
with endpoint() as client_endpoint:
async with trio.open_nursery() as nursery:
for i in range(1, 10):
channel = client_endpoint.connect(("127.0.0.1", i), client_ctx)
nursery.start_soon(channel.do_handshake)
channel = client_endpoint.connect(address, client_ctx)
await channel.send(b"xxx")
assert await channel.receive() == b"xxx"
nursery.cancel_scope.cancel()

View File

@ -0,0 +1,574 @@
from __future__ import annotations # isort: split
import __future__ # Regular import, not special!
import enum
import functools
import importlib
import inspect
import json
import socket as stdlib_socket
import sys
import types
from pathlib import Path, PurePath
from types import ModuleType
from typing import TYPE_CHECKING, Protocol
import attrs
import pytest
import trio
import trio.testing
from trio._tests.pytest_plugin import skip_if_optional_else_raise
from .. import _core, _util
from .._core._tests.tutil import slow
from .pytest_plugin import RUN_SLOW
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
mypy_cache_updated = False
try: # If installed, check both versions of this class.
from typing_extensions import Protocol as Protocol_ext
except ImportError: # pragma: no cover
Protocol_ext = Protocol # type: ignore[assignment]
def _ensure_mypy_cache_updated() -> None:
# This pollutes the `empty` dir. Should this be changed?
try:
from mypy.api import run
except ImportError as error:
skip_if_optional_else_raise(error)
global mypy_cache_updated
if not mypy_cache_updated:
# mypy cache was *probably* already updated by the other tests,
# but `pytest -k ...` might run just this test on its own
result = run(
[
"--config-file=",
"--cache-dir=./.mypy_cache",
"--no-error-summary",
"-c",
"import trio",
],
)
assert not result[1] # stderr
assert not result[0] # stdout
mypy_cache_updated = True
def test_core_is_properly_reexported() -> None:
# Each export from _core should be re-exported by exactly one of these
# three modules:
sources = [trio, trio.lowlevel, trio.testing]
for symbol in dir(_core):
if symbol.startswith("_"):
continue
found = 0
for source in sources:
if symbol in dir(source) and getattr(source, symbol) is getattr(
_core,
symbol,
):
found += 1
print(symbol, found)
assert found == 1
def class_is_final(cls: type) -> bool:
"""Check if a class cannot be subclassed."""
try:
# new_class() handles metaclasses properly, type(...) does not.
types.new_class("SubclassTester", (cls,))
except TypeError:
return True
else:
return False
def iter_modules(
module: types.ModuleType,
only_public: bool,
) -> Iterator[types.ModuleType]:
yield module
for name, class_ in module.__dict__.items():
if name.startswith("_") and only_public:
continue
if not isinstance(class_, ModuleType):
continue
if not class_.__name__.startswith(module.__name__): # pragma: no cover
continue
if class_ is module: # pragma: no cover
continue
yield from iter_modules(class_, only_public)
PUBLIC_MODULES = list(iter_modules(trio, only_public=True))
ALL_MODULES = list(iter_modules(trio, only_public=False))
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
# It doesn't make sense for downstream redistributors to run this test, since
# they might be using a newer version of Python with additional symbols which
# won't be reflected in trio.socket, and this shouldn't cause downstream test
# runs to start failing.
@pytest.mark.redistributors_should_skip()
# Static analysis tools often have trouble with alpha releases, where Python's
# internals are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
sys.version_info.releaselevel == "alpha",
reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"])
@pytest.mark.filterwarnings(
# https://github.com/pypa/setuptools/issues/3274
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
)
def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
module = importlib.import_module(modname)
def no_underscores(symbols: Iterable[str]) -> set[str]:
return {symbol for symbol in symbols if not symbol.startswith("_")}
runtime_names = no_underscores(dir(module))
# ignore deprecated module `tests` being invisible
if modname == "trio":
runtime_names.discard("tests")
# Ignore any __future__ feature objects, if imported under that name.
for name in __future__.all_feature_names:
if getattr(module, name, None) is getattr(__future__, name):
runtime_names.remove(name)
if tool == "pylint":
try:
from pylint.lint import PyLinter
except ImportError as error:
skip_if_optional_else_raise(error)
linter = PyLinter()
assert module.__file__ is not None
ast = linter.get_ast(module.__file__, modname)
static_names = no_underscores(ast) # type: ignore[arg-type]
elif tool == "jedi":
if sys.implementation.name != "cpython":
pytest.skip("jedi does not support pypy")
try:
import jedi
except ImportError as error:
skip_if_optional_else_raise(error)
# Simulate typing "import trio; trio.<TAB>"
script = jedi.Script(f"import {modname}; {modname}.")
completions = script.complete()
static_names = no_underscores(c.name for c in completions)
elif tool == "mypy":
if not RUN_SLOW: # pragma: no cover
pytest.skip("use --run-slow to check against mypy")
cache = Path.cwd() / ".mypy_cache"
_ensure_mypy_cache_updated()
trio_cache = next(cache.glob("*/trio"))
_, modname = (modname + ".").split(".", 1)
modname = modname[:-1]
mod_cache = trio_cache / modname if modname else trio_cache
if mod_cache.is_dir(): # pragma: no coverage
mod_cache = mod_cache / "__init__.data.json"
else:
mod_cache = trio_cache / (modname + ".data.json")
assert mod_cache.exists()
assert mod_cache.is_file()
with mod_cache.open() as cache_file:
cache_json = json.loads(cache_file.read())
static_names = no_underscores(
key
for key, value in cache_json["names"].items()
if not key.startswith(".") and value["kind"] == "Gdef"
)
elif tool == "pyright_verifytypes":
if not RUN_SLOW: # pragma: no cover
pytest.skip("use --run-slow to check against pyright")
try:
import pyright # noqa: F401
except ImportError as error:
skip_if_optional_else_raise(error)
import subprocess
res = subprocess.run(
["pyright", f"--verifytypes={modname}", "--outputjson"],
capture_output=True,
)
current_result = json.loads(res.stdout)
static_names = {
x["name"][len(modname) + 1 :]
for x in current_result["typeCompleteness"]["symbols"]
if x["name"].startswith(modname)
}
else: # pragma: no cover
raise AssertionError()
# It's expected that the static set will contain more names than the
# runtime set:
# - static tools are sometimes sloppy and include deleted names
# - some symbols are platform-specific at runtime, but always show up in
# static analysis (e.g. in trio.socket or trio.lowlevel)
# So we check that the runtime names are a subset of the static names.
missing_names = runtime_names - static_names
# ignore warnings about deprecated module tests
missing_names -= {"tests"}
if missing_names: # pragma: no cover
print(f"{tool} can't see the following names in {modname}:")
print()
for name in sorted(missing_names):
print(f" {name}")
raise AssertionError()
# this could be sped up by only invoking mypy once per module, or even once for all
# modules, instead of once per class.
@slow
# see comment on test_static_tool_sees_all_symbols
@pytest.mark.redistributors_should_skip()
# Static analysis tools often have trouble with alpha releases, where Python's
# internals are in flux, grammar may not have settled down, etc.
@pytest.mark.skipif(
sys.version_info.releaselevel == "alpha",
reason="skip static introspection tools on Python dev/alpha releases",
)
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
def test_static_tool_sees_class_members(
tool: str,
module_name: str,
tmp_path: Path,
) -> None:
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
# ignore hidden, but not dunder, symbols
def no_hidden(symbols: Iterable[str]) -> set[str]:
return {
symbol
for symbol in symbols
if (not symbol.startswith("_")) or symbol.startswith("__")
}
if tool == "jedi" and sys.implementation.name != "cpython":
pytest.skip("jedi does not support pypy")
if tool == "mypy":
cache = Path.cwd() / ".mypy_cache"
_ensure_mypy_cache_updated()
trio_cache = next(cache.glob("*/trio"))
modname = module_name
_, modname = (modname + ".").split(".", 1)
modname = modname[:-1]
mod_cache = trio_cache / modname if modname else trio_cache
if mod_cache.is_dir():
mod_cache = mod_cache / "__init__.data.json"
else:
mod_cache = trio_cache / (modname + ".data.json")
assert mod_cache.exists()
assert mod_cache.is_file()
with mod_cache.open() as cache_file:
cache_json = json.loads(cache_file.read())
# skip a bunch of file-system activity (probably can un-memoize?)
@functools.lru_cache
def lookup_symbol(symbol: str) -> dict[str, str]:
topname, *modname, name = symbol.split(".")
version = next(cache.glob("3.*/"))
mod_cache = version / topname
if not mod_cache.is_dir():
mod_cache = version / (topname + ".data.json")
if modname:
for piece in modname[:-1]:
mod_cache /= piece
next_cache = mod_cache / modname[-1]
if next_cache.is_dir(): # pragma: no coverage
mod_cache = next_cache / "__init__.data.json"
else:
mod_cache = mod_cache / (modname[-1] + ".data.json")
elif mod_cache.is_dir():
mod_cache /= "__init__.data.json"
with mod_cache.open() as f:
return json.loads(f.read())["names"][name] # type: ignore[no-any-return]
errors: dict[str, object] = {}
for class_name, class_ in module.__dict__.items():
if not isinstance(class_, type):
continue
if module_name == "trio.socket" and class_name in dir(stdlib_socket):
continue
# ignore class that does dirty tricks
if class_ is trio.testing.RaisesGroup:
continue
# dir() and inspect.getmembers doesn't display properties from the metaclass
# also ignore some dunder methods that tend to differ but are of no consequence
ignore_names = set(dir(type(class_))) | {
"__annotations__",
"__attrs_attrs__",
"__attrs_own_setattr__",
"__callable_proto_members_only__",
"__class_getitem__",
"__final__",
"__getstate__",
"__match_args__",
"__order__",
"__orig_bases__",
"__parameters__",
"__protocol_attrs__",
"__setstate__",
"__slots__",
"__weakref__",
# ignore errors about dunders inherited from stdlib that tools might
# not see
"__copy__",
"__deepcopy__",
}
if type(class_) is type:
# C extension classes don't have these dunders, but Python classes do
ignore_names.add("__firstlineno__")
ignore_names.add("__static_attributes__")
# pypy seems to have some additional dunders that differ
if sys.implementation.name == "pypy":
ignore_names |= {
"__basicsize__",
"__dictoffset__",
"__itemsize__",
"__sizeof__",
"__weakrefoffset__",
"__unicode__",
}
# inspect.getmembers sees `name` and `value` in Enums, otherwise
# it behaves the same way as `dir`
# runtime_names = no_underscores(dir(class_))
runtime_names = (
no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names
)
if tool == "jedi":
try:
import jedi
except ImportError as error:
skip_if_optional_else_raise(error)
script = jedi.Script(
f"from {module_name} import {class_name}; {class_name}.",
)
completions = script.complete()
static_names = no_hidden(c.name for c in completions) - ignore_names
elif tool == "mypy":
# load the cached type information
cached_type_info = cache_json["names"][class_name]
if "node" not in cached_type_info:
cached_type_info = lookup_symbol(cached_type_info["cross_ref"])
assert "node" in cached_type_info
node = cached_type_info["node"]
static_names = no_hidden(k for k in node["names"] if not k.startswith("."))
for symbol in node["mro"][1:]:
node = lookup_symbol(symbol)["node"]
static_names |= no_hidden(
k for k in node["names"] if not k.startswith(".")
)
static_names -= ignore_names
else: # pragma: no cover
raise AssertionError("unknown tool")
missing = runtime_names - static_names
extra = static_names - runtime_names
# using .remove() instead of .delete() to get an error in case they start not
# being missing
if (
tool == "jedi"
and BaseException in class_.__mro__
and sys.version_info >= (3, 11)
):
missing.remove("add_note")
if (
tool == "mypy"
and BaseException in class_.__mro__
and sys.version_info >= (3, 11)
):
extra.remove("__notes__")
if tool == "mypy" and attrs.has(class_):
# e.g. __trio__core__run_CancelScope_AttrsAttributes__
before = len(extra)
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
assert len(extra) == before - 1
# mypy does not see these attributes in Enum subclasses
if (
tool == "mypy"
and enum.Enum in class_.__mro__
and sys.version_info >= (3, 12)
):
# Another attribute, in 3.12+ only.
extra.remove("__signature__")
# TODO: this *should* be visible via `dir`!!
if tool == "mypy" and class_ == trio.Nursery:
extra.remove("cancel_scope")
# These are (mostly? solely?) *runtime* attributes, often set in
# __init__, which doesn't show up with dir() or inspect.getmembers,
# but we get them in the way we query mypy & jedi
EXTRAS = {
trio.DTLSChannel: {"peer_address", "endpoint"},
trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
trio.SSLListener: {"transport_listener"},
trio.SSLStream: {"transport_stream"},
trio.SocketListener: {"socket"},
trio.SocketStream: {"socket"},
trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
trio.testing.MemorySendStream: {
"close_hook",
"send_all_hook",
"wait_send_all_might_not_block_hook",
},
trio.testing.Matcher: {
"exception_type",
"match",
"check",
},
}
if tool == "mypy" and class_ in EXTRAS:
before = len(extra)
extra -= EXTRAS[class_]
assert len(extra) == before - len(EXTRAS[class_])
# TODO: why is this? Is it a problem?
# see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916
if class_ == trio.StapledStream:
extra.remove("receive_stream")
extra.remove("send_stream")
# I have not researched why these are missing, should maybe create an issue
# upstream with jedi
if tool == "jedi" and sys.version_info >= (3, 11):
if class_ in (
trio.DTLSChannel,
trio.MemoryReceiveChannel,
trio.MemorySendChannel,
trio.SSLListener,
trio.SocketListener,
):
missing.remove("__aenter__")
missing.remove("__aexit__")
if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel):
missing.remove("__aiter__")
missing.remove("__anext__")
if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath):
# These are from inherited subclasses.
missing -= PurePath.__dict__.keys()
# These are unix-only.
if tool == "mypy" and sys.platform == "win32":
missing -= {"owner", "is_mount", "group"}
if tool == "jedi" and sys.platform == "win32":
extra -= {"owner", "is_mount", "group"}
# not sure why jedi in particular ignores this (static?) method in 3.13
# (especially given the method is from 3.12....)
if (
tool == "jedi"
and sys.version_info >= (3, 13)
and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath)
):
missing.remove("with_segments")
if missing or extra: # pragma: no cover
errors[f"{module_name}.{class_name}"] = {
"missing": missing,
"extra": extra,
}
# `assert not errors` will not print the full content of errors, even with
# `--verbose`, so we manually print it
if errors: # pragma: no cover
from pprint import pprint
print(f"\n{tool} can't see the following symbols in {module_name}:")
pprint(errors)
assert not errors
def test_nopublic_is_final() -> None:
"""Check all NoPublicConstructor classes are also @final."""
assert class_is_final(_util.NoPublicConstructor) # This is itself final.
for module in ALL_MODULES:
for class_ in module.__dict__.values():
if isinstance(class_, _util.NoPublicConstructor):
assert class_is_final(class_)
def test_classes_are_final() -> None:
# Sanity checks.
assert not class_is_final(object)
assert class_is_final(bool)
for module in PUBLIC_MODULES:
for name, class_ in module.__dict__.items():
if not isinstance(class_, type):
continue
# Deprecated classes are exported with a leading underscore
if name.startswith("_"): # pragma: no cover
continue
# Abstract classes can be subclassed, because that's the whole
# point of ABCs
if inspect.isabstract(class_):
continue
# Same with protocols, but only direct children.
if Protocol in class_.__bases__ or Protocol_ext in class_.__bases__:
continue
# Exceptions are allowed to be subclassed, because exception
# subclassing isn't used to inherit behavior.
if issubclass(class_, BaseException):
continue
# These are classes that are conceptually abstract, but
# inspect.isabstract returns False for boring reasons.
if class_ is trio.abc.Instrument or class_ is trio.socket.SocketType:
continue
# ... insert other special cases here ...
# The `Path` class needs to support inheritance to allow `WindowsPath` and `PosixPath`.
if class_ is trio.Path:
continue
# don't care about the *Statistics classes
if name.endswith("Statistics"):
continue
assert class_is_final(class_)

View File

@ -0,0 +1,313 @@
import errno
import re
import socket
import sys
import pytest
import trio
from trio.testing._fake_net import FakeNet
# ENOTCONN gives different messages on different platforms
if sys.platform == "linux":
ENOTCONN_MSG = r"^\[Errno 107\] (Transport endpoint is|Socket) not connected$"
elif sys.platform == "darwin":
ENOTCONN_MSG = r"^\[Errno 57\] Socket is not connected$"
else:
ENOTCONN_MSG = r"^\[Errno 10057\] Unknown error$"
def fn() -> FakeNet:
fn = FakeNet()
fn.enable()
return fn
async def test_basic_udp() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s1.bind(("127.0.0.1", 0))
ip, port = s1.getsockname()
assert ip == "127.0.0.1"
assert port != 0
with pytest.raises(
OSError,
match=r"^\[\w+ \d+\] Invalid argument$",
) as exc: # Cannot rebind.
await s1.bind(("192.0.2.1", 0))
assert exc.value.errno == errno.EINVAL
# Cannot bind multiple sockets to the same address
with pytest.raises(
OSError,
match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$",
) as exc:
await s2.bind(("127.0.0.1", port))
assert exc.value.errno == errno.EADDRINUSE
await s2.sendto(b"xyz", s1.getsockname())
data, addr = await s1.recvfrom(10)
assert data == b"xyz"
assert addr == s2.getsockname()
await s1.sendto(b"abc", s2.getsockname())
data, addr = await s2.recvfrom(10)
assert data == b"abc"
assert addr == s1.getsockname()
async def test_msg_trunc() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s1.bind(("127.0.0.1", 0))
await s2.sendto(b"xyz", s1.getsockname())
data, addr = await s1.recvfrom(10)
async def test_recv_methods() -> None:
"""Test all recv methods for codecov"""
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
# receiving on an unbound socket is a bad idea (I think?)
with pytest.raises(NotImplementedError, match="code will most likely hang"):
await s2.recv(10)
await s1.bind(("127.0.0.1", 0))
ip, port = s1.getsockname()
assert ip == "127.0.0.1"
assert port != 0
# recvfrom
await s2.sendto(b"abc", s1.getsockname())
data, addr = await s1.recvfrom(10)
assert data == b"abc"
assert addr == s2.getsockname()
# recv
await s1.sendto(b"def", s2.getsockname())
data = await s2.recv(10)
assert data == b"def"
# recvfrom_into
assert await s1.sendto(b"ghi", s2.getsockname()) == 3
buf = bytearray(10)
with pytest.raises(NotImplementedError, match="^partial recvfrom_into$"):
(nbytes, addr) = await s2.recvfrom_into(buf, nbytes=2)
(nbytes, addr) = await s2.recvfrom_into(buf)
assert nbytes == 3
assert buf == b"ghi" + b"\x00" * 7
assert addr == s1.getsockname()
# recv_into
assert await s1.sendto(b"jkl", s2.getsockname()) == 3
buf2 = bytearray(10)
nbytes = await s2.recv_into(buf2)
assert nbytes == 3
assert buf2 == b"jkl" + b"\x00" * 7
if sys.platform == "linux" and sys.implementation.name == "cpython":
flags: int = socket.MSG_MORE
else:
flags = 1
# Send seems explicitly non-functional
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
await s2.send(b"mno")
assert exc.value.errno == errno.ENOTCONN
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
await s2.send(b"mno", flags)
# sendto errors
# it's successfully used earlier
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
await s2.sendto(b"mno", flags, s1.getsockname())
with pytest.raises(TypeError, match="wrong number of arguments$"):
await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload]
@pytest.mark.skipif(
sys.platform == "win32",
reason="functions not in socket on windows",
)
async def test_nonwindows_functionality() -> None:
# mypy doesn't support a good way of aborting typechecking on different platforms
if sys.platform != "win32": # pragma: no branch
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s2.bind(("127.0.0.1", 0))
# sendmsg
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
await s2.sendmsg([b"mno"])
assert exc.value.errno == errno.ENOTCONN
assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3
(data, ancdata, msg_flags, addr) = await s2.recvmsg(10)
assert data == b"jkl"
assert ancdata == []
assert msg_flags == 0
assert addr == s1.getsockname()
# TODO: recvmsg
# recvmsg_into
assert await s1.sendto(b"xyzw", s2.getsockname()) == 4
buf1 = bytearray(2)
buf2 = bytearray(3)
ret = await s2.recvmsg_into([buf1, buf2])
(nbytes, ancdata, msg_flags, addr) = ret
assert nbytes == 4
assert buf1 == b"xy"
assert buf2 == b"zw" + b"\x00"
assert ancdata == []
assert msg_flags == 0
assert addr == s1.getsockname()
# recvmsg_into with MSG_TRUNC set
assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5
buf1 = bytearray(2)
ret = await s2.recvmsg_into([buf1])
(nbytes, ancdata, msg_flags, addr) = ret
assert nbytes == 2
assert buf1 == b"xy"
assert ancdata == []
assert msg_flags == socket.MSG_TRUNC
assert addr == s1.getsockname()
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'share'$",
):
await s1.share(0) # type: ignore[attr-defined]
@pytest.mark.skipif(
sys.platform != "win32",
reason="windows-specific fakesocket testing",
)
async def test_windows_functionality() -> None:
# mypy doesn't support a good way of aborting typechecking on different platforms
if sys.platform == "win32": # pragma: no branch
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
await s1.bind(("127.0.0.1", 0))
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'sendmsg'$",
):
await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined]
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'recvmsg'$",
):
s2.recvmsg(0) # type: ignore[attr-defined]
with pytest.raises(
AttributeError,
match="^'FakeSocket' object has no attribute 'recvmsg_into'$",
):
s2.recvmsg_into([]) # type: ignore[attr-defined]
with pytest.raises(NotImplementedError):
s1.share(0)
async def test_basic_tcp() -> None:
fn()
with pytest.raises(NotImplementedError):
trio.socket.socket()
async def test_not_implemented_functions() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
# getsockopt
with pytest.raises(
OSError,
match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$",
):
s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
# setsockopt
with pytest.raises(
NotImplementedError,
match="^FakeNet always has IPV6_V6ONLY=True$",
):
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
with pytest.raises(
OSError,
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
):
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
with pytest.raises(
OSError,
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
):
s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
# set_inheritable
s1.set_inheritable(False)
with pytest.raises(
NotImplementedError,
match="^FakeNet can't make inheritable sockets$",
):
s1.set_inheritable(True)
# get_inheritable
assert not s1.get_inheritable()
async def test_getpeername() -> None:
fn()
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
s1.getpeername()
assert exc.value.errno == errno.ENOTCONN
await s1.bind(("127.0.0.1", 0))
with pytest.raises(
AssertionError,
match="^This method seems to assume that self._binding has a remote UDPEndpoint$",
):
s1.getpeername()
async def test_init() -> None:
fn()
with pytest.raises(
NotImplementedError,
match=re.escape(
f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}",
),
):
s1 = trio.socket.socket()
# getsockname on unbound ipv4 socket
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
assert s1.getsockname() == ("0.0.0.0", 0)
# getsockname on bound ipv4 socket
await s1.bind(("0.0.0.0", 0))
ip, port = s1.getsockname()
assert ip == "127.0.0.1"
assert port != 0
# getsockname on unbound ipv6 socket
s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)
assert s2.getsockname() == ("::", 0)
# getsockname on bound ipv6 socket
await s2.bind(("::", 0))
ip, port, *_ = s2.getsockname()
assert ip == "::1"
assert port != 0
assert _ == [0, 0]

View File

@ -0,0 +1,269 @@
from __future__ import annotations
import importlib
import io
import os
import re
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import sentinel
import pytest
import trio
from trio import _core, _file_io
from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper
if TYPE_CHECKING:
import pathlib
@pytest.fixture
def path(tmp_path: pathlib.Path) -> str:
return os.fspath(tmp_path / "test")
@pytest.fixture
def wrapped() -> mock.Mock:
return mock.Mock(spec_set=io.StringIO)
@pytest.fixture
def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]:
return trio.wrap_file(wrapped)
def test_wrap_invalid() -> None:
with pytest.raises(TypeError):
trio.wrap_file("")
def test_wrap_non_iobase() -> None:
class FakeFile:
def close(self) -> None: # pragma: no cover
pass
def write(self) -> None: # pragma: no cover
pass
wrapped = FakeFile()
assert not isinstance(wrapped, io.IOBase)
async_file = trio.wrap_file(wrapped)
assert isinstance(async_file, AsyncIOWrapper)
del FakeFile.write
with pytest.raises(TypeError):
trio.wrap_file(FakeFile())
def test_wrapped_property(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
assert async_file.wrapped is wrapped
def test_dir_matches_wrapped(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
# all supported attrs in wrapped should be available in async_file
assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
# all supported attrs not in wrapped should not be available in async_file
assert not any(
attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
)
def test_unsupported_not_forwarded() -> None:
class FakeFile(io.RawIOBase):
def unsupported_attr(self) -> None: # pragma: no cover
pass
async_file = trio.wrap_file(FakeFile())
assert hasattr(async_file.wrapped, "unsupported_attr")
with pytest.raises(AttributeError):
# B018 "useless expression"
async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018
def test_type_stubs_match_lists() -> None:
"""Check the manual stubs match the list of wrapped methods."""
# Fetch the module's source code.
assert _file_io.__spec__ is not None
loader = _file_io.__spec__.loader
assert isinstance(loader, importlib.abc.SourceLoader)
source = io.StringIO(loader.get_source("trio._file_io"))
# Find the class, then find the TYPE_CHECKING block.
for line in source:
if "class AsyncIOWrapper" in line:
break
else: # pragma: no cover - should always find this
pytest.fail("No class definition line?")
for line in source:
if "if TYPE_CHECKING" in line:
break
else: # pragma: no cover - should always find this
pytest.fail("No TYPE CHECKING line?")
# Now we should be at the type checking block.
found: list[tuple[str, str]] = []
for line in source: # pragma: no branch - expected to break early
if line.strip() and not line.startswith(" " * 8):
break # Dedented out of the if TYPE_CHECKING block.
match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line)
if match is not None:
kind = "async" if match.group(1) is not None else "sync"
found.append((match.group(2), kind))
# Compare two lists so that we can easily see duplicates, and see what is different overall.
expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS]
expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS]
# Ignore order, error if duplicates are present.
found.sort()
expected.sort()
assert found == expected
def test_sync_attrs_forwarded(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name not in dir(async_file):
continue
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
def test_sync_attrs_match_wrapper(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for attr_name in _FILE_SYNC_ATTRS:
if attr_name in dir(async_file):
continue
with pytest.raises(AttributeError):
getattr(async_file, attr_name)
with pytest.raises(AttributeError):
getattr(wrapped, attr_name)
def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
# I gave up on typing this one
def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None:
# use read as a representative of all async methods
assert async_file.read.__name__ == "read"
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
assert async_file.read.__doc__ is not None
assert "io.StringIO.read" in async_file.read.__doc__
async def test_async_methods_wrap(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name not in dir(async_file):
continue
meth = getattr(async_file, meth_name)
wrapped_meth = getattr(wrapped, meth_name)
value = await meth(sentinel.argument, keyword=sentinel.keyword)
wrapped_meth.assert_called_once_with(
sentinel.argument,
keyword=sentinel.keyword,
)
assert value == wrapped_meth()
wrapped.reset_mock()
async def test_async_methods_match_wrapper(
async_file: AsyncIOWrapper[mock.Mock],
wrapped: mock.Mock,
) -> None:
for meth_name in _FILE_ASYNC_METHODS:
if meth_name in dir(async_file):
continue
with pytest.raises(AttributeError):
getattr(async_file, meth_name)
with pytest.raises(AttributeError):
getattr(wrapped, meth_name)
async def test_open(path: pathlib.Path) -> None:
f = await trio.open_file(path, "w")
assert isinstance(f, AsyncIOWrapper)
await f.aclose()
async def test_open_context_manager(path: pathlib.Path) -> None:
async with await trio.open_file(path, "w") as f:
assert isinstance(f, AsyncIOWrapper)
assert not f.closed
assert f.closed
async def test_async_iter() -> None:
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
expected = list(async_file.wrapped)
async_file.wrapped.seek(0)
result = [line async for line in async_file]
assert result == expected
async def test_aclose_cancelled(path: pathlib.Path) -> None:
with _core.CancelScope() as cscope:
f = await trio.open_file(path, "w")
cscope.cancel()
with pytest.raises(_core.Cancelled):
await f.write("a")
with pytest.raises(_core.Cancelled):
await f.aclose()
assert f.closed
async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
tmp_file = tmp_path / "filename"
tmp_file.touch()
# flake8-async does not like opening files in async mode
with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC230
buffered = io.BufferedReader(raw)
async_file = trio.wrap_file(buffered)
detached = await async_file.detach()
assert isinstance(detached, AsyncIOWrapper)
assert detached.wrapped is raw

View File

@ -0,0 +1,98 @@
from __future__ import annotations
from typing import NoReturn
import attrs
import pytest
from .._highlevel_generic import StapledStream
from ..abc import ReceiveStream, SendStream
@attrs.define(slots=False)
class RecordSendStream(SendStream):
record: list[str | tuple[str, object]] = attrs.Factory(list)
async def send_all(self, data: object) -> None:
self.record.append(("send_all", data))
async def wait_send_all_might_not_block(self) -> None:
self.record.append("wait_send_all_might_not_block")
async def aclose(self) -> None:
self.record.append("aclose")
@attrs.define(slots=False)
class RecordReceiveStream(ReceiveStream):
record: list[str | tuple[str, int | None]] = attrs.Factory(list)
async def receive_some(self, max_bytes: int | None = None) -> bytes:
self.record.append(("receive_some", max_bytes))
return b""
async def aclose(self) -> None:
self.record.append("aclose")
async def test_StapledStream() -> None:
send_stream = RecordSendStream()
receive_stream = RecordReceiveStream()
stapled = StapledStream(send_stream, receive_stream)
assert stapled.send_stream is send_stream
assert stapled.receive_stream is receive_stream
await stapled.send_all(b"foo")
await stapled.wait_send_all_might_not_block()
assert send_stream.record == [
("send_all", b"foo"),
"wait_send_all_might_not_block",
]
send_stream.record.clear()
await stapled.send_eof()
assert send_stream.record == ["aclose"]
send_stream.record.clear()
async def fake_send_eof() -> None:
send_stream.record.append("send_eof")
send_stream.send_eof = fake_send_eof # type: ignore[attr-defined]
await stapled.send_eof()
assert send_stream.record == ["send_eof"]
send_stream.record.clear()
assert receive_stream.record == []
await stapled.receive_some(1234)
assert receive_stream.record == [("receive_some", 1234)]
assert send_stream.record == []
receive_stream.record.clear()
await stapled.aclose()
assert receive_stream.record == ["aclose"]
assert send_stream.record == ["aclose"]
async def test_StapledStream_with_erroring_close() -> None:
# Make sure that if one of the aclose methods errors out, then the other
# one still gets called.
class BrokenSendStream(RecordSendStream):
async def aclose(self) -> NoReturn:
await super().aclose()
raise ValueError("send error")
class BrokenReceiveStream(RecordReceiveStream):
async def aclose(self) -> NoReturn:
await super().aclose()
raise ValueError("recv error")
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo:
await stapled.aclose()
assert isinstance(excinfo.value.__context__, ValueError)
assert stapled.send_stream.record == ["aclose"]
assert stapled.receive_stream.record == ["aclose"]

View File

@ -0,0 +1,410 @@
from __future__ import annotations
import errno
import socket as stdlib_socket
import sys
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING, Any, Sequence, overload
import attrs
import pytest
import trio
from trio import (
SocketListener,
open_tcp_listeners,
open_tcp_stream,
serve_tcp,
)
from trio.abc import HostnameResolver, SendStream, SocketFactory
from trio.testing import open_stream_to_socket_listener
from .. import socket as tsocket
from .._core._tests.tutil import binds_ipv6
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
if TYPE_CHECKING:
from typing_extensions import Buffer
async def test_open_tcp_listeners_basic() -> None:
listeners = await open_tcp_listeners(0)
assert isinstance(listeners, list)
for obj in listeners:
assert isinstance(obj, SocketListener)
# Binds to wildcard address by default
assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
listener = listeners[0]
# Make sure the backlog is at least 2
c1 = await open_stream_to_socket_listener(listener)
c2 = await open_stream_to_socket_listener(listener)
s1 = await listener.accept()
s2 = await listener.accept()
# Note that we don't know which client stream is connected to which server
# stream
await s1.send_all(b"x")
await s2.send_all(b"x")
assert await c1.receive_some(1) == b"x"
assert await c2.receive_some(1) == b"x"
for resource in [c1, c2, s1, s2, *listeners]:
await resource.aclose()
async def test_open_tcp_listeners_specific_port_specific_host() -> None:
# Pick a port
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
host, port = sock.getsockname()
sock.close()
(listener,) = await open_tcp_listeners(port, host=host)
async with listener:
assert listener.socket.getsockname() == (host, port)
@binds_ipv6
async def test_open_tcp_listeners_ipv6_v6only() -> None:
# Check IPV6_V6ONLY is working properly
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
async with ipv6_listener:
_, port, *_ = ipv6_listener.socket.getsockname()
with pytest.raises(
OSError,
match=r"(Error|all attempts to) connect(ing)* to (\(')*127\.0\.0\.1(', |:)\d+(\): Connection refused| failed)$",
):
await open_tcp_stream("127.0.0.1", port)
async def test_open_tcp_listeners_rebind() -> None:
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
sockaddr1 = l1.socket.getsockname()
# Plain old rebinding while it's still there should fail, even if we have
# SO_REUSEADDR set
with stdlib_socket.socket() as probe:
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
with pytest.raises(
OSError,
match="(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$",
):
probe.bind(sockaddr1)
# Now use the first listener to set up some connections in various states,
# and make sure that they don't create any obstacle to rebinding a second
# listener after the first one is closed.
c_established = await open_stream_to_socket_listener(l1)
s_established = await l1.accept()
c_time_wait = await open_stream_to_socket_listener(l1)
s_time_wait = await l1.accept()
# Server-initiated close leaves socket in TIME_WAIT
await s_time_wait.aclose()
await l1.aclose()
(l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
sockaddr2 = l2.socket.getsockname()
assert sockaddr1 == sockaddr2
assert s_established.socket.getsockname() == sockaddr2
assert c_time_wait.socket.getpeername() == sockaddr2
for resource in [
l1,
l2,
c_established,
s_established,
c_time_wait,
s_time_wait,
]:
await resource.aclose()
class FakeOSError(OSError):
pass
@attrs.define(slots=False)
class FakeSocket(tsocket.SocketType):
_family: AddressFamily = attrs.field(converter=AddressFamily)
_type: SocketKind = attrs.field(converter=SocketKind)
_proto: int
closed: bool = False
poison_listen: bool = False
backlog: int | None = None
@property
def type(self) -> SocketKind:
return self._type
@property
def family(self) -> AddressFamily:
return self._family
@property
def proto(self) -> int: # pragma: no cover
return self._proto
@overload
def getsockopt(self, /, level: int, optname: int) -> int: ...
@overload
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
def getsockopt(
self,
/,
level: int,
optname: int,
buflen: int | None = None,
) -> int | bytes:
if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
return True
raise AssertionError() # pragma: no cover
@overload
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
@overload
def setsockopt(
self,
/,
level: int,
optname: int,
value: None,
optlen: int,
) -> None: ...
def setsockopt(
self,
/,
level: int,
optname: int,
value: int | Buffer | None,
optlen: int | None = None,
) -> None:
pass
async def bind(self, address: Any) -> None:
pass
def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
assert self.backlog is None
assert backlog is not None
self.backlog = backlog
if self.poison_listen:
raise FakeOSError("whoops")
def close(self) -> None:
self.closed = True
@attrs.define(slots=False)
class FakeSocketFactory(SocketFactory):
poison_after: int
sockets: list[tsocket.SocketType] = attrs.Factory(list)
raise_on_family: dict[AddressFamily, int] = attrs.Factory(dict) # family => errno
def socket(
self,
family: AddressFamily | int | None = None,
type_: SocketKind | int | None = None,
proto: int = 0,
) -> tsocket.SocketType:
assert family is not None
assert type_ is not None
if isinstance(family, int) and not isinstance(family, AddressFamily):
family = AddressFamily(family) # pragma: no cover
if family in self.raise_on_family:
raise OSError(self.raise_on_family[family], "nope")
sock = FakeSocket(family, type_, proto)
self.poison_after -= 1
if self.poison_after == 0:
sock.poison_listen = True
self.sockets.append(sock)
return sock
@attrs.define(slots=False)
class FakeHostnameResolver(HostnameResolver):
family_addr_pairs: Sequence[tuple[AddressFamily, str]]
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
assert isinstance(port, int)
return [
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
for family, addr in self.family_addr_pairs
]
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
raise NotImplementedError()
async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
# If we were trying to bind to multiple hosts and one of them failed, they
# call get cleaned up before returning
fsf = FakeSocketFactory(3)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver(
[
(tsocket.AF_INET, "1.1.1.1"),
(tsocket.AF_INET, "2.2.2.2"),
(tsocket.AF_INET, "3.3.3.3"),
],
),
)
with pytest.raises(FakeOSError):
await open_tcp_listeners(80, host="example.org")
assert len(fsf.sockets) == 3
for sock in fsf.sockets:
# property only exists on FakeSocket
assert sock.closed # type: ignore[attr-defined]
async def test_open_tcp_listeners_port_checking() -> None:
for host in ["127.0.0.1", None]:
with pytest.raises(TypeError):
await open_tcp_listeners(None, host=host) # type: ignore[arg-type]
with pytest.raises(TypeError):
await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type]
with pytest.raises(TypeError):
await open_tcp_listeners("http", host=host) # type: ignore[arg-type]
async def test_serve_tcp() -> None:
async def handler(stream: SendStream) -> None:
await stream.send_all(b"x")
async with trio.open_nursery() as nursery:
# nursery.start is incorrectly typed, awaiting #2773
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
stream = await open_stream_to_socket_listener(listeners[0])
async with stream:
assert await stream.receive_some(1) == b"x"
nursery.cancel_scope.cancel()
@pytest.mark.parametrize(
"try_families",
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
@pytest.mark.parametrize(
"fail_families",
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
)
async def test_open_tcp_listeners_some_address_families_unavailable(
try_families: set[AddressFamily],
fail_families: set[AddressFamily],
) -> None:
fsf = FakeSocketFactory(
10,
raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families},
)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver([(family, "foo") for family in try_families]),
)
should_succeed = try_families - fail_families
if not should_succeed:
with pytest.raises(OSError, match="This system doesn't support") as exc_info:
await open_tcp_listeners(80, host="example.org")
# open_listeners always creates an exceptiongroup with the
# unsupported address families, regardless of the value of
# strict_exception_groups or number of unsupported families.
assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)
for subexc in exc_info.value.__cause__.exceptions:
assert "nope" in str(subexc)
else:
listeners = await open_tcp_listeners(80)
for listener in listeners:
should_succeed.remove(listener.socket.family)
assert not should_succeed
async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None:
fsf = FakeSocketFactory(
10,
raise_on_family={
tsocket.AF_INET: errno.EAFNOSUPPORT,
tsocket.AF_INET6: errno.EINVAL,
},
)
tsocket.set_custom_socket_factory(fsf)
tsocket.set_custom_hostname_resolver(
FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]),
)
with pytest.raises(OSError, match="nope") as exc_info:
await open_tcp_listeners(80, host="example.org")
assert exc_info.value.errno == errno.EINVAL
assert exc_info.value.__cause__ is None
assert "nope" in str(exc_info.value)
# We used to have an elaborate test that opened a real TCP listening socket
# and then tried to measure its backlog by making connections to it. And most
# of the time, it worked. But no matter what we tried, it was always fragile,
# because it had to do things like use timeouts to guess when the listening
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
# effectively is no backlog), sometimes the host might not be enough resources
# to give us the full requested backlog... it was a mess. So now we just check
# that the backlog argument is passed through correctly.
async def test_open_tcp_listeners_backlog() -> None:
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for given, expected in [
(None, 0xFFFF),
(99999999, 0xFFFF),
(10, 10),
(1, 1),
]:
listeners = await open_tcp_listeners(0, backlog=given)
assert listeners
for listener in listeners:
# `backlog` only exists on FakeSocket
assert listener.socket.backlog == expected # type: ignore[attr-defined]
async def test_open_tcp_listeners_backlog_float_error() -> None:
fsf = FakeSocketFactory(99)
tsocket.set_custom_socket_factory(fsf)
for should_fail in (0.0, 2.18, 3.14, 9.75):
with pytest.raises(
TypeError,
match=f"backlog must be an int or None, not {should_fail!r}",
):
await open_tcp_listeners(0, backlog=should_fail) # type: ignore[arg-type]

View File

@ -0,0 +1,686 @@
from __future__ import annotations
import socket
import sys
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING, Any, Sequence
import attrs
import pytest
import trio
from trio._highlevel_open_tcp_stream import (
close_all,
format_host_port,
open_tcp_stream,
reorder_for_rfc_6555_section_5_4,
)
from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
from trio.testing import Matcher, RaisesGroup
if TYPE_CHECKING:
from trio.testing import MockClock
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
def test_close_all() -> None:
class CloseMe(SocketType):
closed = False
def close(self) -> None:
self.closed = True
class CloseKiller(SocketType):
def close(self) -> None:
raise OSError("os error text")
c: CloseMe = CloseMe()
with close_all() as to_close:
to_close.add(c)
assert c.closed
c = CloseMe()
with pytest.raises(RuntimeError):
with close_all() as to_close:
to_close.add(c)
raise RuntimeError
assert c.closed
c = CloseMe()
with pytest.raises(OSError, match="os error text"):
with close_all() as to_close:
to_close.add(CloseKiller())
to_close.add(c)
assert c.closed
def test_reorder_for_rfc_6555_section_5_4() -> None:
def fake4(
i: int,
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
return (
AF_INET,
SOCK_STREAM,
IPPROTO_TCP,
"",
(f"10.0.0.{i}", 80),
)
def fake6(
i: int,
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80))
for fake in fake4, fake6:
# No effect on homogeneous lists
targets = [fake(0), fake(1), fake(2)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake(0), fake(1), fake(2)]
# Single item lists also OK
targets = [fake(0)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake(0)]
# If the list starts out with different families in positions 0 and 1,
# then it's left alone
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
targets = list(orig)
reorder_for_rfc_6555_section_5_4(targets)
assert targets == orig
# If not, it's reordered
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
reorder_for_rfc_6555_section_5_4(targets)
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
def test_format_host_port() -> None:
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
assert format_host_port("example.com", 443) == "example.com:443"
assert format_host_port(b"example.com", 443) == "example.com:443"
assert format_host_port("::1", "http") == "[::1]:http"
assert format_host_port(b"::1", "http") == "[::1]:http"
# Make sure we can connect to localhost using real kernel sockets
async def test_open_tcp_stream_real_socket_smoketest() -> None:
listen_sock = trio.socket.socket()
await listen_sock.bind(("127.0.0.1", 0))
_, listen_port = listen_sock.getsockname()
listen_sock.listen(1)
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
server_sock, _ = await listen_sock.accept()
await client_stream.send_all(b"x")
assert await server_sock.recv(1) == b"x"
await client_stream.aclose()
server_sock.close()
listen_sock.close()
async def test_open_tcp_stream_input_validation() -> None:
with pytest.raises(ValueError, match="^host must be str or bytes, not None$"):
await open_tcp_stream(None, 80) # type: ignore[arg-type]
with pytest.raises(TypeError):
await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type]
def can_bind_127_0_0_2() -> bool:
with socket.socket() as s:
try:
s.bind(("127.0.0.2", 0))
except OSError:
return False
# s.getsockname() is typed as returning Any
return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
async def test_local_address_real() -> None:
with trio.socket.socket() as listener:
await listener.bind(("127.0.0.1", 0))
listener.listen()
# It's hard to test local_address properly, because you need multiple
# local addresses that you can bind to. Fortunately, on most Linux
# systems, you can bind to any 127.*.*.* address, and they all go
# through the loopback interface. So we can use a non-standard
# loopback address. On other systems, the only address we know for
# certain we have is 127.0.0.1, so we can't really test local_address=
# properly -- passing local_address=127.0.0.1 is indistinguishable
# from not passing local_address= at all. But, we can still do a smoke
# test to make sure the local_address= code doesn't crash.
local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"
async with await open_tcp_stream(
*listener.getsockname(),
local_address=local_address,
) as client_stream:
assert client_stream.socket.getsockname()[0] == local_address
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
assert client_stream.socket.getsockopt(
trio.socket.IPPROTO_IP,
trio.socket.IP_BIND_ADDRESS_NO_PORT,
)
server_sock, remote_addr = await listener.accept()
await client_stream.aclose()
server_sock.close()
# accept returns tuple[SocketType, object], due to typeshed returning `Any`
assert remote_addr[0] == local_address
# Trying to connect to an ipv4 address with the ipv6 wildcard
# local_address should fail
with pytest.raises(
OSError,
match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$",
):
await open_tcp_stream(*listener.getsockname(), local_address="::")
# But the ipv4 wildcard address should work
async with await open_tcp_stream(
*listener.getsockname(),
local_address="0.0.0.0",
) as client_stream:
server_sock, remote_addr = await listener.accept()
server_sock.close()
assert remote_addr == client_stream.socket.getsockname()
# Now, thorough tests using fake sockets
@attrs.define(eq=False, slots=False)
class FakeSocket(trio.socket.SocketType):
scenario: Scenario
_family: AddressFamily
_type: SocketKind
_proto: int
ip: str | int | None = None
port: str | int | None = None
succeeded: bool = False
closed: bool = False
failing: bool = False
@property
def type(self) -> SocketKind:
return self._type
@property
def family(self) -> AddressFamily: # pragma: no cover
return self._family
@property
def proto(self) -> int: # pragma: no cover
return self._proto
async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None:
self.ip = sockaddr[0]
self.port = sockaddr[1]
assert self.ip not in self.scenario.sockets
self.scenario.sockets[self.ip] = self
self.scenario.connect_times[self.ip] = trio.current_time()
delay, result = self.scenario.ip_dict[self.ip]
await trio.sleep(delay)
if result == "error":
raise OSError("sorry")
if result == "postconnect_fail":
self.failing = True
self.succeeded = True
def close(self) -> None:
self.closed = True
# called when SocketStream is constructed
def setsockopt(self, *args: object, **kwargs: object) -> None:
if self.failing:
# raise something that isn't OSError as SocketStream
# ignores those
raise KeyboardInterrupt
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
def __init__(
self,
port: int,
ip_list: Sequence[tuple[str, float, str]],
supported_families: set[AddressFamily],
) -> None:
# ip_list have to be unique
ip_order = [ip for (ip, _, _) in ip_list]
assert len(set(ip_order)) == len(ip_list)
ip_dict: dict[str | int, tuple[float, str]] = {}
for ip, delay, result in ip_list:
assert delay >= 0
assert result in ["error", "success", "postconnect_fail"]
ip_dict[ip] = (delay, result)
self.port = port
self.ip_order = ip_order
self.ip_dict = ip_dict
self.supported_families = supported_families
self.socket_count = 0
self.sockets: dict[str | int, FakeSocket] = {}
self.connect_times: dict[str | int, float] = {}
def socket(
self,
family: AddressFamily | int | None = None,
type_: SocketKind | int | None = None,
proto: int | None = None,
) -> SocketType:
assert isinstance(family, AddressFamily)
assert isinstance(type_, SocketKind)
assert proto is not None
if family not in self.supported_families:
raise OSError("pretending not to support this family")
self.socket_count += 1
return FakeSocket(self, family, type_, proto)
def _ip_to_gai_entry(self, ip: str) -> tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int, int, int] | tuple[str, int],
]:
sockaddr: tuple[str, int] | tuple[str, int, int, int]
if ":" in ip:
family = trio.socket.AF_INET6
sockaddr = (ip, self.port, 0, 0)
else:
family = trio.socket.AF_INET
sockaddr = (ip, self.port)
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = -1,
type: int = -1,
proto: int = -1,
flags: int = -1,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int, int, int] | tuple[str, int],
]
]:
assert host == b"test.example.com"
assert port == self.port
assert family == trio.socket.AF_UNSPEC
assert type == trio.socket.SOCK_STREAM
assert proto == 0
assert flags == 0
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
async def getnameinfo(
self,
sockaddr: tuple[str, int] | tuple[str, int, int, int],
flags: int,
) -> tuple[str, str]:
raise NotImplementedError
def check(self, succeeded: SocketType | None) -> None:
# sockets only go into self.sockets when connect is called; make sure
# all the sockets that were created did in fact go in there.
assert self.socket_count == len(self.sockets)
for ip, socket_ in self.sockets.items():
assert ip in self.ip_dict
if socket_ is not succeeded:
assert socket_.closed
assert socket_.port == self.port
async def run_scenario(
# The port to connect to
port: int,
# A list of
# (ip, delay, result)
# tuples, where delay is in seconds and result is "success" or "error"
# The ip's will be returned from getaddrinfo in this order, and then
# connect() calls to them will have the given result.
ip_list: Sequence[tuple[str, float, str]],
*,
# If False, AF_INET4/6 sockets error out on creation, before connect is
# even called.
ipv4_supported: bool = True,
ipv6_supported: bool = True,
# Normally, we return (winning_sock, scenario object)
# If this is True, we require there to be an exception, and return
# (exception, scenario object)
expect_error: tuple[type[BaseException], ...] | type[BaseException] = (),
**kwargs: Any,
) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]:
supported_families = set()
if ipv4_supported:
supported_families.add(trio.socket.AF_INET)
if ipv6_supported:
supported_families.add(trio.socket.AF_INET6)
scenario = Scenario(port, ip_list, supported_families)
trio.socket.set_custom_hostname_resolver(scenario)
trio.socket.set_custom_socket_factory(scenario)
try:
stream = await open_tcp_stream("test.example.com", port, **kwargs)
assert expect_error == ()
scenario.check(stream.socket)
return (stream.socket, scenario)
except AssertionError: # pragma: no cover
raise
except expect_error as exc:
scenario.check(None)
return (exc, scenario)
async def test_one_host_quick_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 0.123
async def test_one_host_slow_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.2.3.4"
assert trio.current_time() == 100
async def test_one_host_quick_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
82,
[("1.2.3.4", 0.123, "error")],
expect_error=OSError,
)
assert isinstance(exc, OSError)
assert trio.current_time() == 0.123
async def test_one_host_slow_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
83,
[("1.2.3.4", 100, "error")],
expect_error=OSError,
)
assert isinstance(exc, OSError)
assert trio.current_time() == 100
async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
83,
[("1.2.3.4", 1, "postconnect_fail")],
expect_error=KeyboardInterrupt,
)
assert isinstance(exc, KeyboardInterrupt)
# With the default 0.250 second delay, the third attempt will win
async def test_basic_fallthrough(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "3.3.3.3"
# current time is default time + default time + connection time
assert trio.current_time() == (0.250 + 0.250 + 0.2)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
"3.3.3.3": 0.500,
}
async def test_early_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 0.1, "success"),
("3.3.3.3", 0.2, "success"),
],
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "2.2.2.2"
assert trio.current_time() == (0.250 + 0.1)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
# 3.3.3.3 was never even started
}
# With a 0.450 second delay, the first attempt will win
async def test_custom_delay(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
happy_eyeballs_delay=0.450,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "1.1.1.1"
assert trio.current_time() == 1
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.450,
"3.3.3.3": 0.900,
}
async def test_none_default(autojump_clock: MockClock) -> None:
"""Copy of test_basic_fallthrough, but specifying the delay =None"""
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 1, "success"),
("2.2.2.2", 1, "success"),
("3.3.3.3", 0.2, "success"),
],
happy_eyeballs_delay=None,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "3.3.3.3"
# current time is default time + default time + connection time
assert trio.current_time() == (0.250 + 0.250 + 0.2)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.250,
"3.3.3.3": 0.500,
}
async def test_custom_errors_expedite(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.1, "error"),
("2.2.2.2", 0.2, "error"),
("3.3.3.3", 10, "success"),
# .25 is the default timeout
("4.4.4.4", 0.25, "success"),
],
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.1,
"3.3.3.3": 0.1 + 0.2,
"4.4.4.4": 0.1 + 0.2 + 0.25,
}
async def test_all_fail(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.1, "error"),
("2.2.2.2", 0.2, "error"),
("3.3.3.3", 10, "error"),
("4.4.4.4", 0.250, "error"),
],
expect_error=OSError,
)
assert isinstance(exc, OSError)
subexceptions = (Matcher(OSError, match="^sorry$"),) * 4
assert RaisesGroup(
*subexceptions,
match="all attempts to connect to test.example.com:80 failed",
).matches(exc.__cause__)
assert trio.current_time() == (0.1 + 0.2 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.1,
"3.3.3.3": 0.1 + 0.2,
"4.4.4.4": 0.1 + 0.2 + 0.25,
}
async def test_multi_success(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 0.5, "error"),
("2.2.2.2", 10, "success"),
("3.3.3.3", 10 - 1, "success"),
("4.4.4.4", 10 - 2, "success"),
("5.5.5.5", 0.5, "error"),
],
happy_eyeballs_delay=1,
)
assert not scenario.sockets["1.1.1.1"].succeeded
assert (
scenario.sockets["2.2.2.2"].succeeded
or scenario.sockets["3.3.3.3"].succeeded
or scenario.sockets["4.4.4.4"].succeeded
)
assert not scenario.sockets["5.5.5.5"].succeeded
assert isinstance(sock, FakeSocket)
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
assert trio.current_time() == (0.5 + 10)
assert scenario.connect_times == {
"1.1.1.1": 0,
"2.2.2.2": 0.5,
"3.3.3.3": 1.5,
"4.4.4.4": 2.5,
"5.5.5.5": 3.5,
}
async def test_does_reorder(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
[
("1.1.1.1", 10, "error"),
# This would win if we tried it first...
("2.2.2.2", 1, "success"),
# But in fact we try this first, because of section 5.4
("::3", 0.5, "success"),
],
happy_eyeballs_delay=1,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.5
assert scenario.connect_times == {
"1.1.1.1": 0,
"::3": 1,
}
async def test_handles_no_ipv4(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
# configuration doesn't matter
[
("::1", 10, "success"),
("2.2.2.2", 0, "success"),
("::3", 0.1, "success"),
("4.4.4.4", 0, "success"),
],
happy_eyeballs_delay=1,
ipv4_supported=False,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "::3"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
"::1": 0,
"::3": 1.0,
}
async def test_handles_no_ipv6(autojump_clock: MockClock) -> None:
sock, scenario = await run_scenario(
80,
# Here the ipv6 addresses fail at socket creation time, so the connect
# configuration doesn't matter
[
("::1", 0, "success"),
("2.2.2.2", 10, "success"),
("::3", 0, "success"),
("4.4.4.4", 0.1, "success"),
],
happy_eyeballs_delay=1,
ipv6_supported=False,
)
assert isinstance(sock, FakeSocket)
assert sock.ip == "4.4.4.4"
assert trio.current_time() == 1 + 0.1
assert scenario.connect_times == {
"2.2.2.2": 0,
"4.4.4.4": 1.0,
}
async def test_no_hosts(autojump_clock: MockClock) -> None:
exc, scenario = await run_scenario(80, [], expect_error=OSError)
assert "no results found" in str(exc)
async def test_cancel(autojump_clock: MockClock) -> None:
with trio.move_on_after(5) as cancel_scope:
exc, scenario = await run_scenario(
80,
[
("1.1.1.1", 10, "success"),
("2.2.2.2", 10, "success"),
("3.3.3.3", 10, "success"),
("4.4.4.4", 10, "success"),
],
expect_error=BaseExceptionGroup,
)
assert isinstance(exc, BaseException)
# What comes out should be 1 or more Cancelled errors that all belong
# to this cancel_scope; this is the easiest way to check that
raise exc
assert cancel_scope.cancelled_caught
assert trio.current_time() == 5
# This should have been called already, but just to make sure, since the
# exception-handling logic in run_scenario is a bit complicated and the
# main thing we care about here is that all the sockets were cleaned up.
scenario.check(succeeded=None)

View File

@ -0,0 +1,86 @@
import os
import socket
import sys
import tempfile
from typing import TYPE_CHECKING
import pytest
from trio import Path, open_unix_socket
from trio._highlevel_open_unix_stream import close_on_error
assert not TYPE_CHECKING or sys.platform != "win32"
skip_if_not_unix = pytest.mark.skipif(
not hasattr(socket, "AF_UNIX"),
reason="Needs unix socket support",
)
@skip_if_not_unix
def test_close_on_error() -> None:
class CloseMe:
closed = False
def close(self) -> None:
self.closed = True
with close_on_error(CloseMe()) as c:
pass
assert not c.closed
with pytest.raises(RuntimeError):
with close_on_error(CloseMe()) as c:
raise RuntimeError
assert c.closed
@skip_if_not_unix
@pytest.mark.parametrize("filename", [4, 4.5])
async def test_open_with_bad_filename_type(filename: float) -> None:
with pytest.raises(TypeError):
await open_unix_socket(filename) # type: ignore[arg-type]
@skip_if_not_unix
async def test_open_bad_socket() -> None:
# mktemp is marked as insecure, but that's okay, we don't want the file to
# exist
name = tempfile.mktemp()
with pytest.raises(FileNotFoundError):
await open_unix_socket(name)
@skip_if_not_unix
async def test_open_unix_socket() -> None:
for name_type in [Path, str]:
name = tempfile.mktemp()
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
with serv_sock:
serv_sock.bind(name)
try:
serv_sock.listen(1)
# The actual function we're testing
unix_socket = await open_unix_socket(name_type(name))
async with unix_socket:
client, _ = serv_sock.accept()
with client:
await unix_socket.send_all(b"test")
assert client.recv(2048) == b"test"
client.sendall(b"response")
received = await unix_socket.receive_some(2048)
assert received == b"response"
finally:
os.unlink(name)
@pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms")
async def test_error_on_no_unix() -> None:
with pytest.raises(
RuntimeError,
match="^Unix sockets are not supported on this platform$",
):
await open_unix_socket("")

View File

@ -0,0 +1,183 @@
from __future__ import annotations
import errno
from functools import partial
from typing import TYPE_CHECKING, Awaitable, Callable, NoReturn
import attrs
import trio
from trio import Nursery, StapledStream, TaskStatus
from trio.testing import (
Matcher,
MemoryReceiveStream,
MemorySendStream,
MockClock,
RaisesGroup,
memory_stream_pair,
wait_all_tasks_blocked,
)
if TYPE_CHECKING:
import pytest
from trio._channel import MemoryReceiveChannel, MemorySendChannel
from trio.abc import Stream
# types are somewhat tentative - I just bruteforced them until I got something that didn't
# give errors
StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream]
@attrs.define(eq=False, slots=False)
class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
closed: bool = False
accepted_streams: list[trio.abc.Stream] = attrs.Factory(list)
queued_streams: tuple[
MemorySendChannel[StapledMemoryStream],
MemoryReceiveChannel[StapledMemoryStream],
] = attrs.Factory(lambda: trio.open_memory_channel[StapledMemoryStream](1))
accept_hook: Callable[[], Awaitable[object]] | None = None
async def connect(self) -> StapledMemoryStream:
assert not self.closed
client, server = memory_stream_pair()
await self.queued_streams[0].send(server)
return client
async def accept(self) -> StapledMemoryStream:
await trio.lowlevel.checkpoint()
assert not self.closed
if self.accept_hook is not None:
await self.accept_hook()
stream = await self.queued_streams[1].receive()
self.accepted_streams.append(stream)
return stream
async def aclose(self) -> None:
self.closed = True
await trio.lowlevel.checkpoint()
async def test_serve_listeners_basic() -> None:
listeners = [MemoryListener(), MemoryListener()]
record = []
def close_hook() -> None:
# Make sure this is a forceful close
assert trio.current_effective_deadline() == float("-inf")
record.append("closed")
async def handler(stream: StapledMemoryStream) -> None:
await stream.send_all(b"123")
assert await stream.receive_some(10) == b"456"
stream.send_stream.close_hook = close_hook
stream.receive_stream.close_hook = close_hook
async def client(listener: MemoryListener) -> None:
s = await listener.connect()
assert await s.receive_some(10) == b"123"
await s.send_all(b"456")
async def do_tests(parent_nursery: Nursery) -> None:
async with trio.open_nursery() as nursery:
for listener in listeners:
for _ in range(3):
nursery.start_soon(client, listener)
await wait_all_tasks_blocked()
# verifies that all 6 streams x 2 directions each were closed ok
assert len(record) == 12
parent_nursery.cancel_scope.cancel()
async with trio.open_nursery() as nursery:
l2: list[MemoryListener] = await nursery.start(
trio.serve_listeners,
handler,
listeners,
)
assert l2 == listeners
# This is just split into another function because gh-136 isn't
# implemented yet
nursery.start_soon(do_tests, nursery)
for listener in listeners:
assert listener.closed
async def test_serve_listeners_accept_unrecognized_error() -> None:
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
listener = MemoryListener()
async def raise_error() -> NoReturn:
raise error # noqa: B023 # Set from loop
def check_error(e: BaseException) -> bool:
return e is error # noqa: B023
listener.accept_hook = raise_error
with RaisesGroup(Matcher(check=check_error)):
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
async def test_serve_listeners_accept_capacity_error(
autojump_clock: MockClock,
caplog: pytest.LogCaptureFixture,
) -> None:
listener = MemoryListener()
async def raise_EMFILE() -> NoReturn:
raise OSError(errno.EMFILE, "out of file descriptors")
listener.accept_hook = raise_EMFILE
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
# = 10 times total
with trio.move_on_after(0.950):
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
assert len(caplog.records) == 10
for record in caplog.records:
assert "retrying" in record.msg
assert record.exc_info is not None
assert isinstance(record.exc_info[1], OSError)
assert record.exc_info[1].errno == errno.EMFILE
async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None:
listener = MemoryListener()
async def handler(stream: Stream) -> None:
await trio.sleep(1)
class Done(Exception):
pass
async def connection_watcher(
*,
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
) -> NoReturn:
async with trio.open_nursery() as nursery:
task_status.started(nursery)
await wait_all_tasks_blocked()
assert len(nursery.child_tasks) == 10
raise Done
# the exception is wrapped twice because we open two nested nurseries
with RaisesGroup(RaisesGroup(Done)):
async with trio.open_nursery() as nursery:
handler_nursery: trio.Nursery = await nursery.start(connection_watcher)
await nursery.start(
partial(
trio.serve_listeners,
handler,
[listener],
handler_nursery=handler_nursery,
),
)
for _ in range(10):
nursery.start_soon(listener.connect)

View File

@ -0,0 +1,330 @@
from __future__ import annotations
import errno
import socket as stdlib_socket
import sys
from typing import Sequence
import pytest
from .. import _core, socket as tsocket
from .._highlevel_socket import *
from ..testing import (
assert_checkpoints,
check_half_closeable_stream,
wait_all_tasks_blocked,
)
from .test_socket import setsockopt_tests
async def test_SocketStream_basics() -> None:
# stdlib socket bad (even if connected)
stdlib_a, stdlib_b = stdlib_socket.socketpair()
with stdlib_a, stdlib_b:
with pytest.raises(TypeError):
SocketStream(stdlib_a) # type: ignore[arg-type]
# DGRAM socket bad
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
with pytest.raises(
ValueError,
match="^SocketStream requires a SOCK_STREAM socket$",
):
# TODO: does not raise an error?
SocketStream(sock)
a, b = tsocket.socketpair()
with a, b:
s = SocketStream(a)
assert s.socket is a
# Use a real, connected socket to test socket options, because
# socketpair() might give us a unix socket that doesn't support any of
# these options
with tsocket.socket() as listen_sock:
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(1)
with tsocket.socket() as client_sock:
await client_sock.connect(listen_sock.getsockname())
s = SocketStream(client_sock)
# TCP_NODELAY enabled by default
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
# We can disable it though
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
assert isinstance(res, bytes)
setsockopt_tests(s)
async def test_SocketStream_send_all() -> None:
BIG = 10000000
a_sock, b_sock = tsocket.socketpair()
with a_sock, b_sock:
a = SocketStream(a_sock)
b = SocketStream(b_sock)
# Check a send_all that has to be split into multiple parts (on most
# platforms... on Windows every send() either succeeds or fails as a
# whole)
async def sender() -> None:
data = bytearray(BIG)
await a.send_all(data)
# send_all uses memoryviews internally, which temporarily "lock"
# the object they view. If it doesn't clean them up properly, then
# some bytearray operations might raise an error afterwards, which
# would be a pretty weird and annoying side-effect to spring on
# users. So test that this doesn't happen, by forcing the
# bytearray's underlying buffer to be realloc'ed:
data += bytes(BIG)
# (Note: the above line of code doesn't do a very good job at
# testing anything, because:
# - on CPython, the refcount GC generally cleans up memoryviews
# for us even if we're sloppy.
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
# bytearray code conspire so that resizing never fails if
# resizing forces the bytearray's internal buffer to move, then
# all memoryview references are automagically updated (!!).
# See:
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
# But I'm leaving the test here in hopes that if this ever changes
# and we break our implementation of send_all, then we'll get some
# early warning...)
async def receiver() -> None:
# Make sure the sender fills up the kernel buffers and blocks
await wait_all_tasks_blocked()
nbytes = 0
while nbytes < BIG:
nbytes += len(await b.receive_some(BIG))
assert nbytes == BIG
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(receiver)
# We know that we received BIG bytes of NULs so far. Make sure that
# was all the data in there.
await a.send_all(b"e")
assert await b.receive_some(10) == b"e"
await a.send_eof()
assert await b.receive_some(10) == b""
async def fill_stream(s: SocketStream) -> None:
async def sender() -> None:
while True:
await s.send_all(b"x" * 10000)
async def waiter(nursery: _core.Nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
async with _core.open_nursery() as nursery:
nursery.start_soon(sender)
nursery.start_soon(waiter, nursery)
async def test_SocketStream_generic() -> None:
async def stream_maker() -> tuple[SocketStream, SocketStream]:
left, right = tsocket.socketpair()
return SocketStream(left), SocketStream(right)
async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
left, right = await stream_maker()
await fill_stream(left)
await fill_stream(right)
return left, right
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
async def test_SocketListener() -> None:
# Not a Trio socket
with stdlib_socket.socket() as s:
s.bind(("127.0.0.1", 0))
s.listen(10)
with pytest.raises(TypeError):
SocketListener(s) # type: ignore[arg-type]
# Not a SOCK_STREAM
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
await s.bind(("127.0.0.1", 0))
with pytest.raises(
ValueError,
match="^SocketListener requires a SOCK_STREAM socket$",
) as excinfo:
SocketListener(s)
excinfo.match(r".*SOCK_STREAM")
# Didn't call .listen()
# macOS has no way to check for this, so skip testing it there.
if sys.platform != "darwin":
with tsocket.socket() as s:
await s.bind(("127.0.0.1", 0))
with pytest.raises(
ValueError,
match="^SocketListener requires a listening socket$",
) as excinfo:
SocketListener(s)
excinfo.match(r".*listen")
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
listener = SocketListener(listen_sock)
assert listener.socket is listen_sock
client_sock = tsocket.socket()
await client_sock.connect(listen_sock.getsockname())
with assert_checkpoints():
server_stream = await listener.accept()
assert isinstance(server_stream, SocketStream)
assert server_stream.socket.getsockname() == listen_sock.getsockname()
assert server_stream.socket.getpeername() == client_sock.getsockname()
with assert_checkpoints():
await listener.aclose()
with assert_checkpoints():
await listener.aclose()
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()
client_sock.close()
await server_stream.aclose()
async def test_SocketListener_socket_closed_underfoot() -> None:
listen_sock = tsocket.socket()
await listen_sock.bind(("127.0.0.1", 0))
listen_sock.listen(10)
listener = SocketListener(listen_sock)
# Close the socket, not the listener
listen_sock.close()
# SocketListener gives correct error
with assert_checkpoints():
with pytest.raises(_core.ClosedResourceError):
await listener.accept()
async def test_SocketListener_accept_errors() -> None:
class FakeSocket(tsocket.SocketType):
def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
self._events = iter(events)
type = tsocket.SOCK_STREAM
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
@overload
def getsockopt(self, /, level: int, optname: int) -> int: ...
@overload
def getsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
buflen: int,
) -> bytes: ...
def getsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
buflen: int | None = None,
) -> int | bytes:
return True
@overload
def setsockopt(
self,
/,
level: int,
optname: int,
value: int | Buffer,
) -> None: ...
@overload
def setsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
value: None,
optlen: int,
) -> None: ...
def setsockopt( # noqa: F811
self,
/,
level: int,
optname: int,
value: int | Buffer | None,
optlen: int | None = None,
) -> None:
pass
async def accept(self) -> tuple[SocketType, object]:
await _core.checkpoint()
event = next(self._events)
if isinstance(event, BaseException):
raise event
else:
return event, None
fake_server_sock = FakeSocket([])
fake_listen_sock = FakeSocket(
[
OSError(errno.ECONNABORTED, "Connection aborted"),
OSError(errno.EPERM, "Permission denied"),
OSError(errno.EPROTO, "Bad protocol"),
fake_server_sock,
OSError(errno.EMFILE, "Out of file descriptors"),
OSError(errno.EFAULT, "attempt to write to read-only memory"),
OSError(errno.ENOBUFS, "out of buffers"),
fake_server_sock,
],
)
listener = SocketListener(fake_listen_sock)
with assert_checkpoints():
stream = await listener.accept()
assert stream.socket is fake_server_sock
for code, match in {
errno.EMFILE: r"\[\w+ \d+\] Out of file descriptors$",
errno.EFAULT: r"\[\w+ \d+\] attempt to write to read-only memory$",
errno.ENOBUFS: r"\[\w+ \d+\] out of buffers$",
}.items():
with assert_checkpoints():
with pytest.raises(OSError, match=match) as excinfo:
await listener.accept()
assert excinfo.value.errno == code
with assert_checkpoints():
stream = await listener.accept()
assert stream.socket is fake_server_sock
async def test_socket_stream_works_when_peer_has_already_closed() -> None:
sock_a, sock_b = tsocket.socketpair()
with sock_a, sock_b:
await sock_b.send(b"x")
sock_b.close()
stream = SocketStream(sock_a)
assert await stream.receive_some(1) == b"x"
assert await stream.receive_some(1) == b""

View File

@ -0,0 +1,166 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, NoReturn
import attrs
import pytest
import trio
import trio.testing
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
from .._highlevel_ssl_helpers import (
open_ssl_over_tcp_listeners,
open_ssl_over_tcp_stream,
serve_ssl_over_tcp,
)
# using noqa because linters don't understand how pytest fixtures work.
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
if TYPE_CHECKING:
from socket import AddressFamily, SocketKind
from ssl import SSLContext
from trio.abc import Stream
from .._highlevel_socket import SocketListener
from .._ssl import SSLListener
async def echo_handler(stream: Stream) -> None:
async with stream:
try:
while True:
data = await stream.receive_some(10000)
if not data:
break
await stream.send_all(data)
except trio.BrokenResourceError:
pass
# Resolver that always returns the given sockaddr, no matter what host/port
# you ask for.
@attrs.define(slots=False)
class FakeHostnameResolver(trio.abc.HostnameResolver):
sockaddr: tuple[str, int] | tuple[str, int, int, int]
async def getaddrinfo(
self,
host: bytes | None,
port: bytes | str | int | None,
family: int = 0,
type: int = 0,
proto: int = 0,
flags: int = 0,
) -> list[
tuple[
AddressFamily,
SocketKind,
int,
str,
tuple[str, int] | tuple[str, int, int, int],
]
]:
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover
raise NotImplementedError
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
# using noqa because linters don't understand how pytest fixtures work.
async def test_open_ssl_over_tcp_stream_and_everything_else(
client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
) -> None:
async with trio.open_nursery() as nursery:
# TODO: this function wraps an SSLListener around a SocketListener, this is illegal
# according to current type hints, and probably for good reason. But there should
# maybe be a different wrapper class/function that could be used instead?
res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var]
await nursery.start(
partial(
serve_ssl_over_tcp,
echo_handler,
0,
SERVER_CTX,
host="127.0.0.1",
),
)
)
(listener,) = res
async with listener:
# listener.transport_listener is of type Listener[Stream]
tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
sockaddr = tp_listener.socket.getsockname()
hostname_resolver = FakeHostnameResolver(sockaddr)
trio.socket.set_custom_hostname_resolver(hostname_resolver)
# We don't have the right trust set up
# (checks that ssl_context=None is doing some validation)
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()
# We have the trust but not the hostname
# (checks custom ssl_context + hostname checking)
stream = await open_ssl_over_tcp_stream(
"xyzzy.example.org",
80,
ssl_context=client_ctx,
)
async with stream:
with pytest.raises(trio.BrokenResourceError):
await stream.do_handshake()
# This one should work!
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
)
async with stream:
assert isinstance(stream, trio.SSLStream)
assert stream.server_hostname == "trio-test-1.example.org"
await stream.send_all(b"x")
assert await stream.receive_some(1) == b"x"
# Check https_compatible settings are being passed through
assert not stream._https_compatible
stream = await open_ssl_over_tcp_stream(
"trio-test-1.example.org",
80,
ssl_context=client_ctx,
https_compatible=True,
# also, smoke test happy_eyeballs_delay
happy_eyeballs_delay=1,
)
async with stream:
assert stream._https_compatible
# Stop the echo server
nursery.cancel_scope.cancel()
async def test_open_ssl_over_tcp_listeners() -> None:
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
async with listener:
assert isinstance(listener, trio.SSLListener)
tl = listener.transport_listener
assert isinstance(tl, trio.SocketListener)
assert tl.socket.getsockname()[0] == "127.0.0.1"
assert not listener._https_compatible
(listener,) = await open_ssl_over_tcp_listeners(
0,
SERVER_CTX,
host="127.0.0.1",
https_compatible=True,
)
async with listener:
assert listener._https_compatible

View File

@ -0,0 +1,275 @@
from __future__ import annotations
import os
import pathlib
from typing import TYPE_CHECKING, Type, Union
import pytest
import trio
from trio._file_io import AsyncIOWrapper
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
@pytest.fixture
def path(tmp_path: pathlib.Path) -> trio.Path:
return trio.Path(tmp_path / "test")
def method_pair(
path: str,
method_name: str,
) -> tuple[Callable[[], object], Callable[[], Awaitable[object]]]:
sync_path = pathlib.Path(path)
async_path = trio.Path(path)
return getattr(sync_path, method_name), getattr(async_path, method_name)
@pytest.mark.skipif(os.name == "nt", reason="OS is not posix")
async def test_instantiate_posix() -> None:
assert isinstance(trio.Path(), trio.PosixPath)
@pytest.mark.skipif(os.name != "nt", reason="OS is not Windows")
async def test_instantiate_windows() -> None:
assert isinstance(trio.Path(), trio.WindowsPath)
async def test_open_is_async_context_manager(path: trio.Path) -> None:
async with await path.open("w") as f:
assert isinstance(f, AsyncIOWrapper)
assert f.closed
async def test_magic() -> None:
path = trio.Path("test")
assert str(path) == "test"
assert bytes(path) == b"test"
EitherPathType = Union[Type[trio.Path], Type[pathlib.Path]]
PathOrStrType = Union[EitherPathType, Type[str]]
cls_pairs: list[tuple[EitherPathType, EitherPathType]] = [
(trio.Path, pathlib.Path),
(pathlib.Path, trio.Path),
(trio.Path, trio.Path),
]
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs)
async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None:
a, b = cls_a(""), cls_b("")
assert a == b
assert not a != b # noqa: SIM202 # negate-not-equal-op
a, b = cls_a("a"), cls_b("b")
assert a < b
assert b > a
# this is intentionally testing equivalence with none, due to the
# other=sentinel logic in _forward_magic
assert not a == None # noqa
assert not b == None # noqa
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
# __*div__ does not properly raise NotImplementedError like the other comparison
# magic, so trio.Path's implementation does not get dispatched
cls_pairs_str: list[tuple[PathOrStrType, PathOrStrType]] = [
(trio.Path, pathlib.Path),
(trio.Path, trio.Path),
(trio.Path, str),
(str, trio.Path),
]
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str)
async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None:
a, b = cls_a("a"), cls_b("b")
result = a / b # type: ignore[operator]
# Type checkers think str / str could happen. Check each combo manually in type_tests/.
assert isinstance(result, trio.Path)
assert str(result) == os.path.join("a", "b")
@pytest.mark.parametrize(
("cls_a", "cls_b"),
[(trio.Path, pathlib.Path), (trio.Path, trio.Path)],
)
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
async def test_hash_magic(
cls_a: EitherPathType,
cls_b: EitherPathType,
path: str,
) -> None:
a, b = cls_a(path), cls_b(path)
assert hash(a) == hash(b)
async def test_forwarded_properties(path: trio.Path) -> None:
# use `name` as a representative of forwarded properties
assert "name" in dir(path)
assert path.name == "test"
async def test_async_method_signature(path: trio.Path) -> None:
# use `resolve` as a representative of wrapped methods
assert path.resolve.__name__ == "resolve"
assert path.resolve.__qualname__ == "Path.resolve"
assert path.resolve.__doc__ is not None
assert path.resolve.__qualname__ in path.resolve.__doc__
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
async def test_compare_async_stat_methods(method_name: str) -> None:
method, async_method = method_pair(".", method_name)
result = method()
async_result = await async_method()
assert result == async_result
async def test_invalid_name_not_wrapped(path: trio.Path) -> None:
with pytest.raises(AttributeError):
getattr(path, "invalid_fake_attr") # noqa: B009 # "get-attr-with-constant"
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
async def test_async_methods_rewrap(method_name: str) -> None:
method, async_method = method_pair(".", method_name)
result = method()
async_result = await async_method()
assert isinstance(async_result, trio.Path)
assert str(result) == str(async_result)
async def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None:
with_name = path.with_name("foo")
with_suffix = path.with_suffix(".py")
assert isinstance(with_name, trio.Path)
assert with_name == tmp_path / "foo"
assert isinstance(with_suffix, trio.Path)
assert with_suffix == tmp_path / "test.py"
async def test_forward_properties_rewrap(path: trio.Path) -> None:
assert isinstance(path.parent, trio.Path)
async def test_forward_methods_without_rewrap(path: trio.Path) -> None:
path = await path.parent.resolve()
assert path.as_uri().startswith("file:///")
async def test_repr() -> None:
path = trio.Path(".")
assert repr(path) == "trio.Path('.')"
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
async def test_path_wraps_path(
path: trio.Path,
meth: Callable[[trio.Path, trio.Path], object],
) -> None:
wrapped = await path.absolute()
result = meth(path, wrapped)
if result is None:
result = path
assert wrapped == result
async def test_path_nonpath() -> None:
with pytest.raises(TypeError):
trio.Path(1) # type: ignore
async def test_open_file_can_open_path(path: trio.Path) -> None:
async with await trio.open_file(path, "w") as f:
assert f.name == os.fspath(path)
async def test_globmethods(path: trio.Path) -> None:
# Populate a directory tree
await path.mkdir()
await (path / "foo").mkdir()
await (path / "foo" / "_bar.txt").write_bytes(b"")
await (path / "bar.txt").write_bytes(b"")
await (path / "bar.dat").write_bytes(b"")
# Path.glob
for _pattern, _results in {
"*.txt": {"bar.txt"},
"**/*.txt": {"_bar.txt", "bar.txt"},
}.items():
entries = set()
for entry in await path.glob(_pattern):
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == _results
# Path.rglob
entries = set()
for entry in await path.rglob("*.txt"):
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == {"_bar.txt", "bar.txt"}
async def test_iterdir(path: trio.Path) -> None:
# Populate a directory
await path.mkdir()
await (path / "foo").mkdir()
await (path / "bar.txt").write_bytes(b"")
entries = set()
for entry in await path.iterdir():
assert isinstance(entry, trio.Path)
entries.add(entry.name)
assert entries == {"bar.txt", "foo"}
async def test_classmethods() -> None:
assert isinstance(await trio.Path.home(), trio.Path)
# pathlib.Path has only two classmethods
assert str(await trio.Path.home()) == os.path.expanduser("~")
assert str(await trio.Path.cwd()) == os.getcwd()
# Wrapped method has docstring
assert trio.Path.home.__doc__
@pytest.mark.parametrize(
"wrapper",
[
trio._path._wraps_async,
trio._path._wrap_method,
trio._path._wrap_method_path,
trio._path._wrap_method_path_iterable,
],
)
def test_wrapping_without_docstrings(
wrapper: Callable[[Callable[[], None]], Callable[[], None]],
) -> None:
@wrapper
def func_without_docstring() -> None: ... # pragma: no cover
assert func_without_docstring.__doc__ is None

View File

@ -0,0 +1,242 @@
from __future__ import annotations
import subprocess
import sys
from typing import Protocol
import pytest
import trio._repl
class RawInput(Protocol):
def __call__(self, prompt: str = "") -> str: ...
def build_raw_input(cmds: list[str]) -> RawInput:
"""
Pass in a list of strings.
Returns a callable that returns each string, each time its called
When there are not more strings to return, raise EOFError
"""
cmds_iter = iter(cmds)
prompts = []
def _raw_helper(prompt: str = "") -> str:
prompts.append(prompt)
try:
return next(cmds_iter)
except StopIteration:
raise EOFError from None
return _raw_helper
def test_build_raw_input() -> None:
"""Quick test of our helper function."""
raw_input = build_raw_input(["cmd1"])
assert raw_input() == "cmd1"
with pytest.raises(EOFError):
raw_input()
# In 3.10 or later, types.FunctionType (used internally) will automatically
# attach __builtins__ to the function objects. However we need to explicitly
# include it for 3.8 & 3.9
def build_locals() -> dict[str, object]:
return {"__builtins__": __builtins__}
async def test_basic_interaction(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""
Run some basic commands through the interpreter while capturing stdout.
Ensure that the interpreted prints the expected results.
"""
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# evaluate simple expression and recall the value
"x = 1",
"print(f'{x=}')",
# Literal gets printed
"'hello'",
# define and call sync function
"def func():",
" print(x + 1)",
"",
"func()",
# define and call async function
"async def afunc():",
" return 4",
"",
"await afunc()",
# import works
"import sys",
"sys.stdout.write('hello stdout\\n')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]
async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"raise SystemExit",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
with pytest.raises(SystemExit):
await trio._repl.run_repl(console)
async def test_KI_interrupts(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"from trio._util import signal_raise",
"import signal, trio, trio.lowlevel",
"async def f():",
" trio.lowlevel.spawn_system_task("
" trio.to_thread.run_sync,"
" signal_raise,signal.SIGINT,"
" )", # just awaiting this kills the test runner?!
" await trio.sleep_forever()",
" print('should not see this')",
"",
"await f()",
"print('AFTER KeyboardInterrupt')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "KeyboardInterrupt" in err
assert "should" not in out
assert "AFTER KeyboardInterrupt" in out
async def test_system_exits_in_exc_group(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"import sys",
"if sys.version_info < (3, 11):",
" from exceptiongroup import BaseExceptionGroup",
"",
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
"print('AFTER BaseExceptionGroup')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
# assert that raise SystemExit in an exception group
# doesn't quit
assert "AFTER BaseExceptionGroup" in out
async def test_system_exits_in_nested_exc_group(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"import sys",
"if sys.version_info < (3, 11):",
" from exceptiongroup import BaseExceptionGroup",
"",
"raise BaseExceptionGroup(",
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
"print('AFTER BaseExceptionGroup')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
# assert that raise SystemExit in an exception group
# doesn't quit
assert "AFTER BaseExceptionGroup" in out
async def test_base_exception_captured(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# The statement after raise should still get executed
"raise BaseException",
"print('AFTER BaseException')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "_threads.py" not in err
assert "_repl.py" not in err
assert "AFTER BaseException" in out
async def test_exc_group_captured(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
# The statement after raise should still get executed
"raise ExceptionGroup('', [KeyError()])",
"print('AFTER ExceptionGroup')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "AFTER ExceptionGroup" in out
async def test_base_exception_capture_from_coroutine(
capsys: pytest.CaptureFixture[str],
monkeypatch: pytest.MonkeyPatch,
) -> None:
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
raw_input = build_raw_input(
[
"async def async_func_raises_base_exception():",
" raise BaseException",
"",
# This will raise, but the statement after should still
# be executed
"await async_func_raises_base_exception()",
"print('AFTER BaseException')",
],
)
monkeypatch.setattr(console, "raw_input", raw_input)
await trio._repl.run_repl(console)
out, err = capsys.readouterr()
assert "_threads.py" not in err
assert "_repl.py" not in err
assert "AFTER BaseException" in out
def test_main_entrypoint() -> None:
"""
Basic smoke test when running via the package __main__ entrypoint.
"""
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
assert repl.returncode == 0

View File

@ -0,0 +1,47 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import trio
if TYPE_CHECKING:
import pytest
async def scheduler_trace() -> tuple[tuple[str, int], ...]:
"""Returns a scheduler-dependent value we can use to check determinism."""
trace = []
async def tracer(name: str) -> None:
for i in range(50):
trace.append((name, i))
await trio.lowlevel.checkpoint()
async with trio.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(tracer, str(i))
return tuple(trace)
def test_the_trio_scheduler_is_not_deterministic() -> None:
# At least, not yet. See https://github.com/python-trio/trio/issues/32
traces = [trio.run(scheduler_trace) for _ in range(10)]
assert len(set(traces)) == len(traces)
def test_the_trio_scheduler_is_deterministic_if_seeded(
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
traces = []
for _ in range(10):
state = trio._core._run._r.getstate()
try:
trio._core._run._r.seed(0)
traces.append(trio.run(scheduler_trace))
finally:
trio._core._run._r.setstate(state)
assert len(traces) == 10
assert len(set(traces)) == 1

View File

@ -0,0 +1,188 @@
from __future__ import annotations
import signal
from typing import TYPE_CHECKING, NoReturn
import pytest
import trio
from trio.testing import RaisesGroup
from .. import _core
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
from .._util import signal_raise
if TYPE_CHECKING:
from types import FrameType
async def test_open_signal_receiver() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL) as receiver:
# Raise it a few times, to exercise signal coalescing, both at the
# call_soon level and at the SignalQueue level
signal_raise(signal.SIGILL)
signal_raise(signal.SIGILL)
await _core.wait_all_tasks_blocked()
signal_raise(signal.SIGILL)
await _core.wait_all_tasks_blocked()
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert get_pending_signal_count(receiver) == 0
signal_raise(signal.SIGILL)
async for signum in receiver: # pragma: no branch
assert signum == signal.SIGILL
break
assert get_pending_signal_count(receiver) == 0
with pytest.raises(RuntimeError):
await receiver.__anext__()
assert signal.getsignal(signal.SIGILL) is orig
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with pytest.raises(
ValueError,
match="(signal number out of range|invalid signal value)$",
):
with open_signal_receiver(signal.SIGILL, 1234567):
pass # pragma: no cover
# Still restored even if we errored out
assert signal.getsignal(signal.SIGILL) is orig
async def test_open_signal_receiver_empty_fail() -> None:
with pytest.raises(TypeError, match="No signals were provided"):
with open_signal_receiver():
pass
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
orig = signal.getsignal(signal.SIGILL)
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
pass
# Still restored correctly
assert signal.getsignal(signal.SIGILL) is orig
async def test_catch_signals_wrong_thread() -> None:
async def naughty() -> None:
with open_signal_receiver(signal.SIGINT):
pass # pragma: no cover
with pytest.raises(RuntimeError):
await trio.to_thread.run_sync(trio.run, naughty)
async def test_open_signal_receiver_conflict() -> None:
with RaisesGroup(trio.BusyResourceError):
with open_signal_receiver(signal.SIGILL) as receiver:
async with trio.open_nursery() as nursery:
nursery.start_soon(receiver.__anext__)
nursery.start_soon(receiver.__anext__)
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
# processed.
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
ev = trio.Event()
token = _core.current_trio_token()
token.run_sync_soon(ev.set, idempotent=True)
await ev.wait()
async def test_open_signal_receiver_no_starvation() -> None:
# Set up a situation where there are always 2 pending signals available to
# report, and make sure that instead of getting the same signal reported
# over and over, it alternates between reporting both of them.
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
try:
print(signal.getsignal(signal.SIGILL))
previous = None
for _ in range(10):
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
if previous is None:
previous = await receiver.__anext__()
else:
got = await receiver.__anext__()
assert got in [signal.SIGILL, signal.SIGFPE]
assert got != previous
previous = got
# Clear out the last signal so that it doesn't get redelivered
while get_pending_signal_count(receiver) != 0:
await receiver.__anext__()
except BaseException: # pragma: no cover
# If there's an unhandled exception above, then exiting the
# open_signal_receiver block might cause the signal to be
# redelivered and give us a core dump instead of a traceback...
import traceback
traceback.print_exc()
async def test_catch_signals_race_condition_on_exit() -> None:
delivered_directly: set[int] = set()
def direct_handler(signo: int, frame: FrameType | None) -> None:
delivered_directly.add(signo)
print(1)
# Test the version where the call_soon *doesn't* have a chance to run
# before we exit the with block:
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()
print(2)
# Test the version where the call_soon *does* have a chance to run before
# we exit the with block:
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert get_pending_signal_count(receiver) == 2
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
delivered_directly.clear()
# Again, but with a SIG_IGN signal:
print(3)
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
# test passes if the process reaches this point without dying
print(4)
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
with open_signal_receiver(signal.SIGILL) as receiver:
signal_raise(signal.SIGILL)
await wait_run_sync_soon_idempotent_queue_barrier()
assert get_pending_signal_count(receiver) == 1
# test passes if the process reaches this point without dying
# Check exception chaining if there are multiple exception-raising
# handlers
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
raise RuntimeError(signum)
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
with pytest.raises(RuntimeError) as excinfo:
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
signal_raise(signal.SIGILL)
signal_raise(signal.SIGFPE)
await wait_run_sync_soon_idempotent_queue_barrier()
assert get_pending_signal_count(receiver) == 2
exc = excinfo.value
signums = {exc.args[0]}
assert isinstance(exc.__context__, RuntimeError)
signums.add(exc.__context__.args[0])
assert signums == {signal.SIGILL, signal.SIGFPE}

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,696 @@
from __future__ import annotations
import gc
import os
import random
import signal
import subprocess
import sys
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path as SyncPath
from signal import Signals
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
AsyncIterator,
Callable,
NoReturn,
)
import pytest
import trio
from trio.testing import Matcher, RaisesGroup
from .. import (
Event,
Process,
_core,
fail_after,
move_on_after,
run_process,
sleep,
sleep_forever,
)
from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow
from ..lowlevel import open_process
from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
if TYPE_CHECKING:
from types import FrameType
from typing_extensions import TypeAlias
from .._abc import ReceiveStream
if sys.platform == "win32":
SignalType: TypeAlias = None
else:
SignalType: TypeAlias = Signals
SIGKILL: SignalType
SIGTERM: SignalType
SIGUSR1: SignalType
posix = os.name == "posix"
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
from signal import SIGKILL, SIGTERM, SIGUSR1
else:
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
# Since Windows has very few command-line utilities generally available,
# all of our subprocesses are Python processes running short bits of
# (mostly) cross-platform code.
def python(code: str) -> list[str]:
return [sys.executable, "-u", "-c", "import sys; " + code]
EXIT_TRUE = python("sys.exit(0)")
EXIT_FALSE = python("sys.exit(1)")
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
if posix:
def SLEEP(seconds: int) -> list[str]:
return ["sleep", str(seconds)]
else:
def SLEEP(seconds: int) -> list[str]:
return python(f"import time; time.sleep({seconds})")
def got_signal(proc: Process, sig: SignalType) -> bool:
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
return proc.returncode == -sig
else:
return proc.returncode != 0
@asynccontextmanager # type: ignore[misc] # Any in decorator
async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
proc = await open_process(*args, **kwargs)
try:
yield proc
finally:
proc.kill()
await proc.wait()
@asynccontextmanager # type: ignore[misc] # Any in decorator
async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
async with _core.open_nursery() as nursery:
kwargs.setdefault("check", False)
proc: Process = await nursery.start(partial(run_process, *args, **kwargs))
yield proc
nursery.cancel_scope.cancel()
background_process_param = pytest.mark.parametrize(
"background_process",
[open_process_then_kill, run_process_in_nursery],
ids=["open_process", "run_process in nursery"],
)
BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]]
@background_process_param
async def test_basic(background_process: BackgroundProcessType) -> None:
async with background_process(EXIT_TRUE) as proc:
await proc.wait()
assert isinstance(proc, Process)
assert proc._pidfd is None
assert proc.returncode == 0
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
async with background_process(EXIT_FALSE) as proc:
await proc.wait()
assert proc.returncode == 1
assert repr(proc) == "<trio.Process {!r}: {}>".format(
EXIT_FALSE,
"exited with status 1",
)
@background_process_param
async def test_auto_update_returncode(
background_process: BackgroundProcessType,
) -> None:
async with background_process(SLEEP(9999)) as p:
assert p.returncode is None
assert "running" in repr(p)
p.kill()
p._proc.wait()
assert p.returncode is not None
assert "exited" in repr(p)
assert p._pidfd is None
assert p.returncode is not None
@background_process_param
async def test_multi_wait(background_process: BackgroundProcessType) -> None:
async with background_process(SLEEP(10)) as proc:
# Check that wait (including multi-wait) tolerates being cancelled
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# Now try waiting for real
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
proc.kill()
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
"data = sys.stdin.buffer.read(); "
"sys.stdout.buffer.write(data); "
"sys.stderr.buffer.write(data[::-1])",
)
@background_process_param
async def test_pipes(background_process: BackgroundProcessType) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
msg = b"the quick brown fox jumps over the lazy dog"
async def feed_input() -> None:
assert proc.stdin is not None
await proc.stdin.send_all(msg)
await proc.stdin.aclose()
async def check_output(stream: ReceiveStream, expected: bytes) -> None:
seen = bytearray()
async for chunk in stream:
seen += chunk
assert seen == expected
assert proc.stdout is not None
assert proc.stderr is not None
async with _core.open_nursery() as nursery:
# fail eventually if something is broken
nursery.cancel_scope.deadline = _core.current_time() + 30.0
nursery.start_soon(feed_input)
nursery.start_soon(check_output, proc.stdout, msg)
nursery.start_soon(check_output, proc.stderr, msg[::-1])
assert not nursery.cancel_scope.cancelled_caught
assert await proc.wait() == 0
@background_process_param
async def test_interactive(background_process: BackgroundProcessType) -> None:
# Test some back-and-forth with a subprocess. This one works like so:
# in: 32\n
# out: 0000...0000\n (32 zeroes)
# err: 1111...1111\n (64 ones)
# in: 10\n
# out: 2222222222\n (10 twos)
# err: 3333....3333\n (20 threes)
# in: EOF
# out: EOF
# err: EOF
async with background_process(
python(
"idx = 0\n"
"while True:\n"
" line = sys.stdin.readline()\n"
" if line == '': break\n"
" request = int(line.strip())\n"
" print(str(idx * 2) * request)\n"
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
" idx += 1\n",
),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
) as proc:
newline = b"\n" if posix else b"\r\n"
async def expect(idx: int, request: int) -> None:
async with _core.open_nursery() as nursery:
async def drain_one(
stream: ReceiveStream,
count: int,
digit: int,
) -> None:
while count > 0:
result = await stream.receive_some(count)
assert result == (f"{digit}".encode() * len(result))
count -= len(result)
assert count == 0
assert await stream.receive_some(len(newline)) == newline
assert proc.stdout is not None
assert proc.stderr is not None
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
assert proc.stdin is not None
assert proc.stdout is not None
assert proc.stderr is not None
with fail_after(5):
await proc.stdin.send_all(b"12")
await sleep(0.1)
await proc.stdin.send_all(b"345" + newline)
await expect(0, 12345)
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
await expect(1, 100)
await expect(2, 200)
await proc.stdin.send_all(b"0" + newline)
await expect(3, 0)
await proc.stdin.send_all(b"999999")
with move_on_after(0.1) as scope:
await expect(4, 0)
assert scope.cancelled_caught
await proc.stdin.send_all(newline)
await expect(4, 999999)
await proc.stdin.aclose()
assert await proc.stdout.receive_some(1) == b""
assert await proc.stderr.receive_some(1) == b""
await proc.wait()
assert proc.returncode == 0
async def test_run() -> None:
data = bytes(random.randint(0, 255) for _ in range(2**18))
result = await run_process(
CAT,
stdin=data,
capture_stdout=True,
capture_stderr=True,
)
assert result.args == CAT
assert result.returncode == 0
assert result.stdout == data
assert result.stderr == b""
result = await run_process(CAT, capture_stdout=True)
assert result.args == CAT
assert result.returncode == 0
assert result.stdout == b""
assert result.stderr is None
result = await run_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=data,
capture_stdout=True,
capture_stderr=True,
)
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
assert result.returncode == 0
assert result.stdout == data
assert result.stderr == data[::-1]
# invalid combinations
with pytest.raises(UnicodeError):
await run_process(CAT, stdin="oh no, it's text")
pipe_stdout_error = r"^stdout=subprocess\.PIPE is only valid with nursery\.start, since that's the only way to access the pipe(; use nursery\.start or pass the data you want to write directly)*$"
with pytest.raises(ValueError, match=pipe_stdout_error):
await run_process(CAT, stdin=subprocess.PIPE)
with pytest.raises(ValueError, match=pipe_stdout_error):
await run_process(CAT, stdout=subprocess.PIPE)
with pytest.raises(
ValueError,
match=pipe_stdout_error.replace("stdout", "stderr", 1),
):
await run_process(CAT, stderr=subprocess.PIPE)
with pytest.raises(
ValueError,
match="^can't specify both stdout and capture_stdout$",
):
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
with pytest.raises(
ValueError,
match="^can't specify both stderr and capture_stderr$",
):
await run_process(CAT, capture_stderr=True, stderr=None)
async def test_run_check() -> None:
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
with pytest.raises(subprocess.CalledProcessError) as excinfo:
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
assert excinfo.value.cmd == cmd
assert excinfo.value.returncode == 1
assert excinfo.value.stderr == b"test\n"
assert excinfo.value.stdout is None
result = await run_process(
cmd,
capture_stdout=True,
capture_stderr=True,
check=False,
)
assert result.args == cmd
assert result.stdout == b""
assert result.stderr == b"test\n"
assert result.returncode == 1
@skip_if_fbsd_pipes_broken
async def test_run_with_broken_pipe() -> None:
result = await run_process(
[sys.executable, "-c", "import sys; sys.stdin.close()"],
stdin=b"x" * 131072,
)
assert result.returncode == 0
assert result.stdout is result.stderr is None
@background_process_param
async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as proc:
assert proc.stdio is not None
assert proc.stdout is not None
assert proc.stderr is None
await proc.stdio.send_all(b"1234")
await proc.stdio.send_eof()
output = []
while True:
chunk = await proc.stdio.receive_some(16)
if chunk == b"":
break
output.append(chunk)
assert b"".join(output) == b"12344321"
assert proc.returncode == 0
# equivalent test with run_process()
result = await run_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=b"1234",
capture_stdout=True,
stderr=subprocess.STDOUT,
)
assert result.returncode == 0
assert result.stdout == b"12344321"
assert result.stderr is None
# this one hits the branch where stderr=STDOUT but stdout
# is not redirected
async with background_process(
CAT,
stdin=subprocess.PIPE,
stderr=subprocess.STDOUT,
) as proc:
assert proc.stdout is None
assert proc.stderr is None
await proc.stdin.aclose()
await proc.wait()
assert proc.returncode == 0
if posix:
try:
r, w = os.pipe()
async with background_process(
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
stdin=subprocess.PIPE,
stdout=w,
stderr=subprocess.STDOUT,
) as proc:
os.close(w)
assert proc.stdio is None
assert proc.stdout is None
assert proc.stderr is None
await proc.stdin.send_all(b"1234")
await proc.stdin.aclose()
assert await proc.wait() == 0
assert os.read(r, 4096) == b"12344321"
assert os.read(r, 4096) == b""
finally:
os.close(r)
async def test_errors() -> None:
with pytest.raises(TypeError) as excinfo:
# call-overload on unix, call-arg on windows
await open_process(["ls"], encoding="utf-8") # type: ignore
assert "unbuffered byte streams" in str(excinfo.value)
assert "the 'encoding' option is not supported" in str(excinfo.value)
if posix:
with pytest.raises(TypeError) as excinfo:
await open_process(["ls"], shell=True)
with pytest.raises(TypeError) as excinfo:
await open_process("ls", shell=False)
@background_process_param
async def test_signals(background_process: BackgroundProcessType) -> None:
async def test_one_signal(
send_it: Callable[[Process], None],
signum: signal.Signals | None,
) -> None:
with move_on_after(1.0) as scope:
async with background_process(SLEEP(3600)) as proc:
send_it(proc)
await proc.wait()
assert not scope.cancelled_caught
if posix:
assert signum is not None
assert proc.returncode == -signum
else:
assert proc.returncode != 0
await test_one_signal(Process.kill, SIGKILL)
await test_one_signal(Process.terminate, SIGTERM)
# Test that we can send arbitrary signals.
#
# We used to use SIGINT here, but it turns out that the Python interpreter
# has race conditions that can cause it to explode in weird ways if it
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
# to terminate the target process, and Python doesn't try to do anything
# clever to handle it.
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
@pytest.mark.skipif(not posix, reason="POSIX specific")
@background_process_param
async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
if TYPE_CHECKING and sys.platform == "win32":
return
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
try:
# With SIGCHLD disabled, the wait() syscall will wait for the
# process to exit but then fail with ECHILD. Make sure we
# support this case as the stdlib subprocess module does.
async with background_process(SLEEP(3600)) as proc:
async with _core.open_nursery() as nursery:
nursery.start_soon(proc.wait)
await wait_all_tasks_blocked()
proc.kill()
nursery.cancel_scope.deadline = _core.current_time() + 1.0
assert not nursery.cancel_scope.cancelled_caught
assert proc.returncode == 0 # exit status unknowable, so...
finally:
signal.signal(signal.SIGCHLD, old_sigchld)
@slow
def test_waitid_eintr() -> None:
# This only matters on PyPy (where we're coding EINTR handling
# ourselves) but the test works on all waitid platforms.
from .._subprocess_platform import wait_child_exiting
if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"):
return
if not wait_child_exiting.__module__.endswith("waitid"):
pytest.skip("waitid only")
# despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
# this import is still checked on win32&darwin and raises [attr-defined].
# Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore]
sync_wait_reapable,
)
got_alarm = False
sleeper = subprocess.Popen(["sleep", "3600"])
def on_alarm(sig: int, frame: FrameType | None) -> None:
nonlocal got_alarm
got_alarm = True
sleeper.kill()
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
try:
signal.alarm(1)
sync_wait_reapable(sleeper.pid)
assert sleeper.wait(timeout=1) == -9
finally:
if sleeper.returncode is None: # pragma: no cover
# We only get here if something fails in the above;
# if the test passes, wait() will reap the process
sleeper.kill()
sleeper.wait()
signal.signal(signal.SIGALRM, old_sigalrm)
async def test_custom_deliver_cancel() -> None:
custom_deliver_cancel_called = False
async def custom_deliver_cancel(proc: Process) -> None:
nonlocal custom_deliver_cancel_called
custom_deliver_cancel_called = True
proc.terminate()
# Make sure this does get cancelled when the process exits, and that
# the process really exited.
try:
await sleep_forever()
finally:
assert proc.returncode is not None
async with _core.open_nursery() as nursery:
nursery.start_soon(
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert custom_deliver_cancel_called
def test_bad_deliver_cancel() -> None:
async def custom_deliver_cancel(proc: Process) -> None:
proc.terminate()
raise ValueError("foo")
async def do_stuff() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# double wrap from our nursery + the internal nursery
with RaisesGroup(RaisesGroup(Matcher(ValueError, "^foo$"))):
_core.run(do_stuff, strict_exception_groups=True)
async def test_warn_on_failed_cancel_terminate(monkeypatch: pytest.MonkeyPatch) -> None:
original_terminate = Process.terminate
def broken_terminate(self: Process) -> NoReturn:
original_terminate(self)
raise OSError("whoops")
monkeypatch.setattr(Process, "terminate", broken_terminate)
with pytest.warns(RuntimeWarning, match=".*whoops.*"):
async with _core.open_nursery() as nursery:
nursery.start_soon(run_process, SLEEP(9999))
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
@pytest.mark.skipif(not posix, reason="posix only")
async def test_warn_on_cancel_SIGKILL_escalation(
autojump_clock: MockClock,
monkeypatch: pytest.MonkeyPatch,
) -> None:
monkeypatch.setattr(Process, "terminate", lambda *args: None)
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
async with _core.open_nursery() as nursery:
nursery.start_soon(run_process, SLEEP(9999))
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# the background_process_param exercises a lot of run_process cases, but it uses
# check=False, so lets have a test that uses check=True as well
async def test_run_process_background_fail() -> None:
with RaisesGroup(subprocess.CalledProcessError):
async with _core.open_nursery() as nursery:
proc: Process = await nursery.start(run_process, EXIT_FALSE)
assert proc.returncode == 1
@pytest.mark.skipif(
not SyncPath("/dev/fd").exists(),
reason="requires a way to iterate through open files",
)
async def test_for_leaking_fds() -> None:
gc.collect() # address possible flakiness on PyPy
starting_fds = set(SyncPath("/dev/fd").iterdir())
await run_process(EXIT_TRUE)
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
with pytest.raises(subprocess.CalledProcessError):
await run_process(EXIT_FALSE)
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
with pytest.raises(PermissionError):
await run_process(["/dev/fd/0"])
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
async def test_run_process_internal_error(monkeypatch: pytest.MonkeyPatch) -> None:
# There's probably less extreme ways of triggering errors inside the nursery
# in run_process.
async def very_broken_open(*args: object, **kwargs: object) -> str:
return "oops"
monkeypatch.setattr(trio._subprocess, "open_process", very_broken_open)
with RaisesGroup(AttributeError, AttributeError):
await run_process(EXIT_TRUE, capture_stdout=True)
# regression test for #2209
async def test_subprocess_pidfd_unnotified() -> None:
noticed_exit = None
async def wait_and_tell(proc: Process) -> None:
nonlocal noticed_exit
noticed_exit = Event()
await proc.wait()
noticed_exit.set()
proc = await open_process(SLEEP(9999))
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_and_tell, proc)
await wait_all_tasks_blocked()
assert isinstance(noticed_exit, Event)
proc.terminate()
# without giving trio a chance to do so,
with assert_no_checkpoints():
# wait until the process has actually exited;
proc._proc.wait()
# force a call to poll (that closes the pidfd on linux)
proc.poll()
with move_on_after(5):
# Some platforms use threads to wait for exit, so it might take a bit
# for everything to notice
await noticed_exit.wait()
assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"

View File

@ -0,0 +1,655 @@
from __future__ import annotations
import re
import weakref
from typing import TYPE_CHECKING, Callable, Union
import pytest
from trio.testing import Matcher, RaisesGroup
from .. import _core
from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER
from .._sync import *
from .._timeouts import sleep_forever
from ..testing import assert_checkpoints, wait_all_tasks_blocked
if TYPE_CHECKING:
from typing_extensions import TypeAlias
async def test_Event() -> None:
e = Event()
assert not e.is_set()
assert e.statistics().tasks_waiting == 0
e.set()
assert e.is_set()
with assert_checkpoints():
await e.wait()
e = Event()
record = []
async def child() -> None:
record.append("sleeping")
await e.wait()
record.append("woken")
async with _core.open_nursery() as nursery:
nursery.start_soon(child)
nursery.start_soon(child)
await wait_all_tasks_blocked()
assert record == ["sleeping", "sleeping"]
assert e.statistics().tasks_waiting == 2
e.set()
await wait_all_tasks_blocked()
assert record == ["sleeping", "sleeping", "woken", "woken"]
async def test_CapacityLimiter() -> None:
with pytest.raises(TypeError):
CapacityLimiter(1.0)
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
CapacityLimiter(-1)
c = CapacityLimiter(2)
repr(c) # smoke test
assert c.total_tokens == 2
assert c.borrowed_tokens == 0
assert c.available_tokens == 2
with pytest.raises(RuntimeError):
c.release()
assert c.borrowed_tokens == 0
c.acquire_nowait()
assert c.borrowed_tokens == 1
assert c.available_tokens == 1
stats = c.statistics()
assert stats.borrowed_tokens == 1
assert stats.total_tokens == 2
assert stats.borrowers == [_core.current_task()]
assert stats.tasks_waiting == 0
# Can't re-acquire when we already have it
with pytest.raises(RuntimeError):
c.acquire_nowait()
assert c.borrowed_tokens == 1
with pytest.raises(RuntimeError):
await c.acquire()
assert c.borrowed_tokens == 1
# We can acquire on behalf of someone else though
with assert_checkpoints():
await c.acquire_on_behalf_of("someone")
# But then we've run out of capacity
assert c.borrowed_tokens == 2
with pytest.raises(_core.WouldBlock):
c.acquire_on_behalf_of_nowait("third party")
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
# Until we release one
c.release_on_behalf_of(_core.current_task())
assert c.statistics().borrowers == ["someone"]
c.release_on_behalf_of("someone")
assert c.borrowed_tokens == 0
with assert_checkpoints():
async with c:
assert c.borrowed_tokens == 1
async with _core.open_nursery() as nursery:
await c.acquire_on_behalf_of("value 1")
await c.acquire_on_behalf_of("value 2")
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
await wait_all_tasks_blocked()
assert c.borrowed_tokens == 2
assert c.statistics().tasks_waiting == 1
c.release_on_behalf_of("value 2")
# Fairness:
assert c.borrowed_tokens == 2
with pytest.raises(_core.WouldBlock):
c.acquire_nowait()
c.release_on_behalf_of("value 3")
c.release_on_behalf_of("value 1")
async def test_CapacityLimiter_inf() -> None:
from math import inf
c = CapacityLimiter(inf)
repr(c) # smoke test
assert c.total_tokens == inf
assert c.borrowed_tokens == 0
assert c.available_tokens == inf
with pytest.raises(RuntimeError):
c.release()
assert c.borrowed_tokens == 0
c.acquire_nowait()
assert c.borrowed_tokens == 1
assert c.available_tokens == inf
async def test_CapacityLimiter_change_total_tokens() -> None:
c = CapacityLimiter(2)
with pytest.raises(TypeError):
c.total_tokens = 1.0
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
c.total_tokens = 0
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
c.total_tokens = -10
assert c.total_tokens == 2
async with _core.open_nursery() as nursery:
for i in range(5):
nursery.start_soon(c.acquire_on_behalf_of, i)
await wait_all_tasks_blocked()
assert set(c.statistics().borrowers) == {0, 1}
assert c.statistics().tasks_waiting == 3
c.total_tokens += 2
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
assert c.statistics().tasks_waiting == 1
c.total_tokens -= 3
assert c.borrowed_tokens == 4
assert c.total_tokens == 1
c.release_on_behalf_of(0)
c.release_on_behalf_of(1)
c.release_on_behalf_of(2)
assert set(c.statistics().borrowers) == {3}
assert c.statistics().tasks_waiting == 1
c.release_on_behalf_of(3)
assert set(c.statistics().borrowers) == {4}
assert c.statistics().tasks_waiting == 0
# regression test for issue #548
async def test_CapacityLimiter_memleak_548() -> None:
limiter = CapacityLimiter(total_tokens=1)
await limiter.acquire()
async with _core.open_nursery() as n:
n.start_soon(limiter.acquire)
await wait_all_tasks_blocked() # give it a chance to run the task
n.cancel_scope.cancel()
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
# leak memory all the while the limiter is active
assert len(limiter._pending_borrowers) == 0
async def test_Semaphore() -> None:
with pytest.raises(TypeError):
Semaphore(1.0) # type: ignore[arg-type]
with pytest.raises(ValueError, match="^initial value must be >= 0$"):
Semaphore(-1)
s = Semaphore(1)
repr(s) # smoke test
assert s.value == 1
assert s.max_value is None
s.release()
assert s.value == 2
assert s.statistics().tasks_waiting == 0
s.acquire_nowait()
assert s.value == 1
with assert_checkpoints():
await s.acquire()
assert s.value == 0
with pytest.raises(_core.WouldBlock):
s.acquire_nowait()
s.release()
assert s.value == 1
with assert_checkpoints():
async with s:
assert s.value == 0
assert s.value == 1
s.acquire_nowait()
record = []
async def do_acquire(s: Semaphore) -> None:
record.append("started")
await s.acquire()
record.append("finished")
async with _core.open_nursery() as nursery:
nursery.start_soon(do_acquire, s)
await wait_all_tasks_blocked()
assert record == ["started"]
assert s.value == 0
s.release()
# Fairness:
assert s.value == 0
with pytest.raises(_core.WouldBlock):
s.acquire_nowait()
assert record == ["started", "finished"]
async def test_Semaphore_bounded() -> None:
with pytest.raises(TypeError):
Semaphore(1, max_value=1.0) # type: ignore[arg-type]
with pytest.raises(ValueError, match="^max_values must be >= initial_value$"):
Semaphore(2, max_value=1)
bs = Semaphore(1, max_value=1)
assert bs.max_value == 1
repr(bs) # smoke test
with pytest.raises(ValueError, match="^semaphore released too many times$"):
bs.release()
assert bs.value == 1
bs.acquire_nowait()
assert bs.value == 0
bs.release()
assert bs.value == 1
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
async def test_Lock_and_StrictFIFOLock(
lockcls: type[Lock | StrictFIFOLock],
) -> None:
l = lockcls() # noqa
assert not l.locked()
# make sure locks can be weakref'ed (gh-331)
r = weakref.ref(l)
assert r() is l
repr(l) # smoke test
# make sure repr uses the right name for subclasses
assert lockcls.__name__ in repr(l)
with assert_checkpoints():
async with l:
assert l.locked()
repr(l) # smoke test (repr branches on locked/unlocked)
assert not l.locked()
l.acquire_nowait()
assert l.locked()
l.release()
assert not l.locked()
with assert_checkpoints():
await l.acquire()
assert l.locked()
l.release()
assert not l.locked()
l.acquire_nowait()
with pytest.raises(RuntimeError):
# Error out if we already own the lock
l.acquire_nowait()
l.release()
with pytest.raises(RuntimeError):
# Error out if we don't own the lock
l.release()
holder_task = None
async def holder() -> None:
nonlocal holder_task
holder_task = _core.current_task()
async with l:
await sleep_forever()
async with _core.open_nursery() as nursery:
assert not l.locked()
nursery.start_soon(holder)
await wait_all_tasks_blocked()
assert l.locked()
# WouldBlock if someone else holds the lock
with pytest.raises(_core.WouldBlock):
l.acquire_nowait()
# Can't release a lock someone else holds
with pytest.raises(RuntimeError):
l.release()
statistics = l.statistics()
print(statistics)
assert statistics.locked
assert statistics.owner is holder_task
assert statistics.tasks_waiting == 0
nursery.start_soon(holder)
await wait_all_tasks_blocked()
statistics = l.statistics()
print(statistics)
assert statistics.tasks_waiting == 1
nursery.cancel_scope.cancel()
statistics = l.statistics()
assert not statistics.locked
assert statistics.owner is None
assert statistics.tasks_waiting == 0
async def test_Condition() -> None:
with pytest.raises(TypeError):
Condition(Semaphore(1)) # type: ignore[arg-type]
with pytest.raises(TypeError):
Condition(StrictFIFOLock) # type: ignore[arg-type]
l = Lock() # noqa
c = Condition(l)
assert not l.locked()
assert not c.locked()
with assert_checkpoints():
await c.acquire()
assert l.locked()
assert c.locked()
c = Condition()
assert not c.locked()
c.acquire_nowait()
assert c.locked()
with pytest.raises(RuntimeError):
c.acquire_nowait()
c.release()
with pytest.raises(RuntimeError):
# Can't wait without holding the lock
await c.wait()
with pytest.raises(RuntimeError):
# Can't notify without holding the lock
c.notify()
with pytest.raises(RuntimeError):
# Can't notify without holding the lock
c.notify_all()
finished_waiters = set()
async def waiter(i: int) -> None:
async with c:
await c.wait()
finished_waiters.add(i)
async with _core.open_nursery() as nursery:
for i in range(3):
nursery.start_soon(waiter, i)
await wait_all_tasks_blocked()
async with c:
c.notify()
assert c.locked()
await wait_all_tasks_blocked()
assert finished_waiters == {0}
async with c:
c.notify_all()
await wait_all_tasks_blocked()
assert finished_waiters == {0, 1, 2}
finished_waiters = set()
async with _core.open_nursery() as nursery:
for i in range(3):
nursery.start_soon(waiter, i)
await wait_all_tasks_blocked()
async with c:
c.notify(2)
statistics = c.statistics()
print(statistics)
assert statistics.tasks_waiting == 1
assert statistics.lock_statistics.tasks_waiting == 2
# exiting the context manager hands off the lock to the first task
assert c.statistics().lock_statistics.tasks_waiting == 1
await wait_all_tasks_blocked()
assert finished_waiters == {0, 1}
async with c:
c.notify_all()
# After being cancelled still hold the lock (!)
# (Note that c.__aexit__ checks that we hold the lock as well)
with _core.CancelScope() as scope:
async with c:
scope.cancel()
try:
await c.wait()
finally:
assert c.locked()
from .._channel import open_memory_channel
from .._sync import AsyncContextManagerMixin
# Three ways of implementing a Lock in terms of a channel. Used to let us put
# the channel through the generic lock tests.
class ChannelLock1(AsyncContextManagerMixin):
def __init__(self, capacity: int) -> None:
self.s, self.r = open_memory_channel[None](capacity)
for _ in range(capacity - 1):
self.s.send_nowait(None)
def acquire_nowait(self) -> None:
self.s.send_nowait(None)
async def acquire(self) -> None:
await self.s.send(None)
def release(self) -> None:
self.r.receive_nowait()
class ChannelLock2(AsyncContextManagerMixin):
def __init__(self) -> None:
self.s, self.r = open_memory_channel[None](10)
self.s.send_nowait(None)
def acquire_nowait(self) -> None:
self.r.receive_nowait()
async def acquire(self) -> None:
await self.r.receive()
def release(self) -> None:
self.s.send_nowait(None)
class ChannelLock3(AsyncContextManagerMixin):
def __init__(self) -> None:
self.s, self.r = open_memory_channel[None](0)
# self.acquired is true when one task acquires the lock and
# only becomes false when it's released and no tasks are
# waiting to acquire.
self.acquired = False
def acquire_nowait(self) -> None:
assert not self.acquired
self.acquired = True
async def acquire(self) -> None:
if self.acquired:
await self.s.send(None)
else:
self.acquired = True
await _core.checkpoint()
def release(self) -> None:
try:
self.r.receive_nowait()
except _core.WouldBlock:
assert self.acquired
self.acquired = False
lock_factories = [
lambda: CapacityLimiter(1),
lambda: Semaphore(1),
Lock,
StrictFIFOLock,
lambda: ChannelLock1(10),
lambda: ChannelLock1(1),
ChannelLock2,
ChannelLock3,
]
lock_factory_names = [
"CapacityLimiter(1)",
"Semaphore(1)",
"Lock",
"StrictFIFOLock",
"ChannelLock1(10)",
"ChannelLock1(1)",
"ChannelLock2",
"ChannelLock3",
]
generic_lock_test = pytest.mark.parametrize(
"lock_factory",
lock_factories,
ids=lock_factory_names,
)
LockLike: TypeAlias = Union[
CapacityLimiter,
Semaphore,
Lock,
StrictFIFOLock,
ChannelLock1,
ChannelLock2,
ChannelLock3,
]
LockFactory: TypeAlias = Callable[[], LockLike]
# Spawn a bunch of workers that take a lock and then yield; make sure that
# only one worker is ever in the critical section at a time.
@generic_lock_test
async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
LOOPS = 10
WORKERS = 5
in_critical_section = False
acquires = 0
async def worker(lock_like: LockLike) -> None:
nonlocal in_critical_section, acquires
for _ in range(LOOPS):
async with lock_like:
acquires += 1
assert not in_critical_section
in_critical_section = True
await _core.checkpoint()
await _core.checkpoint()
assert in_critical_section
in_critical_section = False
async with _core.open_nursery() as nursery:
lock_like = lock_factory()
for _ in range(WORKERS):
nursery.start_soon(worker, lock_like)
assert not in_critical_section
assert acquires == LOOPS * WORKERS
# Several workers queue on the same lock; make sure they each get it, in
# order.
@generic_lock_test
async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
initial_order = []
record = []
LOOPS = 5
async def loopy(name: int, lock_like: LockLike) -> None:
# Record the order each task was initially scheduled in
initial_order.append(name)
for _ in range(LOOPS):
async with lock_like:
record.append(name)
lock_like = lock_factory()
async with _core.open_nursery() as nursery:
nursery.start_soon(loopy, 1, lock_like)
nursery.start_soon(loopy, 2, lock_like)
nursery.start_soon(loopy, 3, lock_like)
# The first three could be in any order due to scheduling randomness,
# but after that they should repeat in the same order
for i in range(LOOPS):
assert record[3 * i : 3 * (i + 1)] == initial_order
@generic_lock_test
async def test_generic_lock_acquire_nowait_blocks_acquire(
lock_factory: LockFactory,
) -> None:
lock_like = lock_factory()
record = []
async def lock_taker() -> None:
record.append("started")
async with lock_like:
pass
record.append("finished")
async with _core.open_nursery() as nursery:
lock_like.acquire_nowait()
nursery.start_soon(lock_taker)
await wait_all_tasks_blocked()
assert record == ["started"]
lock_like.release()
async def test_lock_acquire_unowned_lock() -> None:
"""Test that trying to acquire a lock whose owner has exited raises an error.
see https://github.com/python-trio/trio/issues/3035
"""
assert not GLOBAL_PARKING_LOT_BREAKER
lock = trio.Lock()
async with trio.open_nursery() as nursery:
nursery.start_soon(lock.acquire)
owner_str = re.escape(str(lock._lot.broken_by[0]))
with pytest.raises(
trio.BrokenResourceError,
match=f"^Owner of this lock exited without releasing: {owner_str}$",
):
await lock.acquire()
assert not GLOBAL_PARKING_LOT_BREAKER
async def test_lock_multiple_acquire() -> None:
"""Test for error if awaiting on a lock whose owner exits without releasing.
see https://github.com/python-trio/trio/issues/3035"""
assert not GLOBAL_PARKING_LOT_BREAKER
lock = trio.Lock()
with RaisesGroup(
Matcher(
trio.BrokenResourceError,
match="^Owner of this lock exited without releasing: ",
),
):
async with trio.open_nursery() as nursery:
nursery.start_soon(lock.acquire)
nursery.start_soon(lock.acquire)
assert not GLOBAL_PARKING_LOT_BREAKER
async def test_lock_handover() -> None:
assert not GLOBAL_PARKING_LOT_BREAKER
child_task: Task | None = None
lock = trio.Lock()
# this task acquires the lock
lock.acquire_nowait()
assert GLOBAL_PARKING_LOT_BREAKER == {
_core.current_task(): [
lock._lot,
],
}
async with trio.open_nursery() as nursery:
nursery.start_soon(lock.acquire)
await wait_all_tasks_blocked()
# hand over the lock to the child task
lock.release()
# check values, and get the identifier out of the dict for later check
assert len(GLOBAL_PARKING_LOT_BREAKER) == 1
child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER))
assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot]
assert lock._lot.broken_by == [child_task]
assert not GLOBAL_PARKING_LOT_BREAKER

View File

@ -0,0 +1,684 @@
from __future__ import annotations
# XX this should get broken up, like testing.py did
import tempfile
from typing import TYPE_CHECKING
import pytest
from trio.testing import RaisesGroup
from .. import _core, sleep, socket as tsocket
from .._core._tests.tutil import can_bind_ipv6
from .._highlevel_generic import StapledStream, aclose_forcefully
from .._highlevel_socket import SocketListener
from ..testing import *
from ..testing._check_streams import _assert_raises
from ..testing._memory_streams import _UnboundedByteQueue
if TYPE_CHECKING:
from trio import Nursery
from trio.abc import ReceiveStream, SendStream
async def test_wait_all_tasks_blocked() -> None:
record = []
async def busy_bee() -> None:
for _ in range(10):
await _core.checkpoint()
record.append("busy bee exhausted")
async def waiting_for_bee_to_leave() -> None:
await wait_all_tasks_blocked()
record.append("quiet at last!")
async with _core.open_nursery() as nursery:
nursery.start_soon(busy_bee)
nursery.start_soon(waiting_for_bee_to_leave)
nursery.start_soon(waiting_for_bee_to_leave)
# check cancellation
record = []
async def cancelled_while_waiting() -> None:
try:
await wait_all_tasks_blocked()
except _core.Cancelled:
record.append("ok")
async with _core.open_nursery() as nursery:
nursery.start_soon(cancelled_while_waiting)
nursery.cancel_scope.cancel()
assert record == ["ok"]
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None:
record = []
async def timeout_task() -> None:
record.append("tt start")
await sleep(5)
record.append("tt finished")
async with _core.open_nursery() as nursery:
nursery.start_soon(timeout_task)
await wait_all_tasks_blocked()
assert record == ["tt start"]
mock_clock.jump(10)
await wait_all_tasks_blocked()
assert record == ["tt start", "tt finished"]
async def test_wait_all_tasks_blocked_with_cushion() -> None:
record = []
async def blink() -> None:
record.append("blink start")
await sleep(0.01)
await sleep(0.01)
await sleep(0.01)
record.append("blink end")
async def wait_no_cushion() -> None:
await wait_all_tasks_blocked()
record.append("wait_no_cushion end")
async def wait_small_cushion() -> None:
await wait_all_tasks_blocked(0.02)
record.append("wait_small_cushion end")
async def wait_big_cushion() -> None:
await wait_all_tasks_blocked(0.03)
record.append("wait_big_cushion end")
async with _core.open_nursery() as nursery:
nursery.start_soon(blink)
nursery.start_soon(wait_no_cushion)
nursery.start_soon(wait_small_cushion)
nursery.start_soon(wait_small_cushion)
nursery.start_soon(wait_big_cushion)
assert record == [
"blink start",
"wait_no_cushion end",
"blink end",
"wait_small_cushion end",
"wait_small_cushion end",
"wait_big_cushion end",
]
################################################################
async def test_assert_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
with assert_checkpoints():
await _core.checkpoint()
with pytest.raises(AssertionError):
with assert_checkpoints():
1 + 1 # noqa: B018 # "useless expression"
# partial yield cases
# if you have a schedule point but not a cancel point, or vice-versa, then
# that's not a checkpoint.
for partial_yield in [
_core.checkpoint_if_cancelled,
_core.cancel_shielded_checkpoint,
]:
print(partial_yield)
with pytest.raises(AssertionError):
with assert_checkpoints():
await partial_yield()
# But both together count as a checkpoint
with assert_checkpoints():
await _core.checkpoint_if_cancelled()
await _core.cancel_shielded_checkpoint()
async def test_assert_no_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
with assert_no_checkpoints():
1 + 1 # noqa: B018 # "useless expression"
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await _core.checkpoint()
# partial yield cases
# if you have a schedule point but not a cancel point, or vice-versa, then
# that doesn't make *either* version of assert_{no_,}yields happy.
for partial_yield in [
_core.checkpoint_if_cancelled,
_core.cancel_shielded_checkpoint,
]:
print(partial_yield)
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await partial_yield()
# And both together also count as a checkpoint
with pytest.raises(AssertionError):
with assert_no_checkpoints():
await _core.checkpoint_if_cancelled()
await _core.cancel_shielded_checkpoint()
################################################################
async def test_Sequencer() -> None:
record = []
def t(val: object) -> None:
print(val)
record.append(val)
async def f1(seq: Sequencer) -> None:
async with seq(1):
t(("f1", 1))
async with seq(3):
t(("f1", 3))
async with seq(4):
t(("f1", 4))
async def f2(seq: Sequencer) -> None:
async with seq(0):
t(("f2", 0))
async with seq(2):
t(("f2", 2))
seq = Sequencer()
async with _core.open_nursery() as nursery:
nursery.start_soon(f1, seq)
nursery.start_soon(f2, seq)
async with seq(5):
await wait_all_tasks_blocked()
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
seq = Sequencer()
# Catches us if we try to reuse a sequence point:
async with seq(0):
pass
with pytest.raises(RuntimeError):
async with seq(0):
pass # pragma: no cover
async def test_Sequencer_cancel() -> None:
# Killing a blocked task makes everything blow up
record = []
seq = Sequencer()
async def child(i: int) -> None:
with _core.CancelScope() as scope:
if i == 1:
scope.cancel()
try:
async with seq(i):
pass # pragma: no cover
except RuntimeError:
record.append(f"seq({i}) RuntimeError")
async with _core.open_nursery() as nursery:
nursery.start_soon(child, 1)
nursery.start_soon(child, 2)
async with seq(0):
pass # pragma: no cover
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
# Late arrivals also get errors
with pytest.raises(RuntimeError):
async with seq(3):
pass # pragma: no cover
################################################################
async def test__assert_raises() -> None:
with pytest.raises(AssertionError):
with _assert_raises(RuntimeError):
1 + 1 # noqa: B018 # "useless expression"
with pytest.raises(TypeError):
with _assert_raises(RuntimeError):
"foo" + 1 # type: ignore[operator] # noqa: B018 # "useless expression"
with _assert_raises(RuntimeError):
raise RuntimeError
# This is a private implementation detail, but it's complex enough to be worth
# testing directly
async def test__UnboundeByteQueue() -> None:
ubq = _UnboundedByteQueue()
ubq.put(b"123")
ubq.put(b"456")
assert ubq.get_nowait(1) == b"1"
assert ubq.get_nowait(10) == b"23456"
ubq.put(b"789")
assert ubq.get_nowait() == b"789"
with pytest.raises(_core.WouldBlock):
ubq.get_nowait(10)
with pytest.raises(_core.WouldBlock):
ubq.get_nowait()
with pytest.raises(TypeError):
ubq.put("string") # type: ignore[arg-type]
ubq.put(b"abc")
with assert_checkpoints():
assert await ubq.get(10) == b"abc"
ubq.put(b"def")
ubq.put(b"ghi")
with assert_checkpoints():
assert await ubq.get(1) == b"d"
with assert_checkpoints():
assert await ubq.get() == b"efghi"
async def putter(data: bytes) -> None:
await wait_all_tasks_blocked()
ubq.put(data)
async def getter(expect: bytes) -> None:
with assert_checkpoints():
assert await ubq.get() == expect
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"xyz")
nursery.start_soon(putter, b"xyz")
# Two gets at the same time -> BusyResourceError
with RaisesGroup(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"asdf")
nursery.start_soon(getter, b"asdf")
# Closing
ubq.close()
with pytest.raises(_core.ClosedResourceError):
ubq.put(b"---")
assert ubq.get_nowait(10) == b""
assert ubq.get_nowait() == b""
assert await ubq.get(10) == b""
assert await ubq.get() == b""
# close is idempotent
ubq.close()
# close wakes up blocked getters
ubq2 = _UnboundedByteQueue()
async def closer() -> None:
await wait_all_tasks_blocked()
ubq2.close()
async with _core.open_nursery() as nursery:
nursery.start_soon(getter, b"")
nursery.start_soon(closer)
async def test_MemorySendStream() -> None:
mss = MemorySendStream()
async def do_send_all(data: bytes) -> None:
with assert_checkpoints():
await mss.send_all(data)
await do_send_all(b"123")
assert mss.get_data_nowait(1) == b"1"
assert mss.get_data_nowait() == b"23"
with assert_checkpoints():
await mss.wait_send_all_might_not_block()
with pytest.raises(_core.WouldBlock):
mss.get_data_nowait()
with pytest.raises(_core.WouldBlock):
mss.get_data_nowait(10)
await do_send_all(b"456")
with assert_checkpoints():
assert await mss.get_data() == b"456"
# Call send_all twice at once; one should get BusyResourceError and one
# should succeed. But we can't let the error propagate, because it might
# cause the other to be cancelled before it can finish doing its thing,
# and we don't know which one will get the error.
resource_busy_count = 0
async def do_send_all_count_resourcebusy() -> None:
nonlocal resource_busy_count
try:
await do_send_all(b"xxx")
except _core.BusyResourceError:
resource_busy_count += 1
async with _core.open_nursery() as nursery:
nursery.start_soon(do_send_all_count_resourcebusy)
nursery.start_soon(do_send_all_count_resourcebusy)
assert resource_busy_count == 1
with assert_checkpoints():
await mss.aclose()
assert await mss.get_data() == b"xxx"
assert await mss.get_data() == b""
with pytest.raises(_core.ClosedResourceError):
await do_send_all(b"---")
# hooks
assert mss.send_all_hook is None
assert mss.wait_send_all_might_not_block_hook is None
assert mss.close_hook is None
record = []
async def send_all_hook() -> None:
# hook runs after send_all does its work (can pull data out)
assert mss2.get_data_nowait() == b"abc"
record.append("send_all_hook")
async def wait_send_all_might_not_block_hook() -> None:
record.append("wait_send_all_might_not_block_hook")
def close_hook() -> None:
record.append("close_hook")
mss2 = MemorySendStream(
send_all_hook,
wait_send_all_might_not_block_hook,
close_hook,
)
assert mss2.send_all_hook is send_all_hook
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
assert mss2.close_hook is close_hook
await mss2.send_all(b"abc")
await mss2.wait_send_all_might_not_block()
await aclose_forcefully(mss2)
mss2.close()
assert record == [
"send_all_hook",
"wait_send_all_might_not_block_hook",
"close_hook",
"close_hook",
]
async def test_MemoryReceiveStream() -> None:
mrs = MemoryReceiveStream()
async def do_receive_some(max_bytes: int | None) -> bytes:
with assert_checkpoints():
return await mrs.receive_some(max_bytes)
mrs.put_data(b"abc")
assert await do_receive_some(1) == b"a"
assert await do_receive_some(10) == b"bc"
mrs.put_data(b"abc")
assert await do_receive_some(None) == b"abc"
with RaisesGroup(_core.BusyResourceError):
async with _core.open_nursery() as nursery:
nursery.start_soon(do_receive_some, 10)
nursery.start_soon(do_receive_some, 10)
assert mrs.receive_some_hook is None
mrs.put_data(b"def")
mrs.put_eof()
mrs.put_eof()
assert await do_receive_some(10) == b"def"
assert await do_receive_some(10) == b""
assert await do_receive_some(10) == b""
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"---")
async def receive_some_hook() -> None:
mrs2.put_data(b"xxx")
record = []
def close_hook() -> None:
record.append("closed")
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
assert mrs2.receive_some_hook is receive_some_hook
assert mrs2.close_hook is close_hook
mrs2.put_data(b"yyy")
assert await mrs2.receive_some(10) == b"yyyxxx"
assert await mrs2.receive_some(10) == b"xxx"
assert await mrs2.receive_some(10) == b"xxx"
mrs2.put_data(b"zzz")
mrs2.receive_some_hook = None
assert await mrs2.receive_some(10) == b"zzz"
mrs2.put_data(b"lost on close")
with assert_checkpoints():
await mrs2.aclose()
assert record == ["closed"]
with pytest.raises(_core.ClosedResourceError):
await mrs2.receive_some(10)
async def test_MemoryRecvStream_closing() -> None:
mrs = MemoryReceiveStream()
# close with no pending data
mrs.close()
with pytest.raises(_core.ClosedResourceError):
assert await mrs.receive_some(10) == b""
# repeated closes ok
mrs.close()
# put_data now fails
with pytest.raises(_core.ClosedResourceError):
mrs.put_data(b"123")
mrs2 = MemoryReceiveStream()
# close with pending data
mrs2.put_data(b"xyz")
mrs2.close()
with pytest.raises(_core.ClosedResourceError):
await mrs2.receive_some(10)
async def test_memory_stream_pump() -> None:
mss = MemorySendStream()
mrs = MemoryReceiveStream()
# no-op if no data present
memory_stream_pump(mss, mrs)
await mss.send_all(b"123")
memory_stream_pump(mss, mrs)
assert await mrs.receive_some(10) == b"123"
await mss.send_all(b"456")
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert await mrs.receive_some(10) == b"4"
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert memory_stream_pump(mss, mrs, max_bytes=1)
assert not memory_stream_pump(mss, mrs, max_bytes=1)
assert await mrs.receive_some(10) == b"56"
mss.close()
memory_stream_pump(mss, mrs)
assert await mrs.receive_some(10) == b""
async def test_memory_stream_one_way_pair() -> None:
s, r = memory_stream_one_way_pair()
assert s.send_all_hook is not None
assert s.wait_send_all_might_not_block_hook is None
assert s.close_hook is not None
assert r.receive_some_hook is None
await s.send_all(b"123")
assert await r.receive_some(10) == b"123"
async def receiver(expected: bytes) -> None:
assert await r.receive_some(10) == expected
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"abc")
await wait_all_tasks_blocked()
await s.send_all(b"abc")
# And this fails if we don't pump from close_hook
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"")
await wait_all_tasks_blocked()
await s.aclose()
s, r = memory_stream_one_way_pair()
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver, b"")
await wait_all_tasks_blocked()
s.close()
s, r = memory_stream_one_way_pair()
old = s.send_all_hook
s.send_all_hook = None
await s.send_all(b"456")
async def cancel_after_idle(nursery: Nursery) -> None:
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
async def check_for_cancel() -> None:
with pytest.raises(_core.Cancelled):
# This should block forever... or until cancelled. Even though we
# sent some data on the send stream.
await r.receive_some(10)
async with _core.open_nursery() as nursery:
nursery.start_soon(cancel_after_idle, nursery)
nursery.start_soon(check_for_cancel)
s.send_all_hook = old
await s.send_all(b"789")
assert await r.receive_some(10) == b"456789"
async def test_memory_stream_pair() -> None:
a, b = memory_stream_pair()
await a.send_all(b"123")
await b.send_all(b"abc")
assert await b.receive_some(10) == b"123"
assert await a.receive_some(10) == b"abc"
await a.send_eof()
assert await b.receive_some(10) == b""
async def sender() -> None:
await wait_all_tasks_blocked()
await b.send_all(b"xyz")
async def receiver() -> None:
assert await a.receive_some(10) == b"xyz"
async with _core.open_nursery() as nursery:
nursery.start_soon(receiver)
nursery.start_soon(sender)
async def test_memory_streams_with_generic_tests() -> None:
async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]:
return memory_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, None)
async def half_closeable_stream_maker() -> tuple[
StapledStream[MemorySendStream, MemoryReceiveStream],
StapledStream[MemorySendStream, MemoryReceiveStream],
]:
return memory_stream_pair()
await check_half_closeable_stream(half_closeable_stream_maker, None)
async def test_lockstep_streams_with_generic_tests() -> None:
async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]:
return lockstep_stream_one_way_pair()
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
async def two_way_stream_maker() -> tuple[
StapledStream[SendStream, ReceiveStream],
StapledStream[SendStream, ReceiveStream],
]:
return lockstep_stream_pair()
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
async def test_open_stream_to_socket_listener() -> None:
async def check(listener: SocketListener) -> None:
async with listener:
client_stream = await open_stream_to_socket_listener(listener)
async with client_stream:
server_stream = await listener.accept()
async with server_stream:
await client_stream.send_all(b"x")
assert await server_stream.receive_some(1) == b"x"
# Listener bound to localhost
sock = tsocket.socket()
await sock.bind(("127.0.0.1", 0))
sock.listen(10)
await check(SocketListener(sock))
# Listener bound to IPv4 wildcard (needs special handling)
sock = tsocket.socket()
await sock.bind(("0.0.0.0", 0))
sock.listen(10)
await check(SocketListener(sock))
# true on all CI systems
if can_bind_ipv6: # pragma: no branch
# Listener bound to IPv6 wildcard (needs special handling)
sock = tsocket.socket(family=tsocket.AF_INET6)
await sock.bind(("::", 0))
sock.listen(10)
await check(SocketListener(sock))
if hasattr(tsocket, "AF_UNIX"):
# Listener bound to Unix-domain socket
sock = tsocket.socket(family=tsocket.AF_UNIX)
# can't use pytest's tmpdir; if we try then macOS says "OSError:
# AF_UNIX path too long"
with tempfile.TemporaryDirectory() as tmpdir:
path = f"{tmpdir}/sock"
await sock.bind(path)
sock.listen(10)
await check(SocketListener(sock))
def test_trio_test() -> None:
async def busy_kitchen(
*,
mock_clock: object,
autojump_clock: object,
) -> None: ... # pragma: no cover
with pytest.raises(ValueError, match="^too many clocks spoil the broth!$"):
trio_test(busy_kitchen)(
mock_clock=MockClock(),
autojump_clock=MockClock(autojump_threshold=0),
)

View File

@ -0,0 +1,374 @@
from __future__ import annotations
import re
import sys
from types import TracebackType
from typing import Any
import pytest
import trio
from trio.testing import Matcher, RaisesGroup
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
def wrap_escape(s: str) -> str:
return "^" + re.escape(s) + "$"
def test_raises_group() -> None:
with pytest.raises(
ValueError,
match=wrap_escape(
f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.',
),
):
RaisesGroup(TypeError())
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (ValueError(),))
with RaisesGroup(SyntaxError):
with RaisesGroup(ValueError):
raise ExceptionGroup("foo", (SyntaxError(),))
# multiple exceptions
with RaisesGroup(ValueError, SyntaxError):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# order doesn't matter
with RaisesGroup(SyntaxError, ValueError):
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
# nested exceptions
with RaisesGroup(RaisesGroup(ValueError)):
raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
with RaisesGroup(
SyntaxError,
RaisesGroup(ValueError),
RaisesGroup(RuntimeError),
):
raise ExceptionGroup(
"foo",
(
SyntaxError(),
ExceptionGroup("bar", (ValueError(),)),
ExceptionGroup("", (RuntimeError(),)),
),
)
# will error if there's excess exceptions
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (ValueError(), ValueError()))
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError):
raise ExceptionGroup("", (RuntimeError(), ValueError()))
# will error if there's missing exceptions
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, ValueError):
raise ExceptionGroup("", (ValueError(),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, SyntaxError):
raise ExceptionGroup("", (ValueError(),))
def test_flatten_subgroups() -> None:
# loose semantics, as with expect*
with RaisesGroup(ValueError, flatten_subgroups=True):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(), TypeError())),))
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()]), TypeError()])
# mixed loose is possible if you want it to be at least N deep
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
raise ExceptionGroup(
"",
(ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),),
)
with pytest.raises(ExceptionGroup):
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
raise ExceptionGroup("", (ValueError(),))
# but not the other way around
with pytest.raises(
ValueError,
match="^You cannot specify a nested structure inside a RaisesGroup with",
):
RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore[call-overload]
def test_catch_unwrapped_exceptions() -> None:
# Catches lone exceptions with strict=False
# just as except* would
with RaisesGroup(ValueError, allow_unwrapped=True):
raise ValueError
# expecting multiple unwrapped exceptions is not possible
with pytest.raises(
ValueError,
match="^You cannot specify multiple exceptions with",
):
RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload]
# if users want one of several exception types they need to use a Matcher
# (which the error message suggests)
with RaisesGroup(
Matcher(check=lambda e: isinstance(e, (SyntaxError, ValueError))),
allow_unwrapped=True,
):
raise ValueError
# Unwrapped nested `RaisesGroup` is likely a user error, so we raise an error.
with pytest.raises(ValueError, match="has no effect when expecting"):
RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore[call-overload]
# But it *can* be used to check for nesting level +- 1 if they move it to
# the nested RaisesGroup. Users should probably use `Matcher`s instead though.
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
raise ExceptionGroup("", [ValueError()])
# with allow_unwrapped=False (default) it will not be caught
with pytest.raises(ValueError, match="^value error text$"):
with RaisesGroup(ValueError):
raise ValueError("value error text")
# allow_unwrapped on it's own won't match against nested groups
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, allow_unwrapped=True):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
# for that you need both allow_unwrapped and flatten_subgroups
with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True):
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
# code coverage
with pytest.raises(TypeError):
with RaisesGroup(ValueError, allow_unwrapped=True):
raise TypeError
def test_match() -> None:
# supports match string
with RaisesGroup(ValueError, match="bar"):
raise ExceptionGroup("bar", (ValueError(),))
# now also works with ^$
with RaisesGroup(ValueError, match="^bar$"):
raise ExceptionGroup("bar", (ValueError(),))
# it also includes notes
with RaisesGroup(ValueError, match="my note"):
e = ExceptionGroup("bar", (ValueError(),))
e.add_note("my note")
raise e
# and technically you can match it all with ^$
# but you're probably better off using a Matcher at that point
with RaisesGroup(ValueError, match="^bar\nmy note$"):
e = ExceptionGroup("bar", (ValueError(),))
e.add_note("my note")
raise e
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, match="foo"):
raise ExceptionGroup("bar", (ValueError(),))
def test_check() -> None:
exc = ExceptionGroup("", (ValueError(),))
with RaisesGroup(ValueError, check=lambda x: x is exc):
raise exc
with pytest.raises(ExceptionGroup):
with RaisesGroup(ValueError, check=lambda x: x is exc):
raise ExceptionGroup("", (ValueError(),))
def test_unwrapped_match_check() -> None:
def my_check(e: object) -> bool: # pragma: no cover
return True
msg = (
"`allow_unwrapped=True` bypasses the `match` and `check` parameters"
" if the exception is unwrapped. If you intended to match/check the"
" exception you should use a `Matcher` object. If you want to match/check"
" the exceptiongroup when the exception *is* wrapped you need to"
" do e.g. `if isinstance(exc.value, ExceptionGroup):"
" assert RaisesGroup(...).matches(exc.value)` afterwards."
)
with pytest.raises(ValueError, match=re.escape(msg)):
RaisesGroup(ValueError, allow_unwrapped=True, match="foo") # type: ignore[call-overload]
with pytest.raises(ValueError, match=re.escape(msg)):
RaisesGroup(ValueError, allow_unwrapped=True, check=my_check) # type: ignore[call-overload]
# Users should instead use a Matcher
rg = RaisesGroup(Matcher(ValueError, match="^foo$"), allow_unwrapped=True)
with rg:
raise ValueError("foo")
with rg:
raise ExceptionGroup("", [ValueError("foo")])
# or if they wanted to match/check the group, do a conditional `.matches()`
with RaisesGroup(ValueError, allow_unwrapped=True) as exc:
raise ExceptionGroup("bar", [ValueError("foo")])
if isinstance(exc.value, ExceptionGroup): # pragma: no branch
assert RaisesGroup(ValueError, match="bar").matches(exc.value)
def test_RaisesGroup_matches() -> None:
rg = RaisesGroup(ValueError)
assert not rg.matches(None)
assert not rg.matches(ValueError())
assert rg.matches(ExceptionGroup("", (ValueError(),)))
def test_message() -> None:
def check_message(message: str, body: RaisesGroup[Any]) -> None:
with pytest.raises(
AssertionError,
match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$",
):
with body:
...
# basic
check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError))
# multiple exceptions
check_message(
"ExceptionGroup(ValueError, ValueError)",
RaisesGroup(ValueError, ValueError),
)
# nested
check_message(
"ExceptionGroup(ExceptionGroup(ValueError))",
RaisesGroup(RaisesGroup(ValueError)),
)
# Matcher
check_message(
"ExceptionGroup(Matcher(ValueError, match='my_str'))",
RaisesGroup(Matcher(ValueError, "my_str")),
)
check_message(
"ExceptionGroup(Matcher(match='my_str'))",
RaisesGroup(Matcher(match="my_str")),
)
# BaseExceptionGroup
check_message(
"BaseExceptionGroup(KeyboardInterrupt)",
RaisesGroup(KeyboardInterrupt),
)
# BaseExceptionGroup with type inside Matcher
check_message(
"BaseExceptionGroup(Matcher(KeyboardInterrupt))",
RaisesGroup(Matcher(KeyboardInterrupt)),
)
# Base-ness transfers to parent containers
check_message(
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt))",
RaisesGroup(RaisesGroup(KeyboardInterrupt)),
)
# but not to child containers
check_message(
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt), ExceptionGroup(ValueError))",
RaisesGroup(RaisesGroup(KeyboardInterrupt), RaisesGroup(ValueError)),
)
def test_matcher() -> None:
with pytest.raises(
ValueError,
match="^You must specify at least one parameter to match on.$",
):
Matcher() # type: ignore[call-overload]
with pytest.raises(
ValueError,
match=f"^exception_type {re.escape(repr(object))} must be a subclass of BaseException$",
):
Matcher(object) # type: ignore[type-var]
with RaisesGroup(Matcher(ValueError)):
raise ExceptionGroup("", (ValueError(),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(TypeError)):
raise ExceptionGroup("", (ValueError(),))
def test_matcher_match() -> None:
with RaisesGroup(Matcher(ValueError, "foo")):
raise ExceptionGroup("", (ValueError("foo"),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(ValueError, "foo")):
raise ExceptionGroup("", (ValueError("bar"),))
# Can be used without specifying the type
with RaisesGroup(Matcher(match="foo")):
raise ExceptionGroup("", (ValueError("foo"),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(match="foo")):
raise ExceptionGroup("", (ValueError("bar"),))
# check ^$
with RaisesGroup(Matcher(ValueError, match="^bar$")):
raise ExceptionGroup("", [ValueError("bar")])
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(ValueError, match="^bar$")):
raise ExceptionGroup("", [ValueError("barr")])
def test_Matcher_check() -> None:
def check_oserror_and_errno_is_5(e: BaseException) -> bool:
return isinstance(e, OSError) and e.errno == 5
with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)):
raise ExceptionGroup("", (OSError(5, ""),))
# specifying exception_type narrows the parameter type to the callable
def check_errno_is_5(e: OSError) -> bool:
return e.errno == 5
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
raise ExceptionGroup("", (OSError(5, ""),))
with pytest.raises(ExceptionGroup):
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
raise ExceptionGroup("", (OSError(6, ""),))
def test_matcher_tostring() -> None:
assert str(Matcher(ValueError)) == "Matcher(ValueError)"
assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')"
pattern_no_flags = re.compile("noflag", 0)
assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')"
pattern_flags = re.compile("noflag", re.IGNORECASE)
assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})"
assert (
str(Matcher(ValueError, match="re", check=bool))
== f"Matcher(ValueError, match='re', check={bool!r})"
)
def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
trio.testing._raises_group,
"ExceptionInfo",
trio.testing._raises_group._ExceptionInfo,
)
with trio.testing.RaisesGroup(ValueError) as excinfo:
raise ExceptionGroup("", (ValueError("hello"),))
assert excinfo.type is ExceptionGroup
assert excinfo.value.exceptions[0].args == ("hello",)
assert isinstance(excinfo.tb, TracebackType)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,272 @@
import time
from typing import Awaitable, Callable, Protocol, TypeVar
import outcome
import pytest
import trio
from .. import _core
from .._core._tests.tutil import slow
from .._timeouts import (
TooSlowError,
fail_after,
fail_at,
move_on_after,
move_on_at,
sleep,
sleep_forever,
sleep_until,
)
from ..testing import assert_checkpoints
T = TypeVar("T")
async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T:
start = time.perf_counter()
result = await outcome.acapture(f)
dur = time.perf_counter() - start
print(dur / expected_dur)
# 1.5 is an arbitrary fudge factor because there's always some delay
# between when we become eligible to wake up and when we actually do. We
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
# now we bumped up the sleep to 1 second, marked the tests as slow, and
# hopefully now the proportional error will be less huge.
#
# We also also for durations that are a hair shorter than expected. For
# example, here's a run on Windows where a 1.0 second sleep was measured
# to take 0.9999999999999858 seconds:
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
# I believe that what happened here is that Windows's low clock resolution
# meant that our calls to time.monotonic() returned exactly the same
# values as the calls inside the actual run loop, but the two subtractions
# returned slightly different values because the run loop's clock adds a
# random floating point offset to both times, which should cancel out, but
# lol floating point we got slightly different rounding errors. (That
# value above is exactly 128 ULPs below 1.0, which would make sense if it
# started as a 1 ULP error at a different dynamic range.)
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
return result.unwrap()
# How long to (attempt to) sleep for when testing. Smaller numbers make the
# test suite go faster.
TARGET = 1.0
@slow
async def test_sleep() -> None:
async def sleep_1() -> None:
await sleep_until(_core.current_time() + TARGET)
await check_takes_about(sleep_1, TARGET)
async def sleep_2() -> None:
await sleep(TARGET)
await check_takes_about(sleep_2, TARGET)
with assert_checkpoints():
await sleep(0)
# This also serves as a test of the trivial move_on_at
with move_on_at(_core.current_time()):
with pytest.raises(_core.Cancelled):
await sleep(0)
@slow
async def test_move_on_after() -> None:
async def sleep_3() -> None:
with move_on_after(TARGET):
await sleep(100)
await check_takes_about(sleep_3, TARGET)
async def test_cannot_wake_sleep_forever() -> None:
# Test an error occurs if you manually wake sleep_forever().
task = trio.lowlevel.current_task()
async def wake_task() -> None:
await trio.lowlevel.checkpoint()
trio.lowlevel.reschedule(task, outcome.Value(None))
async with trio.open_nursery() as nursery:
nursery.start_soon(wake_task)
with pytest.raises(RuntimeError):
await trio.sleep_forever()
class TimeoutScope(Protocol):
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...
@pytest.mark.parametrize("scope", [move_on_after, fail_after])
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
outer.cancel()
try:
await trio.lowlevel.checkpoint()
except trio.Cancelled:
pytest.fail("shield didn't work")
inner.shield = False
with pytest.raises(trio.Cancelled):
await trio.lowlevel.checkpoint()
@slow
async def test_move_on_after_moves_on_even_if_shielded() -> None:
async def task() -> None:
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()
await check_takes_about(task, TARGET)
@slow
async def test_fail_after_fails_even_if_shielded() -> None:
async def task() -> None:
with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after(
TARGET,
shield=True,
):
outer.cancel()
# The outer scope is cancelled, but this task is protected by the
# shield, so it manages to get to sleep until deadline is met
await sleep_forever()
await check_takes_about(task, TARGET)
@slow
async def test_fail() -> None:
async def sleep_4() -> None:
with fail_at(_core.current_time() + TARGET):
await sleep(100)
with pytest.raises(TooSlowError):
await check_takes_about(sleep_4, TARGET)
with fail_at(_core.current_time() + 100):
await sleep(0)
async def sleep_5() -> None:
with fail_after(TARGET):
await sleep(100)
with pytest.raises(TooSlowError):
await check_takes_about(sleep_5, TARGET)
with fail_after(100):
await sleep(0)
async def test_timeouts_raise_value_error() -> None:
# deadlines are allowed to be negative, but not delays.
# neither delays nor deadlines are allowed to be NaN
nan = float("nan")
for fun, val in (
(sleep, -1),
(sleep, nan),
(sleep_until, nan),
):
with pytest.raises(
ValueError,
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
):
await fun(val)
for cm, val in (
(fail_after, -1),
(fail_after, nan),
(fail_at, nan),
(move_on_after, -1),
(move_on_after, nan),
(move_on_at, nan),
):
with pytest.raises(
ValueError,
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
):
with cm(val):
pass # pragma: no cover
async def test_timeout_deadline_on_entry(mock_clock: _core.MockClock) -> None:
rcs = move_on_after(5)
assert rcs.relative_deadline == 5
mock_clock.jump(3)
start = _core.current_time()
with rcs as cs:
assert cs.is_relative is None
# This would previously be start+2
assert cs.deadline == start + 5
assert cs.relative_deadline == 5
cs.deadline = start + 3
assert cs.deadline == start + 3
assert cs.relative_deadline == 3
cs.relative_deadline = 4
assert cs.deadline == start + 4
assert cs.relative_deadline == 4
rcs = move_on_after(5)
assert rcs.shield is False
rcs.shield = True
assert rcs.shield is True
mock_clock.jump(3)
start = _core.current_time()
with rcs as cs:
assert cs.deadline == start + 5
assert rcs is cs
async def test_invalid_access_unentered(mock_clock: _core.MockClock) -> None:
cs = move_on_after(5)
mock_clock.jump(3)
start = _core.current_time()
match_str = "^unentered relative cancel scope does not have an absolute deadline"
with pytest.warns(DeprecationWarning, match=match_str):
assert cs.deadline == start + 5
mock_clock.jump(1)
# this is hella sketchy, but they *have* been warned
with pytest.warns(DeprecationWarning, match=match_str):
assert cs.deadline == start + 6
with pytest.warns(DeprecationWarning, match=match_str):
cs.deadline = 7
# now transformed into absolute
assert cs.deadline == 7
assert not cs.is_relative
cs = move_on_at(5)
match_str = (
"^unentered non-relative cancel scope does not have a relative deadline$"
)
with pytest.raises(RuntimeError, match=match_str):
assert cs.relative_deadline
with pytest.raises(RuntimeError, match=match_str):
cs.relative_deadline = 7
@pytest.mark.xfail(reason="not implemented")
async def test_fail_access_before_entering() -> None: # pragma: no cover
my_fail_at = fail_at(5)
assert my_fail_at.deadline # type: ignore[attr-defined]
my_fail_after = fail_after(5)
assert my_fail_after.relative_deadline # type: ignore[attr-defined]

View File

@ -0,0 +1,66 @@
from typing import AsyncGenerator
import trio
async def coro1(event: trio.Event) -> None:
event.set()
await trio.sleep_forever()
async def coro2(event: trio.Event) -> None:
await coro1(event)
async def coro3(event: trio.Event) -> None:
await coro2(event)
async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]:
# mypy does not like `yield await trio.lowlevel.checkpoint()` - but that
# should be equivalent to splitting the statement
await trio.lowlevel.checkpoint()
yield
await coro1(event)
yield # pragma: no cover
await trio.lowlevel.checkpoint() # pragma: no cover
yield # pragma: no cover
async def coro3_async_gen(event: trio.Event) -> None:
async for _ in coro2_async_gen(event):
pass
async def test_task_iter_await_frames() -> None:
async with trio.open_nursery() as nursery:
event = trio.Event()
nursery.start_soon(coro3, event)
await event.wait()
(task,) = nursery.child_tasks
assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [
"coro3",
"coro2",
"coro1",
]
nursery.cancel_scope.cancel()
async def test_task_iter_await_frames_async_gen() -> None:
async with trio.open_nursery() as nursery:
event = trio.Event()
nursery.start_soon(coro3_async_gen, event)
await event.wait()
(task,) = nursery.child_tasks
assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [
"coro3_async_gen",
"coro2_async_gen",
"coro1",
]
nursery.cancel_scope.cancel()

View File

@ -0,0 +1,8 @@
def test_trio_import() -> None:
import sys
for module in list(sys.modules.keys()):
if module.startswith("trio"):
del sys.modules[module]
import trio # noqa: F401

View File

@ -0,0 +1,282 @@
from __future__ import annotations
import errno
import os
import select
import sys
from typing import TYPE_CHECKING
import pytest
from .. import _core
from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
from ..testing import check_one_way_stream, wait_all_tasks_blocked
posix = os.name == "posix"
pytestmark = pytest.mark.skipif(not posix, reason="posix only")
assert not TYPE_CHECKING or sys.platform == "unix"
if posix:
from .._unix_pipes import FdStream
else:
with pytest.raises(ImportError):
from .._unix_pipes import FdStream
async def make_pipe() -> tuple[FdStream, FdStream]:
"""Makes a new pair of pipes."""
(r, w) = os.pipe()
return FdStream(w), FdStream(r)
async def make_clogged_pipe():
s, r = await make_pipe()
try:
while True:
# We want to totally fill up the pipe buffer.
# This requires working around a weird feature that POSIX pipes
# have.
# If you do a write of <= PIPE_BUF bytes, then it's guaranteed
# to either complete entirely, or not at all. So if we tried to
# write PIPE_BUF bytes, and the buffer's free space is only
# PIPE_BUF/2, then the write will raise BlockingIOError... even
# though a smaller write could still succeed! To avoid this,
# make sure to write >PIPE_BUF bytes each time, which disables
# the special behavior.
# For details, search for PIPE_BUF here:
# http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html
# for the getattr:
# https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3
buf_size = getattr(select, "PIPE_BUF", 8192)
os.write(s.fileno(), b"x" * buf_size * 2)
except BlockingIOError:
pass
return s, r
async def test_send_pipe() -> None:
r, w = os.pipe()
async with FdStream(w) as send:
assert send.fileno() == w
await send.send_all(b"123")
assert (os.read(r, 8)) == b"123"
os.close(r)
async def test_receive_pipe() -> None:
r, w = os.pipe()
async with FdStream(r) as recv:
assert (recv.fileno()) == r
os.write(w, b"123")
assert (await recv.receive_some(8)) == b"123"
os.close(w)
async def test_pipes_combined() -> None:
write, read = await make_pipe()
count = 2**20
async def sender() -> None:
big = bytearray(count)
await write.send_all(big)
async def reader() -> None:
await wait_all_tasks_blocked()
received = 0
while received < count:
received += len(await read.receive_some(4096))
assert received == count
async with _core.open_nursery() as n:
n.start_soon(sender)
n.start_soon(reader)
await read.aclose()
await write.aclose()
async def test_pipe_errors() -> None:
with pytest.raises(TypeError):
FdStream(None)
r, w = os.pipe()
os.close(w)
async with FdStream(r) as s:
with pytest.raises(ValueError, match="^max_bytes must be integer >= 1$"):
await s.receive_some(0)
async def test_del() -> None:
w, r = await make_pipe()
f1, f2 = w.fileno(), r.fileno()
del w, r
gc_collect_harder()
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
os.close(f1)
assert excinfo.value.errno == errno.EBADF
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
os.close(f2)
assert excinfo.value.errno == errno.EBADF
async def test_async_with() -> None:
w, r = await make_pipe()
async with w, r:
pass
assert w.fileno() == -1
assert r.fileno() == -1
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
os.close(w.fileno())
assert excinfo.value.errno == errno.EBADF
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
os.close(r.fileno())
assert excinfo.value.errno == errno.EBADF
async def test_misdirected_aclose_regression() -> None:
# https://github.com/python-trio/trio/issues/661#issuecomment-456582356
w, r = await make_pipe()
old_r_fd = r.fileno()
# Close the original objects
await w.aclose()
await r.aclose()
# Do a little dance to get a new pipe whose receive handle matches the old
# receive handle.
r2_fd, w2_fd = os.pipe()
if r2_fd != old_r_fd: # pragma: no cover
os.dup2(r2_fd, old_r_fd)
os.close(r2_fd)
async with FdStream(old_r_fd) as r2:
assert r2.fileno() == old_r_fd
# And now set up a background task that's working on the new receive
# handle
async def expect_eof() -> None:
assert await r2.receive_some(10) == b""
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_eof)
await wait_all_tasks_blocked()
# Here's the key test: does calling aclose() again on the *old*
# handle, cause the task blocked on the *new* handle to raise
# ClosedResourceError?
await r.aclose()
await wait_all_tasks_blocked()
# Guess we survived! Close the new write handle so that the task
# gets an EOF and can exit cleanly.
os.close(w2_fd)
async def test_close_at_bad_time_for_receive_some(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
# https://github.com/python-trio/trio/issues/661
#
# This tests what happens if the pipe gets closed in the moment *between*
# when receive_some wakes up, and when it tries to call os.read
async def expect_closedresourceerror() -> None:
with pytest.raises(_core.ClosedResourceError):
await r.receive_some(10)
orig_wait_readable = _core._run.TheIOManager.wait_readable
async def patched_wait_readable(*args, **kwargs) -> None:
await orig_wait_readable(*args, **kwargs)
await r.aclose()
monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable)
s, r = await make_pipe()
async with s, r:
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_closedresourceerror)
await wait_all_tasks_blocked()
# Trigger everything by waking up the receiver
await s.send_all(b"x")
async def test_close_at_bad_time_for_send_all(monkeypatch: pytest.MonkeyPatch) -> None:
# We used to have race conditions where if one task was using the pipe,
# and another closed it at *just* the wrong moment, it would give an
# unexpected error instead of ClosedResourceError:
# https://github.com/python-trio/trio/issues/661
#
# This tests what happens if the pipe gets closed in the moment *between*
# when send_all wakes up, and when it tries to call os.write
async def expect_closedresourceerror() -> None:
with pytest.raises(_core.ClosedResourceError):
await s.send_all(b"x" * 100)
orig_wait_writable = _core._run.TheIOManager.wait_writable
async def patched_wait_writable(*args, **kwargs) -> None:
await orig_wait_writable(*args, **kwargs)
await s.aclose()
monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable)
s, r = await make_clogged_pipe()
async with s, r:
async with _core.open_nursery() as nursery:
nursery.start_soon(expect_closedresourceerror)
await wait_all_tasks_blocked()
# Trigger everything by waking up the sender. On ppc64el, PIPE_BUF
# is 8192 but make_clogged_pipe() ends up writing a total of
# 1048576 bytes before the pipe is full, and then a subsequent
# receive_some(10000) isn't sufficient for orig_wait_writable() to
# return for our subsequent aclose() call. It's necessary to empty
# the pipe further before this happens. So we loop here until the
# pipe is empty to make sure that the sender wakes up even in this
# case. Otherwise patched_wait_writable() never gets to the
# aclose(), so expect_closedresourceerror() never returns, the
# nursery never finishes all tasks and this test hangs.
received_data = await r.receive_some(10000)
while received_data:
received_data = await r.receive_some(10000)
# On FreeBSD, directories are readable, and we haven't found any other trick
# for making an unreadable fd, so there's no way to run this test. Fortunately
# the logic this is testing doesn't depend on the platform, so testing on
# other platforms is probably good enough.
@pytest.mark.skipif(
sys.platform.startswith("freebsd"),
reason="no way to make read() return a bizarro error on FreeBSD",
)
async def test_bizarro_OSError_from_receive() -> None:
# Make sure that if the read syscall returns some bizarro error, then we
# get a BrokenResourceError. This is incredibly unlikely; there's almost
# no way to trigger a failure here intentionally (except for EBADF, but we
# exploit that to detect file closure, so it takes a different path). So
# we set up a strange scenario where the pipe fd somehow transmutes into a
# directory fd, causing os.read to raise IsADirectoryError (yes, that's a
# real built-in exception type).
s, r = await make_pipe()
async with s, r:
dir_fd = os.open("/", os.O_DIRECTORY, 0)
try:
os.dup2(dir_fd, r.fileno())
with pytest.raises(_core.BrokenResourceError):
await r.receive_some(10)
finally:
os.close(dir_fd)
@skip_if_fbsd_pipes_broken
async def test_pipe_fully() -> None:
await check_one_way_stream(make_pipe, make_clogged_pipe)

View File

@ -0,0 +1,275 @@
import signal
import sys
import types
from typing import Any, TypeVar
import pytest
import trio
from trio.testing import Matcher, RaisesGroup
from .. import _core
from .._core._tests.tutil import (
create_asyncio_future_in_new_loop,
ignore_coroutine_never_awaited_warnings,
)
from .._util import (
ConflictDetector,
NoPublicConstructor,
coroutine_or_error,
final,
fixup_module_metadata,
generic_function,
is_main_thread,
signal_raise,
)
from ..testing import wait_all_tasks_blocked
T = TypeVar("T")
def test_signal_raise() -> None:
record = []
def handler(signum: int, _: object) -> None:
record.append(signum)
old = signal.signal(signal.SIGFPE, handler)
try:
signal_raise(signal.SIGFPE)
finally:
signal.signal(signal.SIGFPE, old)
assert record == [signal.SIGFPE]
async def test_ConflictDetector() -> None:
ul1 = ConflictDetector("ul1")
ul2 = ConflictDetector("ul2")
with ul1:
with ul2:
print("ok")
with pytest.raises(_core.BusyResourceError, match="ul1"):
with ul1:
with ul1:
pass # pragma: no cover
async def wait_with_ul1() -> None:
with ul1:
await wait_all_tasks_blocked()
with RaisesGroup(Matcher(_core.BusyResourceError, "ul1")):
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_with_ul1)
nursery.start_soon(wait_with_ul1)
def test_module_metadata_is_fixed_up() -> None:
import trio
import trio.testing
assert trio.Cancelled.__module__ == "trio"
assert trio.open_nursery.__module__ == "trio"
assert trio.abc.Stream.__module__ == "trio.abc"
assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel"
assert trio.testing.trio_test.__module__ == "trio.testing"
# Also check methods
assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel"
assert trio.abc.Stream.send_all.__module__ == "trio.abc"
# And names
assert trio.Cancelled.__name__ == "Cancelled"
assert trio.Cancelled.__qualname__ == "Cancelled"
assert trio.abc.SendStream.send_all.__name__ == "send_all"
assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all"
assert trio.to_thread.__name__ == "trio.to_thread"
assert trio.to_thread.run_sync.__name__ == "run_sync"
assert trio.to_thread.run_sync.__qualname__ == "run_sync"
async def test_is_main_thread() -> None:
assert is_main_thread()
def not_main_thread() -> None:
assert not is_main_thread()
await trio.to_thread.run_sync(not_main_thread)
# @coroutine is deprecated since python 3.8, which is fine with us.
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
def test_coroutine_or_error() -> None:
class Deferred:
"Just kidding"
with ignore_coroutine_never_awaited_warnings():
async def f() -> None: # pragma: no cover
pass
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(f()) # type: ignore[arg-type, unused-coroutine]
assert "expecting an async function" in str(excinfo.value)
import asyncio
if sys.version_info < (3, 11):
# not bothering to type this one
@asyncio.coroutine # type: ignore[misc]
def generator_based_coro() -> Any: # pragma: no cover
yield from asyncio.sleep(1)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(generator_based_coro()) # type: ignore[arg-type, unused-coroutine]
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(create_asyncio_future_in_new_loop()) # type: ignore[arg-type, unused-coroutine]
assert "asyncio" in str(excinfo.value)
# does not raise arg-type error
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(create_asyncio_future_in_new_loop) # type: ignore[unused-coroutine]
assert "asyncio" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(Deferred()) # type: ignore[arg-type, unused-coroutine]
assert "twisted" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(lambda: Deferred()) # type: ignore[arg-type, unused-coroutine, return-value]
assert "twisted" in str(excinfo.value)
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(len, [[1, 2, 3]]) # type: ignore[arg-type, unused-coroutine]
assert "appears to be synchronous" in str(excinfo.value)
async def async_gen(_: object) -> Any: # pragma: no cover
yield
# does not give arg-type typing error
with pytest.raises(TypeError) as excinfo:
coroutine_or_error(async_gen, [0]) # type: ignore[unused-coroutine]
msg = "expected an async function but got an async generator"
assert msg in str(excinfo.value)
# Make sure no references are kept around to keep anything alive
del excinfo
def test_generic_function() -> None:
@generic_function # Decorated function contains "Any".
def test_func(arg: T) -> T: # type: ignore[misc]
"""Look, a docstring!"""
return arg
assert test_func is test_func[int] is test_func[int, str]
assert test_func(42) == test_func[int](42) == 42
assert test_func.__doc__ == "Look, a docstring!"
assert test_func.__qualname__ == "test_generic_function.<locals>.test_func" # type: ignore[attr-defined]
assert test_func.__name__ == "test_func" # type: ignore[attr-defined]
assert test_func.__module__ == __name__
def test_final_decorator() -> None:
"""Test that subclassing a @final-annotated class is not allowed.
This checks both runtime results, and verifies that type checkers detect
the error statically through the type-ignore comment.
"""
@final
class FinalClass:
pass
with pytest.raises(TypeError):
class SubClass(FinalClass): # type: ignore[misc]
pass
def test_no_public_constructor_metaclass() -> None:
"""The NoPublicConstructor metaclass prevents calling the constructor directly."""
class SpecialClass(metaclass=NoPublicConstructor):
def __init__(self, a: int, b: float) -> None:
"""Check arguments can be passed to __init__."""
assert a == 8
assert b == 3.14
with pytest.raises(TypeError):
SpecialClass(8, 3.14)
# Private constructor should not raise, and passes args to __init__.
assert isinstance(SpecialClass._create(8, b=3.14), SpecialClass)
def test_fixup_module_metadata() -> None:
# Ignores modules not in the trio.X tree.
non_trio_module = types.ModuleType("not_trio")
non_trio_module.some_func = lambda: None # type: ignore[attr-defined]
non_trio_module.some_func.__name__ = "some_func"
non_trio_module.some_func.__qualname__ = "some_func"
fixup_module_metadata(non_trio_module.__name__, vars(non_trio_module))
assert non_trio_module.some_func.__name__ == "some_func"
assert non_trio_module.some_func.__qualname__ == "some_func"
# Bulild up a fake module to test. Just use lambdas since all we care about is the names.
mod = types.ModuleType("trio._somemodule_impl")
mod.some_func = lambda: None # type: ignore[attr-defined]
mod.some_func.__name__ = "_something_else"
mod.some_func.__qualname__ = "_something_else"
# No __module__ means it's unchanged.
mod.not_funclike = types.SimpleNamespace() # type: ignore[attr-defined]
mod.not_funclike.__name__ = "not_funclike"
# Check __qualname__ being absent works.
mod.only_has_name = types.SimpleNamespace() # type: ignore[attr-defined]
mod.only_has_name.__module__ = "trio._somemodule_impl"
mod.only_has_name.__name__ = "only_name"
# Underscored names are unchanged.
mod._private = lambda: None # type: ignore[attr-defined]
mod._private.__module__ = "trio._somemodule_impl"
mod._private.__name__ = mod._private.__qualname__ = "_private"
# We recurse into classes.
mod.SomeClass = type( # type: ignore[attr-defined]
"SomeClass",
(),
{
"__init__": lambda self: None,
"method": lambda self: None,
},
)
# Reference loop is fine.
mod.SomeClass.recursion = mod.SomeClass # type: ignore[attr-defined]
fixup_module_metadata("trio.somemodule", vars(mod))
assert mod.some_func.__name__ == "some_func"
assert mod.some_func.__module__ == "trio.somemodule"
assert mod.some_func.__qualname__ == "some_func"
assert mod.not_funclike.__name__ == "not_funclike"
assert mod._private.__name__ == "_private"
assert mod._private.__module__ == "trio._somemodule_impl"
assert mod._private.__qualname__ == "_private"
assert mod.only_has_name.__name__ == "only_has_name"
assert mod.only_has_name.__module__ == "trio.somemodule"
assert not hasattr(mod.only_has_name, "__qualname__")
assert mod.SomeClass.method.__name__ == "method" # type: ignore[attr-defined]
assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined]
assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined]
# Make coverage happy.
non_trio_module.some_func()
mod.some_func()
mod._private()
mod.SomeClass().method()

View File

@ -0,0 +1,225 @@
import os
import pytest
on_windows = os.name == "nt"
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
import trio
from .. import _core, _timeouts
from .._core._tests.tutil import slow
if on_windows:
from .._core._windows_cffi import Handle, ffi, kernel32
from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject
async def test_WaitForMultipleObjects_sync() -> None:
# This does a series of tests where we set/close the handle before
# initiating the waiting for it.
#
# Note that closing the handle (not signaling) will cause the
# *initiation* of a wait to return immediately. But closing a handle
# that is already being waited on will not stop whatever is waiting
# for it.
# One handle
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle1)
WaitForMultipleObjects_sync(handle1)
kernel32.CloseHandle(handle1)
print("test_WaitForMultipleObjects_sync one OK")
# Two handles, signal first
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle1)
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync set first OK")
# Two handles, signal second
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle2)
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync set second OK")
# Two handles, close first
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.CloseHandle(handle1)
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync close first OK")
# Two handles, close second
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.CloseHandle(handle2)
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
WaitForMultipleObjects_sync(handle1, handle2)
kernel32.CloseHandle(handle1)
print("test_WaitForMultipleObjects_sync close second OK")
@slow
async def test_WaitForMultipleObjects_sync_slow() -> None:
# This does a series of test in which the main thread sync-waits for
# handles, while we spawn a thread to set the handles after a short while.
TIMEOUT = 0.3
# One handle
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(
trio.to_thread.run_sync,
WaitForMultipleObjects_sync,
handle1,
)
await _timeouts.sleep(TIMEOUT)
# If we would comment the line below, the above thread will be stuck,
# and Trio won't exit this scope
kernel32.SetEvent(handle1)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
kernel32.CloseHandle(handle1)
print("test_WaitForMultipleObjects_sync_slow one OK")
# Two handles, signal first
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(
trio.to_thread.run_sync,
WaitForMultipleObjects_sync,
handle1,
handle2,
)
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle1)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync_slow thread-set first OK")
# Two handles, signal second
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(
trio.to_thread.run_sync,
WaitForMultipleObjects_sync,
handle1,
handle2,
)
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle2)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
kernel32.CloseHandle(handle1)
kernel32.CloseHandle(handle2)
print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
async def test_WaitForSingleObject() -> None:
# This does a series of test for setting/closing the handle before
# initiating the wait.
# Test already set
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.SetEvent(handle)
await WaitForSingleObject(handle) # should return at once
kernel32.CloseHandle(handle)
print("test_WaitForSingleObject already set OK")
# Test already set, as int
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle_int = int(ffi.cast("intptr_t", handle))
kernel32.SetEvent(handle)
await WaitForSingleObject(handle_int) # should return at once
kernel32.CloseHandle(handle)
print("test_WaitForSingleObject already set OK")
# Test already closed
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
kernel32.CloseHandle(handle)
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
await WaitForSingleObject(handle) # should return at once
print("test_WaitForSingleObject already closed OK")
# Not a handle
with pytest.raises(TypeError):
await WaitForSingleObject("not a handle") # type: ignore[arg-type] # Wrong type
# with pytest.raises(OSError):
# await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
print("test_WaitForSingleObject not a handle OK")
@slow
async def test_WaitForSingleObject_slow() -> None:
# This does a series of test for setting the handle in another task,
# and cancelling the wait task.
# Set the timeout used in the tests. We test the waiting time against
# the timeout with a certain margin.
TIMEOUT = 0.3
async def signal_soon_async(handle: Handle) -> None:
await _timeouts.sleep(TIMEOUT)
kernel32.SetEvent(handle)
# Test handle is SET after TIMEOUT in separate coroutine
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(WaitForSingleObject, handle)
nursery.start_soon(signal_soon_async, handle)
kernel32.CloseHandle(handle)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
print("test_WaitForSingleObject_slow set from task OK")
# Test handle is SET after TIMEOUT in separate coroutine, as int
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
handle_int = int(ffi.cast("intptr_t", handle))
t0 = _core.current_time()
async with _core.open_nursery() as nursery:
nursery.start_soon(WaitForSingleObject, handle_int)
nursery.start_soon(signal_soon_async, handle)
kernel32.CloseHandle(handle)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
print("test_WaitForSingleObject_slow set from task as int OK")
# Test handle is CLOSED after 1 sec - NOPE see comment above
# Test cancellation
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
t0 = _core.current_time()
with _timeouts.move_on_after(TIMEOUT):
await WaitForSingleObject(handle)
kernel32.CloseHandle(handle)
t1 = _core.current_time()
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
print("test_WaitForSingleObject_slow cancellation OK")

View File

@ -0,0 +1,112 @@
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
import pytest
from .. import _core
from ..testing import check_one_way_stream, wait_all_tasks_blocked
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(sys.platform != "win32", reason="windows only")
assert ( # Skip type checking when not on Windows
sys.platform == "win32" or not TYPE_CHECKING
)
if sys.platform == "win32":
from asyncio.windows_utils import pipe
from .._core._windows_cffi import _handle, kernel32
from .._windows_pipes import PipeReceiveStream, PipeSendStream
async def make_pipe() -> tuple[PipeSendStream, PipeReceiveStream]:
"""Makes a new pair of pipes."""
(r, w) = pipe()
return PipeSendStream(w), PipeReceiveStream(r)
async def test_pipe_typecheck() -> None:
with pytest.raises(TypeError):
PipeSendStream(1.0) # type: ignore[arg-type]
with pytest.raises(TypeError):
PipeReceiveStream(None) # type: ignore[arg-type]
async def test_pipe_error_on_close() -> None:
# Make sure we correctly handle a failure from kernel32.CloseHandle
r, w = pipe()
send_stream = PipeSendStream(w)
receive_stream = PipeReceiveStream(r)
assert kernel32.CloseHandle(_handle(r))
assert kernel32.CloseHandle(_handle(w))
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
await send_stream.aclose()
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
await receive_stream.aclose()
async def test_pipes_combined() -> None:
write, read = await make_pipe()
count = 2**20
replicas = 3
async def sender() -> None:
async with write:
big = bytearray(count)
for _ in range(replicas):
await write.send_all(big)
async def reader() -> None:
async with read:
await wait_all_tasks_blocked()
total_received = 0
while True:
# 5000 is chosen because it doesn't evenly divide 2**20
received = len(await read.receive_some(5000))
if not received:
break
total_received += received
assert total_received == count * replicas
async with _core.open_nursery() as n:
n.start_soon(sender)
n.start_soon(reader)
async def test_async_with() -> None:
w, r = await make_pipe()
async with w, r:
pass
with pytest.raises(_core.ClosedResourceError):
await w.send_all(b"")
with pytest.raises(_core.ClosedResourceError):
await r.receive_some(10)
async def test_close_during_write() -> None:
w, r = await make_pipe()
async with _core.open_nursery() as nursery:
async def write_forever() -> None:
with pytest.raises(_core.ClosedResourceError) as excinfo:
while True:
await w.send_all(b"x" * 4096)
assert "another task" in str(excinfo.value)
nursery.start_soon(write_forever)
await wait_all_tasks_blocked(0.1)
await w.aclose()
async def test_pipe_fully() -> None:
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
# can't implement that on Windows
await check_one_way_stream(make_pipe, None)

View File

@ -0,0 +1,176 @@
import ast
import sys
from pathlib import Path
import pytest
from trio._tests.pytest_plugin import skip_if_optional_else_raise
# imports in gen_exports that are not in `install_requires` in setup.py
try:
import astor # noqa: F401
import isort # noqa: F401
except ImportError as error:
skip_if_optional_else_raise(error)
from trio._tools.gen_exports import (
File,
create_passthrough_args,
get_public_methods,
process,
run_linters,
run_ruff,
)
SOURCE = '''from _run import _public
from collections import Counter
class Test:
@_public
def public_func(self):
"""With doc string"""
@ignore_this
@_public
@another_decorator
async def public_async_func(self) -> Counter:
pass # no doc string
def not_public(self):
pass
async def not_public_async(self):
pass
'''
IMPORT_1 = """\
from collections import Counter
"""
IMPORT_2 = """\
from collections import Counter
import os
"""
IMPORT_3 = """\
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections import Counter
"""
def test_get_public_methods() -> None:
methods = list(get_public_methods(ast.parse(SOURCE)))
assert {m.name for m in methods} == {"public_func", "public_async_func"}
def test_create_pass_through_args() -> None:
testcases = [
("def f()", "()"),
("def f(one)", "(one)"),
("def f(one, two)", "(one, two)"),
("def f(one, *args)", "(one, *args)"),
(
"def f(one, *args, kw1, kw2=None, **kwargs)",
"(one, *args, kw1=kw1, kw2=kw2, **kwargs)",
),
]
for funcdef, expected in testcases:
func_node = ast.parse(funcdef + ":\n pass").body[0]
assert isinstance(func_node, ast.FunctionDef)
assert create_passthrough_args(func_node) == expected
skip_lints = pytest.mark.skipif(
sys.implementation.name != "cpython",
reason="gen_exports is internal, black/isort only runs on CPython",
)
@skip_lints
@pytest.mark.parametrize("imports", [IMPORT_1, IMPORT_2, IMPORT_3])
def test_process(
tmp_path: Path,
imports: str,
capsys: pytest.CaptureFixture[str],
) -> None:
try:
import black # noqa: F401
# there's no dedicated CI run that has astor+isort, but lacks black.
except ImportError as error: # pragma: no cover
skip_if_optional_else_raise(error)
modpath = tmp_path / "_module.py"
genpath = tmp_path / "_generated_module.py"
modpath.write_text(SOURCE, encoding="utf-8")
file = File(modpath, "runner", platform="linux", imports=imports)
assert not genpath.exists()
with pytest.raises(SystemExit) as excinfo:
process([file], do_test=True)
assert excinfo.value.code == 1
captured = capsys.readouterr()
assert "Generated sources are outdated. Please regenerate." in captured.out
with pytest.raises(SystemExit) as excinfo:
process([file], do_test=False)
assert excinfo.value.code == 1
captured = capsys.readouterr()
assert "Regenerated sources successfully." in captured.out
assert genpath.exists()
process([file], do_test=True)
# But if we change the lookup path it notices
with pytest.raises(SystemExit) as excinfo:
process(
[File(modpath, "runner.io_manager", platform="linux", imports=imports)],
do_test=True,
)
assert excinfo.value.code == 1
# Also if the platform is changed.
with pytest.raises(SystemExit) as excinfo:
process([File(modpath, "runner", imports=imports)], do_test=True)
assert excinfo.value.code == 1
@skip_lints
def test_run_ruff(tmp_path: Path) -> None:
"""Test that processing properly fails if ruff does."""
try:
import ruff # noqa: F401
except ImportError as error: # pragma: no cover
skip_if_optional_else_raise(error)
file = File(tmp_path / "module.py", "module")
success, _ = run_ruff(file, "class not valid code ><")
assert not success
test_function = '''def combine_and(data: list[str]) -> str:
"""Join values of text, and have 'and' with the last one properly."""
if len(data) >= 2:
data[-1] = 'and ' + data[-1]
if len(data) > 2:
return ', '.join(data)
return ' '.join(data)'''
success, response = run_ruff(file, test_function)
assert success
assert response == test_function
@skip_lints
def test_lint_failure(tmp_path: Path) -> None:
"""Test that processing properly fails if black or ruff does."""
try:
import black # noqa: F401
import ruff # noqa: F401
except ImportError as error: # pragma: no cover
skip_if_optional_else_raise(error)
file = File(tmp_path / "module.py", "module")
with pytest.raises(SystemExit):
run_linters(file, "class not valid code ><")
with pytest.raises(SystemExit):
run_linters(file, "import waffle\n;import trio")

View File

@ -0,0 +1,140 @@
from __future__ import annotations
import io
import sys
from typing import TYPE_CHECKING
import pytest
from trio._tools.mypy_annotate import Result, export, main, process_line
if TYPE_CHECKING:
from pathlib import Path
@pytest.mark.parametrize(
("src", "expected"),
[
("", None),
("a regular line\n", None),
(
"package\\filename.py:42:8: note: Some info\n",
Result(
kind="notice",
filename="package\\filename.py",
start_line=42,
start_col=8,
end_line=None,
end_col=None,
message=" Some info",
),
),
(
"package/filename.py:42:1:46:3: error: Type error here [code]\n",
Result(
kind="error",
filename="package/filename.py",
start_line=42,
start_col=1,
end_line=46,
end_col=3,
message=" Type error here [code]",
),
),
(
"package/module.py:87: warn: Bad code\n",
Result(
kind="warning",
filename="package/module.py",
start_line=87,
message=" Bad code",
),
),
],
ids=["blank", "normal", "note-wcol", "error-wend", "warn-lineonly"],
)
def test_processing(src: str, expected: Result | None) -> None:
result = process_line(src)
assert result == expected
def test_export(capsys: pytest.CaptureFixture[str]) -> None:
results = {
Result(
kind="notice",
filename="package\\filename.py",
start_line=42,
start_col=8,
end_line=None,
end_col=None,
message=" Some info",
): ["Windows", "Mac"],
Result(
kind="error",
filename="package/filename.py",
start_line=42,
start_col=1,
end_line=46,
end_col=3,
message=" Type error here [code]",
): ["Linux", "Mac"],
Result(
kind="warning",
filename="package/module.py",
start_line=87,
message=" Bad code",
): ["Linux"],
}
export(results)
std = capsys.readouterr()
assert std.err == ""
assert std.out == (
"::notice file=package\\filename.py,line=42,col=8,"
"title=Mypy-Windows+Mac::package\\filename.py:(42:8): Some info"
"\n"
"::error file=package/filename.py,line=42,col=1,endLine=46,endColumn=3,"
"title=Mypy-Linux+Mac::package/filename.py:(42:1 - 46:3): Type error here [code]"
"\n"
"::warning file=package/module.py,line=87,"
"title=Mypy-Linux::package/module.py:87: Bad code\n"
)
def test_endtoend(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
import trio._tools.mypy_annotate as mypy_annotate
inp_text = """\
Mypy begun
trio/core.py:15: error: Bad types here [misc]
trio/package/module.py:48:4:56:18: warn: Missing annotations [no-untyped-def]
Found 3 errors in 29 files
"""
result_file = tmp_path / "dump.dat"
assert not result_file.exists()
with monkeypatch.context():
monkeypatch.setattr(sys, "stdin", io.StringIO(inp_text))
mypy_annotate.main(
["--dumpfile", str(result_file), "--platform", "SomePlatform"],
)
std = capsys.readouterr()
assert std.err == ""
assert std.out == inp_text # Echos the original.
assert result_file.exists()
main(["--dumpfile", str(result_file)])
std = capsys.readouterr()
assert std.err == ""
assert std.out == (
"::error file=trio/core.py,line=15,title=Mypy-SomePlatform::trio/core.py:15: Bad types here [misc]\n"
"::warning file=trio/package/module.py,line=48,col=4,endLine=56,endColumn=18,"
"title=Mypy-SomePlatform::trio/package/module.py:(48:4 - 56:18): Missing "
"annotations [no-untyped-def]\n"
)

View File

@ -0,0 +1,9 @@
# https://github.com/python-trio/trio/issues/2775#issuecomment-1702267589
# (except platform independent...)
import trio
from typing_extensions import assert_type
async def fn(s: trio.SocketStream) -> None:
result = await s.socket.sendto(b"a", "h")
assert_type(result, int)

View File

@ -0,0 +1,4 @@
# https://github.com/python-trio/trio/issues/2873
import trio
s, r = trio.open_memory_channel[int](0)

View File

@ -0,0 +1,144 @@
"""Path wrapping is quite complex, ensure all methods are understood as wrapped correctly."""
import io
import os
import pathlib
import sys
from typing import IO, Any, BinaryIO, List, Tuple
import trio
from trio._file_io import AsyncIOWrapper
from typing_extensions import assert_type
def operator_checks(text: str, tpath: trio.Path, ppath: pathlib.Path) -> None:
"""Verify operators produce the right results."""
assert_type(tpath / ppath, trio.Path)
assert_type(tpath / tpath, trio.Path)
assert_type(tpath / text, trio.Path)
assert_type(text / tpath, trio.Path)
assert_type(tpath > tpath, bool)
assert_type(tpath >= tpath, bool)
assert_type(tpath < tpath, bool)
assert_type(tpath <= tpath, bool)
assert_type(tpath > ppath, bool)
assert_type(tpath >= ppath, bool)
assert_type(tpath < ppath, bool)
assert_type(tpath <= ppath, bool)
assert_type(ppath > tpath, bool)
assert_type(ppath >= tpath, bool)
assert_type(ppath < tpath, bool)
assert_type(ppath <= tpath, bool)
def sync_attrs(path: trio.Path) -> None:
assert_type(path.parts, Tuple[str, ...])
assert_type(path.drive, str)
assert_type(path.root, str)
assert_type(path.anchor, str)
assert_type(path.parents[3], trio.Path)
assert_type(path.parent, trio.Path)
assert_type(path.name, str)
assert_type(path.suffix, str)
assert_type(path.suffixes, List[str])
assert_type(path.stem, str)
assert_type(path.as_posix(), str)
assert_type(path.as_uri(), str)
assert_type(path.is_absolute(), bool)
if sys.version_info > (3, 9):
assert_type(path.is_relative_to(path), bool)
assert_type(path.is_reserved(), bool)
assert_type(path.joinpath(path, "folder"), trio.Path)
assert_type(path.match("*.py"), bool)
assert_type(path.relative_to("/usr"), trio.Path)
if sys.version_info > (3, 12):
assert_type(path.relative_to("/", walk_up=True), bool)
assert_type(path.with_name("filename.txt"), trio.Path)
if sys.version_info > (3, 9):
assert_type(path.with_stem("readme"), trio.Path)
assert_type(path.with_suffix(".log"), trio.Path)
async def async_attrs(path: trio.Path) -> None:
assert_type(await trio.Path.cwd(), trio.Path)
assert_type(await trio.Path.home(), trio.Path)
assert_type(await path.stat(), os.stat_result)
assert_type(await path.chmod(0o777), None)
assert_type(await path.exists(), bool)
assert_type(await path.expanduser(), trio.Path)
for result in await path.glob("*.py"):
assert_type(result, trio.Path)
if sys.platform != "win32":
assert_type(await path.group(), str)
assert_type(await path.is_dir(), bool)
assert_type(await path.is_file(), bool)
if sys.version_info > (3, 12):
assert_type(await path.is_junction(), bool)
if sys.platform != "win32":
assert_type(await path.is_mount(), bool)
assert_type(await path.is_symlink(), bool)
assert_type(await path.is_socket(), bool)
assert_type(await path.is_fifo(), bool)
assert_type(await path.is_block_device(), bool)
assert_type(await path.is_char_device(), bool)
for child_iter in await path.iterdir():
assert_type(child_iter, trio.Path)
# TODO: Path.walk() in 3.12
assert_type(await path.lchmod(0o111), None)
assert_type(await path.lstat(), os.stat_result)
assert_type(await path.mkdir(mode=0o777, parents=True, exist_ok=False), None)
# Open done separately.
if sys.platform != "win32":
assert_type(await path.owner(), str)
assert_type(await path.read_bytes(), bytes)
assert_type(await path.read_text(encoding="utf16", errors="replace"), str)
if sys.version_info > (3, 9):
assert_type(await path.readlink(), trio.Path)
assert_type(await path.rename("another"), trio.Path)
assert_type(await path.replace(path), trio.Path)
assert_type(await path.resolve(), trio.Path)
for child_glob in await path.glob("*.py"):
assert_type(child_glob, trio.Path)
for child_rglob in await path.rglob("*.py"):
assert_type(child_rglob, trio.Path)
assert_type(await path.rmdir(), None)
assert_type(await path.samefile("something_else"), bool)
assert_type(await path.symlink_to("somewhere"), None)
if sys.version_info > (3, 10):
assert_type(await path.hardlink_to("elsewhere"), None)
assert_type(await path.touch(), None)
assert_type(await path.unlink(missing_ok=True), None)
assert_type(await path.write_bytes(b"123"), int)
assert_type(
await path.write_text("hello", encoding="utf32le", errors="ignore"),
int,
)
async def open_results(path: trio.Path, some_int: int, some_str: str) -> None:
# Check the overloads.
assert_type(await path.open(), AsyncIOWrapper[io.TextIOWrapper])
assert_type(await path.open("r"), AsyncIOWrapper[io.TextIOWrapper])
assert_type(await path.open("r+"), AsyncIOWrapper[io.TextIOWrapper])
assert_type(await path.open("w"), AsyncIOWrapper[io.TextIOWrapper])
assert_type(await path.open("rb", buffering=0), AsyncIOWrapper[io.FileIO])
assert_type(await path.open("rb+"), AsyncIOWrapper[io.BufferedRandom])
assert_type(await path.open("wb"), AsyncIOWrapper[io.BufferedWriter])
assert_type(await path.open("rb"), AsyncIOWrapper[io.BufferedReader])
assert_type(await path.open("rb", buffering=some_int), AsyncIOWrapper[BinaryIO])
assert_type(await path.open(some_str), AsyncIOWrapper[IO[Any]])
# Check they produce the right types.
file_bin = await path.open("rb+")
assert_type(await file_bin.read(), bytes)
assert_type(await file_bin.write(b"test"), int)
assert_type(await file_bin.seek(32), int)
file_text = await path.open("r+t")
assert_type(await file_text.read(), str)
assert_type(await file_text.write("test"), int)
# TODO: report mypy bug: equiv to https://github.com/microsoft/pyright/issues/6833
assert_type(await file_text.readlines(), List[str])

View File

@ -0,0 +1,254 @@
"""The typing of RaisesGroup involves a lot of deception and lies, since AFAIK what we
actually want to achieve is ~impossible. This is because we specify what we expect with
instances of RaisesGroup and exception classes, but excinfo.value will be instances of
[Base]ExceptionGroup and instances of exceptions. So we need to "translate" from
RaisesGroup to ExceptionGroup.
The way it currently works is that RaisesGroup[E] corresponds to
ExceptionInfo[BaseExceptionGroup[E]], so the top-level group will be correct. But
RaisesGroup[RaisesGroup[ValueError]] will become
ExceptionInfo[BaseExceptionGroup[RaisesGroup[ValueError]]]. To get around that we specify
RaisesGroup as a subclass of BaseExceptionGroup during type checking - which should mean
that most static type checking for end users should be mostly correct.
"""
from __future__ import annotations
import sys
from typing import Union
from trio.testing import Matcher, RaisesGroup
from typing_extensions import assert_type
if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
# split into functions to isolate the different scopes
def check_inheritance_and_assignments() -> None:
# Check inheritance
_: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError)
_ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore
a: BaseExceptionGroup[BaseExceptionGroup[ValueError]]
a = RaisesGroup(RaisesGroup(ValueError))
a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),))
assert a
def check_matcher_typevar_default(e: Matcher) -> object:
assert e.exception_type is not None
exc: type[BaseException] = e.exception_type
# this would previously pass, as the type would be `Any`
e.exception_type().blah() # type: ignore
return exc # Silence Pyright unused var warning
def check_basic_contextmanager() -> None:
# One level of Group is correctly translated - except it's a BaseExceptionGroup
# instead of an ExceptionGroup.
with RaisesGroup(ValueError) as e:
raise ExceptionGroup("foo", (ValueError(),))
assert_type(e.value, BaseExceptionGroup[ValueError])
def check_basic_matches() -> None:
# check that matches gets rid of the naked ValueError in the union
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup("", (ValueError(),))
if RaisesGroup(ValueError).matches(exc):
assert_type(exc, BaseExceptionGroup[ValueError])
def check_matches_with_different_exception_type() -> None:
# This should probably raise some type error somewhere, since
# ValueError != KeyboardInterrupt
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
"",
(KeyboardInterrupt(),),
)
if RaisesGroup(ValueError).matches(e):
assert_type(e, BaseExceptionGroup[ValueError])
def check_matcher_init() -> None:
def check_exc(exc: BaseException) -> bool:
return isinstance(exc, ValueError)
# Check various combinations of constructor signatures.
# At least 1 arg must be provided.
Matcher() # type: ignore
Matcher(ValueError)
Matcher(ValueError, "regex")
Matcher(ValueError, "regex", check_exc)
Matcher(exception_type=ValueError)
Matcher(match="regex")
Matcher(check=check_exc)
Matcher(ValueError, match="regex")
Matcher(match="regex", check=check_exc)
def check_filenotfound(exc: FileNotFoundError) -> bool:
return not exc.filename.endswith(".tmp")
# If exception_type is provided, that narrows the `check` method's argument.
Matcher(FileNotFoundError, check=check_filenotfound)
Matcher(ValueError, check=check_filenotfound) # type: ignore
Matcher(check=check_filenotfound) # type: ignore
Matcher(FileNotFoundError, match="regex", check=check_filenotfound)
def raisesgroup_check_type_narrowing() -> None:
"""Check type narrowing on the `check` argument to `RaisesGroup`.
All `type: ignore`s are correctly pointing out type errors, except
where otherwise noted.
"""
def handle_exc(e: BaseExceptionGroup[BaseException]) -> bool:
return True
def handle_kbi(e: BaseExceptionGroup[KeyboardInterrupt]) -> bool:
return True
def handle_value(e: BaseExceptionGroup[ValueError]) -> bool:
return True
RaisesGroup(BaseException, check=handle_exc)
RaisesGroup(BaseException, check=handle_kbi) # type: ignore
RaisesGroup(Exception, check=handle_exc)
RaisesGroup(Exception, check=handle_value) # type: ignore
RaisesGroup(KeyboardInterrupt, check=handle_exc)
RaisesGroup(KeyboardInterrupt, check=handle_kbi)
RaisesGroup(KeyboardInterrupt, check=handle_value) # type: ignore
RaisesGroup(ValueError, check=handle_exc)
RaisesGroup(ValueError, check=handle_kbi) # type: ignore
RaisesGroup(ValueError, check=handle_value)
RaisesGroup(ValueError, KeyboardInterrupt, check=handle_exc)
RaisesGroup(ValueError, KeyboardInterrupt, check=handle_kbi) # type: ignore
RaisesGroup(ValueError, KeyboardInterrupt, check=handle_value) # type: ignore
def raisesgroup_narrow_baseexceptiongroup() -> None:
"""Check type narrowing specifically for the container exceptiongroup.
This is not currently working, and after playing around with it for a bit
I think the only way is to introduce a subclass `NonBaseRaisesGroup`, and overload
`__new__` in Raisesgroup to return the subclass when exceptions are non-base.
(or make current class BaseRaisesGroup and introduce RaisesGroup for non-base)
I encountered problems trying to type this though, see
https://github.com/python/mypy/issues/17251
That is probably possible to work around by entirely using `__new__` instead of
`__init__`, but........ ugh.
"""
def handle_group(e: ExceptionGroup[Exception]) -> bool:
return True
def handle_group_value(e: ExceptionGroup[ValueError]) -> bool:
return True
# should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup
RaisesGroup(ValueError, check=handle_group_value) # type: ignore
# should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup
RaisesGroup(Exception, check=handle_group) # type: ignore
def check_matcher_transparent() -> None:
with RaisesGroup(Matcher(ValueError)) as e:
...
_: BaseExceptionGroup[ValueError] = e.value
assert_type(e.value, BaseExceptionGroup[ValueError])
def check_nested_raisesgroups_contextmanager() -> None:
with RaisesGroup(RaisesGroup(ValueError)) as excinfo:
raise ExceptionGroup("foo", (ValueError(),))
# thanks to inheritance this assignment works
_: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value
# and it can mostly be treated like an exceptiongroup
print(excinfo.value.exceptions[0].exceptions[0])
# but assert_type reveals the lies
print(type(excinfo.value)) # would print "ExceptionGroup"
# typing says it's a BaseExceptionGroup
assert_type(
excinfo.value,
BaseExceptionGroup[RaisesGroup[ValueError]],
)
print(type(excinfo.value.exceptions[0])) # would print "ExceptionGroup"
# but type checkers are utterly confused
assert_type(
excinfo.value.exceptions[0],
Union[RaisesGroup[ValueError], BaseExceptionGroup[RaisesGroup[ValueError]]],
)
def check_nested_raisesgroups_matches() -> None:
"""Check nested RaisesGroups with .matches"""
exc: ExceptionGroup[ExceptionGroup[ValueError]] = ExceptionGroup(
"",
(ExceptionGroup("", (ValueError(),)),),
)
# has the same problems as check_nested_raisesgroups_contextmanager
if RaisesGroup(RaisesGroup(ValueError)).matches(exc):
assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]])
def check_multiple_exceptions_1() -> None:
a = RaisesGroup(ValueError, ValueError)
b = RaisesGroup(Matcher(ValueError), Matcher(ValueError))
c = RaisesGroup(ValueError, Matcher(ValueError))
d: BaseExceptionGroup[ValueError]
d = a
d = b
d = c
assert d
def check_multiple_exceptions_2() -> None:
# This previously failed due to lack of covariance in the TypeVar
a = RaisesGroup(Matcher(ValueError), Matcher(TypeError))
b = RaisesGroup(Matcher(ValueError), TypeError)
c = RaisesGroup(ValueError, TypeError)
d: BaseExceptionGroup[Exception]
d = a
d = b
d = c
assert d
def check_raisesgroup_overloads() -> None:
# allow_unwrapped=True does not allow:
# multiple exceptions
RaisesGroup(ValueError, TypeError, allow_unwrapped=True) # type: ignore
# nested RaisesGroup
RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore
# specifying match
RaisesGroup(ValueError, match="foo", allow_unwrapped=True) # type: ignore
# specifying check
RaisesGroup(ValueError, check=bool, allow_unwrapped=True) # type: ignore
# allowed variants
RaisesGroup(ValueError, allow_unwrapped=True)
RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True)
RaisesGroup(Matcher(ValueError), allow_unwrapped=True)
# flatten_subgroups=True does not allow nested RaisesGroup
RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore
# but rest is plenty fine
RaisesGroup(ValueError, TypeError, flatten_subgroups=True)
RaisesGroup(ValueError, match="foo", flatten_subgroups=True)
RaisesGroup(ValueError, check=bool, flatten_subgroups=True)
RaisesGroup(ValueError, flatten_subgroups=True)
RaisesGroup(Matcher(ValueError), flatten_subgroups=True)
# if they're both false we can of course specify nested raisesgroup
RaisesGroup(RaisesGroup(ValueError))

View File

@ -0,0 +1,29 @@
"""Check that started() can only be called for TaskStatus[None]."""
from trio import TaskStatus
from typing_extensions import assert_type
async def check_status(
none_status_explicit: TaskStatus[None],
none_status_implicit: TaskStatus,
int_status: TaskStatus[int],
) -> None:
assert_type(none_status_explicit, TaskStatus[None])
assert_type(none_status_implicit, TaskStatus[None]) # Default typevar
assert_type(int_status, TaskStatus[int])
# Omitting the parameter is only allowed for None.
none_status_explicit.started()
none_status_implicit.started()
int_status.started() # type: ignore
# Explicit None is allowed.
none_status_explicit.started(None)
none_status_implicit.started(None)
int_status.started(None) # type: ignore
none_status_explicit.started(42) # type: ignore
none_status_implicit.started(42) # type: ignore
int_status.started(42)
int_status.started(True)