Updated script that can be controled by Nodejs web app
This commit is contained in:
@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env python3
|
||||
"""This is a file that wraps calls to `pyright --verifytypes`, achieving two things:
|
||||
1. give an error if docstrings are missing.
|
||||
pyright will give a number of missing docstrings, and error messages, but not exit with a non-zero value.
|
||||
2. filter out specific errors we don't care about.
|
||||
this is largely due to 1, but also because Trio does some very complex stuff and --verifytypes has few to no ways of ignoring specific errors.
|
||||
|
||||
If this check is giving you false alarms, you can ignore them by adding logic to `has_docstring_at_runtime`, in the main loop in `check_type`, or by updating the json file.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
# this file is not run as part of the tests, instead it's run standalone from check.sh
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
|
||||
# not needed if everything is working, but if somebody does something to generate
|
||||
# tons of errors, we can be nice and stop them from getting 3*tons of output
|
||||
printed_diagnostics: set[str] = set()
|
||||
|
||||
|
||||
# TODO: consider checking manually without `--ignoreexternal`, and/or
|
||||
# removing it from the below call later on.
|
||||
def run_pyright(platform: str) -> subprocess.CompletedProcess[bytes]:
|
||||
return subprocess.run(
|
||||
[
|
||||
"pyright",
|
||||
# Specify a platform and version to keep imported modules consistent.
|
||||
f"--pythonplatform={platform}",
|
||||
"--pythonversion=3.8",
|
||||
"--verifytypes=trio",
|
||||
"--outputjson",
|
||||
"--ignoreexternal",
|
||||
],
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
|
||||
def has_docstring_at_runtime(name: str) -> bool:
|
||||
"""Pyright gives us an object identifier of xx.yy.zz
|
||||
This function tries to decompose that into its constituent parts, such that we
|
||||
can resolve it, in order to check whether it has a `__doc__` at runtime and
|
||||
verifytypes misses it because we're doing overly fancy stuff.
|
||||
"""
|
||||
# This assert is solely for stopping isort from removing our imports of trio & trio.testing
|
||||
# It could also be done with isort:skip, but that'd also disable import sorting and the like.
|
||||
assert trio.testing
|
||||
|
||||
# figure out what part of the name is the module, so we can "import" it
|
||||
name_parts = name.split(".")
|
||||
assert name_parts[0] == "trio"
|
||||
if name_parts[1] == "tests":
|
||||
return True
|
||||
|
||||
# traverse down the remaining identifiers with getattr
|
||||
obj = trio
|
||||
try:
|
||||
for obj_name in name_parts[1:]:
|
||||
obj = getattr(obj, obj_name)
|
||||
except AttributeError as exc:
|
||||
# asynciowrapper does funky getattr stuff
|
||||
if "AsyncIOWrapper" in str(exc) or name in (
|
||||
# Symbols not existing on all platforms, so we can't dynamically inspect them.
|
||||
# Manually confirmed to have docstrings but pyright doesn't see them due to
|
||||
# export shenanigans. TODO: actually manually confirm that.
|
||||
# In theory we could verify these at runtime, probably by running the script separately
|
||||
# on separate platforms. It might also be a decent idea to work the other way around,
|
||||
# a la test_static_tool_sees_class_members
|
||||
# darwin
|
||||
"trio.lowlevel.current_kqueue",
|
||||
"trio.lowlevel.monitor_kevent",
|
||||
"trio.lowlevel.wait_kevent",
|
||||
"trio._core._io_kqueue._KqueueStatistics",
|
||||
# windows
|
||||
"trio._socket.SocketType.share",
|
||||
"trio._core._io_windows._WindowsStatistics",
|
||||
"trio._core._windows_cffi.Handle",
|
||||
"trio.lowlevel.current_iocp",
|
||||
"trio.lowlevel.monitor_completion_key",
|
||||
"trio.lowlevel.readinto_overlapped",
|
||||
"trio.lowlevel.register_with_iocp",
|
||||
"trio.lowlevel.wait_overlapped",
|
||||
"trio.lowlevel.write_overlapped",
|
||||
"trio.lowlevel.WaitForSingleObject",
|
||||
"trio.socket.fromshare",
|
||||
# linux
|
||||
# this test will fail on linux, but I don't develop on linux. So the next
|
||||
# person to do so is very welcome to open a pull request and populate with
|
||||
# objects
|
||||
# TODO: these are erroring on all platforms, why?
|
||||
"trio._highlevel_generic.StapledStream.send_stream",
|
||||
"trio._highlevel_generic.StapledStream.receive_stream",
|
||||
"trio._ssl.SSLStream.transport_stream",
|
||||
"trio._file_io._HasFileNo",
|
||||
"trio._file_io._HasFileNo.fileno",
|
||||
):
|
||||
return True
|
||||
|
||||
else:
|
||||
print(
|
||||
f"Pyright sees {name} at runtime, but unable to getattr({obj.__name__}, {obj_name}).",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return False
|
||||
return bool(obj.__doc__)
|
||||
|
||||
|
||||
def check_type(
|
||||
platform: str,
|
||||
full_diagnostics_file: Path | None,
|
||||
expected_errors: list[object],
|
||||
) -> list[object]:
|
||||
# convince isort we use the trio import
|
||||
assert trio
|
||||
|
||||
# run pyright, load output into json
|
||||
res = run_pyright(platform)
|
||||
current_result = json.loads(res.stdout)
|
||||
|
||||
if res.stderr:
|
||||
print(res.stderr, file=sys.stderr)
|
||||
|
||||
if full_diagnostics_file:
|
||||
with open(full_diagnostics_file, "a") as f:
|
||||
json.dump(current_result, f, sort_keys=True, indent=4)
|
||||
|
||||
errors = []
|
||||
|
||||
for symbol in current_result["typeCompleteness"]["symbols"]:
|
||||
diagnostics = symbol["diagnostics"]
|
||||
name = symbol["name"]
|
||||
for diagnostic in diagnostics:
|
||||
message = diagnostic["message"]
|
||||
if name in (
|
||||
"trio._path.PosixPath",
|
||||
"trio._path.WindowsPath",
|
||||
) and message.startswith("Type of base class "):
|
||||
continue
|
||||
|
||||
if name.startswith("trio._path.Path"):
|
||||
if message.startswith("No docstring found for"):
|
||||
continue
|
||||
if message.startswith(
|
||||
"Type is missing type annotation and could be inferred differently by type checkers",
|
||||
):
|
||||
continue
|
||||
|
||||
# ignore errors about missing docstrings if they're available at runtime
|
||||
if message.startswith("No docstring found for"):
|
||||
if has_docstring_at_runtime(symbol["name"]):
|
||||
continue
|
||||
else:
|
||||
# Missing docstring messages include the name of the object.
|
||||
# Other errors don't, so we add it.
|
||||
message = f"{name}: {message}"
|
||||
if message not in expected_errors and message not in printed_diagnostics:
|
||||
print(f"new error: {message}", file=sys.stderr)
|
||||
errors.append(message)
|
||||
printed_diagnostics.add(message)
|
||||
|
||||
continue
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> int:
|
||||
if args.full_diagnostics_file:
|
||||
full_diagnostics_file = Path(args.full_diagnostics_file)
|
||||
full_diagnostics_file.write_text("")
|
||||
else:
|
||||
full_diagnostics_file = None
|
||||
|
||||
errors_by_platform_file = Path(__file__).parent / "_check_type_completeness.json"
|
||||
if errors_by_platform_file.exists():
|
||||
with open(errors_by_platform_file) as f:
|
||||
errors_by_platform = json.load(f)
|
||||
else:
|
||||
errors_by_platform = {"Linux": [], "Windows": [], "Darwin": [], "all": []}
|
||||
|
||||
changed = False
|
||||
for platform in "Linux", "Windows", "Darwin":
|
||||
platform_errors = errors_by_platform[platform] + errors_by_platform["all"]
|
||||
print("*" * 20, f"\nChecking {platform}...")
|
||||
errors = check_type(platform, full_diagnostics_file, platform_errors)
|
||||
|
||||
new_errors = [e for e in errors if e not in platform_errors]
|
||||
missing_errors = [e for e in platform_errors if e not in errors]
|
||||
|
||||
if new_errors:
|
||||
print(
|
||||
f"New errors introduced in `pyright --verifytypes`. Fix them, or ignore them by modifying {errors_by_platform_file}, either manually or with '--overwrite-file'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
changed = True
|
||||
if missing_errors:
|
||||
print(
|
||||
f"Congratulations, you have resolved existing errors! Please remove them from {errors_by_platform_file}, either manually or with '--overwrite-file'.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
changed = True
|
||||
print(missing_errors, file=sys.stderr)
|
||||
|
||||
errors_by_platform[platform] = errors
|
||||
print("*" * 20)
|
||||
|
||||
# cut down the size of the json file by a lot, and make it easier to parse for
|
||||
# humans, by moving errors that appear on all platforms to a separate category
|
||||
errors_by_platform["all"] = []
|
||||
for e in errors_by_platform["Linux"].copy():
|
||||
if e in errors_by_platform["Darwin"] and e in errors_by_platform["Windows"]:
|
||||
for platform in "Linux", "Windows", "Darwin":
|
||||
errors_by_platform[platform].remove(e)
|
||||
errors_by_platform["all"].append(e)
|
||||
|
||||
if changed and args.overwrite_file:
|
||||
with open(errors_by_platform_file, "w") as f:
|
||||
json.dump(errors_by_platform, f, indent=4, sort_keys=True)
|
||||
# newline at end of file
|
||||
f.write("\n")
|
||||
|
||||
# True -> 1 -> non-zero exit value -> error
|
||||
return changed
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--overwrite-file",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Use this flag to overwrite the current stored results. Either in CI together with a diff check, or to avoid having to manually correct it.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--full-diagnostics-file",
|
||||
type=Path,
|
||||
default=None,
|
||||
help="Use this for debugging, it will dump the output of all three pyright runs by platform into this file.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
assert __name__ == "__main__", "This script should be run standalone"
|
||||
sys.exit(main(args))
|
@ -0,0 +1,24 @@
|
||||
regular = "hi"
|
||||
|
||||
from .. import _deprecate
|
||||
|
||||
_deprecate.enable_attribute_deprecations(__name__)
|
||||
|
||||
# Make sure that we don't trigger infinite recursion when accessing module
|
||||
# attributes in between calling enable_attribute_deprecations and defining
|
||||
# __deprecated_attributes__:
|
||||
import sys
|
||||
|
||||
this_mod = sys.modules[__name__]
|
||||
assert this_mod.regular == "hi"
|
||||
assert not hasattr(this_mod, "dep1")
|
||||
|
||||
__deprecated_attributes__ = {
|
||||
"dep1": _deprecate.DeprecatedAttribute("value1", "1.1", issue=1),
|
||||
"dep2": _deprecate.DeprecatedAttribute(
|
||||
"value2",
|
||||
"1.2",
|
||||
issue=1,
|
||||
instead="instead-string",
|
||||
),
|
||||
}
|
54
lib/python3.13/site-packages/trio/_tests/pytest_plugin.py
Normal file
54
lib/python3.13/site-packages/trio/_tests/pytest_plugin.py
Normal file
@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from typing import NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
from ..testing import MockClock, trio_test
|
||||
|
||||
RUN_SLOW = True
|
||||
SKIP_OPTIONAL_IMPORTS = False
|
||||
|
||||
|
||||
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||
parser.addoption("--run-slow", action="store_true", help="run slow tests")
|
||||
parser.addoption(
|
||||
"--skip-optional-imports",
|
||||
action="store_true",
|
||||
help="skip tests that rely on libraries not required by trio itself",
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config: pytest.Config) -> None:
|
||||
global RUN_SLOW
|
||||
RUN_SLOW = config.getoption("--run-slow", default=True)
|
||||
global SKIP_OPTIONAL_IMPORTS
|
||||
SKIP_OPTIONAL_IMPORTS = config.getoption("--skip-optional-imports", default=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_clock() -> MockClock:
|
||||
return MockClock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def autojump_clock() -> MockClock:
|
||||
return MockClock(autojump_threshold=0)
|
||||
|
||||
|
||||
# FIXME: split off into a package (or just make part of Trio's public
|
||||
# interface?), with config file to enable? and I guess a mark option too; I
|
||||
# guess it's useful with the class- and file-level marking machinery (where
|
||||
# the raw @trio_test decorator isn't enough).
|
||||
@pytest.hookimpl(tryfirst=True)
|
||||
def pytest_pyfunc_call(pyfuncitem: pytest.Function) -> None:
|
||||
if inspect.iscoroutinefunction(pyfuncitem.obj):
|
||||
pyfuncitem.obj = trio_test(pyfuncitem.obj)
|
||||
|
||||
|
||||
def skip_if_optional_else_raise(error: ImportError) -> NoReturn:
|
||||
if SKIP_OPTIONAL_IMPORTS:
|
||||
pytest.skip(error.msg, allow_module_level=True)
|
||||
else: # pragma: no cover
|
||||
raise error
|
72
lib/python3.13/site-packages/trio/_tests/test_abc.py
Normal file
72
lib/python3.13/site-packages/trio/_tests/test_abc.py
Normal file
@ -0,0 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from .. import abc as tabc
|
||||
from ..lowlevel import Task
|
||||
|
||||
|
||||
def test_instrument_implements_hook_methods() -> None:
|
||||
attrs = {
|
||||
"before_run": (),
|
||||
"after_run": (),
|
||||
"task_spawned": (Task,),
|
||||
"task_scheduled": (Task,),
|
||||
"before_task_step": (Task,),
|
||||
"after_task_step": (Task,),
|
||||
"task_exited": (Task,),
|
||||
"before_io_wait": (3.3,),
|
||||
"after_io_wait": (3.3,),
|
||||
}
|
||||
|
||||
mayonnaise = tabc.Instrument()
|
||||
|
||||
for method_name, args in attrs.items():
|
||||
assert hasattr(mayonnaise, method_name)
|
||||
method = getattr(mayonnaise, method_name)
|
||||
assert callable(method)
|
||||
method(*args)
|
||||
|
||||
|
||||
async def test_AsyncResource_defaults() -> None:
|
||||
@attrs.define(slots=False)
|
||||
class MyAR(tabc.AsyncResource):
|
||||
record: list[str] = attrs.Factory(list)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("ac")
|
||||
|
||||
async with MyAR() as myar:
|
||||
assert isinstance(myar, MyAR)
|
||||
assert myar.record == []
|
||||
|
||||
assert myar.record == ["ac"]
|
||||
|
||||
|
||||
def test_abc_generics() -> None:
|
||||
# Pythons below 3.5.2 had a typing.Generic that would throw
|
||||
# errors when instantiating or subclassing a parameterized
|
||||
# version of a class with any __slots__. This is why RunVar
|
||||
# (which has slots) is not generic. This tests that
|
||||
# the generic ABCs are fine, because while they are slotted
|
||||
# they don't actually define any slots.
|
||||
|
||||
class SlottedChannel(tabc.SendChannel[tabc.Stream]):
|
||||
__slots__ = ("x",)
|
||||
|
||||
def send_nowait(self, value: object) -> None:
|
||||
raise RuntimeError
|
||||
|
||||
async def send(self, value: object) -> None:
|
||||
raise RuntimeError # pragma: no cover
|
||||
|
||||
def clone(self) -> None:
|
||||
raise RuntimeError # pragma: no cover
|
||||
|
||||
async def aclose(self) -> None:
|
||||
pass # pragma: no cover
|
||||
|
||||
channel = SlottedChannel()
|
||||
with pytest.raises(RuntimeError):
|
||||
channel.send_nowait(None)
|
413
lib/python3.13/site-packages/trio/_tests/test_channel.py
Normal file
413
lib/python3.13/site-packages/trio/_tests/test_channel.py
Normal file
@ -0,0 +1,413 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import EndOfChannel, open_memory_channel
|
||||
|
||||
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
|
||||
async def test_channel() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
open_memory_channel(1.0)
|
||||
with pytest.raises(ValueError, match="^max_buffer_size must be >= 0$"):
|
||||
open_memory_channel(-1)
|
||||
|
||||
s, r = open_memory_channel[Union[int, str, None]](2)
|
||||
repr(s) # smoke test
|
||||
repr(r) # smoke test
|
||||
|
||||
s.send_nowait(1)
|
||||
with assert_checkpoints():
|
||||
await s.send(2)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(None)
|
||||
|
||||
with assert_checkpoints():
|
||||
assert await r.receive() == 1
|
||||
assert r.receive_nowait() == 2
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
|
||||
s.send_nowait("last")
|
||||
await s.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send("too late")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait("too late")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.clone()
|
||||
await s.aclose()
|
||||
|
||||
assert r.receive_nowait() == "last"
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
await r.aclose()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.receive_nowait()
|
||||
await r.aclose()
|
||||
|
||||
|
||||
async def test_553(autojump_clock: trio.abc.Clock) -> None:
|
||||
s, r = open_memory_channel[str](1)
|
||||
with trio.move_on_after(10) as timeout_scope:
|
||||
await r.receive()
|
||||
assert timeout_scope.cancelled_caught
|
||||
await s.send("Test for PR #553")
|
||||
|
||||
|
||||
async def test_channel_multiple_producers() -> None:
|
||||
async def producer(send_channel: trio.MemorySendChannel[int], i: int) -> None:
|
||||
# We close our handle when we're done with it
|
||||
async with send_channel:
|
||||
for j in range(3 * i, 3 * (i + 1)):
|
||||
await send_channel.send(j)
|
||||
|
||||
send_channel, receive_channel = open_memory_channel[int](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
# We hand out clones to all the new producers, and then close the
|
||||
# original.
|
||||
async with send_channel:
|
||||
for i in range(10):
|
||||
nursery.start_soon(producer, send_channel.clone(), i)
|
||||
|
||||
got = [value async for value in receive_channel]
|
||||
|
||||
got.sort()
|
||||
assert got == list(range(30))
|
||||
|
||||
|
||||
async def test_channel_multiple_consumers() -> None:
|
||||
successful_receivers = set()
|
||||
received = []
|
||||
|
||||
async def consumer(receive_channel: trio.MemoryReceiveChannel[int], i: int) -> None:
|
||||
async for value in receive_channel:
|
||||
successful_receivers.add(i)
|
||||
received.append(value)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
send_channel, receive_channel = trio.open_memory_channel[int](1)
|
||||
async with send_channel:
|
||||
for i in range(5):
|
||||
nursery.start_soon(consumer, receive_channel, i)
|
||||
await wait_all_tasks_blocked()
|
||||
for i in range(10):
|
||||
await send_channel.send(i)
|
||||
|
||||
assert successful_receivers == set(range(5))
|
||||
assert len(received) == 10
|
||||
assert set(received) == set(range(10))
|
||||
|
||||
|
||||
async def test_close_basics() -> None:
|
||||
async def send_block(
|
||||
s: trio.MemorySendChannel[None],
|
||||
expect: type[BaseException],
|
||||
) -> None:
|
||||
with pytest.raises(expect):
|
||||
await s.send(None)
|
||||
|
||||
# closing send -> other send gets ClosedResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.ClosedResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# and receive gets EndOfChannel
|
||||
with pytest.raises(EndOfChannel):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
|
||||
# closing receive -> send gets BrokenResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.BrokenResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
await r.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# closing receive -> other receive gets ClosedResourceError
|
||||
async def receive_block(r: trio.MemoryReceiveChannel[int]) -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
s2, r2 = open_memory_channel[int](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_block, r2)
|
||||
await wait_all_tasks_blocked()
|
||||
await r2.aclose()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r2.receive_nowait()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r2.receive()
|
||||
|
||||
|
||||
async def test_close_sync() -> None:
|
||||
async def send_block(
|
||||
s: trio.MemorySendChannel[None],
|
||||
expect: type[BaseException],
|
||||
) -> None:
|
||||
with pytest.raises(expect):
|
||||
await s.send(None)
|
||||
|
||||
# closing send -> other send gets ClosedResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.ClosedResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# and receive gets EndOfChannel
|
||||
with pytest.raises(EndOfChannel):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(EndOfChannel):
|
||||
await r.receive()
|
||||
|
||||
# closing receive -> send gets BrokenResourceError
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_block, s, trio.BrokenResourceError)
|
||||
await wait_all_tasks_blocked()
|
||||
r.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await s.send(None)
|
||||
|
||||
# closing receive -> other receive gets ClosedResourceError
|
||||
async def receive_block(r: trio.MemoryReceiveChannel[None]) -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
s, r = open_memory_channel[None](0)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_block, r)
|
||||
await wait_all_tasks_blocked()
|
||||
r.close()
|
||||
|
||||
# and it's persistent
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r.receive()
|
||||
|
||||
|
||||
async def test_receive_channel_clone_and_close() -> None:
|
||||
s, r = open_memory_channel[None](10)
|
||||
|
||||
r2 = r.clone()
|
||||
r3 = r.clone()
|
||||
|
||||
s.send_nowait(None)
|
||||
await r.aclose()
|
||||
with r2:
|
||||
pass
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r.clone()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
r2.clone()
|
||||
|
||||
# Can still send, r3 is still open
|
||||
s.send_nowait(None)
|
||||
|
||||
await r3.aclose()
|
||||
|
||||
# But now the receiver is really closed
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
s.send_nowait(None)
|
||||
|
||||
|
||||
async def test_close_multiple_send_handles() -> None:
|
||||
# With multiple send handles, closing one handle only wakes senders on
|
||||
# that handle, but others can continue just fine
|
||||
s1, r = open_memory_channel[str](0)
|
||||
s2 = s1.clone()
|
||||
|
||||
async def send_will_close() -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await s1.send("nope")
|
||||
|
||||
async def send_will_succeed() -> None:
|
||||
await s2.send("ok")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(send_will_close)
|
||||
nursery.start_soon(send_will_succeed)
|
||||
await wait_all_tasks_blocked()
|
||||
await s1.aclose()
|
||||
assert await r.receive() == "ok"
|
||||
|
||||
|
||||
async def test_close_multiple_receive_handles() -> None:
|
||||
# With multiple receive handles, closing one handle only wakes receivers on
|
||||
# that handle, but others can continue just fine
|
||||
s, r1 = open_memory_channel[str](0)
|
||||
r2 = r1.clone()
|
||||
|
||||
async def receive_will_close() -> None:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await r1.receive()
|
||||
|
||||
async def receive_will_succeed() -> None:
|
||||
assert await r2.receive() == "ok"
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receive_will_close)
|
||||
nursery.start_soon(receive_will_succeed)
|
||||
await wait_all_tasks_blocked()
|
||||
await r1.aclose()
|
||||
await s.send("ok")
|
||||
|
||||
|
||||
async def test_inf_capacity() -> None:
|
||||
send, receive = open_memory_channel[int](float("inf"))
|
||||
|
||||
# It's accepted, and we can send all day without blocking
|
||||
with send:
|
||||
for i in range(10):
|
||||
send.send_nowait(i)
|
||||
|
||||
got = [i async for i in receive]
|
||||
assert got == list(range(10))
|
||||
|
||||
|
||||
async def test_statistics() -> None:
|
||||
s, r = open_memory_channel[None](2)
|
||||
|
||||
assert s.statistics() == r.statistics()
|
||||
stats = s.statistics()
|
||||
assert stats.current_buffer_used == 0
|
||||
assert stats.max_buffer_size == 2
|
||||
assert stats.open_send_channels == 1
|
||||
assert stats.open_receive_channels == 1
|
||||
assert stats.tasks_waiting_send == 0
|
||||
assert stats.tasks_waiting_receive == 0
|
||||
|
||||
s.send_nowait(None)
|
||||
assert s.statistics().current_buffer_used == 1
|
||||
|
||||
s2 = s.clone()
|
||||
assert s.statistics().open_send_channels == 2
|
||||
await s.aclose()
|
||||
assert s2.statistics().open_send_channels == 1
|
||||
|
||||
r2 = r.clone()
|
||||
assert s2.statistics().open_receive_channels == 2
|
||||
await r2.aclose()
|
||||
assert s2.statistics().open_receive_channels == 1
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
s2.send_nowait(None) # fill up the buffer
|
||||
assert s.statistics().current_buffer_used == 2
|
||||
nursery.start_soon(s2.send, None)
|
||||
nursery.start_soon(s2.send, None)
|
||||
await wait_all_tasks_blocked()
|
||||
assert s.statistics().tasks_waiting_send == 2
|
||||
nursery.cancel_scope.cancel()
|
||||
assert s.statistics().tasks_waiting_send == 0
|
||||
|
||||
# empty out the buffer again
|
||||
try:
|
||||
while True:
|
||||
r.receive_nowait()
|
||||
except trio.WouldBlock:
|
||||
pass
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(r.receive)
|
||||
await wait_all_tasks_blocked()
|
||||
assert s.statistics().tasks_waiting_receive == 1
|
||||
nursery.cancel_scope.cancel()
|
||||
assert s.statistics().tasks_waiting_receive == 0
|
||||
|
||||
|
||||
async def test_channel_fairness() -> None:
|
||||
# We can remove an item we just sent, and send an item back in after, if
|
||||
# no-one else is waiting.
|
||||
s, r = open_memory_channel[Union[int, None]](1)
|
||||
s.send_nowait(1)
|
||||
assert r.receive_nowait() == 1
|
||||
s.send_nowait(2)
|
||||
assert r.receive_nowait() == 2
|
||||
|
||||
# But if someone else is waiting to receive, then they "own" the item we
|
||||
# send, so we can't receive it (even though we run first):
|
||||
|
||||
result: int | None = None
|
||||
|
||||
async def do_receive(r: trio.MemoryReceiveChannel[int | None]) -> None:
|
||||
nonlocal result
|
||||
result = await r.receive()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive, r)
|
||||
await wait_all_tasks_blocked()
|
||||
s.send_nowait(2)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
assert result == 2
|
||||
|
||||
# And the analogous situation for send: if we free up a space, we can't
|
||||
# immediately send something in it if someone is already waiting to do
|
||||
# that
|
||||
s, r = open_memory_channel[Union[int, None]](1)
|
||||
s.send_nowait(1)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(None)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(s.send, 2)
|
||||
await wait_all_tasks_blocked()
|
||||
assert r.receive_nowait() == 1
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(3)
|
||||
assert (await r.receive()) == 2
|
||||
|
||||
|
||||
async def test_unbuffered() -> None:
|
||||
s, r = open_memory_channel[int](0)
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
s.send_nowait(1)
|
||||
|
||||
async def do_send(s: trio.MemorySendChannel[int], v: int) -> None:
|
||||
with assert_checkpoints():
|
||||
await s.send(v)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send, s, 1)
|
||||
with assert_checkpoints():
|
||||
assert await r.receive() == 1
|
||||
with pytest.raises(trio.WouldBlock):
|
||||
r.receive_nowait()
|
56
lib/python3.13/site-packages/trio/_tests/test_contextvars.py
Normal file
56
lib/python3.13/site-packages/trio/_tests/test_contextvars.py
Normal file
@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
|
||||
from .. import _core
|
||||
|
||||
trio_testing_contextvar: contextvars.ContextVar[str] = contextvars.ContextVar(
|
||||
"trio_testing_contextvar",
|
||||
)
|
||||
|
||||
|
||||
async def test_contextvars_default() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
assert record == ["main"]
|
||||
|
||||
|
||||
async def test_contextvars_set() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
trio_testing_contextvar.set("child")
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
value = trio_testing_contextvar.get()
|
||||
assert record == ["child"]
|
||||
assert value == "main"
|
||||
|
||||
|
||||
async def test_contextvars_copy() -> None:
|
||||
trio_testing_contextvar.set("main")
|
||||
context = contextvars.copy_context()
|
||||
trio_testing_contextvar.set("second_main")
|
||||
record: list[str] = []
|
||||
|
||||
async def child() -> None:
|
||||
value = trio_testing_contextvar.get()
|
||||
record.append(value)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
context.run(nursery.start_soon, child)
|
||||
nursery.start_soon(child)
|
||||
value = trio_testing_contextvar.get()
|
||||
assert set(record) == {"main", "second_main"}
|
||||
assert value == "second_main"
|
283
lib/python3.13/site-packages/trio/_tests/test_deprecate.py
Normal file
283
lib/python3.13/site-packages/trio/_tests/test_deprecate.py
Normal file
@ -0,0 +1,283 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
|
||||
from .._deprecate import (
|
||||
TrioDeprecationWarning,
|
||||
deprecated,
|
||||
deprecated_alias,
|
||||
warn_deprecated,
|
||||
)
|
||||
from . import module_with_deprecations
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def recwarn_always(recwarn: pytest.WarningsRecorder) -> pytest.WarningsRecorder:
|
||||
warnings.simplefilter("always")
|
||||
# ResourceWarnings about unclosed sockets can occur nondeterministically
|
||||
# (during GC) which throws off the tests in this file
|
||||
warnings.simplefilter("ignore", ResourceWarning)
|
||||
return recwarn
|
||||
|
||||
|
||||
def _here() -> tuple[str, int]:
|
||||
frame = inspect.currentframe()
|
||||
assert frame is not None
|
||||
assert frame.f_back is not None
|
||||
info = inspect.getframeinfo(frame.f_back)
|
||||
return (info.filename, info.lineno)
|
||||
|
||||
|
||||
def test_warn_deprecated(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
def deprecated_thing() -> None:
|
||||
warn_deprecated("ice", "1.2", issue=1, instead="water")
|
||||
|
||||
deprecated_thing()
|
||||
filename, lineno = _here()
|
||||
assert len(recwarn_always) == 1
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "ice is deprecated" in got.message.args[0]
|
||||
assert "Trio 1.2" in got.message.args[0]
|
||||
assert "water instead" in got.message.args[0]
|
||||
assert "/issues/1" in got.message.args[0]
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno - 1
|
||||
|
||||
|
||||
def test_warn_deprecated_no_instead_or_issue(
|
||||
recwarn_always: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
# Explicitly no instead or issue
|
||||
warn_deprecated("water", "1.3", issue=None, instead=None)
|
||||
assert len(recwarn_always) == 1
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "water is deprecated" in got.message.args[0]
|
||||
assert "no replacement" in got.message.args[0]
|
||||
assert "Trio 1.3" in got.message.args[0]
|
||||
|
||||
|
||||
def test_warn_deprecated_stacklevel(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
def nested1() -> None:
|
||||
nested2()
|
||||
|
||||
def nested2() -> None:
|
||||
warn_deprecated("x", "1.3", issue=7, instead="y", stacklevel=3)
|
||||
|
||||
filename, lineno = _here()
|
||||
nested1()
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno + 1
|
||||
|
||||
|
||||
def old() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def new() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
|
||||
def test_warn_deprecated_formatting(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
warn_deprecated(old, "1.0", issue=1, instead=new)
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.old is deprecated" in got.message.args[0]
|
||||
assert "test_deprecate.new instead" in got.message.args[0]
|
||||
|
||||
|
||||
@deprecated("1.5", issue=123, instead=new)
|
||||
def deprecated_old() -> int:
|
||||
return 3
|
||||
|
||||
|
||||
def test_deprecated_decorator(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert deprecated_old() == 3
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.deprecated_old is deprecated" in got.message.args[0]
|
||||
assert "1.5" in got.message.args[0]
|
||||
assert "test_deprecate.new" in got.message.args[0]
|
||||
assert "issues/123" in got.message.args[0]
|
||||
|
||||
|
||||
class Foo:
|
||||
@deprecated("1.0", issue=123, instead="crying")
|
||||
def method(self) -> int:
|
||||
return 7
|
||||
|
||||
|
||||
def test_deprecated_decorator_method(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
f = Foo()
|
||||
assert f.method() == 7
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.Foo.method is deprecated" in got.message.args[0]
|
||||
|
||||
|
||||
@deprecated("1.2", thing="the thing", issue=None, instead=None)
|
||||
def deprecated_with_thing() -> int:
|
||||
return 72
|
||||
|
||||
|
||||
def test_deprecated_decorator_with_explicit_thing(
|
||||
recwarn_always: pytest.WarningsRecorder,
|
||||
) -> None:
|
||||
assert deprecated_with_thing() == 72
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "the thing is deprecated" in got.message.args[0]
|
||||
|
||||
|
||||
def new_hotness() -> str:
|
||||
return "new hotness"
|
||||
|
||||
|
||||
old_hotness = deprecated_alias("old_hotness", new_hotness, "1.23", issue=1)
|
||||
|
||||
|
||||
def test_deprecated_alias(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert old_hotness() == "new hotness"
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "test_deprecate.old_hotness is deprecated" in got.message.args[0]
|
||||
assert "1.23" in got.message.args[0]
|
||||
assert "test_deprecate.new_hotness instead" in got.message.args[0]
|
||||
assert "issues/1" in got.message.args[0]
|
||||
|
||||
assert isinstance(old_hotness.__doc__, str)
|
||||
assert ".. deprecated:: 1.23" in old_hotness.__doc__
|
||||
assert "test_deprecate.new_hotness instead" in old_hotness.__doc__
|
||||
assert "issues/1>`__" in old_hotness.__doc__
|
||||
|
||||
|
||||
class Alias:
|
||||
def new_hotness_method(self) -> str:
|
||||
return "new hotness method"
|
||||
|
||||
old_hotness_method = deprecated_alias(
|
||||
"Alias.old_hotness_method",
|
||||
new_hotness_method,
|
||||
"3.21",
|
||||
issue=1,
|
||||
)
|
||||
|
||||
|
||||
def test_deprecated_alias_method(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
obj = Alias()
|
||||
assert obj.old_hotness_method() == "new hotness method"
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
msg = got.message.args[0]
|
||||
assert "test_deprecate.Alias.old_hotness_method is deprecated" in msg
|
||||
assert "test_deprecate.Alias.new_hotness_method instead" in msg
|
||||
|
||||
|
||||
@deprecated("2.1", issue=1, instead="hi")
|
||||
def docstring_test1() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=None, instead="hi")
|
||||
def docstring_test2() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=1, instead=None)
|
||||
def docstring_test3() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
@deprecated("2.1", issue=None, instead=None)
|
||||
def docstring_test4() -> None: # pragma: no cover
|
||||
"""Hello!"""
|
||||
|
||||
|
||||
def test_deprecated_docstring_munging() -> None:
|
||||
assert (
|
||||
docstring_test1.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
Use hi instead.
|
||||
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
assert (
|
||||
docstring_test2.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
Use hi instead.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
assert (
|
||||
docstring_test3.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
For details, see `issue #1 <https://github.com/python-trio/trio/issues/1>`__.
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
assert (
|
||||
docstring_test4.__doc__
|
||||
== """Hello!
|
||||
|
||||
.. deprecated:: 2.1
|
||||
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def test_module_with_deprecations(recwarn_always: pytest.WarningsRecorder) -> None:
|
||||
assert module_with_deprecations.regular == "hi"
|
||||
assert len(recwarn_always) == 0
|
||||
|
||||
filename, lineno = _here()
|
||||
assert module_with_deprecations.dep1 == "value1" # type: ignore[attr-defined]
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert got.filename == filename
|
||||
assert got.lineno == lineno + 1
|
||||
|
||||
assert "module_with_deprecations.dep1" in got.message.args[0]
|
||||
assert "Trio 1.1" in got.message.args[0]
|
||||
assert "/issues/1" in got.message.args[0]
|
||||
assert "value1 instead" in got.message.args[0]
|
||||
|
||||
assert module_with_deprecations.dep2 == "value2" # type: ignore[attr-defined]
|
||||
got = recwarn_always.pop(DeprecationWarning)
|
||||
assert isinstance(got.message, Warning)
|
||||
assert "instead-string instead" in got.message.args[0]
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
module_with_deprecations.asdf # type: ignore[attr-defined] # noqa: B018 # "useless expression"
|
||||
|
||||
|
||||
def test_warning_class() -> None:
|
||||
with pytest.deprecated_call():
|
||||
warn_deprecated("foo", "bar", issue=None, instead=None)
|
||||
|
||||
# essentially the same as the above check
|
||||
with pytest.warns(DeprecationWarning):
|
||||
warn_deprecated("foo", "bar", issue=None, instead=None)
|
||||
|
||||
with pytest.warns(TrioDeprecationWarning):
|
||||
warn_deprecated(
|
||||
"foo",
|
||||
"bar",
|
||||
issue=None,
|
||||
instead=None,
|
||||
use_triodeprecationwarning=True,
|
||||
)
|
@ -0,0 +1,64 @@
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
async def test_deprecation_warning_open_nursery() -> None:
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
async with trio.open_nursery(strict_exception_groups=False):
|
||||
...
|
||||
assert len(record) == 1
|
||||
async with trio.open_nursery(strict_exception_groups=True):
|
||||
...
|
||||
async with trio.open_nursery():
|
||||
...
|
||||
|
||||
|
||||
def test_deprecation_warning_run() -> None:
|
||||
async def foo() -> None: ...
|
||||
|
||||
async def foo_nursery() -> None:
|
||||
# this should not raise a warning, even if it's implied loose
|
||||
async with trio.open_nursery():
|
||||
...
|
||||
|
||||
async def foo_loose_nursery() -> None:
|
||||
# this should raise a warning, even if specifying the parameter is redundant
|
||||
async with trio.open_nursery(strict_exception_groups=False):
|
||||
...
|
||||
|
||||
def helper(fun: Callable[..., Awaitable[None]], num: int) -> None:
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
trio.run(fun, strict_exception_groups=False)
|
||||
assert len(record) == num
|
||||
|
||||
helper(foo, 1)
|
||||
helper(foo_nursery, 1)
|
||||
helper(foo_loose_nursery, 2)
|
||||
|
||||
|
||||
def test_deprecation_warning_start_guest_run() -> None:
|
||||
# "The simplest possible "host" loop."
|
||||
from .._core._tests.test_guest_mode import trivial_guest_run
|
||||
|
||||
async def trio_return(in_host: object) -> str:
|
||||
await trio.lowlevel.checkpoint()
|
||||
return "ok"
|
||||
|
||||
with pytest.warns(
|
||||
trio.TrioDeprecationWarning,
|
||||
match="strict_exception_groups=False",
|
||||
) as record:
|
||||
trivial_guest_run(
|
||||
trio_return,
|
||||
strict_exception_groups=False,
|
||||
)
|
||||
assert len(record) == 1
|
900
lib/python3.13/site-packages/trio/_tests/test_dtls.py
Normal file
900
lib/python3.13/site-packages/trio/_tests/test_dtls.py
Normal file
@ -0,0 +1,900 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from contextlib import asynccontextmanager
|
||||
from itertools import count
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
try:
|
||||
import trustme
|
||||
from OpenSSL import SSL
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio import DTLSChannel, DTLSEndpoint
|
||||
from trio.testing._fake_net import FakeNet, UDPPacket
|
||||
|
||||
from .._core._tests.tutil import binds_ipv6, gc_collect_harder, slow
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
ca = trustme.CA()
|
||||
server_cert = ca.issue_cert("example.com")
|
||||
|
||||
server_ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
server_cert.configure_cert(server_ctx)
|
||||
|
||||
client_ctx = SSL.Context(SSL.DTLS_METHOD)
|
||||
ca.configure_trust(client_ctx)
|
||||
|
||||
|
||||
parametrize_ipv6 = pytest.mark.parametrize(
|
||||
"ipv6",
|
||||
[False, pytest.param(True, marks=binds_ipv6)],
|
||||
ids=["ipv4", "ipv6"],
|
||||
)
|
||||
|
||||
|
||||
def endpoint(**kwargs: int | bool) -> DTLSEndpoint:
|
||||
ipv6 = kwargs.pop("ipv6", False)
|
||||
family = trio.socket.AF_INET6 if ipv6 else trio.socket.AF_INET
|
||||
sock = trio.socket.socket(type=trio.socket.SOCK_DGRAM, family=family)
|
||||
return DTLSEndpoint(sock, **kwargs)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def dtls_echo_server(
|
||||
*,
|
||||
autocancel: bool = True,
|
||||
mtu: int | None = None,
|
||||
ipv6: bool = False,
|
||||
) -> AsyncGenerator[tuple[DTLSEndpoint, tuple[str, int]], None]:
|
||||
with endpoint(ipv6=ipv6) as server:
|
||||
localhost = "::1" if ipv6 else "127.0.0.1"
|
||||
await server.socket.bind((localhost, 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def echo_handler(dtls_channel: DTLSChannel) -> None:
|
||||
print(
|
||||
"echo handler started: "
|
||||
f"server {dtls_channel.endpoint.socket.getsockname()!r} "
|
||||
f"client {dtls_channel.peer_address!r}",
|
||||
)
|
||||
if mtu is not None:
|
||||
dtls_channel.set_ciphertext_mtu(mtu)
|
||||
try:
|
||||
print("server starting do_handshake")
|
||||
await dtls_channel.do_handshake()
|
||||
print("server finished do_handshake")
|
||||
async for packet in dtls_channel:
|
||||
print(f"echoing {packet!r} -> {dtls_channel.peer_address!r}")
|
||||
await dtls_channel.send(packet)
|
||||
except trio.BrokenResourceError: # pragma: no cover
|
||||
print("echo handler channel broken")
|
||||
|
||||
await nursery.start(server.serve, server_ctx, echo_handler)
|
||||
|
||||
yield server, server.socket.getsockname()
|
||||
|
||||
if autocancel:
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@parametrize_ipv6
|
||||
async def test_smoke(ipv6: bool) -> None:
|
||||
async with dtls_echo_server(ipv6=ipv6) as (server_endpoint, address):
|
||||
with endpoint(ipv6=ipv6) as client_endpoint:
|
||||
client_channel = client_endpoint.connect(address, client_ctx)
|
||||
with pytest.raises(trio.NeedHandshakeError):
|
||||
client_channel.get_cleartext_mtu()
|
||||
|
||||
await client_channel.do_handshake()
|
||||
await client_channel.send(b"hello")
|
||||
assert await client_channel.receive() == b"hello"
|
||||
await client_channel.send(b"goodbye")
|
||||
assert await client_channel.receive() == b"goodbye"
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^openssl doesn't support sending empty DTLS packets$",
|
||||
):
|
||||
await client_channel.send(b"")
|
||||
|
||||
client_channel.set_ciphertext_mtu(1234)
|
||||
cleartext_mtu_1234 = client_channel.get_cleartext_mtu()
|
||||
client_channel.set_ciphertext_mtu(4321)
|
||||
assert client_channel.get_cleartext_mtu() > cleartext_mtu_1234
|
||||
client_channel.set_ciphertext_mtu(1234)
|
||||
assert client_channel.get_cleartext_mtu() == cleartext_mtu_1234
|
||||
|
||||
|
||||
@slow
|
||||
async def test_handshake_over_terrible_network(
|
||||
autojump_clock: trio.testing.MockClock,
|
||||
) -> None:
|
||||
HANDSHAKES = 100
|
||||
r = random.Random(0)
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
# avoid spurious timeouts on slow machines
|
||||
autojump_clock.autojump_threshold = 0.001
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def route_packet(packet: UDPPacket) -> None:
|
||||
while True:
|
||||
op = r.choices(
|
||||
["deliver", "drop", "dupe", "delay"],
|
||||
weights=[0.7, 0.1, 0.1, 0.1],
|
||||
)[0]
|
||||
print(f"{packet.source} -> {packet.destination}: {op}")
|
||||
if op == "drop":
|
||||
return
|
||||
elif op == "dupe":
|
||||
fn.send_packet(packet)
|
||||
elif op == "delay":
|
||||
await trio.sleep(r.random() * 3)
|
||||
# I wanted to test random packet corruption too, but it turns out
|
||||
# openssl has a bug in the following scenario:
|
||||
#
|
||||
# - client sends ClientHello
|
||||
# - server sends HelloVerifyRequest with cookie -- but cookie is
|
||||
# invalid b/c either the ClientHello or HelloVerifyRequest was
|
||||
# corrupted
|
||||
# - client re-sends ClientHello with invalid cookie
|
||||
# - server replies with new HelloVerifyRequest and correct cookie
|
||||
#
|
||||
# At this point, the client *should* switch to the new, valid
|
||||
# cookie. But OpenSSL doesn't; it stubbornly insists on re-sending
|
||||
# the original, invalid cookie over and over. In theory we could
|
||||
# work around this by detecting cookie changes and starting over
|
||||
# with a whole new SSL object, but (a) it doesn't seem worth it, (b)
|
||||
# when I tried then I ran into another issue where OpenSSL got stuck
|
||||
# in an infinite loop sending alerts over and over, which I didn't
|
||||
# dig into because see (a).
|
||||
#
|
||||
# elif op == "distort":
|
||||
# payload = bytearray(packet.payload)
|
||||
# payload[r.randrange(len(payload))] ^= 1 << r.randrange(8)
|
||||
# packet = attrs.evolve(packet, payload=payload)
|
||||
else:
|
||||
assert op == "deliver"
|
||||
print(
|
||||
f"{packet.source} -> {packet.destination}: delivered"
|
||||
f" {packet.payload.hex()}",
|
||||
)
|
||||
fn.deliver_packet(packet)
|
||||
break
|
||||
|
||||
def route_packet_wrapper(packet: UDPPacket) -> None:
|
||||
try: # noqa: SIM105 # suppressible-exception
|
||||
nursery.start_soon(route_packet, packet)
|
||||
except RuntimeError: # pragma: no cover
|
||||
# We're exiting the nursery, so any remaining packets can just get
|
||||
# dropped
|
||||
pass
|
||||
|
||||
fn.route_packet = route_packet_wrapper # type: ignore[assignment] # TODO: Fix FakeNet typing
|
||||
|
||||
for i in range(HANDSHAKES):
|
||||
print("#" * 80)
|
||||
print("#" * 80)
|
||||
print("#" * 80)
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
print("client starting do_handshake")
|
||||
await client.do_handshake()
|
||||
print("client finished do_handshake")
|
||||
msg = str(i).encode()
|
||||
# Make multiple attempts to send data, because the network might
|
||||
# drop it
|
||||
while True:
|
||||
with trio.move_on_after(10) as cscope:
|
||||
await client.send(msg)
|
||||
assert await client.receive() == msg
|
||||
if not cscope.cancelled_caught:
|
||||
break
|
||||
|
||||
|
||||
async def test_implicit_handshake() -> None:
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
|
||||
# Implicit handshake
|
||||
await client.send(b"xyz")
|
||||
assert await client.receive() == b"xyz"
|
||||
|
||||
|
||||
async def test_full_duplex() -> None:
|
||||
# Tests simultaneous send/receive, and also multiple methods implicitly invoking
|
||||
# do_handshake simultaneously.
|
||||
with endpoint() as server_endpoint, endpoint() as client_endpoint:
|
||||
await server_endpoint.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as server_nursery:
|
||||
|
||||
async def handler(channel: DTLSChannel) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(channel.send, b"from server")
|
||||
nursery.start_soon(channel.receive)
|
||||
|
||||
await server_nursery.start(server_endpoint.serve, server_ctx, handler)
|
||||
|
||||
client = client_endpoint.connect(
|
||||
server_endpoint.socket.getsockname(),
|
||||
client_ctx,
|
||||
)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(client.send, b"from client")
|
||||
nursery.start_soon(client.receive)
|
||||
|
||||
server_nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_channel_closing() -> None:
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake()
|
||||
client.close()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client.send(b"abc")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client.receive()
|
||||
|
||||
# close is idempotent
|
||||
client.close()
|
||||
# can also aclose
|
||||
await client.aclose()
|
||||
|
||||
|
||||
async def test_serve_exits_cleanly_on_close() -> None:
|
||||
async with dtls_echo_server(autocancel=False) as (server_endpoint, address):
|
||||
server_endpoint.close()
|
||||
# Testing that the nursery exits even without being cancelled
|
||||
# close is idempotent
|
||||
server_endpoint.close()
|
||||
|
||||
|
||||
async def test_client_multiplex() -> None:
|
||||
async with dtls_echo_server() as (_, address1), dtls_echo_server() as (_, address2):
|
||||
with endpoint() as client_endpoint:
|
||||
client1 = client_endpoint.connect(address1, client_ctx)
|
||||
client2 = client_endpoint.connect(address2, client_ctx)
|
||||
|
||||
await client1.send(b"abc")
|
||||
await client2.send(b"xyz")
|
||||
assert await client2.receive() == b"xyz"
|
||||
assert await client1.receive() == b"abc"
|
||||
|
||||
client_endpoint.close()
|
||||
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client1.send(b"xxx")
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await client2.receive()
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
client_endpoint.connect(address1, client_ctx)
|
||||
|
||||
async def null_handler(_: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
with pytest.raises(trio.ClosedResourceError):
|
||||
await nursery.start(client_endpoint.serve, server_ctx, null_handler)
|
||||
|
||||
|
||||
async def test_dtls_over_dgram_only() -> None:
|
||||
with trio.socket.socket() as s:
|
||||
with pytest.raises(ValueError, match="^DTLS requires a SOCK_DGRAM socket$"):
|
||||
DTLSEndpoint(s)
|
||||
|
||||
|
||||
async def test_double_serve() -> None:
|
||||
async def null_handler(_: object) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
with endpoint() as server_endpoint:
|
||||
await server_endpoint.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
with pytest.raises(trio.BusyResourceError):
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(server_endpoint.serve, server_ctx, null_handler)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_connect_to_non_server(autojump_clock: trio.abc.Clock) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
with endpoint() as client1, endpoint() as client2:
|
||||
await client1.socket.bind(("127.0.0.1", 0))
|
||||
# This should just time out
|
||||
with trio.move_on_after(100) as cscope:
|
||||
channel = client2.connect(client1.socket.getsockname(), client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
|
||||
async def test_incoming_buffer_overflow(autojump_clock: trio.abc.Clock) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
for buffer_size in [10, 20]:
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint(incoming_packets_buffer=buffer_size) as client_endpoint:
|
||||
assert client_endpoint.incoming_packets_buffer == buffer_size
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
for i in range(buffer_size + 15):
|
||||
await client.send(str(i).encode())
|
||||
await trio.sleep(1)
|
||||
stats = client.statistics()
|
||||
assert stats.incoming_packets_dropped_in_trio == 15
|
||||
for i in range(buffer_size):
|
||||
assert await client.receive() == str(i).encode()
|
||||
await client.send(b"buffer clear now")
|
||||
assert await client.receive() == b"buffer clear now"
|
||||
|
||||
|
||||
async def test_server_socket_doesnt_crash_on_garbage(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
from trio._dtls import (
|
||||
ContentType,
|
||||
HandshakeFragment,
|
||||
HandshakeType,
|
||||
ProtocolVersion,
|
||||
Record,
|
||||
encode_handshake_fragment,
|
||||
encode_record,
|
||||
)
|
||||
|
||||
client_hello = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=10,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_extended = client_hello + b"\x00"
|
||||
client_hello_short = client_hello[:-1]
|
||||
# cuts off in middle of handshake message header
|
||||
client_hello_really_short = client_hello[:14]
|
||||
client_hello_corrupt_record_len = bytearray(client_hello)
|
||||
client_hello_corrupt_record_len[11] = 0xFF
|
||||
|
||||
client_hello_fragmented = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=20,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_trailing_data_in_record = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=encode_handshake_fragment(
|
||||
HandshakeFragment(
|
||||
msg_type=HandshakeType.client_hello,
|
||||
msg_len=20,
|
||||
msg_seq=0,
|
||||
frag_offset=0,
|
||||
frag_len=10,
|
||||
frag=bytes(10),
|
||||
),
|
||||
)
|
||||
+ b"\x00",
|
||||
),
|
||||
)
|
||||
|
||||
handshake_empty = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=b"",
|
||||
),
|
||||
)
|
||||
|
||||
client_hello_truncated_in_cookie = encode_record(
|
||||
Record(
|
||||
content_type=ContentType.handshake,
|
||||
version=ProtocolVersion.DTLS10,
|
||||
epoch_seqno=0,
|
||||
payload=bytes(2 + 32 + 1) + b"\xff",
|
||||
),
|
||||
)
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as sock:
|
||||
for bad_packet in [
|
||||
b"",
|
||||
b"xyz",
|
||||
client_hello_extended,
|
||||
client_hello_short,
|
||||
client_hello_really_short,
|
||||
client_hello_corrupt_record_len,
|
||||
client_hello_fragmented,
|
||||
client_hello_trailing_data_in_record,
|
||||
handshake_empty,
|
||||
client_hello_truncated_in_cookie,
|
||||
]:
|
||||
await sock.sendto(bad_packet, address)
|
||||
await trio.sleep(1)
|
||||
|
||||
|
||||
async def test_invalid_cookie_rejected(autojump_clock: trio.abc.Clock) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
from trio._dtls import BadPacket, decode_client_hello_untrusted
|
||||
|
||||
with trio.CancelScope() as cscope:
|
||||
# the first 11 bytes of ClientHello aren't protected by the cookie, so only test
|
||||
# corrupting bytes after that.
|
||||
offset_to_corrupt = count(11)
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
try:
|
||||
_, cookie, _ = decode_client_hello_untrusted(packet.payload)
|
||||
except BadPacket:
|
||||
pass
|
||||
else:
|
||||
if len(cookie) != 0:
|
||||
# this is a challenge response packet
|
||||
# let's corrupt the next offset so the handshake should fail
|
||||
payload = bytearray(packet.payload)
|
||||
offset = next(offset_to_corrupt)
|
||||
if offset >= len(payload):
|
||||
# We've tried all offsets. Clamp offset to the end of the
|
||||
# payload, and terminate the test.
|
||||
offset = len(payload) - 1
|
||||
cscope.cancel()
|
||||
payload[offset] ^= 0x01
|
||||
packet = attrs.evolve(packet, payload=payload)
|
||||
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO: Fix FakeNet typing
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
while True:
|
||||
with endpoint() as client:
|
||||
channel = client.connect(address, client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert cscope.cancelled_caught
|
||||
|
||||
|
||||
async def test_client_cancels_handshake_and_starts_new_one(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
# if a client disappears during the handshake, and then starts a new handshake from
|
||||
# scratch, then the first handler's channel should fail, and a new handler get
|
||||
# started
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
with endpoint() as server, endpoint() as client:
|
||||
await server.socket.bind(("127.0.0.1", 0))
|
||||
async with trio.open_nursery() as nursery:
|
||||
first_time = True
|
||||
|
||||
async def handler(channel: DTLSChannel) -> None:
|
||||
nonlocal first_time
|
||||
if first_time:
|
||||
first_time = False
|
||||
print("handler: first time, cancelling connect")
|
||||
connect_cscope.cancel()
|
||||
await trio.sleep(0.5)
|
||||
print("handler: handshake should fail now")
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await channel.do_handshake()
|
||||
else:
|
||||
print("handler: not first time, sending hello")
|
||||
await channel.send(b"hello")
|
||||
|
||||
await nursery.start(server.serve, server_ctx, handler)
|
||||
|
||||
print("client: starting first connect")
|
||||
with trio.CancelScope() as connect_cscope:
|
||||
channel = client.connect(server.socket.getsockname(), client_ctx)
|
||||
await channel.do_handshake()
|
||||
assert connect_cscope.cancelled_caught
|
||||
|
||||
print("client: starting second connect")
|
||||
channel = client.connect(server.socket.getsockname(), client_ctx)
|
||||
assert await channel.receive() == b"hello"
|
||||
|
||||
# Give handlers a chance to finish
|
||||
await trio.sleep(10)
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_swap_client_server() -> None:
|
||||
with endpoint() as a, endpoint() as b:
|
||||
await a.socket.bind(("127.0.0.1", 0))
|
||||
await b.socket.bind(("127.0.0.1", 0))
|
||||
|
||||
async def echo_handler(channel: DTLSChannel) -> None:
|
||||
async for packet in channel:
|
||||
await channel.send(packet)
|
||||
|
||||
async def crashing_echo_handler(channel: DTLSChannel) -> None:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await echo_handler(channel)
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
await nursery.start(a.serve, server_ctx, crashing_echo_handler)
|
||||
await nursery.start(b.serve, server_ctx, echo_handler)
|
||||
|
||||
b_to_a = b.connect(a.socket.getsockname(), client_ctx)
|
||||
await b_to_a.send(b"b as client")
|
||||
assert await b_to_a.receive() == b"b as client"
|
||||
|
||||
a_to_b = a.connect(b.socket.getsockname(), client_ctx)
|
||||
await a_to_b.do_handshake()
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await b_to_a.send(b"association broken")
|
||||
await a_to_b.send(b"a as client")
|
||||
assert await a_to_b.receive() == b"a as client"
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_openssl_retransmit_doesnt_break_stuff() -> None:
|
||||
# can't use autojump_clock here, because the point of the test is to wait for
|
||||
# openssl's built-in retransmit timer to expire, which is hard-coded to use
|
||||
# wall-clock time.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
blackholed = True
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
if blackholed:
|
||||
print("dropped packet", packet)
|
||||
return
|
||||
print("delivered packet", packet)
|
||||
# packets.append(
|
||||
# scapy.all.IP(
|
||||
# src=packet.source.ip.compressed, dst=packet.destination.ip.compressed
|
||||
# )
|
||||
# / scapy.all.UDP(sport=packet.source.port, dport=packet.destination.port)
|
||||
# / packet.payload
|
||||
# )
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (server_endpoint, address):
|
||||
with endpoint() as client_endpoint:
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def connecter() -> None:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake(initial_retransmit_timeout=1.5)
|
||||
await client.send(b"hi")
|
||||
assert await client.receive() == b"hi"
|
||||
|
||||
nursery.start_soon(connecter)
|
||||
|
||||
# openssl's default timeout is 1 second, so this ensures that it thinks
|
||||
# the timeout has expired
|
||||
await trio.sleep(1.1)
|
||||
# disable blackholing and send a garbage packet to wake up openssl so it
|
||||
# notices the timeout has expired
|
||||
blackholed = False
|
||||
await server_endpoint.socket.sendto(
|
||||
b"xxx",
|
||||
client_endpoint.socket.getsockname(),
|
||||
)
|
||||
# now the client task should finish connecting and exit cleanly
|
||||
|
||||
# scapy.all.wrpcap("/tmp/trace.pcap", packets)
|
||||
|
||||
|
||||
async def test_initial_retransmit_timeout_configuration(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
blackholed = True
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
nonlocal blackholed
|
||||
if blackholed:
|
||||
blackholed = False
|
||||
else:
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
for t in [1, 2, 4]:
|
||||
with endpoint() as client:
|
||||
before = trio.current_time()
|
||||
blackholed = True
|
||||
channel = client.connect(address, client_ctx)
|
||||
await channel.do_handshake(initial_retransmit_timeout=t)
|
||||
after = trio.current_time()
|
||||
assert after - before == t
|
||||
|
||||
|
||||
async def test_explicit_tiny_mtu_is_respected() -> None:
|
||||
# ClientHello is ~240 bytes, and it can't be fragmented, so our mtu has to
|
||||
# be larger than that. (300 is still smaller than any real network though.)
|
||||
MTU = 300
|
||||
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
print(f"delivering {packet}")
|
||||
print(f"payload size: {len(packet.payload)}")
|
||||
assert len(packet.payload) <= MTU
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server(mtu=MTU) as (server, address):
|
||||
with endpoint() as client:
|
||||
channel = client.connect(address, client_ctx)
|
||||
channel.set_ciphertext_mtu(MTU)
|
||||
await channel.do_handshake()
|
||||
await channel.send(b"hi")
|
||||
assert await channel.receive() == b"hi"
|
||||
|
||||
|
||||
@parametrize_ipv6
|
||||
async def test_handshake_handles_minimum_network_mtu(
|
||||
ipv6: bool,
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
# Fake network that has the minimum allowable MTU for whatever protocol we're using.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
mtu = 1280 - 48 if ipv6 else 576 - 28
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
if len(packet.payload) > mtu:
|
||||
print(f"dropping {packet}")
|
||||
else:
|
||||
print(f"delivering {packet}")
|
||||
fn.deliver_packet(packet)
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
# See if we can successfully do a handshake -- some of the volleys will get dropped,
|
||||
# and the retransmit logic should detect this and back off the MTU to something
|
||||
# smaller until it succeeds.
|
||||
async with dtls_echo_server(ipv6=ipv6) as (_, address):
|
||||
with endpoint(ipv6=ipv6) as client_endpoint:
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
# the handshake mtu backoff shouldn't affect the return value from
|
||||
# get_cleartext_mtu, b/c that's under the user's control via
|
||||
# set_ciphertext_mtu
|
||||
client.set_ciphertext_mtu(9999)
|
||||
await client.send(b"xyz")
|
||||
assert await client.receive() == b"xyz"
|
||||
assert client.get_cleartext_mtu() > 9000 # as vegeta said
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_system_task_cleaned_up_on_gc() -> None:
|
||||
before_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
|
||||
# We put this into a sub-function so that everything automatically becomes garbage
|
||||
# when the frame exits. For some reason just doing 'del e' wasn't enough on pypy
|
||||
# with coverage enabled -- I think we were hitting this bug:
|
||||
# https://foss.heptapod.net/pypy/pypy/-/issues/3656
|
||||
async def start_and_forget_endpoint() -> int:
|
||||
e = endpoint()
|
||||
|
||||
# This connection/handshake attempt can't succeed. The only purpose is to force
|
||||
# the endpoint to set up a receive loop.
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
c = e.connect(s.getsockname(), client_ctx)
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(c.do_handshake)
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
during_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
return during_tasks
|
||||
|
||||
with pytest.warns(ResourceWarning):
|
||||
during_tasks = await start_and_forget_endpoint()
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
gc_collect_harder()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
after_tasks = trio.lowlevel.current_statistics().tasks_living
|
||||
assert before_tasks < during_tasks
|
||||
assert before_tasks == after_tasks
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_gc_before_system_task_starts() -> None:
|
||||
e = endpoint()
|
||||
|
||||
with pytest.warns(ResourceWarning):
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
async def test_gc_as_packet_received() -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
e = endpoint()
|
||||
await e.socket.bind(("127.0.0.1", 0))
|
||||
e._ensure_receive_loop()
|
||||
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
with trio.socket.socket(type=trio.socket.SOCK_DGRAM) as s:
|
||||
await s.sendto(b"xxx", e.socket.getsockname())
|
||||
# At this point, the endpoint's receive loop has been marked runnable because it
|
||||
# just received a packet; closing the endpoint socket won't interrupt that. But by
|
||||
# the time it wakes up to process the packet, the endpoint will be gone.
|
||||
with pytest.warns(ResourceWarning):
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("always:unclosed DTLS:ResourceWarning")
|
||||
def test_gc_after_trio_exits() -> None:
|
||||
async def main() -> DTLSEndpoint:
|
||||
# We use fakenet just to make sure no real sockets can leak out of the test
|
||||
# case - on pypy somehow the socket was outliving the gc_collect_harder call
|
||||
# below. Since the test is just making sure DTLSEndpoint.__del__ doesn't explode
|
||||
# when called after trio exits, it doesn't need a real socket.
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
return endpoint()
|
||||
|
||||
e = trio.run(main)
|
||||
with pytest.warns(ResourceWarning):
|
||||
del e
|
||||
gc_collect_harder()
|
||||
|
||||
|
||||
async def test_already_closed_socket_doesnt_crash() -> None:
|
||||
with endpoint() as e:
|
||||
# We close the socket before checkpointing, so the socket will already be closed
|
||||
# when the system task starts up
|
||||
e.socket.close()
|
||||
# Now give it a chance to start up, and hopefully not crash
|
||||
await trio.testing.wait_all_tasks_blocked()
|
||||
|
||||
|
||||
async def test_socket_closed_while_processing_clienthello(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
# Check what happens if the socket is discovered to be closed when sending a
|
||||
# HelloVerifyRequest, since that has its own sending logic
|
||||
async with dtls_echo_server() as (server, address):
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
fn.deliver_packet(packet)
|
||||
server.socket.close()
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
with endpoint() as client_endpoint:
|
||||
with trio.move_on_after(10):
|
||||
client = client_endpoint.connect(address, client_ctx)
|
||||
await client.do_handshake()
|
||||
|
||||
|
||||
async def test_association_replaced_while_handshake_running(
|
||||
autojump_clock: trio.abc.Clock,
|
||||
) -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
def route_packet(packet: UDPPacket) -> None:
|
||||
pass
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
c1 = client_endpoint.connect(address, client_ctx)
|
||||
async with trio.open_nursery() as nursery:
|
||||
|
||||
async def doomed_handshake() -> None:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await c1.do_handshake()
|
||||
|
||||
nursery.start_soon(doomed_handshake)
|
||||
|
||||
await trio.sleep(10)
|
||||
|
||||
client_endpoint.connect(address, client_ctx)
|
||||
|
||||
|
||||
async def test_association_replaced_before_handshake_starts() -> None:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
|
||||
# This test shouldn't send any packets
|
||||
def route_packet(packet: UDPPacket) -> NoReturn: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
fn.route_packet = route_packet # type: ignore[assignment] # TODO add type annotations for FakeNet
|
||||
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
c1 = client_endpoint.connect(address, client_ctx)
|
||||
client_endpoint.connect(address, client_ctx)
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await c1.do_handshake()
|
||||
|
||||
|
||||
async def test_send_to_closed_local_port() -> None:
|
||||
# On Windows, sending a UDP packet to a closed local port can cause a weird
|
||||
# ECONNRESET error later, inside the receive task. Make sure we're handling it
|
||||
# properly.
|
||||
async with dtls_echo_server() as (_, address):
|
||||
with endpoint() as client_endpoint:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(1, 10):
|
||||
channel = client_endpoint.connect(("127.0.0.1", i), client_ctx)
|
||||
nursery.start_soon(channel.do_handshake)
|
||||
channel = client_endpoint.connect(address, client_ctx)
|
||||
await channel.send(b"xxx")
|
||||
assert await channel.receive() == b"xxx"
|
||||
nursery.cancel_scope.cancel()
|
574
lib/python3.13/site-packages/trio/_tests/test_exports.py
Normal file
574
lib/python3.13/site-packages/trio/_tests/test_exports.py
Normal file
@ -0,0 +1,574 @@
|
||||
from __future__ import annotations # isort: split
|
||||
|
||||
import __future__ # Regular import, not special!
|
||||
|
||||
import enum
|
||||
import functools
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path, PurePath
|
||||
from types import ModuleType
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
from .. import _core, _util
|
||||
from .._core._tests.tutil import slow
|
||||
from .pytest_plugin import RUN_SLOW
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator
|
||||
|
||||
mypy_cache_updated = False
|
||||
|
||||
|
||||
try: # If installed, check both versions of this class.
|
||||
from typing_extensions import Protocol as Protocol_ext
|
||||
except ImportError: # pragma: no cover
|
||||
Protocol_ext = Protocol # type: ignore[assignment]
|
||||
|
||||
|
||||
def _ensure_mypy_cache_updated() -> None:
|
||||
# This pollutes the `empty` dir. Should this be changed?
|
||||
try:
|
||||
from mypy.api import run
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
global mypy_cache_updated
|
||||
if not mypy_cache_updated:
|
||||
# mypy cache was *probably* already updated by the other tests,
|
||||
# but `pytest -k ...` might run just this test on its own
|
||||
result = run(
|
||||
[
|
||||
"--config-file=",
|
||||
"--cache-dir=./.mypy_cache",
|
||||
"--no-error-summary",
|
||||
"-c",
|
||||
"import trio",
|
||||
],
|
||||
)
|
||||
assert not result[1] # stderr
|
||||
assert not result[0] # stdout
|
||||
mypy_cache_updated = True
|
||||
|
||||
|
||||
def test_core_is_properly_reexported() -> None:
|
||||
# Each export from _core should be re-exported by exactly one of these
|
||||
# three modules:
|
||||
sources = [trio, trio.lowlevel, trio.testing]
|
||||
for symbol in dir(_core):
|
||||
if symbol.startswith("_"):
|
||||
continue
|
||||
found = 0
|
||||
for source in sources:
|
||||
if symbol in dir(source) and getattr(source, symbol) is getattr(
|
||||
_core,
|
||||
symbol,
|
||||
):
|
||||
found += 1
|
||||
print(symbol, found)
|
||||
assert found == 1
|
||||
|
||||
|
||||
def class_is_final(cls: type) -> bool:
|
||||
"""Check if a class cannot be subclassed."""
|
||||
try:
|
||||
# new_class() handles metaclasses properly, type(...) does not.
|
||||
types.new_class("SubclassTester", (cls,))
|
||||
except TypeError:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def iter_modules(
|
||||
module: types.ModuleType,
|
||||
only_public: bool,
|
||||
) -> Iterator[types.ModuleType]:
|
||||
yield module
|
||||
for name, class_ in module.__dict__.items():
|
||||
if name.startswith("_") and only_public:
|
||||
continue
|
||||
if not isinstance(class_, ModuleType):
|
||||
continue
|
||||
if not class_.__name__.startswith(module.__name__): # pragma: no cover
|
||||
continue
|
||||
if class_ is module: # pragma: no cover
|
||||
continue
|
||||
yield from iter_modules(class_, only_public)
|
||||
|
||||
|
||||
PUBLIC_MODULES = list(iter_modules(trio, only_public=True))
|
||||
ALL_MODULES = list(iter_modules(trio, only_public=False))
|
||||
PUBLIC_MODULE_NAMES = [m.__name__ for m in PUBLIC_MODULES]
|
||||
|
||||
|
||||
# It doesn't make sense for downstream redistributors to run this test, since
|
||||
# they might be using a newer version of Python with additional symbols which
|
||||
# won't be reflected in trio.socket, and this shouldn't cause downstream test
|
||||
# runs to start failing.
|
||||
@pytest.mark.redistributors_should_skip()
|
||||
# Static analysis tools often have trouble with alpha releases, where Python's
|
||||
# internals are in flux, grammar may not have settled down, etc.
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info.releaselevel == "alpha",
|
||||
reason="skip static introspection tools on Python dev/alpha releases",
|
||||
)
|
||||
@pytest.mark.parametrize("modname", PUBLIC_MODULE_NAMES)
|
||||
@pytest.mark.parametrize("tool", ["pylint", "jedi", "mypy", "pyright_verifytypes"])
|
||||
@pytest.mark.filterwarnings(
|
||||
# https://github.com/pypa/setuptools/issues/3274
|
||||
"ignore:module 'sre_constants' is deprecated:DeprecationWarning",
|
||||
)
|
||||
def test_static_tool_sees_all_symbols(tool: str, modname: str, tmp_path: Path) -> None:
|
||||
module = importlib.import_module(modname)
|
||||
|
||||
def no_underscores(symbols: Iterable[str]) -> set[str]:
|
||||
return {symbol for symbol in symbols if not symbol.startswith("_")}
|
||||
|
||||
runtime_names = no_underscores(dir(module))
|
||||
|
||||
# ignore deprecated module `tests` being invisible
|
||||
if modname == "trio":
|
||||
runtime_names.discard("tests")
|
||||
|
||||
# Ignore any __future__ feature objects, if imported under that name.
|
||||
for name in __future__.all_feature_names:
|
||||
if getattr(module, name, None) is getattr(__future__, name):
|
||||
runtime_names.remove(name)
|
||||
|
||||
if tool == "pylint":
|
||||
try:
|
||||
from pylint.lint import PyLinter
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
linter = PyLinter()
|
||||
assert module.__file__ is not None
|
||||
ast = linter.get_ast(module.__file__, modname)
|
||||
static_names = no_underscores(ast) # type: ignore[arg-type]
|
||||
elif tool == "jedi":
|
||||
if sys.implementation.name != "cpython":
|
||||
pytest.skip("jedi does not support pypy")
|
||||
|
||||
try:
|
||||
import jedi
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
# Simulate typing "import trio; trio.<TAB>"
|
||||
script = jedi.Script(f"import {modname}; {modname}.")
|
||||
completions = script.complete()
|
||||
static_names = no_underscores(c.name for c in completions)
|
||||
elif tool == "mypy":
|
||||
if not RUN_SLOW: # pragma: no cover
|
||||
pytest.skip("use --run-slow to check against mypy")
|
||||
|
||||
cache = Path.cwd() / ".mypy_cache"
|
||||
|
||||
_ensure_mypy_cache_updated()
|
||||
|
||||
trio_cache = next(cache.glob("*/trio"))
|
||||
_, modname = (modname + ".").split(".", 1)
|
||||
modname = modname[:-1]
|
||||
mod_cache = trio_cache / modname if modname else trio_cache
|
||||
if mod_cache.is_dir(): # pragma: no coverage
|
||||
mod_cache = mod_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = trio_cache / (modname + ".data.json")
|
||||
|
||||
assert mod_cache.exists()
|
||||
assert mod_cache.is_file()
|
||||
with mod_cache.open() as cache_file:
|
||||
cache_json = json.loads(cache_file.read())
|
||||
static_names = no_underscores(
|
||||
key
|
||||
for key, value in cache_json["names"].items()
|
||||
if not key.startswith(".") and value["kind"] == "Gdef"
|
||||
)
|
||||
elif tool == "pyright_verifytypes":
|
||||
if not RUN_SLOW: # pragma: no cover
|
||||
pytest.skip("use --run-slow to check against pyright")
|
||||
|
||||
try:
|
||||
import pyright # noqa: F401
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
import subprocess
|
||||
|
||||
res = subprocess.run(
|
||||
["pyright", f"--verifytypes={modname}", "--outputjson"],
|
||||
capture_output=True,
|
||||
)
|
||||
current_result = json.loads(res.stdout)
|
||||
|
||||
static_names = {
|
||||
x["name"][len(modname) + 1 :]
|
||||
for x in current_result["typeCompleteness"]["symbols"]
|
||||
if x["name"].startswith(modname)
|
||||
}
|
||||
else: # pragma: no cover
|
||||
raise AssertionError()
|
||||
|
||||
# It's expected that the static set will contain more names than the
|
||||
# runtime set:
|
||||
# - static tools are sometimes sloppy and include deleted names
|
||||
# - some symbols are platform-specific at runtime, but always show up in
|
||||
# static analysis (e.g. in trio.socket or trio.lowlevel)
|
||||
# So we check that the runtime names are a subset of the static names.
|
||||
missing_names = runtime_names - static_names
|
||||
|
||||
# ignore warnings about deprecated module tests
|
||||
missing_names -= {"tests"}
|
||||
|
||||
if missing_names: # pragma: no cover
|
||||
print(f"{tool} can't see the following names in {modname}:")
|
||||
print()
|
||||
for name in sorted(missing_names):
|
||||
print(f" {name}")
|
||||
raise AssertionError()
|
||||
|
||||
|
||||
# this could be sped up by only invoking mypy once per module, or even once for all
|
||||
# modules, instead of once per class.
|
||||
@slow
|
||||
# see comment on test_static_tool_sees_all_symbols
|
||||
@pytest.mark.redistributors_should_skip()
|
||||
# Static analysis tools often have trouble with alpha releases, where Python's
|
||||
# internals are in flux, grammar may not have settled down, etc.
|
||||
@pytest.mark.skipif(
|
||||
sys.version_info.releaselevel == "alpha",
|
||||
reason="skip static introspection tools on Python dev/alpha releases",
|
||||
)
|
||||
@pytest.mark.parametrize("module_name", PUBLIC_MODULE_NAMES)
|
||||
@pytest.mark.parametrize("tool", ["jedi", "mypy"])
|
||||
def test_static_tool_sees_class_members(
|
||||
tool: str,
|
||||
module_name: str,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
module = PUBLIC_MODULES[PUBLIC_MODULE_NAMES.index(module_name)]
|
||||
|
||||
# ignore hidden, but not dunder, symbols
|
||||
def no_hidden(symbols: Iterable[str]) -> set[str]:
|
||||
return {
|
||||
symbol
|
||||
for symbol in symbols
|
||||
if (not symbol.startswith("_")) or symbol.startswith("__")
|
||||
}
|
||||
|
||||
if tool == "jedi" and sys.implementation.name != "cpython":
|
||||
pytest.skip("jedi does not support pypy")
|
||||
|
||||
if tool == "mypy":
|
||||
cache = Path.cwd() / ".mypy_cache"
|
||||
|
||||
_ensure_mypy_cache_updated()
|
||||
|
||||
trio_cache = next(cache.glob("*/trio"))
|
||||
modname = module_name
|
||||
_, modname = (modname + ".").split(".", 1)
|
||||
modname = modname[:-1]
|
||||
mod_cache = trio_cache / modname if modname else trio_cache
|
||||
if mod_cache.is_dir():
|
||||
mod_cache = mod_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = trio_cache / (modname + ".data.json")
|
||||
|
||||
assert mod_cache.exists()
|
||||
assert mod_cache.is_file()
|
||||
with mod_cache.open() as cache_file:
|
||||
cache_json = json.loads(cache_file.read())
|
||||
|
||||
# skip a bunch of file-system activity (probably can un-memoize?)
|
||||
@functools.lru_cache
|
||||
def lookup_symbol(symbol: str) -> dict[str, str]:
|
||||
topname, *modname, name = symbol.split(".")
|
||||
version = next(cache.glob("3.*/"))
|
||||
mod_cache = version / topname
|
||||
if not mod_cache.is_dir():
|
||||
mod_cache = version / (topname + ".data.json")
|
||||
|
||||
if modname:
|
||||
for piece in modname[:-1]:
|
||||
mod_cache /= piece
|
||||
next_cache = mod_cache / modname[-1]
|
||||
if next_cache.is_dir(): # pragma: no coverage
|
||||
mod_cache = next_cache / "__init__.data.json"
|
||||
else:
|
||||
mod_cache = mod_cache / (modname[-1] + ".data.json")
|
||||
elif mod_cache.is_dir():
|
||||
mod_cache /= "__init__.data.json"
|
||||
with mod_cache.open() as f:
|
||||
return json.loads(f.read())["names"][name] # type: ignore[no-any-return]
|
||||
|
||||
errors: dict[str, object] = {}
|
||||
for class_name, class_ in module.__dict__.items():
|
||||
if not isinstance(class_, type):
|
||||
continue
|
||||
if module_name == "trio.socket" and class_name in dir(stdlib_socket):
|
||||
continue
|
||||
|
||||
# ignore class that does dirty tricks
|
||||
if class_ is trio.testing.RaisesGroup:
|
||||
continue
|
||||
|
||||
# dir() and inspect.getmembers doesn't display properties from the metaclass
|
||||
# also ignore some dunder methods that tend to differ but are of no consequence
|
||||
ignore_names = set(dir(type(class_))) | {
|
||||
"__annotations__",
|
||||
"__attrs_attrs__",
|
||||
"__attrs_own_setattr__",
|
||||
"__callable_proto_members_only__",
|
||||
"__class_getitem__",
|
||||
"__final__",
|
||||
"__getstate__",
|
||||
"__match_args__",
|
||||
"__order__",
|
||||
"__orig_bases__",
|
||||
"__parameters__",
|
||||
"__protocol_attrs__",
|
||||
"__setstate__",
|
||||
"__slots__",
|
||||
"__weakref__",
|
||||
# ignore errors about dunders inherited from stdlib that tools might
|
||||
# not see
|
||||
"__copy__",
|
||||
"__deepcopy__",
|
||||
}
|
||||
|
||||
if type(class_) is type:
|
||||
# C extension classes don't have these dunders, but Python classes do
|
||||
ignore_names.add("__firstlineno__")
|
||||
ignore_names.add("__static_attributes__")
|
||||
|
||||
# pypy seems to have some additional dunders that differ
|
||||
if sys.implementation.name == "pypy":
|
||||
ignore_names |= {
|
||||
"__basicsize__",
|
||||
"__dictoffset__",
|
||||
"__itemsize__",
|
||||
"__sizeof__",
|
||||
"__weakrefoffset__",
|
||||
"__unicode__",
|
||||
}
|
||||
|
||||
# inspect.getmembers sees `name` and `value` in Enums, otherwise
|
||||
# it behaves the same way as `dir`
|
||||
# runtime_names = no_underscores(dir(class_))
|
||||
runtime_names = (
|
||||
no_hidden(x[0] for x in inspect.getmembers(class_)) - ignore_names
|
||||
)
|
||||
|
||||
if tool == "jedi":
|
||||
try:
|
||||
import jedi
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
script = jedi.Script(
|
||||
f"from {module_name} import {class_name}; {class_name}.",
|
||||
)
|
||||
completions = script.complete()
|
||||
static_names = no_hidden(c.name for c in completions) - ignore_names
|
||||
|
||||
elif tool == "mypy":
|
||||
# load the cached type information
|
||||
cached_type_info = cache_json["names"][class_name]
|
||||
if "node" not in cached_type_info:
|
||||
cached_type_info = lookup_symbol(cached_type_info["cross_ref"])
|
||||
|
||||
assert "node" in cached_type_info
|
||||
node = cached_type_info["node"]
|
||||
static_names = no_hidden(k for k in node["names"] if not k.startswith("."))
|
||||
for symbol in node["mro"][1:]:
|
||||
node = lookup_symbol(symbol)["node"]
|
||||
static_names |= no_hidden(
|
||||
k for k in node["names"] if not k.startswith(".")
|
||||
)
|
||||
static_names -= ignore_names
|
||||
|
||||
else: # pragma: no cover
|
||||
raise AssertionError("unknown tool")
|
||||
|
||||
missing = runtime_names - static_names
|
||||
extra = static_names - runtime_names
|
||||
|
||||
# using .remove() instead of .delete() to get an error in case they start not
|
||||
# being missing
|
||||
|
||||
if (
|
||||
tool == "jedi"
|
||||
and BaseException in class_.__mro__
|
||||
and sys.version_info >= (3, 11)
|
||||
):
|
||||
missing.remove("add_note")
|
||||
|
||||
if (
|
||||
tool == "mypy"
|
||||
and BaseException in class_.__mro__
|
||||
and sys.version_info >= (3, 11)
|
||||
):
|
||||
extra.remove("__notes__")
|
||||
|
||||
if tool == "mypy" and attrs.has(class_):
|
||||
# e.g. __trio__core__run_CancelScope_AttrsAttributes__
|
||||
before = len(extra)
|
||||
extra = {e for e in extra if not e.endswith("AttrsAttributes__")}
|
||||
assert len(extra) == before - 1
|
||||
|
||||
# mypy does not see these attributes in Enum subclasses
|
||||
if (
|
||||
tool == "mypy"
|
||||
and enum.Enum in class_.__mro__
|
||||
and sys.version_info >= (3, 12)
|
||||
):
|
||||
# Another attribute, in 3.12+ only.
|
||||
extra.remove("__signature__")
|
||||
|
||||
# TODO: this *should* be visible via `dir`!!
|
||||
if tool == "mypy" and class_ == trio.Nursery:
|
||||
extra.remove("cancel_scope")
|
||||
|
||||
# These are (mostly? solely?) *runtime* attributes, often set in
|
||||
# __init__, which doesn't show up with dir() or inspect.getmembers,
|
||||
# but we get them in the way we query mypy & jedi
|
||||
EXTRAS = {
|
||||
trio.DTLSChannel: {"peer_address", "endpoint"},
|
||||
trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"},
|
||||
trio.Process: {"args", "pid", "stderr", "stdin", "stdio", "stdout"},
|
||||
trio.SSLListener: {"transport_listener"},
|
||||
trio.SSLStream: {"transport_stream"},
|
||||
trio.SocketListener: {"socket"},
|
||||
trio.SocketStream: {"socket"},
|
||||
trio.testing.MemoryReceiveStream: {"close_hook", "receive_some_hook"},
|
||||
trio.testing.MemorySendStream: {
|
||||
"close_hook",
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
},
|
||||
trio.testing.Matcher: {
|
||||
"exception_type",
|
||||
"match",
|
||||
"check",
|
||||
},
|
||||
}
|
||||
if tool == "mypy" and class_ in EXTRAS:
|
||||
before = len(extra)
|
||||
extra -= EXTRAS[class_]
|
||||
assert len(extra) == before - len(EXTRAS[class_])
|
||||
|
||||
# TODO: why is this? Is it a problem?
|
||||
# see https://github.com/python-trio/trio/pull/2631#discussion_r1185615916
|
||||
if class_ == trio.StapledStream:
|
||||
extra.remove("receive_stream")
|
||||
extra.remove("send_stream")
|
||||
|
||||
# I have not researched why these are missing, should maybe create an issue
|
||||
# upstream with jedi
|
||||
if tool == "jedi" and sys.version_info >= (3, 11):
|
||||
if class_ in (
|
||||
trio.DTLSChannel,
|
||||
trio.MemoryReceiveChannel,
|
||||
trio.MemorySendChannel,
|
||||
trio.SSLListener,
|
||||
trio.SocketListener,
|
||||
):
|
||||
missing.remove("__aenter__")
|
||||
missing.remove("__aexit__")
|
||||
if class_ in (trio.DTLSChannel, trio.MemoryReceiveChannel):
|
||||
missing.remove("__aiter__")
|
||||
missing.remove("__anext__")
|
||||
|
||||
if class_ in (trio.Path, trio.WindowsPath, trio.PosixPath):
|
||||
# These are from inherited subclasses.
|
||||
missing -= PurePath.__dict__.keys()
|
||||
# These are unix-only.
|
||||
if tool == "mypy" and sys.platform == "win32":
|
||||
missing -= {"owner", "is_mount", "group"}
|
||||
if tool == "jedi" and sys.platform == "win32":
|
||||
extra -= {"owner", "is_mount", "group"}
|
||||
|
||||
# not sure why jedi in particular ignores this (static?) method in 3.13
|
||||
# (especially given the method is from 3.12....)
|
||||
if (
|
||||
tool == "jedi"
|
||||
and sys.version_info >= (3, 13)
|
||||
and class_ in (trio.Path, trio.WindowsPath, trio.PosixPath)
|
||||
):
|
||||
missing.remove("with_segments")
|
||||
|
||||
if missing or extra: # pragma: no cover
|
||||
errors[f"{module_name}.{class_name}"] = {
|
||||
"missing": missing,
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
# `assert not errors` will not print the full content of errors, even with
|
||||
# `--verbose`, so we manually print it
|
||||
if errors: # pragma: no cover
|
||||
from pprint import pprint
|
||||
|
||||
print(f"\n{tool} can't see the following symbols in {module_name}:")
|
||||
pprint(errors)
|
||||
assert not errors
|
||||
|
||||
|
||||
def test_nopublic_is_final() -> None:
|
||||
"""Check all NoPublicConstructor classes are also @final."""
|
||||
assert class_is_final(_util.NoPublicConstructor) # This is itself final.
|
||||
|
||||
for module in ALL_MODULES:
|
||||
for class_ in module.__dict__.values():
|
||||
if isinstance(class_, _util.NoPublicConstructor):
|
||||
assert class_is_final(class_)
|
||||
|
||||
|
||||
def test_classes_are_final() -> None:
|
||||
# Sanity checks.
|
||||
assert not class_is_final(object)
|
||||
assert class_is_final(bool)
|
||||
|
||||
for module in PUBLIC_MODULES:
|
||||
for name, class_ in module.__dict__.items():
|
||||
if not isinstance(class_, type):
|
||||
continue
|
||||
# Deprecated classes are exported with a leading underscore
|
||||
if name.startswith("_"): # pragma: no cover
|
||||
continue
|
||||
|
||||
# Abstract classes can be subclassed, because that's the whole
|
||||
# point of ABCs
|
||||
if inspect.isabstract(class_):
|
||||
continue
|
||||
# Same with protocols, but only direct children.
|
||||
if Protocol in class_.__bases__ or Protocol_ext in class_.__bases__:
|
||||
continue
|
||||
# Exceptions are allowed to be subclassed, because exception
|
||||
# subclassing isn't used to inherit behavior.
|
||||
if issubclass(class_, BaseException):
|
||||
continue
|
||||
# These are classes that are conceptually abstract, but
|
||||
# inspect.isabstract returns False for boring reasons.
|
||||
if class_ is trio.abc.Instrument or class_ is trio.socket.SocketType:
|
||||
continue
|
||||
# ... insert other special cases here ...
|
||||
|
||||
# The `Path` class needs to support inheritance to allow `WindowsPath` and `PosixPath`.
|
||||
if class_ is trio.Path:
|
||||
continue
|
||||
# don't care about the *Statistics classes
|
||||
if name.endswith("Statistics"):
|
||||
continue
|
||||
|
||||
assert class_is_final(class_)
|
313
lib/python3.13/site-packages/trio/_tests/test_fakenet.py
Normal file
313
lib/python3.13/site-packages/trio/_tests/test_fakenet.py
Normal file
@ -0,0 +1,313 @@
|
||||
import errno
|
||||
import re
|
||||
import socket
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing._fake_net import FakeNet
|
||||
|
||||
# ENOTCONN gives different messages on different platforms
|
||||
if sys.platform == "linux":
|
||||
ENOTCONN_MSG = r"^\[Errno 107\] (Transport endpoint is|Socket) not connected$"
|
||||
elif sys.platform == "darwin":
|
||||
ENOTCONN_MSG = r"^\[Errno 57\] Socket is not connected$"
|
||||
else:
|
||||
ENOTCONN_MSG = r"^\[Errno 10057\] Unknown error$"
|
||||
|
||||
|
||||
def fn() -> FakeNet:
|
||||
fn = FakeNet()
|
||||
fn.enable()
|
||||
return fn
|
||||
|
||||
|
||||
async def test_basic_udp() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+\] Invalid argument$",
|
||||
) as exc: # Cannot rebind.
|
||||
await s1.bind(("192.0.2.1", 0))
|
||||
assert exc.value.errno == errno.EINVAL
|
||||
|
||||
# Cannot bind multiple sockets to the same address
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^\[\w+ \d+\] (Address (already )?in use|Unknown error)$",
|
||||
) as exc:
|
||||
await s2.bind(("127.0.0.1", port))
|
||||
assert exc.value.errno == errno.EADDRINUSE
|
||||
|
||||
await s2.sendto(b"xyz", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
assert data == b"xyz"
|
||||
assert addr == s2.getsockname()
|
||||
await s1.sendto(b"abc", s2.getsockname())
|
||||
data, addr = await s2.recvfrom(10)
|
||||
assert data == b"abc"
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
|
||||
async def test_msg_trunc() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
await s2.sendto(b"xyz", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
|
||||
|
||||
async def test_recv_methods() -> None:
|
||||
"""Test all recv methods for codecov"""
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
# receiving on an unbound socket is a bad idea (I think?)
|
||||
with pytest.raises(NotImplementedError, match="code will most likely hang"):
|
||||
await s2.recv(10)
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
# recvfrom
|
||||
await s2.sendto(b"abc", s1.getsockname())
|
||||
data, addr = await s1.recvfrom(10)
|
||||
assert data == b"abc"
|
||||
assert addr == s2.getsockname()
|
||||
|
||||
# recv
|
||||
await s1.sendto(b"def", s2.getsockname())
|
||||
data = await s2.recv(10)
|
||||
assert data == b"def"
|
||||
|
||||
# recvfrom_into
|
||||
assert await s1.sendto(b"ghi", s2.getsockname()) == 3
|
||||
buf = bytearray(10)
|
||||
|
||||
with pytest.raises(NotImplementedError, match="^partial recvfrom_into$"):
|
||||
(nbytes, addr) = await s2.recvfrom_into(buf, nbytes=2)
|
||||
|
||||
(nbytes, addr) = await s2.recvfrom_into(buf)
|
||||
assert nbytes == 3
|
||||
assert buf == b"ghi" + b"\x00" * 7
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# recv_into
|
||||
assert await s1.sendto(b"jkl", s2.getsockname()) == 3
|
||||
buf2 = bytearray(10)
|
||||
nbytes = await s2.recv_into(buf2)
|
||||
assert nbytes == 3
|
||||
assert buf2 == b"jkl" + b"\x00" * 7
|
||||
|
||||
if sys.platform == "linux" and sys.implementation.name == "cpython":
|
||||
flags: int = socket.MSG_MORE
|
||||
else:
|
||||
flags = 1
|
||||
|
||||
# Send seems explicitly non-functional
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
await s2.send(b"mno")
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
|
||||
await s2.send(b"mno", flags)
|
||||
|
||||
# sendto errors
|
||||
# it's successfully used earlier
|
||||
with pytest.raises(NotImplementedError, match="^FakeNet send flags must be 0, not"):
|
||||
await s2.sendto(b"mno", flags, s1.getsockname())
|
||||
with pytest.raises(TypeError, match="wrong number of arguments$"):
|
||||
await s2.sendto(b"mno", flags, s1.getsockname(), "extra arg") # type: ignore[call-overload]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform == "win32",
|
||||
reason="functions not in socket on windows",
|
||||
)
|
||||
async def test_nonwindows_functionality() -> None:
|
||||
# mypy doesn't support a good way of aborting typechecking on different platforms
|
||||
if sys.platform != "win32": # pragma: no branch
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s2.bind(("127.0.0.1", 0))
|
||||
|
||||
# sendmsg
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
await s2.sendmsg([b"mno"])
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
|
||||
assert await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) == 3
|
||||
(data, ancdata, msg_flags, addr) = await s2.recvmsg(10)
|
||||
assert data == b"jkl"
|
||||
assert ancdata == []
|
||||
assert msg_flags == 0
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# TODO: recvmsg
|
||||
|
||||
# recvmsg_into
|
||||
assert await s1.sendto(b"xyzw", s2.getsockname()) == 4
|
||||
buf1 = bytearray(2)
|
||||
buf2 = bytearray(3)
|
||||
ret = await s2.recvmsg_into([buf1, buf2])
|
||||
(nbytes, ancdata, msg_flags, addr) = ret
|
||||
assert nbytes == 4
|
||||
assert buf1 == b"xy"
|
||||
assert buf2 == b"zw" + b"\x00"
|
||||
assert ancdata == []
|
||||
assert msg_flags == 0
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
# recvmsg_into with MSG_TRUNC set
|
||||
assert await s1.sendto(b"xyzwv", s2.getsockname()) == 5
|
||||
buf1 = bytearray(2)
|
||||
ret = await s2.recvmsg_into([buf1])
|
||||
(nbytes, ancdata, msg_flags, addr) = ret
|
||||
assert nbytes == 2
|
||||
assert buf1 == b"xy"
|
||||
assert ancdata == []
|
||||
assert msg_flags == socket.MSG_TRUNC
|
||||
assert addr == s1.getsockname()
|
||||
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'share'$",
|
||||
):
|
||||
await s1.share(0) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
sys.platform != "win32",
|
||||
reason="windows-specific fakesocket testing",
|
||||
)
|
||||
async def test_windows_functionality() -> None:
|
||||
# mypy doesn't support a good way of aborting typechecking on different platforms
|
||||
if sys.platform == "win32": # pragma: no branch
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
s2 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'sendmsg'$",
|
||||
):
|
||||
await s1.sendmsg([b"jkl"], (), 0, s2.getsockname()) # type: ignore[attr-defined]
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'recvmsg'$",
|
||||
):
|
||||
s2.recvmsg(0) # type: ignore[attr-defined]
|
||||
with pytest.raises(
|
||||
AttributeError,
|
||||
match="^'FakeSocket' object has no attribute 'recvmsg_into'$",
|
||||
):
|
||||
s2.recvmsg_into([]) # type: ignore[attr-defined]
|
||||
with pytest.raises(NotImplementedError):
|
||||
s1.share(0)
|
||||
|
||||
|
||||
async def test_basic_tcp() -> None:
|
||||
fn()
|
||||
with pytest.raises(NotImplementedError):
|
||||
trio.socket.socket()
|
||||
|
||||
|
||||
async def test_not_implemented_functions() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
|
||||
# getsockopt
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement getsockopt\(\d, \d\)$",
|
||||
):
|
||||
s1.getsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY)
|
||||
|
||||
# setsockopt
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="^FakeNet always has IPV6_V6ONLY=True$",
|
||||
):
|
||||
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, False)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
|
||||
):
|
||||
s1.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, True)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^FakeNet doesn't implement setsockopt\(\d+, \d+, \.\.\.\)$",
|
||||
):
|
||||
s1.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
|
||||
# set_inheritable
|
||||
s1.set_inheritable(False)
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match="^FakeNet can't make inheritable sockets$",
|
||||
):
|
||||
s1.set_inheritable(True)
|
||||
|
||||
# get_inheritable
|
||||
assert not s1.get_inheritable()
|
||||
|
||||
|
||||
async def test_getpeername() -> None:
|
||||
fn()
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
with pytest.raises(OSError, match=ENOTCONN_MSG) as exc:
|
||||
s1.getpeername()
|
||||
assert exc.value.errno == errno.ENOTCONN
|
||||
|
||||
await s1.bind(("127.0.0.1", 0))
|
||||
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match="^This method seems to assume that self._binding has a remote UDPEndpoint$",
|
||||
):
|
||||
s1.getpeername()
|
||||
|
||||
|
||||
async def test_init() -> None:
|
||||
fn()
|
||||
with pytest.raises(
|
||||
NotImplementedError,
|
||||
match=re.escape(
|
||||
f"FakeNet doesn't (yet) support type={trio.socket.SOCK_STREAM}",
|
||||
),
|
||||
):
|
||||
s1 = trio.socket.socket()
|
||||
|
||||
# getsockname on unbound ipv4 socket
|
||||
s1 = trio.socket.socket(type=trio.socket.SOCK_DGRAM)
|
||||
assert s1.getsockname() == ("0.0.0.0", 0)
|
||||
|
||||
# getsockname on bound ipv4 socket
|
||||
await s1.bind(("0.0.0.0", 0))
|
||||
ip, port = s1.getsockname()
|
||||
assert ip == "127.0.0.1"
|
||||
assert port != 0
|
||||
|
||||
# getsockname on unbound ipv6 socket
|
||||
s2 = trio.socket.socket(family=socket.AF_INET6, type=socket.SOCK_DGRAM)
|
||||
assert s2.getsockname() == ("::", 0)
|
||||
|
||||
# getsockname on bound ipv6 socket
|
||||
await s2.bind(("::", 0))
|
||||
ip, port, *_ = s2.getsockname()
|
||||
assert ip == "::1"
|
||||
assert port != 0
|
||||
assert _ == [0, 0]
|
269
lib/python3.13/site-packages/trio/_tests/test_file_io.py
Normal file
269
lib/python3.13/site-packages/trio/_tests/test_file_io.py
Normal file
@ -0,0 +1,269 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from unittest import mock
|
||||
from unittest.mock import sentinel
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import _core, _file_io
|
||||
from trio._file_io import _FILE_ASYNC_METHODS, _FILE_SYNC_ATTRS, AsyncIOWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pathlib
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmp_path: pathlib.Path) -> str:
|
||||
return os.fspath(tmp_path / "test")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def wrapped() -> mock.Mock:
|
||||
return mock.Mock(spec_set=io.StringIO)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_file(wrapped: mock.Mock) -> AsyncIOWrapper[mock.Mock]:
|
||||
return trio.wrap_file(wrapped)
|
||||
|
||||
|
||||
def test_wrap_invalid() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
trio.wrap_file("")
|
||||
|
||||
|
||||
def test_wrap_non_iobase() -> None:
|
||||
class FakeFile:
|
||||
def close(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
def write(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
wrapped = FakeFile()
|
||||
assert not isinstance(wrapped, io.IOBase)
|
||||
|
||||
async_file = trio.wrap_file(wrapped)
|
||||
assert isinstance(async_file, AsyncIOWrapper)
|
||||
|
||||
del FakeFile.write
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
trio.wrap_file(FakeFile())
|
||||
|
||||
|
||||
def test_wrapped_property(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
assert async_file.wrapped is wrapped
|
||||
|
||||
|
||||
def test_dir_matches_wrapped(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
attrs = _FILE_SYNC_ATTRS.union(_FILE_ASYNC_METHODS)
|
||||
|
||||
# all supported attrs in wrapped should be available in async_file
|
||||
assert all(attr in dir(async_file) for attr in attrs if attr in dir(wrapped))
|
||||
# all supported attrs not in wrapped should not be available in async_file
|
||||
assert not any(
|
||||
attr in dir(async_file) for attr in attrs if attr not in dir(wrapped)
|
||||
)
|
||||
|
||||
|
||||
def test_unsupported_not_forwarded() -> None:
|
||||
class FakeFile(io.RawIOBase):
|
||||
def unsupported_attr(self) -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
async_file = trio.wrap_file(FakeFile())
|
||||
|
||||
assert hasattr(async_file.wrapped, "unsupported_attr")
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
# B018 "useless expression"
|
||||
async_file.unsupported_attr # type: ignore[attr-defined] # noqa: B018
|
||||
|
||||
|
||||
def test_type_stubs_match_lists() -> None:
|
||||
"""Check the manual stubs match the list of wrapped methods."""
|
||||
# Fetch the module's source code.
|
||||
assert _file_io.__spec__ is not None
|
||||
loader = _file_io.__spec__.loader
|
||||
assert isinstance(loader, importlib.abc.SourceLoader)
|
||||
source = io.StringIO(loader.get_source("trio._file_io"))
|
||||
|
||||
# Find the class, then find the TYPE_CHECKING block.
|
||||
for line in source:
|
||||
if "class AsyncIOWrapper" in line:
|
||||
break
|
||||
else: # pragma: no cover - should always find this
|
||||
pytest.fail("No class definition line?")
|
||||
|
||||
for line in source:
|
||||
if "if TYPE_CHECKING" in line:
|
||||
break
|
||||
else: # pragma: no cover - should always find this
|
||||
pytest.fail("No TYPE CHECKING line?")
|
||||
|
||||
# Now we should be at the type checking block.
|
||||
found: list[tuple[str, str]] = []
|
||||
for line in source: # pragma: no branch - expected to break early
|
||||
if line.strip() and not line.startswith(" " * 8):
|
||||
break # Dedented out of the if TYPE_CHECKING block.
|
||||
match = re.match(r"\s*(async )?def ([a-zA-Z0-9_]+)\(", line)
|
||||
if match is not None:
|
||||
kind = "async" if match.group(1) is not None else "sync"
|
||||
found.append((match.group(2), kind))
|
||||
|
||||
# Compare two lists so that we can easily see duplicates, and see what is different overall.
|
||||
expected = [(fname, "async") for fname in _FILE_ASYNC_METHODS]
|
||||
expected += [(fname, "sync") for fname in _FILE_SYNC_ATTRS]
|
||||
# Ignore order, error if duplicates are present.
|
||||
found.sort()
|
||||
expected.sort()
|
||||
assert found == expected
|
||||
|
||||
|
||||
def test_sync_attrs_forwarded(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for attr_name in _FILE_SYNC_ATTRS:
|
||||
if attr_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
assert getattr(async_file, attr_name) is getattr(wrapped, attr_name)
|
||||
|
||||
|
||||
def test_sync_attrs_match_wrapper(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for attr_name in _FILE_SYNC_ATTRS:
|
||||
if attr_name in dir(async_file):
|
||||
continue
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(async_file, attr_name)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(wrapped, attr_name)
|
||||
|
||||
|
||||
def test_async_methods_generated_once(async_file: AsyncIOWrapper[mock.Mock]) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
assert getattr(async_file, meth_name) is getattr(async_file, meth_name)
|
||||
|
||||
|
||||
# I gave up on typing this one
|
||||
def test_async_methods_signature(async_file: AsyncIOWrapper[mock.Mock]) -> None:
|
||||
# use read as a representative of all async methods
|
||||
assert async_file.read.__name__ == "read"
|
||||
assert async_file.read.__qualname__ == "AsyncIOWrapper.read"
|
||||
|
||||
assert async_file.read.__doc__ is not None
|
||||
assert "io.StringIO.read" in async_file.read.__doc__
|
||||
|
||||
|
||||
async def test_async_methods_wrap(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name not in dir(async_file):
|
||||
continue
|
||||
|
||||
meth = getattr(async_file, meth_name)
|
||||
wrapped_meth = getattr(wrapped, meth_name)
|
||||
|
||||
value = await meth(sentinel.argument, keyword=sentinel.keyword)
|
||||
|
||||
wrapped_meth.assert_called_once_with(
|
||||
sentinel.argument,
|
||||
keyword=sentinel.keyword,
|
||||
)
|
||||
assert value == wrapped_meth()
|
||||
|
||||
wrapped.reset_mock()
|
||||
|
||||
|
||||
async def test_async_methods_match_wrapper(
|
||||
async_file: AsyncIOWrapper[mock.Mock],
|
||||
wrapped: mock.Mock,
|
||||
) -> None:
|
||||
for meth_name in _FILE_ASYNC_METHODS:
|
||||
if meth_name in dir(async_file):
|
||||
continue
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(async_file, meth_name)
|
||||
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(wrapped, meth_name)
|
||||
|
||||
|
||||
async def test_open(path: pathlib.Path) -> None:
|
||||
f = await trio.open_file(path, "w")
|
||||
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
await f.aclose()
|
||||
|
||||
|
||||
async def test_open_context_manager(path: pathlib.Path) -> None:
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
assert not f.closed
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_async_iter() -> None:
|
||||
async_file = trio.wrap_file(io.StringIO("test\nfoo\nbar"))
|
||||
expected = list(async_file.wrapped)
|
||||
async_file.wrapped.seek(0)
|
||||
|
||||
result = [line async for line in async_file]
|
||||
|
||||
assert result == expected
|
||||
|
||||
|
||||
async def test_aclose_cancelled(path: pathlib.Path) -> None:
|
||||
with _core.CancelScope() as cscope:
|
||||
f = await trio.open_file(path, "w")
|
||||
cscope.cancel()
|
||||
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await f.write("a")
|
||||
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await f.aclose()
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_detach_rewraps_asynciobase(tmp_path: pathlib.Path) -> None:
|
||||
tmp_file = tmp_path / "filename"
|
||||
tmp_file.touch()
|
||||
# flake8-async does not like opening files in async mode
|
||||
with open(tmp_file, mode="rb", buffering=0) as raw: # noqa: ASYNC230
|
||||
buffered = io.BufferedReader(raw)
|
||||
|
||||
async_file = trio.wrap_file(buffered)
|
||||
|
||||
detached = await async_file.detach()
|
||||
|
||||
assert isinstance(detached, AsyncIOWrapper)
|
||||
assert detached.wrapped is raw
|
@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
from .._highlevel_generic import StapledStream
|
||||
from ..abc import ReceiveStream, SendStream
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RecordSendStream(SendStream):
|
||||
record: list[str | tuple[str, object]] = attrs.Factory(list)
|
||||
|
||||
async def send_all(self, data: object) -> None:
|
||||
self.record.append(("send_all", data))
|
||||
|
||||
async def wait_send_all_might_not_block(self) -> None:
|
||||
self.record.append("wait_send_all_might_not_block")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class RecordReceiveStream(ReceiveStream):
|
||||
record: list[str | tuple[str, int | None]] = attrs.Factory(list)
|
||||
|
||||
async def receive_some(self, max_bytes: int | None = None) -> bytes:
|
||||
self.record.append(("receive_some", max_bytes))
|
||||
return b""
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.record.append("aclose")
|
||||
|
||||
|
||||
async def test_StapledStream() -> None:
|
||||
send_stream = RecordSendStream()
|
||||
receive_stream = RecordReceiveStream()
|
||||
stapled = StapledStream(send_stream, receive_stream)
|
||||
|
||||
assert stapled.send_stream is send_stream
|
||||
assert stapled.receive_stream is receive_stream
|
||||
|
||||
await stapled.send_all(b"foo")
|
||||
await stapled.wait_send_all_might_not_block()
|
||||
assert send_stream.record == [
|
||||
("send_all", b"foo"),
|
||||
"wait_send_all_might_not_block",
|
||||
]
|
||||
send_stream.record.clear()
|
||||
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["aclose"]
|
||||
send_stream.record.clear()
|
||||
|
||||
async def fake_send_eof() -> None:
|
||||
send_stream.record.append("send_eof")
|
||||
|
||||
send_stream.send_eof = fake_send_eof # type: ignore[attr-defined]
|
||||
await stapled.send_eof()
|
||||
assert send_stream.record == ["send_eof"]
|
||||
|
||||
send_stream.record.clear()
|
||||
assert receive_stream.record == []
|
||||
|
||||
await stapled.receive_some(1234)
|
||||
assert receive_stream.record == [("receive_some", 1234)]
|
||||
assert send_stream.record == []
|
||||
receive_stream.record.clear()
|
||||
|
||||
await stapled.aclose()
|
||||
assert receive_stream.record == ["aclose"]
|
||||
assert send_stream.record == ["aclose"]
|
||||
|
||||
|
||||
async def test_StapledStream_with_erroring_close() -> None:
|
||||
# Make sure that if one of the aclose methods errors out, then the other
|
||||
# one still gets called.
|
||||
class BrokenSendStream(RecordSendStream):
|
||||
async def aclose(self) -> NoReturn:
|
||||
await super().aclose()
|
||||
raise ValueError("send error")
|
||||
|
||||
class BrokenReceiveStream(RecordReceiveStream):
|
||||
async def aclose(self) -> NoReturn:
|
||||
await super().aclose()
|
||||
raise ValueError("recv error")
|
||||
|
||||
stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())
|
||||
|
||||
with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo:
|
||||
await stapled.aclose()
|
||||
assert isinstance(excinfo.value.__context__, ValueError)
|
||||
|
||||
assert stapled.send_stream.record == ["aclose"]
|
||||
assert stapled.receive_stream.record == ["aclose"]
|
@ -0,0 +1,410 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
from socket import AddressFamily, SocketKind
|
||||
from typing import TYPE_CHECKING, Any, Sequence, overload
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio import (
|
||||
SocketListener,
|
||||
open_tcp_listeners,
|
||||
open_tcp_stream,
|
||||
serve_tcp,
|
||||
)
|
||||
from trio.abc import HostnameResolver, SendStream, SocketFactory
|
||||
from trio.testing import open_stream_to_socket_listener
|
||||
|
||||
from .. import socket as tsocket
|
||||
from .._core._tests.tutil import binds_ipv6
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Buffer
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_basic() -> None:
|
||||
listeners = await open_tcp_listeners(0)
|
||||
assert isinstance(listeners, list)
|
||||
for obj in listeners:
|
||||
assert isinstance(obj, SocketListener)
|
||||
# Binds to wildcard address by default
|
||||
assert obj.socket.family in [tsocket.AF_INET, tsocket.AF_INET6]
|
||||
assert obj.socket.getsockname()[0] in ["0.0.0.0", "::"]
|
||||
|
||||
listener = listeners[0]
|
||||
# Make sure the backlog is at least 2
|
||||
c1 = await open_stream_to_socket_listener(listener)
|
||||
c2 = await open_stream_to_socket_listener(listener)
|
||||
|
||||
s1 = await listener.accept()
|
||||
s2 = await listener.accept()
|
||||
|
||||
# Note that we don't know which client stream is connected to which server
|
||||
# stream
|
||||
await s1.send_all(b"x")
|
||||
await s2.send_all(b"x")
|
||||
assert await c1.receive_some(1) == b"x"
|
||||
assert await c2.receive_some(1) == b"x"
|
||||
|
||||
for resource in [c1, c2, s1, s2, *listeners]:
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_specific_port_specific_host() -> None:
|
||||
# Pick a port
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
host, port = sock.getsockname()
|
||||
sock.close()
|
||||
|
||||
(listener,) = await open_tcp_listeners(port, host=host)
|
||||
async with listener:
|
||||
assert listener.socket.getsockname() == (host, port)
|
||||
|
||||
|
||||
@binds_ipv6
|
||||
async def test_open_tcp_listeners_ipv6_v6only() -> None:
|
||||
# Check IPV6_V6ONLY is working properly
|
||||
(ipv6_listener,) = await open_tcp_listeners(0, host="::1")
|
||||
async with ipv6_listener:
|
||||
_, port, *_ = ipv6_listener.socket.getsockname()
|
||||
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"(Error|all attempts to) connect(ing)* to (\(')*127\.0\.0\.1(', |:)\d+(\): Connection refused| failed)$",
|
||||
):
|
||||
await open_tcp_stream("127.0.0.1", port)
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_rebind() -> None:
|
||||
(l1,) = await open_tcp_listeners(0, host="127.0.0.1")
|
||||
sockaddr1 = l1.socket.getsockname()
|
||||
|
||||
# Plain old rebinding while it's still there should fail, even if we have
|
||||
# SO_REUSEADDR set
|
||||
with stdlib_socket.socket() as probe:
|
||||
probe.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_REUSEADDR, 1)
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match="(Address (already )?in use|An attempt was made to access a socket in a way forbidden by its access permissions)$",
|
||||
):
|
||||
probe.bind(sockaddr1)
|
||||
|
||||
# Now use the first listener to set up some connections in various states,
|
||||
# and make sure that they don't create any obstacle to rebinding a second
|
||||
# listener after the first one is closed.
|
||||
c_established = await open_stream_to_socket_listener(l1)
|
||||
s_established = await l1.accept()
|
||||
|
||||
c_time_wait = await open_stream_to_socket_listener(l1)
|
||||
s_time_wait = await l1.accept()
|
||||
# Server-initiated close leaves socket in TIME_WAIT
|
||||
await s_time_wait.aclose()
|
||||
|
||||
await l1.aclose()
|
||||
(l2,) = await open_tcp_listeners(sockaddr1[1], host="127.0.0.1")
|
||||
sockaddr2 = l2.socket.getsockname()
|
||||
|
||||
assert sockaddr1 == sockaddr2
|
||||
assert s_established.socket.getsockname() == sockaddr2
|
||||
assert c_time_wait.socket.getpeername() == sockaddr2
|
||||
|
||||
for resource in [
|
||||
l1,
|
||||
l2,
|
||||
c_established,
|
||||
s_established,
|
||||
c_time_wait,
|
||||
s_time_wait,
|
||||
]:
|
||||
await resource.aclose()
|
||||
|
||||
|
||||
class FakeOSError(OSError):
|
||||
pass
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
_family: AddressFamily = attrs.field(converter=AddressFamily)
|
||||
_type: SocketKind = attrs.field(converter=SocketKind)
|
||||
_proto: int
|
||||
|
||||
closed: bool = False
|
||||
poison_listen: bool = False
|
||||
backlog: int | None = None
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily:
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int: # pragma: no cover
|
||||
return self._proto
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int, buflen: int) -> bytes: ...
|
||||
|
||||
def getsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
if (level, optname) == (tsocket.SOL_SOCKET, tsocket.SO_ACCEPTCONN):
|
||||
return True
|
||||
raise AssertionError() # pragma: no cover
|
||||
|
||||
@overload
|
||||
def setsockopt(self, /, level: int, optname: int, value: int | Buffer) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def bind(self, address: Any) -> None:
|
||||
pass
|
||||
|
||||
def listen(self, /, backlog: int = min(stdlib_socket.SOMAXCONN, 128)) -> None:
|
||||
assert self.backlog is None
|
||||
assert backlog is not None
|
||||
self.backlog = backlog
|
||||
if self.poison_listen:
|
||||
raise FakeOSError("whoops")
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeSocketFactory(SocketFactory):
|
||||
poison_after: int
|
||||
sockets: list[tsocket.SocketType] = attrs.Factory(list)
|
||||
raise_on_family: dict[AddressFamily, int] = attrs.Factory(dict) # family => errno
|
||||
|
||||
def socket(
|
||||
self,
|
||||
family: AddressFamily | int | None = None,
|
||||
type_: SocketKind | int | None = None,
|
||||
proto: int = 0,
|
||||
) -> tsocket.SocketType:
|
||||
assert family is not None
|
||||
assert type_ is not None
|
||||
if isinstance(family, int) and not isinstance(family, AddressFamily):
|
||||
family = AddressFamily(family) # pragma: no cover
|
||||
if family in self.raise_on_family:
|
||||
raise OSError(self.raise_on_family[family], "nope")
|
||||
sock = FakeSocket(family, type_, proto)
|
||||
self.poison_after -= 1
|
||||
if self.poison_after == 0:
|
||||
sock.poison_listen = True
|
||||
self.sockets.append(sock)
|
||||
return sock
|
||||
|
||||
|
||||
@attrs.define(slots=False)
|
||||
class FakeHostnameResolver(HostnameResolver):
|
||||
family_addr_pairs: Sequence[tuple[AddressFamily, str]]
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
assert isinstance(port, int)
|
||||
return [
|
||||
(family, tsocket.SOCK_STREAM, 0, "", (addr, port))
|
||||
for family, addr in self.family_addr_pairs
|
||||
]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_multiple_host_cleanup_on_error() -> None:
|
||||
# If we were trying to bind to multiple hosts and one of them failed, they
|
||||
# call get cleaned up before returning
|
||||
fsf = FakeSocketFactory(3)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver(
|
||||
[
|
||||
(tsocket.AF_INET, "1.1.1.1"),
|
||||
(tsocket.AF_INET, "2.2.2.2"),
|
||||
(tsocket.AF_INET, "3.3.3.3"),
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
with pytest.raises(FakeOSError):
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
|
||||
assert len(fsf.sockets) == 3
|
||||
for sock in fsf.sockets:
|
||||
# property only exists on FakeSocket
|
||||
assert sock.closed # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_port_checking() -> None:
|
||||
for host in ["127.0.0.1", None]:
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners(None, host=host) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners(b"80", host=host) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_listeners("http", host=host) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_serve_tcp() -> None:
|
||||
async def handler(stream: SendStream) -> None:
|
||||
await stream.send_all(b"x")
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
# nursery.start is incorrectly typed, awaiting #2773
|
||||
listeners: list[SocketListener] = await nursery.start(serve_tcp, handler, 0)
|
||||
stream = await open_stream_to_socket_listener(listeners[0])
|
||||
async with stream:
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"try_families",
|
||||
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"fail_families",
|
||||
[{tsocket.AF_INET}, {tsocket.AF_INET6}, {tsocket.AF_INET, tsocket.AF_INET6}],
|
||||
)
|
||||
async def test_open_tcp_listeners_some_address_families_unavailable(
|
||||
try_families: set[AddressFamily],
|
||||
fail_families: set[AddressFamily],
|
||||
) -> None:
|
||||
fsf = FakeSocketFactory(
|
||||
10,
|
||||
raise_on_family={family: errno.EAFNOSUPPORT for family in fail_families},
|
||||
)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver([(family, "foo") for family in try_families]),
|
||||
)
|
||||
|
||||
should_succeed = try_families - fail_families
|
||||
|
||||
if not should_succeed:
|
||||
with pytest.raises(OSError, match="This system doesn't support") as exc_info:
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
|
||||
# open_listeners always creates an exceptiongroup with the
|
||||
# unsupported address families, regardless of the value of
|
||||
# strict_exception_groups or number of unsupported families.
|
||||
assert isinstance(exc_info.value.__cause__, BaseExceptionGroup)
|
||||
for subexc in exc_info.value.__cause__.exceptions:
|
||||
assert "nope" in str(subexc)
|
||||
else:
|
||||
listeners = await open_tcp_listeners(80)
|
||||
for listener in listeners:
|
||||
should_succeed.remove(listener.socket.family)
|
||||
assert not should_succeed
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_socket_fails_not_afnosupport() -> None:
|
||||
fsf = FakeSocketFactory(
|
||||
10,
|
||||
raise_on_family={
|
||||
tsocket.AF_INET: errno.EAFNOSUPPORT,
|
||||
tsocket.AF_INET6: errno.EINVAL,
|
||||
},
|
||||
)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
tsocket.set_custom_hostname_resolver(
|
||||
FakeHostnameResolver([(tsocket.AF_INET, "foo"), (tsocket.AF_INET6, "bar")]),
|
||||
)
|
||||
|
||||
with pytest.raises(OSError, match="nope") as exc_info:
|
||||
await open_tcp_listeners(80, host="example.org")
|
||||
assert exc_info.value.errno == errno.EINVAL
|
||||
assert exc_info.value.__cause__ is None
|
||||
assert "nope" in str(exc_info.value)
|
||||
|
||||
|
||||
# We used to have an elaborate test that opened a real TCP listening socket
|
||||
# and then tried to measure its backlog by making connections to it. And most
|
||||
# of the time, it worked. But no matter what we tried, it was always fragile,
|
||||
# because it had to do things like use timeouts to guess when the listening
|
||||
# queue was full, sometimes the CI hosts go into SYN-cookie mode (where there
|
||||
# effectively is no backlog), sometimes the host might not be enough resources
|
||||
# to give us the full requested backlog... it was a mess. So now we just check
|
||||
# that the backlog argument is passed through correctly.
|
||||
async def test_open_tcp_listeners_backlog() -> None:
|
||||
fsf = FakeSocketFactory(99)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
for given, expected in [
|
||||
(None, 0xFFFF),
|
||||
(99999999, 0xFFFF),
|
||||
(10, 10),
|
||||
(1, 1),
|
||||
]:
|
||||
listeners = await open_tcp_listeners(0, backlog=given)
|
||||
assert listeners
|
||||
for listener in listeners:
|
||||
# `backlog` only exists on FakeSocket
|
||||
assert listener.socket.backlog == expected # type: ignore[attr-defined]
|
||||
|
||||
|
||||
async def test_open_tcp_listeners_backlog_float_error() -> None:
|
||||
fsf = FakeSocketFactory(99)
|
||||
tsocket.set_custom_socket_factory(fsf)
|
||||
for should_fail in (0.0, 2.18, 3.14, 9.75):
|
||||
with pytest.raises(
|
||||
TypeError,
|
||||
match=f"backlog must be an int or None, not {should_fail!r}",
|
||||
):
|
||||
await open_tcp_listeners(0, backlog=should_fail) # type: ignore[arg-type]
|
@ -0,0 +1,686 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import sys
|
||||
from socket import AddressFamily, SocketKind
|
||||
from typing import TYPE_CHECKING, Any, Sequence
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._highlevel_open_tcp_stream import (
|
||||
close_all,
|
||||
format_host_port,
|
||||
open_tcp_stream,
|
||||
reorder_for_rfc_6555_section_5_4,
|
||||
)
|
||||
from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trio.testing import MockClock
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup
|
||||
|
||||
|
||||
def test_close_all() -> None:
|
||||
class CloseMe(SocketType):
|
||||
closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
class CloseKiller(SocketType):
|
||||
def close(self) -> None:
|
||||
raise OSError("os error text")
|
||||
|
||||
c: CloseMe = CloseMe()
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_all() as to_close:
|
||||
to_close.add(c)
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
c = CloseMe()
|
||||
with pytest.raises(OSError, match="os error text"):
|
||||
with close_all() as to_close:
|
||||
to_close.add(CloseKiller())
|
||||
to_close.add(c)
|
||||
assert c.closed
|
||||
|
||||
|
||||
def test_reorder_for_rfc_6555_section_5_4() -> None:
|
||||
def fake4(
|
||||
i: int,
|
||||
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
|
||||
return (
|
||||
AF_INET,
|
||||
SOCK_STREAM,
|
||||
IPPROTO_TCP,
|
||||
"",
|
||||
(f"10.0.0.{i}", 80),
|
||||
)
|
||||
|
||||
def fake6(
|
||||
i: int,
|
||||
) -> tuple[socket.AddressFamily, socket.SocketKind, int, str, tuple[str, int]]:
|
||||
return (AF_INET6, SOCK_STREAM, IPPROTO_TCP, "", (f"::{i}", 80))
|
||||
|
||||
for fake in fake4, fake6:
|
||||
# No effect on homogeneous lists
|
||||
targets = [fake(0), fake(1), fake(2)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0), fake(1), fake(2)]
|
||||
|
||||
# Single item lists also OK
|
||||
targets = [fake(0)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake(0)]
|
||||
|
||||
# If the list starts out with different families in positions 0 and 1,
|
||||
# then it's left alone
|
||||
orig = [fake4(0), fake6(0), fake4(1), fake6(1)]
|
||||
targets = list(orig)
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == orig
|
||||
|
||||
# If not, it's reordered
|
||||
targets = [fake4(0), fake4(1), fake4(2), fake6(0), fake6(1)]
|
||||
reorder_for_rfc_6555_section_5_4(targets)
|
||||
assert targets == [fake4(0), fake6(0), fake4(1), fake4(2), fake6(1)]
|
||||
|
||||
|
||||
def test_format_host_port() -> None:
|
||||
assert format_host_port("127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port(b"127.0.0.1", 80) == "127.0.0.1:80"
|
||||
assert format_host_port("example.com", 443) == "example.com:443"
|
||||
assert format_host_port(b"example.com", 443) == "example.com:443"
|
||||
assert format_host_port("::1", "http") == "[::1]:http"
|
||||
assert format_host_port(b"::1", "http") == "[::1]:http"
|
||||
|
||||
|
||||
# Make sure we can connect to localhost using real kernel sockets
|
||||
async def test_open_tcp_stream_real_socket_smoketest() -> None:
|
||||
listen_sock = trio.socket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
_, listen_port = listen_sock.getsockname()
|
||||
listen_sock.listen(1)
|
||||
client_stream = await open_tcp_stream("127.0.0.1", listen_port)
|
||||
server_sock, _ = await listen_sock.accept()
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_sock.recv(1) == b"x"
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
|
||||
listen_sock.close()
|
||||
|
||||
|
||||
async def test_open_tcp_stream_input_validation() -> None:
|
||||
with pytest.raises(ValueError, match="^host must be str or bytes, not None$"):
|
||||
await open_tcp_stream(None, 80) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
await open_tcp_stream("127.0.0.1", b"80") # type: ignore[arg-type]
|
||||
|
||||
|
||||
def can_bind_127_0_0_2() -> bool:
|
||||
with socket.socket() as s:
|
||||
try:
|
||||
s.bind(("127.0.0.2", 0))
|
||||
except OSError:
|
||||
return False
|
||||
# s.getsockname() is typed as returning Any
|
||||
return s.getsockname()[0] == "127.0.0.2" # type: ignore[no-any-return]
|
||||
|
||||
|
||||
async def test_local_address_real() -> None:
|
||||
with trio.socket.socket() as listener:
|
||||
await listener.bind(("127.0.0.1", 0))
|
||||
listener.listen()
|
||||
|
||||
# It's hard to test local_address properly, because you need multiple
|
||||
# local addresses that you can bind to. Fortunately, on most Linux
|
||||
# systems, you can bind to any 127.*.*.* address, and they all go
|
||||
# through the loopback interface. So we can use a non-standard
|
||||
# loopback address. On other systems, the only address we know for
|
||||
# certain we have is 127.0.0.1, so we can't really test local_address=
|
||||
# properly -- passing local_address=127.0.0.1 is indistinguishable
|
||||
# from not passing local_address= at all. But, we can still do a smoke
|
||||
# test to make sure the local_address= code doesn't crash.
|
||||
local_address = "127.0.0.2" if can_bind_127_0_0_2() else "127.0.0.1"
|
||||
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(),
|
||||
local_address=local_address,
|
||||
) as client_stream:
|
||||
assert client_stream.socket.getsockname()[0] == local_address
|
||||
if hasattr(trio.socket, "IP_BIND_ADDRESS_NO_PORT"):
|
||||
assert client_stream.socket.getsockopt(
|
||||
trio.socket.IPPROTO_IP,
|
||||
trio.socket.IP_BIND_ADDRESS_NO_PORT,
|
||||
)
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
await client_stream.aclose()
|
||||
server_sock.close()
|
||||
# accept returns tuple[SocketType, object], due to typeshed returning `Any`
|
||||
assert remote_addr[0] == local_address
|
||||
|
||||
# Trying to connect to an ipv4 address with the ipv6 wildcard
|
||||
# local_address should fail
|
||||
with pytest.raises(
|
||||
OSError,
|
||||
match=r"^all attempts to connect* to *127\.0\.0\.\d:\d+ failed$",
|
||||
):
|
||||
await open_tcp_stream(*listener.getsockname(), local_address="::")
|
||||
|
||||
# But the ipv4 wildcard address should work
|
||||
async with await open_tcp_stream(
|
||||
*listener.getsockname(),
|
||||
local_address="0.0.0.0",
|
||||
) as client_stream:
|
||||
server_sock, remote_addr = await listener.accept()
|
||||
server_sock.close()
|
||||
assert remote_addr == client_stream.socket.getsockname()
|
||||
|
||||
|
||||
# Now, thorough tests using fake sockets
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class FakeSocket(trio.socket.SocketType):
|
||||
scenario: Scenario
|
||||
_family: AddressFamily
|
||||
_type: SocketKind
|
||||
_proto: int
|
||||
|
||||
ip: str | int | None = None
|
||||
port: str | int | None = None
|
||||
succeeded: bool = False
|
||||
closed: bool = False
|
||||
failing: bool = False
|
||||
|
||||
@property
|
||||
def type(self) -> SocketKind:
|
||||
return self._type
|
||||
|
||||
@property
|
||||
def family(self) -> AddressFamily: # pragma: no cover
|
||||
return self._family
|
||||
|
||||
@property
|
||||
def proto(self) -> int: # pragma: no cover
|
||||
return self._proto
|
||||
|
||||
async def connect(self, sockaddr: tuple[str | int, str | int | None]) -> None:
|
||||
self.ip = sockaddr[0]
|
||||
self.port = sockaddr[1]
|
||||
assert self.ip not in self.scenario.sockets
|
||||
self.scenario.sockets[self.ip] = self
|
||||
self.scenario.connect_times[self.ip] = trio.current_time()
|
||||
delay, result = self.scenario.ip_dict[self.ip]
|
||||
await trio.sleep(delay)
|
||||
if result == "error":
|
||||
raise OSError("sorry")
|
||||
if result == "postconnect_fail":
|
||||
self.failing = True
|
||||
self.succeeded = True
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
# called when SocketStream is constructed
|
||||
def setsockopt(self, *args: object, **kwargs: object) -> None:
|
||||
if self.failing:
|
||||
# raise something that isn't OSError as SocketStream
|
||||
# ignores those
|
||||
raise KeyboardInterrupt
|
||||
|
||||
|
||||
class Scenario(trio.abc.SocketFactory, trio.abc.HostnameResolver):
|
||||
def __init__(
|
||||
self,
|
||||
port: int,
|
||||
ip_list: Sequence[tuple[str, float, str]],
|
||||
supported_families: set[AddressFamily],
|
||||
) -> None:
|
||||
# ip_list have to be unique
|
||||
ip_order = [ip for (ip, _, _) in ip_list]
|
||||
assert len(set(ip_order)) == len(ip_list)
|
||||
ip_dict: dict[str | int, tuple[float, str]] = {}
|
||||
for ip, delay, result in ip_list:
|
||||
assert delay >= 0
|
||||
assert result in ["error", "success", "postconnect_fail"]
|
||||
ip_dict[ip] = (delay, result)
|
||||
|
||||
self.port = port
|
||||
self.ip_order = ip_order
|
||||
self.ip_dict = ip_dict
|
||||
self.supported_families = supported_families
|
||||
self.socket_count = 0
|
||||
self.sockets: dict[str | int, FakeSocket] = {}
|
||||
self.connect_times: dict[str | int, float] = {}
|
||||
|
||||
def socket(
|
||||
self,
|
||||
family: AddressFamily | int | None = None,
|
||||
type_: SocketKind | int | None = None,
|
||||
proto: int | None = None,
|
||||
) -> SocketType:
|
||||
assert isinstance(family, AddressFamily)
|
||||
assert isinstance(type_, SocketKind)
|
||||
assert proto is not None
|
||||
if family not in self.supported_families:
|
||||
raise OSError("pretending not to support this family")
|
||||
self.socket_count += 1
|
||||
return FakeSocket(self, family, type_, proto)
|
||||
|
||||
def _ip_to_gai_entry(self, ip: str) -> tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int, int, int] | tuple[str, int],
|
||||
]:
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int]
|
||||
if ":" in ip:
|
||||
family = trio.socket.AF_INET6
|
||||
sockaddr = (ip, self.port, 0, 0)
|
||||
else:
|
||||
family = trio.socket.AF_INET
|
||||
sockaddr = (ip, self.port)
|
||||
return (family, SOCK_STREAM, IPPROTO_TCP, "", sockaddr)
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = -1,
|
||||
type: int = -1,
|
||||
proto: int = -1,
|
||||
flags: int = -1,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int, int, int] | tuple[str, int],
|
||||
]
|
||||
]:
|
||||
assert host == b"test.example.com"
|
||||
assert port == self.port
|
||||
assert family == trio.socket.AF_UNSPEC
|
||||
assert type == trio.socket.SOCK_STREAM
|
||||
assert proto == 0
|
||||
assert flags == 0
|
||||
return [self._ip_to_gai_entry(ip) for ip in self.ip_order]
|
||||
|
||||
async def getnameinfo(
|
||||
self,
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int],
|
||||
flags: int,
|
||||
) -> tuple[str, str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def check(self, succeeded: SocketType | None) -> None:
|
||||
# sockets only go into self.sockets when connect is called; make sure
|
||||
# all the sockets that were created did in fact go in there.
|
||||
assert self.socket_count == len(self.sockets)
|
||||
|
||||
for ip, socket_ in self.sockets.items():
|
||||
assert ip in self.ip_dict
|
||||
if socket_ is not succeeded:
|
||||
assert socket_.closed
|
||||
assert socket_.port == self.port
|
||||
|
||||
|
||||
async def run_scenario(
|
||||
# The port to connect to
|
||||
port: int,
|
||||
# A list of
|
||||
# (ip, delay, result)
|
||||
# tuples, where delay is in seconds and result is "success" or "error"
|
||||
# The ip's will be returned from getaddrinfo in this order, and then
|
||||
# connect() calls to them will have the given result.
|
||||
ip_list: Sequence[tuple[str, float, str]],
|
||||
*,
|
||||
# If False, AF_INET4/6 sockets error out on creation, before connect is
|
||||
# even called.
|
||||
ipv4_supported: bool = True,
|
||||
ipv6_supported: bool = True,
|
||||
# Normally, we return (winning_sock, scenario object)
|
||||
# If this is True, we require there to be an exception, and return
|
||||
# (exception, scenario object)
|
||||
expect_error: tuple[type[BaseException], ...] | type[BaseException] = (),
|
||||
**kwargs: Any,
|
||||
) -> tuple[SocketType, Scenario] | tuple[BaseException, Scenario]:
|
||||
supported_families = set()
|
||||
if ipv4_supported:
|
||||
supported_families.add(trio.socket.AF_INET)
|
||||
if ipv6_supported:
|
||||
supported_families.add(trio.socket.AF_INET6)
|
||||
scenario = Scenario(port, ip_list, supported_families)
|
||||
trio.socket.set_custom_hostname_resolver(scenario)
|
||||
trio.socket.set_custom_socket_factory(scenario)
|
||||
|
||||
try:
|
||||
stream = await open_tcp_stream("test.example.com", port, **kwargs)
|
||||
assert expect_error == ()
|
||||
scenario.check(stream.socket)
|
||||
return (stream.socket, scenario)
|
||||
except AssertionError: # pragma: no cover
|
||||
raise
|
||||
except expect_error as exc:
|
||||
scenario.check(None)
|
||||
return (exc, scenario)
|
||||
|
||||
|
||||
async def test_one_host_quick_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(80, [("1.2.3.4", 0.123, "success")])
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(81, [("1.2.3.4", 100, "success")])
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.2.3.4"
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_quick_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
82,
|
||||
[("1.2.3.4", 0.123, "error")],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 0.123
|
||||
|
||||
|
||||
async def test_one_host_slow_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
83,
|
||||
[("1.2.3.4", 100, "error")],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
assert trio.current_time() == 100
|
||||
|
||||
|
||||
async def test_one_host_failed_after_connect(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
83,
|
||||
[("1.2.3.4", 1, "postconnect_fail")],
|
||||
expect_error=KeyboardInterrupt,
|
||||
)
|
||||
assert isinstance(exc, KeyboardInterrupt)
|
||||
|
||||
|
||||
# With the default 0.250 second delay, the third attempt will win
|
||||
async def test_basic_fallthrough(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_early_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 0.1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "2.2.2.2"
|
||||
assert trio.current_time() == (0.250 + 0.1)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
# 3.3.3.3 was never even started
|
||||
}
|
||||
|
||||
|
||||
# With a 0.450 second delay, the first attempt will win
|
||||
async def test_custom_delay(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=0.450,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "1.1.1.1"
|
||||
assert trio.current_time() == 1
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.450,
|
||||
"3.3.3.3": 0.900,
|
||||
}
|
||||
|
||||
|
||||
async def test_none_default(autojump_clock: MockClock) -> None:
|
||||
"""Copy of test_basic_fallthrough, but specifying the delay =None"""
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 1, "success"),
|
||||
("2.2.2.2", 1, "success"),
|
||||
("3.3.3.3", 0.2, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=None,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "3.3.3.3"
|
||||
# current time is default time + default time + connection time
|
||||
assert trio.current_time() == (0.250 + 0.250 + 0.2)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.250,
|
||||
"3.3.3.3": 0.500,
|
||||
}
|
||||
|
||||
|
||||
async def test_custom_errors_expedite(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
# .25 is the default timeout
|
||||
("4.4.4.4", 0.25, "success"),
|
||||
],
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == (0.1 + 0.2 + 0.25 + 0.25)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_all_fail(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.1, "error"),
|
||||
("2.2.2.2", 0.2, "error"),
|
||||
("3.3.3.3", 10, "error"),
|
||||
("4.4.4.4", 0.250, "error"),
|
||||
],
|
||||
expect_error=OSError,
|
||||
)
|
||||
assert isinstance(exc, OSError)
|
||||
|
||||
subexceptions = (Matcher(OSError, match="^sorry$"),) * 4
|
||||
assert RaisesGroup(
|
||||
*subexceptions,
|
||||
match="all attempts to connect to test.example.com:80 failed",
|
||||
).matches(exc.__cause__)
|
||||
|
||||
assert trio.current_time() == (0.1 + 0.2 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.1,
|
||||
"3.3.3.3": 0.1 + 0.2,
|
||||
"4.4.4.4": 0.1 + 0.2 + 0.25,
|
||||
}
|
||||
|
||||
|
||||
async def test_multi_success(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 0.5, "error"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10 - 1, "success"),
|
||||
("4.4.4.4", 10 - 2, "success"),
|
||||
("5.5.5.5", 0.5, "error"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert not scenario.sockets["1.1.1.1"].succeeded
|
||||
assert (
|
||||
scenario.sockets["2.2.2.2"].succeeded
|
||||
or scenario.sockets["3.3.3.3"].succeeded
|
||||
or scenario.sockets["4.4.4.4"].succeeded
|
||||
)
|
||||
assert not scenario.sockets["5.5.5.5"].succeeded
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip in ["2.2.2.2", "3.3.3.3", "4.4.4.4"]
|
||||
assert trio.current_time() == (0.5 + 10)
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"2.2.2.2": 0.5,
|
||||
"3.3.3.3": 1.5,
|
||||
"4.4.4.4": 2.5,
|
||||
"5.5.5.5": 3.5,
|
||||
}
|
||||
|
||||
|
||||
async def test_does_reorder(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "error"),
|
||||
# This would win if we tried it first...
|
||||
("2.2.2.2", 1, "success"),
|
||||
# But in fact we try this first, because of section 5.4
|
||||
("::3", 0.5, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.5
|
||||
assert scenario.connect_times == {
|
||||
"1.1.1.1": 0,
|
||||
"::3": 1,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv4(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 10, "success"),
|
||||
("2.2.2.2", 0, "success"),
|
||||
("::3", 0.1, "success"),
|
||||
("4.4.4.4", 0, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv4_supported=False,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "::3"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"::1": 0,
|
||||
"::3": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_handles_no_ipv6(autojump_clock: MockClock) -> None:
|
||||
sock, scenario = await run_scenario(
|
||||
80,
|
||||
# Here the ipv6 addresses fail at socket creation time, so the connect
|
||||
# configuration doesn't matter
|
||||
[
|
||||
("::1", 0, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("::3", 0, "success"),
|
||||
("4.4.4.4", 0.1, "success"),
|
||||
],
|
||||
happy_eyeballs_delay=1,
|
||||
ipv6_supported=False,
|
||||
)
|
||||
assert isinstance(sock, FakeSocket)
|
||||
assert sock.ip == "4.4.4.4"
|
||||
assert trio.current_time() == 1 + 0.1
|
||||
assert scenario.connect_times == {
|
||||
"2.2.2.2": 0,
|
||||
"4.4.4.4": 1.0,
|
||||
}
|
||||
|
||||
|
||||
async def test_no_hosts(autojump_clock: MockClock) -> None:
|
||||
exc, scenario = await run_scenario(80, [], expect_error=OSError)
|
||||
assert "no results found" in str(exc)
|
||||
|
||||
|
||||
async def test_cancel(autojump_clock: MockClock) -> None:
|
||||
with trio.move_on_after(5) as cancel_scope:
|
||||
exc, scenario = await run_scenario(
|
||||
80,
|
||||
[
|
||||
("1.1.1.1", 10, "success"),
|
||||
("2.2.2.2", 10, "success"),
|
||||
("3.3.3.3", 10, "success"),
|
||||
("4.4.4.4", 10, "success"),
|
||||
],
|
||||
expect_error=BaseExceptionGroup,
|
||||
)
|
||||
assert isinstance(exc, BaseException)
|
||||
# What comes out should be 1 or more Cancelled errors that all belong
|
||||
# to this cancel_scope; this is the easiest way to check that
|
||||
raise exc
|
||||
assert cancel_scope.cancelled_caught
|
||||
|
||||
assert trio.current_time() == 5
|
||||
|
||||
# This should have been called already, but just to make sure, since the
|
||||
# exception-handling logic in run_scenario is a bit complicated and the
|
||||
# main thing we care about here is that all the sockets were cleaned up.
|
||||
scenario.check(succeeded=None)
|
@ -0,0 +1,86 @@
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio import Path, open_unix_socket
|
||||
from trio._highlevel_open_unix_stream import close_on_error
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform != "win32"
|
||||
|
||||
skip_if_not_unix = pytest.mark.skipif(
|
||||
not hasattr(socket, "AF_UNIX"),
|
||||
reason="Needs unix socket support",
|
||||
)
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
def test_close_on_error() -> None:
|
||||
class CloseMe:
|
||||
closed = False
|
||||
|
||||
def close(self) -> None:
|
||||
self.closed = True
|
||||
|
||||
with close_on_error(CloseMe()) as c:
|
||||
pass
|
||||
assert not c.closed
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
with close_on_error(CloseMe()) as c:
|
||||
raise RuntimeError
|
||||
assert c.closed
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
@pytest.mark.parametrize("filename", [4, 4.5])
|
||||
async def test_open_with_bad_filename_type(filename: float) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
await open_unix_socket(filename) # type: ignore[arg-type]
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
async def test_open_bad_socket() -> None:
|
||||
# mktemp is marked as insecure, but that's okay, we don't want the file to
|
||||
# exist
|
||||
name = tempfile.mktemp()
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await open_unix_socket(name)
|
||||
|
||||
|
||||
@skip_if_not_unix
|
||||
async def test_open_unix_socket() -> None:
|
||||
for name_type in [Path, str]:
|
||||
name = tempfile.mktemp()
|
||||
serv_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
with serv_sock:
|
||||
serv_sock.bind(name)
|
||||
try:
|
||||
serv_sock.listen(1)
|
||||
|
||||
# The actual function we're testing
|
||||
unix_socket = await open_unix_socket(name_type(name))
|
||||
|
||||
async with unix_socket:
|
||||
client, _ = serv_sock.accept()
|
||||
with client:
|
||||
await unix_socket.send_all(b"test")
|
||||
assert client.recv(2048) == b"test"
|
||||
|
||||
client.sendall(b"response")
|
||||
received = await unix_socket.receive_some(2048)
|
||||
assert received == b"response"
|
||||
finally:
|
||||
os.unlink(name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(hasattr(socket, "AF_UNIX"), reason="Test for non-unix platforms")
|
||||
async def test_error_on_no_unix() -> None:
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="^Unix sockets are not supported on this platform$",
|
||||
):
|
||||
await open_unix_socket("")
|
@ -0,0 +1,183 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, NoReturn
|
||||
|
||||
import attrs
|
||||
|
||||
import trio
|
||||
from trio import Nursery, StapledStream, TaskStatus
|
||||
from trio.testing import (
|
||||
Matcher,
|
||||
MemoryReceiveStream,
|
||||
MemorySendStream,
|
||||
MockClock,
|
||||
RaisesGroup,
|
||||
memory_stream_pair,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pytest
|
||||
|
||||
from trio._channel import MemoryReceiveChannel, MemorySendChannel
|
||||
from trio.abc import Stream
|
||||
|
||||
# types are somewhat tentative - I just bruteforced them until I got something that didn't
|
||||
# give errors
|
||||
StapledMemoryStream = StapledStream[MemorySendStream, MemoryReceiveStream]
|
||||
|
||||
|
||||
@attrs.define(eq=False, slots=False)
|
||||
class MemoryListener(trio.abc.Listener[StapledMemoryStream]):
|
||||
closed: bool = False
|
||||
accepted_streams: list[trio.abc.Stream] = attrs.Factory(list)
|
||||
queued_streams: tuple[
|
||||
MemorySendChannel[StapledMemoryStream],
|
||||
MemoryReceiveChannel[StapledMemoryStream],
|
||||
] = attrs.Factory(lambda: trio.open_memory_channel[StapledMemoryStream](1))
|
||||
accept_hook: Callable[[], Awaitable[object]] | None = None
|
||||
|
||||
async def connect(self) -> StapledMemoryStream:
|
||||
assert not self.closed
|
||||
client, server = memory_stream_pair()
|
||||
await self.queued_streams[0].send(server)
|
||||
return client
|
||||
|
||||
async def accept(self) -> StapledMemoryStream:
|
||||
await trio.lowlevel.checkpoint()
|
||||
assert not self.closed
|
||||
if self.accept_hook is not None:
|
||||
await self.accept_hook()
|
||||
stream = await self.queued_streams[1].receive()
|
||||
self.accepted_streams.append(stream)
|
||||
return stream
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self.closed = True
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
async def test_serve_listeners_basic() -> None:
|
||||
listeners = [MemoryListener(), MemoryListener()]
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook() -> None:
|
||||
# Make sure this is a forceful close
|
||||
assert trio.current_effective_deadline() == float("-inf")
|
||||
record.append("closed")
|
||||
|
||||
async def handler(stream: StapledMemoryStream) -> None:
|
||||
await stream.send_all(b"123")
|
||||
assert await stream.receive_some(10) == b"456"
|
||||
stream.send_stream.close_hook = close_hook
|
||||
stream.receive_stream.close_hook = close_hook
|
||||
|
||||
async def client(listener: MemoryListener) -> None:
|
||||
s = await listener.connect()
|
||||
assert await s.receive_some(10) == b"123"
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def do_tests(parent_nursery: Nursery) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for listener in listeners:
|
||||
for _ in range(3):
|
||||
nursery.start_soon(client, listener)
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# verifies that all 6 streams x 2 directions each were closed ok
|
||||
assert len(record) == 12
|
||||
|
||||
parent_nursery.cancel_scope.cancel()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
l2: list[MemoryListener] = await nursery.start(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
listeners,
|
||||
)
|
||||
assert l2 == listeners
|
||||
# This is just split into another function because gh-136 isn't
|
||||
# implemented yet
|
||||
nursery.start_soon(do_tests, nursery)
|
||||
|
||||
for listener in listeners:
|
||||
assert listener.closed
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_unrecognized_error() -> None:
|
||||
for error in [KeyError(), OSError(errno.ECONNABORTED, "ECONNABORTED")]:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_error() -> NoReturn:
|
||||
raise error # noqa: B023 # Set from loop
|
||||
|
||||
def check_error(e: BaseException) -> bool:
|
||||
return e is error # noqa: B023
|
||||
|
||||
listener.accept_hook = raise_error
|
||||
|
||||
with RaisesGroup(Matcher(check=check_error)):
|
||||
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_serve_listeners_accept_capacity_error(
|
||||
autojump_clock: MockClock,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def raise_EMFILE() -> NoReturn:
|
||||
raise OSError(errno.EMFILE, "out of file descriptors")
|
||||
|
||||
listener.accept_hook = raise_EMFILE
|
||||
|
||||
# It retries every 100 ms, so in 950 ms it will retry at 0, 100, ..., 900
|
||||
# = 10 times total
|
||||
with trio.move_on_after(0.950):
|
||||
await trio.serve_listeners(None, [listener]) # type: ignore[arg-type]
|
||||
|
||||
assert len(caplog.records) == 10
|
||||
for record in caplog.records:
|
||||
assert "retrying" in record.msg
|
||||
assert record.exc_info is not None
|
||||
assert isinstance(record.exc_info[1], OSError)
|
||||
assert record.exc_info[1].errno == errno.EMFILE
|
||||
|
||||
|
||||
async def test_serve_listeners_connection_nursery(autojump_clock: MockClock) -> None:
|
||||
listener = MemoryListener()
|
||||
|
||||
async def handler(stream: Stream) -> None:
|
||||
await trio.sleep(1)
|
||||
|
||||
class Done(Exception):
|
||||
pass
|
||||
|
||||
async def connection_watcher(
|
||||
*,
|
||||
task_status: TaskStatus[Nursery] = trio.TASK_STATUS_IGNORED,
|
||||
) -> NoReturn:
|
||||
async with trio.open_nursery() as nursery:
|
||||
task_status.started(nursery)
|
||||
await wait_all_tasks_blocked()
|
||||
assert len(nursery.child_tasks) == 10
|
||||
raise Done
|
||||
|
||||
# the exception is wrapped twice because we open two nested nurseries
|
||||
with RaisesGroup(RaisesGroup(Done)):
|
||||
async with trio.open_nursery() as nursery:
|
||||
handler_nursery: trio.Nursery = await nursery.start(connection_watcher)
|
||||
await nursery.start(
|
||||
partial(
|
||||
trio.serve_listeners,
|
||||
handler,
|
||||
[listener],
|
||||
handler_nursery=handler_nursery,
|
||||
),
|
||||
)
|
||||
for _ in range(10):
|
||||
nursery.start_soon(listener.connect)
|
@ -0,0 +1,330 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import socket as stdlib_socket
|
||||
import sys
|
||||
from typing import Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core, socket as tsocket
|
||||
from .._highlevel_socket import *
|
||||
from ..testing import (
|
||||
assert_checkpoints,
|
||||
check_half_closeable_stream,
|
||||
wait_all_tasks_blocked,
|
||||
)
|
||||
from .test_socket import setsockopt_tests
|
||||
|
||||
|
||||
async def test_SocketStream_basics() -> None:
|
||||
# stdlib socket bad (even if connected)
|
||||
stdlib_a, stdlib_b = stdlib_socket.socketpair()
|
||||
with stdlib_a, stdlib_b:
|
||||
with pytest.raises(TypeError):
|
||||
SocketStream(stdlib_a) # type: ignore[arg-type]
|
||||
|
||||
# DGRAM socket bad
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as sock:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^SocketStream requires a SOCK_STREAM socket$",
|
||||
):
|
||||
# TODO: does not raise an error?
|
||||
SocketStream(sock)
|
||||
|
||||
a, b = tsocket.socketpair()
|
||||
with a, b:
|
||||
s = SocketStream(a)
|
||||
assert s.socket is a
|
||||
|
||||
# Use a real, connected socket to test socket options, because
|
||||
# socketpair() might give us a unix socket that doesn't support any of
|
||||
# these options
|
||||
with tsocket.socket() as listen_sock:
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(1)
|
||||
with tsocket.socket() as client_sock:
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
|
||||
s = SocketStream(client_sock)
|
||||
|
||||
# TCP_NODELAY enabled by default
|
||||
assert s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
# We can disable it though
|
||||
s.setsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, False)
|
||||
assert not s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY)
|
||||
|
||||
res = s.getsockopt(tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY, 1)
|
||||
assert isinstance(res, bytes)
|
||||
|
||||
setsockopt_tests(s)
|
||||
|
||||
|
||||
async def test_SocketStream_send_all() -> None:
|
||||
BIG = 10000000
|
||||
|
||||
a_sock, b_sock = tsocket.socketpair()
|
||||
with a_sock, b_sock:
|
||||
a = SocketStream(a_sock)
|
||||
b = SocketStream(b_sock)
|
||||
|
||||
# Check a send_all that has to be split into multiple parts (on most
|
||||
# platforms... on Windows every send() either succeeds or fails as a
|
||||
# whole)
|
||||
async def sender() -> None:
|
||||
data = bytearray(BIG)
|
||||
await a.send_all(data)
|
||||
# send_all uses memoryviews internally, which temporarily "lock"
|
||||
# the object they view. If it doesn't clean them up properly, then
|
||||
# some bytearray operations might raise an error afterwards, which
|
||||
# would be a pretty weird and annoying side-effect to spring on
|
||||
# users. So test that this doesn't happen, by forcing the
|
||||
# bytearray's underlying buffer to be realloc'ed:
|
||||
data += bytes(BIG)
|
||||
# (Note: the above line of code doesn't do a very good job at
|
||||
# testing anything, because:
|
||||
# - on CPython, the refcount GC generally cleans up memoryviews
|
||||
# for us even if we're sloppy.
|
||||
# - on PyPy3, at least as of 5.7.0, the memoryview code and the
|
||||
# bytearray code conspire so that resizing never fails – if
|
||||
# resizing forces the bytearray's internal buffer to move, then
|
||||
# all memoryview references are automagically updated (!!).
|
||||
# See:
|
||||
# https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
|
||||
# But I'm leaving the test here in hopes that if this ever changes
|
||||
# and we break our implementation of send_all, then we'll get some
|
||||
# early warning...)
|
||||
|
||||
async def receiver() -> None:
|
||||
# Make sure the sender fills up the kernel buffers and blocks
|
||||
await wait_all_tasks_blocked()
|
||||
nbytes = 0
|
||||
while nbytes < BIG:
|
||||
nbytes += len(await b.receive_some(BIG))
|
||||
assert nbytes == BIG
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(receiver)
|
||||
|
||||
# We know that we received BIG bytes of NULs so far. Make sure that
|
||||
# was all the data in there.
|
||||
await a.send_all(b"e")
|
||||
assert await b.receive_some(10) == b"e"
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
|
||||
async def fill_stream(s: SocketStream) -> None:
|
||||
async def sender() -> None:
|
||||
while True:
|
||||
await s.send_all(b"x" * 10000)
|
||||
|
||||
async def waiter(nursery: _core.Nursery) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(sender)
|
||||
nursery.start_soon(waiter, nursery)
|
||||
|
||||
|
||||
async def test_SocketStream_generic() -> None:
|
||||
async def stream_maker() -> tuple[SocketStream, SocketStream]:
|
||||
left, right = tsocket.socketpair()
|
||||
return SocketStream(left), SocketStream(right)
|
||||
|
||||
async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
|
||||
left, right = await stream_maker()
|
||||
await fill_stream(left)
|
||||
await fill_stream(right)
|
||||
return left, right
|
||||
|
||||
await check_half_closeable_stream(stream_maker, clogged_stream_maker)
|
||||
|
||||
|
||||
async def test_SocketListener() -> None:
|
||||
# Not a Trio socket
|
||||
with stdlib_socket.socket() as s:
|
||||
s.bind(("127.0.0.1", 0))
|
||||
s.listen(10)
|
||||
with pytest.raises(TypeError):
|
||||
SocketListener(s) # type: ignore[arg-type]
|
||||
|
||||
# Not a SOCK_STREAM
|
||||
with tsocket.socket(type=tsocket.SOCK_DGRAM) as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^SocketListener requires a SOCK_STREAM socket$",
|
||||
) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*SOCK_STREAM")
|
||||
|
||||
# Didn't call .listen()
|
||||
# macOS has no way to check for this, so skip testing it there.
|
||||
if sys.platform != "darwin":
|
||||
with tsocket.socket() as s:
|
||||
await s.bind(("127.0.0.1", 0))
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^SocketListener requires a listening socket$",
|
||||
) as excinfo:
|
||||
SocketListener(s)
|
||||
excinfo.match(r".*listen")
|
||||
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
assert listener.socket is listen_sock
|
||||
|
||||
client_sock = tsocket.socket()
|
||||
await client_sock.connect(listen_sock.getsockname())
|
||||
with assert_checkpoints():
|
||||
server_stream = await listener.accept()
|
||||
assert isinstance(server_stream, SocketStream)
|
||||
assert server_stream.socket.getsockname() == listen_sock.getsockname()
|
||||
assert server_stream.socket.getpeername() == client_sock.getsockname()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
await listener.aclose()
|
||||
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
client_sock.close()
|
||||
await server_stream.aclose()
|
||||
|
||||
|
||||
async def test_SocketListener_socket_closed_underfoot() -> None:
|
||||
listen_sock = tsocket.socket()
|
||||
await listen_sock.bind(("127.0.0.1", 0))
|
||||
listen_sock.listen(10)
|
||||
listener = SocketListener(listen_sock)
|
||||
|
||||
# Close the socket, not the listener
|
||||
listen_sock.close()
|
||||
|
||||
# SocketListener gives correct error
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await listener.accept()
|
||||
|
||||
|
||||
async def test_SocketListener_accept_errors() -> None:
|
||||
class FakeSocket(tsocket.SocketType):
|
||||
def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
|
||||
self._events = iter(events)
|
||||
|
||||
type = tsocket.SOCK_STREAM
|
||||
|
||||
# Fool the check for SO_ACCEPTCONN in SocketListener.__init__
|
||||
@overload
|
||||
def getsockopt(self, /, level: int, optname: int) -> int: ...
|
||||
|
||||
@overload
|
||||
def getsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int,
|
||||
) -> bytes: ...
|
||||
|
||||
def getsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
buflen: int | None = None,
|
||||
) -> int | bytes:
|
||||
return True
|
||||
|
||||
@overload
|
||||
def setsockopt(
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer,
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def setsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: None,
|
||||
optlen: int,
|
||||
) -> None: ...
|
||||
|
||||
def setsockopt( # noqa: F811
|
||||
self,
|
||||
/,
|
||||
level: int,
|
||||
optname: int,
|
||||
value: int | Buffer | None,
|
||||
optlen: int | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def accept(self) -> tuple[SocketType, object]:
|
||||
await _core.checkpoint()
|
||||
event = next(self._events)
|
||||
if isinstance(event, BaseException):
|
||||
raise event
|
||||
else:
|
||||
return event, None
|
||||
|
||||
fake_server_sock = FakeSocket([])
|
||||
|
||||
fake_listen_sock = FakeSocket(
|
||||
[
|
||||
OSError(errno.ECONNABORTED, "Connection aborted"),
|
||||
OSError(errno.EPERM, "Permission denied"),
|
||||
OSError(errno.EPROTO, "Bad protocol"),
|
||||
fake_server_sock,
|
||||
OSError(errno.EMFILE, "Out of file descriptors"),
|
||||
OSError(errno.EFAULT, "attempt to write to read-only memory"),
|
||||
OSError(errno.ENOBUFS, "out of buffers"),
|
||||
fake_server_sock,
|
||||
],
|
||||
)
|
||||
|
||||
listener = SocketListener(fake_listen_sock)
|
||||
|
||||
with assert_checkpoints():
|
||||
stream = await listener.accept()
|
||||
assert stream.socket is fake_server_sock
|
||||
|
||||
for code, match in {
|
||||
errno.EMFILE: r"\[\w+ \d+\] Out of file descriptors$",
|
||||
errno.EFAULT: r"\[\w+ \d+\] attempt to write to read-only memory$",
|
||||
errno.ENOBUFS: r"\[\w+ \d+\] out of buffers$",
|
||||
}.items():
|
||||
with assert_checkpoints():
|
||||
with pytest.raises(OSError, match=match) as excinfo:
|
||||
await listener.accept()
|
||||
assert excinfo.value.errno == code
|
||||
|
||||
with assert_checkpoints():
|
||||
stream = await listener.accept()
|
||||
assert stream.socket is fake_server_sock
|
||||
|
||||
|
||||
async def test_socket_stream_works_when_peer_has_already_closed() -> None:
|
||||
sock_a, sock_b = tsocket.socketpair()
|
||||
with sock_a, sock_b:
|
||||
await sock_b.send(b"x")
|
||||
sock_b.close()
|
||||
stream = SocketStream(sock_a)
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
assert await stream.receive_some(1) == b""
|
@ -0,0 +1,166 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, NoReturn
|
||||
|
||||
import attrs
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
import trio.testing
|
||||
from trio.socket import AF_INET, IPPROTO_TCP, SOCK_STREAM
|
||||
|
||||
from .._highlevel_ssl_helpers import (
|
||||
open_ssl_over_tcp_listeners,
|
||||
open_ssl_over_tcp_stream,
|
||||
serve_ssl_over_tcp,
|
||||
)
|
||||
|
||||
# using noqa because linters don't understand how pytest fixtures work.
|
||||
from .test_ssl import SERVER_CTX, client_ctx # noqa: F401
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from socket import AddressFamily, SocketKind
|
||||
from ssl import SSLContext
|
||||
|
||||
from trio.abc import Stream
|
||||
|
||||
from .._highlevel_socket import SocketListener
|
||||
from .._ssl import SSLListener
|
||||
|
||||
|
||||
async def echo_handler(stream: Stream) -> None:
|
||||
async with stream:
|
||||
try:
|
||||
while True:
|
||||
data = await stream.receive_some(10000)
|
||||
if not data:
|
||||
break
|
||||
await stream.send_all(data)
|
||||
except trio.BrokenResourceError:
|
||||
pass
|
||||
|
||||
|
||||
# Resolver that always returns the given sockaddr, no matter what host/port
|
||||
# you ask for.
|
||||
@attrs.define(slots=False)
|
||||
class FakeHostnameResolver(trio.abc.HostnameResolver):
|
||||
sockaddr: tuple[str, int] | tuple[str, int, int, int]
|
||||
|
||||
async def getaddrinfo(
|
||||
self,
|
||||
host: bytes | None,
|
||||
port: bytes | str | int | None,
|
||||
family: int = 0,
|
||||
type: int = 0,
|
||||
proto: int = 0,
|
||||
flags: int = 0,
|
||||
) -> list[
|
||||
tuple[
|
||||
AddressFamily,
|
||||
SocketKind,
|
||||
int,
|
||||
str,
|
||||
tuple[str, int] | tuple[str, int, int, int],
|
||||
]
|
||||
]:
|
||||
return [(AF_INET, SOCK_STREAM, IPPROTO_TCP, "", self.sockaddr)]
|
||||
|
||||
async def getnameinfo(self, *args: Any) -> NoReturn: # pragma: no cover
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# This uses serve_ssl_over_tcp, which uses open_ssl_over_tcp_listeners...
|
||||
# using noqa because linters don't understand how pytest fixtures work.
|
||||
async def test_open_ssl_over_tcp_stream_and_everything_else(
|
||||
client_ctx: SSLContext, # noqa: F811 # linters doesn't understand fixture
|
||||
) -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
# TODO: this function wraps an SSLListener around a SocketListener, this is illegal
|
||||
# according to current type hints, and probably for good reason. But there should
|
||||
# maybe be a different wrapper class/function that could be used instead?
|
||||
res: list[SSLListener[SocketListener]] = ( # type: ignore[type-var]
|
||||
await nursery.start(
|
||||
partial(
|
||||
serve_ssl_over_tcp,
|
||||
echo_handler,
|
||||
0,
|
||||
SERVER_CTX,
|
||||
host="127.0.0.1",
|
||||
),
|
||||
)
|
||||
)
|
||||
(listener,) = res
|
||||
async with listener:
|
||||
# listener.transport_listener is of type Listener[Stream]
|
||||
tp_listener: SocketListener = listener.transport_listener # type: ignore[assignment]
|
||||
|
||||
sockaddr = tp_listener.socket.getsockname()
|
||||
hostname_resolver = FakeHostnameResolver(sockaddr)
|
||||
trio.socket.set_custom_hostname_resolver(hostname_resolver)
|
||||
|
||||
# We don't have the right trust set up
|
||||
# (checks that ssl_context=None is doing some validation)
|
||||
stream = await open_ssl_over_tcp_stream("trio-test-1.example.org", 80)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# We have the trust but not the hostname
|
||||
# (checks custom ssl_context + hostname checking)
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"xyzzy.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
)
|
||||
async with stream:
|
||||
with pytest.raises(trio.BrokenResourceError):
|
||||
await stream.do_handshake()
|
||||
|
||||
# This one should work!
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
)
|
||||
async with stream:
|
||||
assert isinstance(stream, trio.SSLStream)
|
||||
assert stream.server_hostname == "trio-test-1.example.org"
|
||||
await stream.send_all(b"x")
|
||||
assert await stream.receive_some(1) == b"x"
|
||||
|
||||
# Check https_compatible settings are being passed through
|
||||
assert not stream._https_compatible
|
||||
stream = await open_ssl_over_tcp_stream(
|
||||
"trio-test-1.example.org",
|
||||
80,
|
||||
ssl_context=client_ctx,
|
||||
https_compatible=True,
|
||||
# also, smoke test happy_eyeballs_delay
|
||||
happy_eyeballs_delay=1,
|
||||
)
|
||||
async with stream:
|
||||
assert stream._https_compatible
|
||||
|
||||
# Stop the echo server
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_open_ssl_over_tcp_listeners() -> None:
|
||||
(listener,) = await open_ssl_over_tcp_listeners(0, SERVER_CTX, host="127.0.0.1")
|
||||
async with listener:
|
||||
assert isinstance(listener, trio.SSLListener)
|
||||
tl = listener.transport_listener
|
||||
assert isinstance(tl, trio.SocketListener)
|
||||
assert tl.socket.getsockname()[0] == "127.0.0.1"
|
||||
|
||||
assert not listener._https_compatible
|
||||
|
||||
(listener,) = await open_ssl_over_tcp_listeners(
|
||||
0,
|
||||
SERVER_CTX,
|
||||
host="127.0.0.1",
|
||||
https_compatible=True,
|
||||
)
|
||||
async with listener:
|
||||
assert listener._https_compatible
|
275
lib/python3.13/site-packages/trio/_tests/test_path.py
Normal file
275
lib/python3.13/site-packages/trio/_tests/test_path.py
Normal file
@ -0,0 +1,275 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
from typing import TYPE_CHECKING, Type, Union
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio._file_io import AsyncIOWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def path(tmp_path: pathlib.Path) -> trio.Path:
|
||||
return trio.Path(tmp_path / "test")
|
||||
|
||||
|
||||
def method_pair(
|
||||
path: str,
|
||||
method_name: str,
|
||||
) -> tuple[Callable[[], object], Callable[[], Awaitable[object]]]:
|
||||
sync_path = pathlib.Path(path)
|
||||
async_path = trio.Path(path)
|
||||
return getattr(sync_path, method_name), getattr(async_path, method_name)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name == "nt", reason="OS is not posix")
|
||||
async def test_instantiate_posix() -> None:
|
||||
assert isinstance(trio.Path(), trio.PosixPath)
|
||||
|
||||
|
||||
@pytest.mark.skipif(os.name != "nt", reason="OS is not Windows")
|
||||
async def test_instantiate_windows() -> None:
|
||||
assert isinstance(trio.Path(), trio.WindowsPath)
|
||||
|
||||
|
||||
async def test_open_is_async_context_manager(path: trio.Path) -> None:
|
||||
async with await path.open("w") as f:
|
||||
assert isinstance(f, AsyncIOWrapper)
|
||||
|
||||
assert f.closed
|
||||
|
||||
|
||||
async def test_magic() -> None:
|
||||
path = trio.Path("test")
|
||||
|
||||
assert str(path) == "test"
|
||||
assert bytes(path) == b"test"
|
||||
|
||||
|
||||
EitherPathType = Union[Type[trio.Path], Type[pathlib.Path]]
|
||||
PathOrStrType = Union[EitherPathType, Type[str]]
|
||||
cls_pairs: list[tuple[EitherPathType, EitherPathType]] = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(pathlib.Path, trio.Path),
|
||||
(trio.Path, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs)
|
||||
async def test_cmp_magic(cls_a: EitherPathType, cls_b: EitherPathType) -> None:
|
||||
a, b = cls_a(""), cls_b("")
|
||||
assert a == b
|
||||
assert not a != b # noqa: SIM202 # negate-not-equal-op
|
||||
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
assert a < b
|
||||
assert b > a
|
||||
|
||||
# this is intentionally testing equivalence with none, due to the
|
||||
# other=sentinel logic in _forward_magic
|
||||
assert not a == None # noqa
|
||||
assert not b == None # noqa
|
||||
|
||||
|
||||
# upstream python3.8 bug: we should also test (pathlib.Path, trio.Path), but
|
||||
# __*div__ does not properly raise NotImplementedError like the other comparison
|
||||
# magic, so trio.Path's implementation does not get dispatched
|
||||
cls_pairs_str: list[tuple[PathOrStrType, PathOrStrType]] = [
|
||||
(trio.Path, pathlib.Path),
|
||||
(trio.Path, trio.Path),
|
||||
(trio.Path, str),
|
||||
(str, trio.Path),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("cls_a", "cls_b"), cls_pairs_str)
|
||||
async def test_div_magic(cls_a: PathOrStrType, cls_b: PathOrStrType) -> None:
|
||||
a, b = cls_a("a"), cls_b("b")
|
||||
|
||||
result = a / b # type: ignore[operator]
|
||||
# Type checkers think str / str could happen. Check each combo manually in type_tests/.
|
||||
assert isinstance(result, trio.Path)
|
||||
assert str(result) == os.path.join("a", "b")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("cls_a", "cls_b"),
|
||||
[(trio.Path, pathlib.Path), (trio.Path, trio.Path)],
|
||||
)
|
||||
@pytest.mark.parametrize("path", ["foo", "foo/bar/baz", "./foo"])
|
||||
async def test_hash_magic(
|
||||
cls_a: EitherPathType,
|
||||
cls_b: EitherPathType,
|
||||
path: str,
|
||||
) -> None:
|
||||
a, b = cls_a(path), cls_b(path)
|
||||
assert hash(a) == hash(b)
|
||||
|
||||
|
||||
async def test_forwarded_properties(path: trio.Path) -> None:
|
||||
# use `name` as a representative of forwarded properties
|
||||
|
||||
assert "name" in dir(path)
|
||||
assert path.name == "test"
|
||||
|
||||
|
||||
async def test_async_method_signature(path: trio.Path) -> None:
|
||||
# use `resolve` as a representative of wrapped methods
|
||||
|
||||
assert path.resolve.__name__ == "resolve"
|
||||
assert path.resolve.__qualname__ == "Path.resolve"
|
||||
|
||||
assert path.resolve.__doc__ is not None
|
||||
assert path.resolve.__qualname__ in path.resolve.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["is_dir", "is_file"])
|
||||
async def test_compare_async_stat_methods(method_name: str) -> None:
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert result == async_result
|
||||
|
||||
|
||||
async def test_invalid_name_not_wrapped(path: trio.Path) -> None:
|
||||
with pytest.raises(AttributeError):
|
||||
getattr(path, "invalid_fake_attr") # noqa: B009 # "get-attr-with-constant"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method_name", ["absolute", "resolve"])
|
||||
async def test_async_methods_rewrap(method_name: str) -> None:
|
||||
method, async_method = method_pair(".", method_name)
|
||||
|
||||
result = method()
|
||||
async_result = await async_method()
|
||||
|
||||
assert isinstance(async_result, trio.Path)
|
||||
assert str(result) == str(async_result)
|
||||
|
||||
|
||||
async def test_forward_methods_rewrap(path: trio.Path, tmp_path: pathlib.Path) -> None:
|
||||
with_name = path.with_name("foo")
|
||||
with_suffix = path.with_suffix(".py")
|
||||
|
||||
assert isinstance(with_name, trio.Path)
|
||||
assert with_name == tmp_path / "foo"
|
||||
assert isinstance(with_suffix, trio.Path)
|
||||
assert with_suffix == tmp_path / "test.py"
|
||||
|
||||
|
||||
async def test_forward_properties_rewrap(path: trio.Path) -> None:
|
||||
assert isinstance(path.parent, trio.Path)
|
||||
|
||||
|
||||
async def test_forward_methods_without_rewrap(path: trio.Path) -> None:
|
||||
path = await path.parent.resolve()
|
||||
|
||||
assert path.as_uri().startswith("file:///")
|
||||
|
||||
|
||||
async def test_repr() -> None:
|
||||
path = trio.Path(".")
|
||||
|
||||
assert repr(path) == "trio.Path('.')"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("meth", [trio.Path.__init__, trio.Path.joinpath])
|
||||
async def test_path_wraps_path(
|
||||
path: trio.Path,
|
||||
meth: Callable[[trio.Path, trio.Path], object],
|
||||
) -> None:
|
||||
wrapped = await path.absolute()
|
||||
result = meth(path, wrapped)
|
||||
if result is None:
|
||||
result = path
|
||||
|
||||
assert wrapped == result
|
||||
|
||||
|
||||
async def test_path_nonpath() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
trio.Path(1) # type: ignore
|
||||
|
||||
|
||||
async def test_open_file_can_open_path(path: trio.Path) -> None:
|
||||
async with await trio.open_file(path, "w") as f:
|
||||
assert f.name == os.fspath(path)
|
||||
|
||||
|
||||
async def test_globmethods(path: trio.Path) -> None:
|
||||
# Populate a directory tree
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "foo" / "_bar.txt").write_bytes(b"")
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
await (path / "bar.dat").write_bytes(b"")
|
||||
|
||||
# Path.glob
|
||||
for _pattern, _results in {
|
||||
"*.txt": {"bar.txt"},
|
||||
"**/*.txt": {"_bar.txt", "bar.txt"},
|
||||
}.items():
|
||||
entries = set()
|
||||
for entry in await path.glob(_pattern):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == _results
|
||||
|
||||
# Path.rglob
|
||||
entries = set()
|
||||
for entry in await path.rglob("*.txt"):
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"_bar.txt", "bar.txt"}
|
||||
|
||||
|
||||
async def test_iterdir(path: trio.Path) -> None:
|
||||
# Populate a directory
|
||||
await path.mkdir()
|
||||
await (path / "foo").mkdir()
|
||||
await (path / "bar.txt").write_bytes(b"")
|
||||
|
||||
entries = set()
|
||||
for entry in await path.iterdir():
|
||||
assert isinstance(entry, trio.Path)
|
||||
entries.add(entry.name)
|
||||
|
||||
assert entries == {"bar.txt", "foo"}
|
||||
|
||||
|
||||
async def test_classmethods() -> None:
|
||||
assert isinstance(await trio.Path.home(), trio.Path)
|
||||
|
||||
# pathlib.Path has only two classmethods
|
||||
assert str(await trio.Path.home()) == os.path.expanduser("~")
|
||||
assert str(await trio.Path.cwd()) == os.getcwd()
|
||||
|
||||
# Wrapped method has docstring
|
||||
assert trio.Path.home.__doc__
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"wrapper",
|
||||
[
|
||||
trio._path._wraps_async,
|
||||
trio._path._wrap_method,
|
||||
trio._path._wrap_method_path,
|
||||
trio._path._wrap_method_path_iterable,
|
||||
],
|
||||
)
|
||||
def test_wrapping_without_docstrings(
|
||||
wrapper: Callable[[Callable[[], None]], Callable[[], None]],
|
||||
) -> None:
|
||||
@wrapper
|
||||
def func_without_docstring() -> None: ... # pragma: no cover
|
||||
|
||||
assert func_without_docstring.__doc__ is None
|
242
lib/python3.13/site-packages/trio/_tests/test_repl.py
Normal file
242
lib/python3.13/site-packages/trio/_tests/test_repl.py
Normal file
@ -0,0 +1,242 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Protocol
|
||||
|
||||
import pytest
|
||||
|
||||
import trio._repl
|
||||
|
||||
|
||||
class RawInput(Protocol):
|
||||
def __call__(self, prompt: str = "") -> str: ...
|
||||
|
||||
|
||||
def build_raw_input(cmds: list[str]) -> RawInput:
|
||||
"""
|
||||
Pass in a list of strings.
|
||||
Returns a callable that returns each string, each time its called
|
||||
When there are not more strings to return, raise EOFError
|
||||
"""
|
||||
cmds_iter = iter(cmds)
|
||||
prompts = []
|
||||
|
||||
def _raw_helper(prompt: str = "") -> str:
|
||||
prompts.append(prompt)
|
||||
try:
|
||||
return next(cmds_iter)
|
||||
except StopIteration:
|
||||
raise EOFError from None
|
||||
|
||||
return _raw_helper
|
||||
|
||||
|
||||
def test_build_raw_input() -> None:
|
||||
"""Quick test of our helper function."""
|
||||
raw_input = build_raw_input(["cmd1"])
|
||||
assert raw_input() == "cmd1"
|
||||
with pytest.raises(EOFError):
|
||||
raw_input()
|
||||
|
||||
|
||||
# In 3.10 or later, types.FunctionType (used internally) will automatically
|
||||
# attach __builtins__ to the function objects. However we need to explicitly
|
||||
# include it for 3.8 & 3.9
|
||||
def build_locals() -> dict[str, object]:
|
||||
return {"__builtins__": __builtins__}
|
||||
|
||||
|
||||
async def test_basic_interaction(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""
|
||||
Run some basic commands through the interpreter while capturing stdout.
|
||||
Ensure that the interpreted prints the expected results.
|
||||
"""
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# evaluate simple expression and recall the value
|
||||
"x = 1",
|
||||
"print(f'{x=}')",
|
||||
# Literal gets printed
|
||||
"'hello'",
|
||||
# define and call sync function
|
||||
"def func():",
|
||||
" print(x + 1)",
|
||||
"",
|
||||
"func()",
|
||||
# define and call async function
|
||||
"async def afunc():",
|
||||
" return 4",
|
||||
"",
|
||||
"await afunc()",
|
||||
# import works
|
||||
"import sys",
|
||||
"sys.stdout.write('hello stdout\\n')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert out.splitlines() == ["x=1", "'hello'", "2", "4", "hello stdout", "13"]
|
||||
|
||||
|
||||
async def test_system_exits_quit_interpreter(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"raise SystemExit",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
with pytest.raises(SystemExit):
|
||||
await trio._repl.run_repl(console)
|
||||
|
||||
|
||||
async def test_KI_interrupts(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"from trio._util import signal_raise",
|
||||
"import signal, trio, trio.lowlevel",
|
||||
"async def f():",
|
||||
" trio.lowlevel.spawn_system_task("
|
||||
" trio.to_thread.run_sync,"
|
||||
" signal_raise,signal.SIGINT,"
|
||||
" )", # just awaiting this kills the test runner?!
|
||||
" await trio.sleep_forever()",
|
||||
" print('should not see this')",
|
||||
"",
|
||||
"await f()",
|
||||
"print('AFTER KeyboardInterrupt')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "KeyboardInterrupt" in err
|
||||
assert "should" not in out
|
||||
assert "AFTER KeyboardInterrupt" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup('', [RuntimeError(), SystemExit()])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_system_exits_in_nested_exc_group(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"import sys",
|
||||
"if sys.version_info < (3, 11):",
|
||||
" from exceptiongroup import BaseExceptionGroup",
|
||||
"",
|
||||
"raise BaseExceptionGroup(",
|
||||
" '', [BaseExceptionGroup('', [RuntimeError(), SystemExit()])])",
|
||||
"print('AFTER BaseExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
# assert that raise SystemExit in an exception group
|
||||
# doesn't quit
|
||||
assert "AFTER BaseExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise BaseException",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
async def test_exc_group_captured(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
# The statement after raise should still get executed
|
||||
"raise ExceptionGroup('', [KeyError()])",
|
||||
"print('AFTER ExceptionGroup')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "AFTER ExceptionGroup" in out
|
||||
|
||||
|
||||
async def test_base_exception_capture_from_coroutine(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
console = trio._repl.TrioInteractiveConsole(repl_locals=build_locals())
|
||||
raw_input = build_raw_input(
|
||||
[
|
||||
"async def async_func_raises_base_exception():",
|
||||
" raise BaseException",
|
||||
"",
|
||||
# This will raise, but the statement after should still
|
||||
# be executed
|
||||
"await async_func_raises_base_exception()",
|
||||
"print('AFTER BaseException')",
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(console, "raw_input", raw_input)
|
||||
await trio._repl.run_repl(console)
|
||||
out, err = capsys.readouterr()
|
||||
assert "_threads.py" not in err
|
||||
assert "_repl.py" not in err
|
||||
assert "AFTER BaseException" in out
|
||||
|
||||
|
||||
def test_main_entrypoint() -> None:
|
||||
"""
|
||||
Basic smoke test when running via the package __main__ entrypoint.
|
||||
"""
|
||||
repl = subprocess.run([sys.executable, "-m", "trio"], input=b"exit()")
|
||||
assert repl.returncode == 0
|
@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import trio
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pytest
|
||||
|
||||
|
||||
async def scheduler_trace() -> tuple[tuple[str, int], ...]:
|
||||
"""Returns a scheduler-dependent value we can use to check determinism."""
|
||||
trace = []
|
||||
|
||||
async def tracer(name: str) -> None:
|
||||
for i in range(50):
|
||||
trace.append((name, i))
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(tracer, str(i))
|
||||
|
||||
return tuple(trace)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_not_deterministic() -> None:
|
||||
# At least, not yet. See https://github.com/python-trio/trio/issues/32
|
||||
traces = [trio.run(scheduler_trace) for _ in range(10)]
|
||||
assert len(set(traces)) == len(traces)
|
||||
|
||||
|
||||
def test_the_trio_scheduler_is_deterministic_if_seeded(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(trio._core._run, "_ALLOW_DETERMINISTIC_SCHEDULING", True)
|
||||
traces = []
|
||||
for _ in range(10):
|
||||
state = trio._core._run._r.getstate()
|
||||
try:
|
||||
trio._core._run._r.seed(0)
|
||||
traces.append(trio.run(scheduler_trace))
|
||||
finally:
|
||||
trio._core._run._r.setstate(state)
|
||||
|
||||
assert len(traces) == 10
|
||||
assert len(set(traces)) == 1
|
188
lib/python3.13/site-packages/trio/_tests/test_signals.py
Normal file
188
lib/python3.13/site-packages/trio/_tests/test_signals.py
Normal file
@ -0,0 +1,188 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import signal
|
||||
from typing import TYPE_CHECKING, NoReturn
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import RaisesGroup
|
||||
|
||||
from .. import _core
|
||||
from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver
|
||||
from .._util import signal_raise
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
|
||||
async def test_open_signal_receiver() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
# Raise it a few times, to exercise signal coalescing, both at the
|
||||
# call_soon level and at the SignalQueue level
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
signal_raise(signal.SIGILL)
|
||||
await _core.wait_all_tasks_blocked()
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert get_pending_signal_count(receiver) == 0
|
||||
signal_raise(signal.SIGILL)
|
||||
async for signum in receiver: # pragma: no branch
|
||||
assert signum == signal.SIGILL
|
||||
break
|
||||
assert get_pending_signal_count(receiver) == 0
|
||||
with pytest.raises(RuntimeError):
|
||||
await receiver.__anext__()
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="(signal number out of range|invalid signal value)$",
|
||||
):
|
||||
with open_signal_receiver(signal.SIGILL, 1234567):
|
||||
pass # pragma: no cover
|
||||
# Still restored even if we errored out
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_open_signal_receiver_empty_fail() -> None:
|
||||
with pytest.raises(TypeError, match="No signals were provided"):
|
||||
with open_signal_receiver():
|
||||
pass
|
||||
|
||||
|
||||
async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None:
|
||||
orig = signal.getsignal(signal.SIGILL)
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGILL):
|
||||
pass
|
||||
# Still restored correctly
|
||||
assert signal.getsignal(signal.SIGILL) is orig
|
||||
|
||||
|
||||
async def test_catch_signals_wrong_thread() -> None:
|
||||
async def naughty() -> None:
|
||||
with open_signal_receiver(signal.SIGINT):
|
||||
pass # pragma: no cover
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.to_thread.run_sync(trio.run, naughty)
|
||||
|
||||
|
||||
async def test_open_signal_receiver_conflict() -> None:
|
||||
with RaisesGroup(trio.BusyResourceError):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
nursery.start_soon(receiver.__anext__)
|
||||
|
||||
|
||||
# Blocks until all previous calls to run_sync_soon(idempotent=True) have been
|
||||
# processed.
|
||||
async def wait_run_sync_soon_idempotent_queue_barrier() -> None:
|
||||
ev = trio.Event()
|
||||
token = _core.current_trio_token()
|
||||
token.run_sync_soon(ev.set, idempotent=True)
|
||||
await ev.wait()
|
||||
|
||||
|
||||
async def test_open_signal_receiver_no_starvation() -> None:
|
||||
# Set up a situation where there are always 2 pending signals available to
|
||||
# report, and make sure that instead of getting the same signal reported
|
||||
# over and over, it alternates between reporting both of them.
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
try:
|
||||
print(signal.getsignal(signal.SIGILL))
|
||||
previous = None
|
||||
for _ in range(10):
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
if previous is None:
|
||||
previous = await receiver.__anext__()
|
||||
else:
|
||||
got = await receiver.__anext__()
|
||||
assert got in [signal.SIGILL, signal.SIGFPE]
|
||||
assert got != previous
|
||||
previous = got
|
||||
# Clear out the last signal so that it doesn't get redelivered
|
||||
while get_pending_signal_count(receiver) != 0:
|
||||
await receiver.__anext__()
|
||||
except BaseException: # pragma: no cover
|
||||
# If there's an unhandled exception above, then exiting the
|
||||
# open_signal_receiver block might cause the signal to be
|
||||
# redelivered and give us a core dump instead of a traceback...
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def test_catch_signals_race_condition_on_exit() -> None:
|
||||
delivered_directly: set[int] = set()
|
||||
|
||||
def direct_handler(signo: int, frame: FrameType | None) -> None:
|
||||
delivered_directly.add(signo)
|
||||
|
||||
print(1)
|
||||
# Test the version where the call_soon *doesn't* have a chance to run
|
||||
# before we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
print(2)
|
||||
# Test the version where the call_soon *does* have a chance to run before
|
||||
# we exit the with block:
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler):
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 2
|
||||
assert delivered_directly == {signal.SIGILL, signal.SIGFPE}
|
||||
delivered_directly.clear()
|
||||
|
||||
# Again, but with a SIG_IGN signal:
|
||||
|
||||
print(3)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
print(4)
|
||||
with _signal_handler({signal.SIGILL}, signal.SIG_IGN):
|
||||
with open_signal_receiver(signal.SIGILL) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 1
|
||||
# test passes if the process reaches this point without dying
|
||||
|
||||
# Check exception chaining if there are multiple exception-raising
|
||||
# handlers
|
||||
def raise_handler(signum: int, frame: FrameType | None) -> NoReturn:
|
||||
raise RuntimeError(signum)
|
||||
|
||||
with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler):
|
||||
with pytest.raises(RuntimeError) as excinfo:
|
||||
with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver:
|
||||
signal_raise(signal.SIGILL)
|
||||
signal_raise(signal.SIGFPE)
|
||||
await wait_run_sync_soon_idempotent_queue_barrier()
|
||||
assert get_pending_signal_count(receiver) == 2
|
||||
exc = excinfo.value
|
||||
signums = {exc.args[0]}
|
||||
assert isinstance(exc.__context__, RuntimeError)
|
||||
signums.add(exc.__context__.args[0])
|
||||
assert signums == {signal.SIGILL, signal.SIGFPE}
|
1173
lib/python3.13/site-packages/trio/_tests/test_socket.py
Normal file
1173
lib/python3.13/site-packages/trio/_tests/test_socket.py
Normal file
File diff suppressed because it is too large
Load Diff
1366
lib/python3.13/site-packages/trio/_tests/test_ssl.py
Normal file
1366
lib/python3.13/site-packages/trio/_tests/test_ssl.py
Normal file
File diff suppressed because it is too large
Load Diff
696
lib/python3.13/site-packages/trio/_tests/test_subprocess.py
Normal file
696
lib/python3.13/site-packages/trio/_tests/test_subprocess.py
Normal file
@ -0,0 +1,696 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import os
|
||||
import random
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path as SyncPath
|
||||
from signal import Signals
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncContextManager,
|
||||
AsyncIterator,
|
||||
Callable,
|
||||
NoReturn,
|
||||
)
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
from .. import (
|
||||
Event,
|
||||
Process,
|
||||
_core,
|
||||
fail_after,
|
||||
move_on_after,
|
||||
run_process,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
)
|
||||
from .._core._tests.tutil import skip_if_fbsd_pipes_broken, slow
|
||||
from ..lowlevel import open_process
|
||||
from ..testing import MockClock, assert_no_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import FrameType
|
||||
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from .._abc import ReceiveStream
|
||||
|
||||
if sys.platform == "win32":
|
||||
SignalType: TypeAlias = None
|
||||
else:
|
||||
SignalType: TypeAlias = Signals
|
||||
|
||||
SIGKILL: SignalType
|
||||
SIGTERM: SignalType
|
||||
SIGUSR1: SignalType
|
||||
|
||||
posix = os.name == "posix"
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
from signal import SIGKILL, SIGTERM, SIGUSR1
|
||||
else:
|
||||
SIGKILL, SIGTERM, SIGUSR1 = None, None, None
|
||||
|
||||
|
||||
# Since Windows has very few command-line utilities generally available,
|
||||
# all of our subprocesses are Python processes running short bits of
|
||||
# (mostly) cross-platform code.
|
||||
def python(code: str) -> list[str]:
|
||||
return [sys.executable, "-u", "-c", "import sys; " + code]
|
||||
|
||||
|
||||
EXIT_TRUE = python("sys.exit(0)")
|
||||
EXIT_FALSE = python("sys.exit(1)")
|
||||
CAT = python("sys.stdout.buffer.write(sys.stdin.buffer.read())")
|
||||
|
||||
if posix:
|
||||
|
||||
def SLEEP(seconds: int) -> list[str]:
|
||||
return ["sleep", str(seconds)]
|
||||
|
||||
else:
|
||||
|
||||
def SLEEP(seconds: int) -> list[str]:
|
||||
return python(f"import time; time.sleep({seconds})")
|
||||
|
||||
|
||||
def got_signal(proc: Process, sig: SignalType) -> bool:
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
return proc.returncode == -sig
|
||||
else:
|
||||
return proc.returncode != 0
|
||||
|
||||
|
||||
@asynccontextmanager # type: ignore[misc] # Any in decorator
|
||||
async def open_process_then_kill(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
|
||||
proc = await open_process(*args, **kwargs)
|
||||
try:
|
||||
yield proc
|
||||
finally:
|
||||
proc.kill()
|
||||
await proc.wait()
|
||||
|
||||
|
||||
@asynccontextmanager # type: ignore[misc] # Any in decorator
|
||||
async def run_process_in_nursery(*args: Any, **kwargs: Any) -> AsyncIterator[Process]:
|
||||
async with _core.open_nursery() as nursery:
|
||||
kwargs.setdefault("check", False)
|
||||
proc: Process = await nursery.start(partial(run_process, *args, **kwargs))
|
||||
yield proc
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
background_process_param = pytest.mark.parametrize(
|
||||
"background_process",
|
||||
[open_process_then_kill, run_process_in_nursery],
|
||||
ids=["open_process", "run_process in nursery"],
|
||||
)
|
||||
|
||||
BackgroundProcessType: TypeAlias = Callable[..., AsyncContextManager[Process]]
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_basic(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(EXIT_TRUE) as proc:
|
||||
await proc.wait()
|
||||
assert isinstance(proc, Process)
|
||||
assert proc._pidfd is None
|
||||
assert proc.returncode == 0
|
||||
assert repr(proc) == f"<trio.Process {EXIT_TRUE}: exited with status 0>"
|
||||
|
||||
async with background_process(EXIT_FALSE) as proc:
|
||||
await proc.wait()
|
||||
assert proc.returncode == 1
|
||||
assert repr(proc) == "<trio.Process {!r}: {}>".format(
|
||||
EXIT_FALSE,
|
||||
"exited with status 1",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_auto_update_returncode(
|
||||
background_process: BackgroundProcessType,
|
||||
) -> None:
|
||||
async with background_process(SLEEP(9999)) as p:
|
||||
assert p.returncode is None
|
||||
assert "running" in repr(p)
|
||||
p.kill()
|
||||
p._proc.wait()
|
||||
assert p.returncode is not None
|
||||
assert "exited" in repr(p)
|
||||
assert p._pidfd is None
|
||||
assert p.returncode is not None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_multi_wait(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(SLEEP(10)) as proc:
|
||||
# Check that wait (including multi-wait) tolerates being cancelled
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# Now try waiting for real
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
|
||||
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR = python(
|
||||
"data = sys.stdin.buffer.read(); "
|
||||
"sys.stdout.buffer.write(data); "
|
||||
"sys.stderr.buffer.write(data[::-1])",
|
||||
)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_pipes(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
msg = b"the quick brown fox jumps over the lazy dog"
|
||||
|
||||
async def feed_input() -> None:
|
||||
assert proc.stdin is not None
|
||||
await proc.stdin.send_all(msg)
|
||||
await proc.stdin.aclose()
|
||||
|
||||
async def check_output(stream: ReceiveStream, expected: bytes) -> None:
|
||||
seen = bytearray()
|
||||
async for chunk in stream:
|
||||
seen += chunk
|
||||
assert seen == expected
|
||||
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
# fail eventually if something is broken
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 30.0
|
||||
nursery.start_soon(feed_input)
|
||||
nursery.start_soon(check_output, proc.stdout, msg)
|
||||
nursery.start_soon(check_output, proc.stderr, msg[::-1])
|
||||
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert await proc.wait() == 0
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_interactive(background_process: BackgroundProcessType) -> None:
|
||||
# Test some back-and-forth with a subprocess. This one works like so:
|
||||
# in: 32\n
|
||||
# out: 0000...0000\n (32 zeroes)
|
||||
# err: 1111...1111\n (64 ones)
|
||||
# in: 10\n
|
||||
# out: 2222222222\n (10 twos)
|
||||
# err: 3333....3333\n (20 threes)
|
||||
# in: EOF
|
||||
# out: EOF
|
||||
# err: EOF
|
||||
|
||||
async with background_process(
|
||||
python(
|
||||
"idx = 0\n"
|
||||
"while True:\n"
|
||||
" line = sys.stdin.readline()\n"
|
||||
" if line == '': break\n"
|
||||
" request = int(line.strip())\n"
|
||||
" print(str(idx * 2) * request)\n"
|
||||
" print(str(idx * 2 + 1) * request * 2, file=sys.stderr)\n"
|
||||
" idx += 1\n",
|
||||
),
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
) as proc:
|
||||
newline = b"\n" if posix else b"\r\n"
|
||||
|
||||
async def expect(idx: int, request: int) -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def drain_one(
|
||||
stream: ReceiveStream,
|
||||
count: int,
|
||||
digit: int,
|
||||
) -> None:
|
||||
while count > 0:
|
||||
result = await stream.receive_some(count)
|
||||
assert result == (f"{digit}".encode() * len(result))
|
||||
count -= len(result)
|
||||
assert count == 0
|
||||
assert await stream.receive_some(len(newline)) == newline
|
||||
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
nursery.start_soon(drain_one, proc.stdout, request, idx * 2)
|
||||
nursery.start_soon(drain_one, proc.stderr, request * 2, idx * 2 + 1)
|
||||
|
||||
assert proc.stdin is not None
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is not None
|
||||
with fail_after(5):
|
||||
await proc.stdin.send_all(b"12")
|
||||
await sleep(0.1)
|
||||
await proc.stdin.send_all(b"345" + newline)
|
||||
await expect(0, 12345)
|
||||
await proc.stdin.send_all(b"100" + newline + b"200" + newline)
|
||||
await expect(1, 100)
|
||||
await expect(2, 200)
|
||||
await proc.stdin.send_all(b"0" + newline)
|
||||
await expect(3, 0)
|
||||
await proc.stdin.send_all(b"999999")
|
||||
with move_on_after(0.1) as scope:
|
||||
await expect(4, 0)
|
||||
assert scope.cancelled_caught
|
||||
await proc.stdin.send_all(newline)
|
||||
await expect(4, 999999)
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.stdout.receive_some(1) == b""
|
||||
assert await proc.stderr.receive_some(1) == b""
|
||||
await proc.wait()
|
||||
|
||||
assert proc.returncode == 0
|
||||
|
||||
|
||||
async def test_run() -> None:
|
||||
data = bytes(random.randint(0, 255) for _ in range(2**18))
|
||||
|
||||
result = await run_process(
|
||||
CAT,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == b""
|
||||
|
||||
result = await run_process(CAT, capture_stdout=True)
|
||||
assert result.args == CAT
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b""
|
||||
assert result.stderr is None
|
||||
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=data,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
)
|
||||
assert result.args == COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == data
|
||||
assert result.stderr == data[::-1]
|
||||
|
||||
# invalid combinations
|
||||
with pytest.raises(UnicodeError):
|
||||
await run_process(CAT, stdin="oh no, it's text")
|
||||
|
||||
pipe_stdout_error = r"^stdout=subprocess\.PIPE is only valid with nursery\.start, since that's the only way to access the pipe(; use nursery\.start or pass the data you want to write directly)*$"
|
||||
with pytest.raises(ValueError, match=pipe_stdout_error):
|
||||
await run_process(CAT, stdin=subprocess.PIPE)
|
||||
with pytest.raises(ValueError, match=pipe_stdout_error):
|
||||
await run_process(CAT, stdout=subprocess.PIPE)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=pipe_stdout_error.replace("stdout", "stderr", 1),
|
||||
):
|
||||
await run_process(CAT, stderr=subprocess.PIPE)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^can't specify both stdout and capture_stdout$",
|
||||
):
|
||||
await run_process(CAT, capture_stdout=True, stdout=subprocess.DEVNULL)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^can't specify both stderr and capture_stderr$",
|
||||
):
|
||||
await run_process(CAT, capture_stderr=True, stderr=None)
|
||||
|
||||
|
||||
async def test_run_check() -> None:
|
||||
cmd = python("sys.stderr.buffer.write(b'test\\n'); sys.exit(1)")
|
||||
with pytest.raises(subprocess.CalledProcessError) as excinfo:
|
||||
await run_process(cmd, stdin=subprocess.DEVNULL, capture_stderr=True)
|
||||
assert excinfo.value.cmd == cmd
|
||||
assert excinfo.value.returncode == 1
|
||||
assert excinfo.value.stderr == b"test\n"
|
||||
assert excinfo.value.stdout is None
|
||||
|
||||
result = await run_process(
|
||||
cmd,
|
||||
capture_stdout=True,
|
||||
capture_stderr=True,
|
||||
check=False,
|
||||
)
|
||||
assert result.args == cmd
|
||||
assert result.stdout == b""
|
||||
assert result.stderr == b"test\n"
|
||||
assert result.returncode == 1
|
||||
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_run_with_broken_pipe() -> None:
|
||||
result = await run_process(
|
||||
[sys.executable, "-c", "import sys; sys.stdin.close()"],
|
||||
stdin=b"x" * 131072,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout is result.stderr is None
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_stderr_stdout(background_process: BackgroundProcessType) -> None:
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdio is not None
|
||||
assert proc.stdout is not None
|
||||
assert proc.stderr is None
|
||||
await proc.stdio.send_all(b"1234")
|
||||
await proc.stdio.send_eof()
|
||||
|
||||
output = []
|
||||
while True:
|
||||
chunk = await proc.stdio.receive_some(16)
|
||||
if chunk == b"":
|
||||
break
|
||||
output.append(chunk)
|
||||
assert b"".join(output) == b"12344321"
|
||||
assert proc.returncode == 0
|
||||
|
||||
# equivalent test with run_process()
|
||||
result = await run_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=b"1234",
|
||||
capture_stdout=True,
|
||||
stderr=subprocess.STDOUT,
|
||||
)
|
||||
assert result.returncode == 0
|
||||
assert result.stdout == b"12344321"
|
||||
assert result.stderr is None
|
||||
|
||||
# this one hits the branch where stderr=STDOUT but stdout
|
||||
# is not redirected
|
||||
async with background_process(
|
||||
CAT,
|
||||
stdin=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.aclose()
|
||||
await proc.wait()
|
||||
assert proc.returncode == 0
|
||||
|
||||
if posix:
|
||||
try:
|
||||
r, w = os.pipe()
|
||||
|
||||
async with background_process(
|
||||
COPY_STDIN_TO_STDOUT_AND_BACKWARD_TO_STDERR,
|
||||
stdin=subprocess.PIPE,
|
||||
stdout=w,
|
||||
stderr=subprocess.STDOUT,
|
||||
) as proc:
|
||||
os.close(w)
|
||||
assert proc.stdio is None
|
||||
assert proc.stdout is None
|
||||
assert proc.stderr is None
|
||||
await proc.stdin.send_all(b"1234")
|
||||
await proc.stdin.aclose()
|
||||
assert await proc.wait() == 0
|
||||
assert os.read(r, 4096) == b"12344321"
|
||||
assert os.read(r, 4096) == b""
|
||||
finally:
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_errors() -> None:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
# call-overload on unix, call-arg on windows
|
||||
await open_process(["ls"], encoding="utf-8") # type: ignore
|
||||
assert "unbuffered byte streams" in str(excinfo.value)
|
||||
assert "the 'encoding' option is not supported" in str(excinfo.value)
|
||||
|
||||
if posix:
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process(["ls"], shell=True)
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
await open_process("ls", shell=False)
|
||||
|
||||
|
||||
@background_process_param
|
||||
async def test_signals(background_process: BackgroundProcessType) -> None:
|
||||
async def test_one_signal(
|
||||
send_it: Callable[[Process], None],
|
||||
signum: signal.Signals | None,
|
||||
) -> None:
|
||||
with move_on_after(1.0) as scope:
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
send_it(proc)
|
||||
await proc.wait()
|
||||
assert not scope.cancelled_caught
|
||||
if posix:
|
||||
assert signum is not None
|
||||
assert proc.returncode == -signum
|
||||
else:
|
||||
assert proc.returncode != 0
|
||||
|
||||
await test_one_signal(Process.kill, SIGKILL)
|
||||
await test_one_signal(Process.terminate, SIGTERM)
|
||||
# Test that we can send arbitrary signals.
|
||||
#
|
||||
# We used to use SIGINT here, but it turns out that the Python interpreter
|
||||
# has race conditions that can cause it to explode in weird ways if it
|
||||
# tries to handle SIGINT during startup. SIGUSR1's default disposition is
|
||||
# to terminate the target process, and Python doesn't try to do anything
|
||||
# clever to handle it.
|
||||
if (not TYPE_CHECKING and posix) or sys.platform != "win32":
|
||||
await test_one_signal(lambda proc: proc.send_signal(SIGUSR1), SIGUSR1)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="POSIX specific")
|
||||
@background_process_param
|
||||
async def test_wait_reapable_fails(background_process: BackgroundProcessType) -> None:
|
||||
if TYPE_CHECKING and sys.platform == "win32":
|
||||
return
|
||||
old_sigchld = signal.signal(signal.SIGCHLD, signal.SIG_IGN)
|
||||
try:
|
||||
# With SIGCHLD disabled, the wait() syscall will wait for the
|
||||
# process to exit but then fail with ECHILD. Make sure we
|
||||
# support this case as the stdlib subprocess module does.
|
||||
async with background_process(SLEEP(3600)) as proc:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(proc.wait)
|
||||
await wait_all_tasks_blocked()
|
||||
proc.kill()
|
||||
nursery.cancel_scope.deadline = _core.current_time() + 1.0
|
||||
assert not nursery.cancel_scope.cancelled_caught
|
||||
assert proc.returncode == 0 # exit status unknowable, so...
|
||||
finally:
|
||||
signal.signal(signal.SIGCHLD, old_sigchld)
|
||||
|
||||
|
||||
@slow
|
||||
def test_waitid_eintr() -> None:
|
||||
# This only matters on PyPy (where we're coding EINTR handling
|
||||
# ourselves) but the test works on all waitid platforms.
|
||||
from .._subprocess_platform import wait_child_exiting
|
||||
|
||||
if TYPE_CHECKING and (sys.platform == "win32" or sys.platform == "darwin"):
|
||||
return
|
||||
|
||||
if not wait_child_exiting.__module__.endswith("waitid"):
|
||||
pytest.skip("waitid only")
|
||||
|
||||
# despite the TYPE_CHECKING early return silencing warnings about signal.SIGALRM etc
|
||||
# this import is still checked on win32&darwin and raises [attr-defined].
|
||||
# Linux doesn't raise [attr-defined] though, so we need [unused-ignore]
|
||||
from .._subprocess_platform.waitid import ( # type: ignore[attr-defined, unused-ignore]
|
||||
sync_wait_reapable,
|
||||
)
|
||||
|
||||
got_alarm = False
|
||||
sleeper = subprocess.Popen(["sleep", "3600"])
|
||||
|
||||
def on_alarm(sig: int, frame: FrameType | None) -> None:
|
||||
nonlocal got_alarm
|
||||
got_alarm = True
|
||||
sleeper.kill()
|
||||
|
||||
old_sigalrm = signal.signal(signal.SIGALRM, on_alarm)
|
||||
try:
|
||||
signal.alarm(1)
|
||||
sync_wait_reapable(sleeper.pid)
|
||||
assert sleeper.wait(timeout=1) == -9
|
||||
finally:
|
||||
if sleeper.returncode is None: # pragma: no cover
|
||||
# We only get here if something fails in the above;
|
||||
# if the test passes, wait() will reap the process
|
||||
sleeper.kill()
|
||||
sleeper.wait()
|
||||
signal.signal(signal.SIGALRM, old_sigalrm)
|
||||
|
||||
|
||||
async def test_custom_deliver_cancel() -> None:
|
||||
custom_deliver_cancel_called = False
|
||||
|
||||
async def custom_deliver_cancel(proc: Process) -> None:
|
||||
nonlocal custom_deliver_cancel_called
|
||||
custom_deliver_cancel_called = True
|
||||
proc.terminate()
|
||||
# Make sure this does get cancelled when the process exits, and that
|
||||
# the process really exited.
|
||||
try:
|
||||
await sleep_forever()
|
||||
finally:
|
||||
assert proc.returncode is not None
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
assert custom_deliver_cancel_called
|
||||
|
||||
|
||||
def test_bad_deliver_cancel() -> None:
|
||||
async def custom_deliver_cancel(proc: Process) -> None:
|
||||
proc.terminate()
|
||||
raise ValueError("foo")
|
||||
|
||||
async def do_stuff() -> None:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
partial(run_process, SLEEP(9999), deliver_cancel=custom_deliver_cancel),
|
||||
)
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
# double wrap from our nursery + the internal nursery
|
||||
with RaisesGroup(RaisesGroup(Matcher(ValueError, "^foo$"))):
|
||||
_core.run(do_stuff, strict_exception_groups=True)
|
||||
|
||||
|
||||
async def test_warn_on_failed_cancel_terminate(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
original_terminate = Process.terminate
|
||||
|
||||
def broken_terminate(self: Process) -> NoReturn:
|
||||
original_terminate(self)
|
||||
raise OSError("whoops")
|
||||
|
||||
monkeypatch.setattr(Process, "terminate", broken_terminate)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*whoops.*"):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
@pytest.mark.skipif(not posix, reason="posix only")
|
||||
async def test_warn_on_cancel_SIGKILL_escalation(
|
||||
autojump_clock: MockClock,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(Process, "terminate", lambda *args: None)
|
||||
|
||||
with pytest.warns(RuntimeWarning, match=".*ignored SIGTERM.*"):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(run_process, SLEEP(9999))
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
# the background_process_param exercises a lot of run_process cases, but it uses
|
||||
# check=False, so lets have a test that uses check=True as well
|
||||
async def test_run_process_background_fail() -> None:
|
||||
with RaisesGroup(subprocess.CalledProcessError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
proc: Process = await nursery.start(run_process, EXIT_FALSE)
|
||||
assert proc.returncode == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not SyncPath("/dev/fd").exists(),
|
||||
reason="requires a way to iterate through open files",
|
||||
)
|
||||
async def test_for_leaking_fds() -> None:
|
||||
gc.collect() # address possible flakiness on PyPy
|
||||
|
||||
starting_fds = set(SyncPath("/dev/fd").iterdir())
|
||||
await run_process(EXIT_TRUE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
with pytest.raises(subprocess.CalledProcessError):
|
||||
await run_process(EXIT_FALSE)
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await run_process(["/dev/fd/0"])
|
||||
assert set(SyncPath("/dev/fd").iterdir()) == starting_fds
|
||||
|
||||
|
||||
async def test_run_process_internal_error(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# There's probably less extreme ways of triggering errors inside the nursery
|
||||
# in run_process.
|
||||
async def very_broken_open(*args: object, **kwargs: object) -> str:
|
||||
return "oops"
|
||||
|
||||
monkeypatch.setattr(trio._subprocess, "open_process", very_broken_open)
|
||||
with RaisesGroup(AttributeError, AttributeError):
|
||||
await run_process(EXIT_TRUE, capture_stdout=True)
|
||||
|
||||
|
||||
# regression test for #2209
|
||||
async def test_subprocess_pidfd_unnotified() -> None:
|
||||
noticed_exit = None
|
||||
|
||||
async def wait_and_tell(proc: Process) -> None:
|
||||
nonlocal noticed_exit
|
||||
noticed_exit = Event()
|
||||
await proc.wait()
|
||||
noticed_exit.set()
|
||||
|
||||
proc = await open_process(SLEEP(9999))
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_and_tell, proc)
|
||||
await wait_all_tasks_blocked()
|
||||
assert isinstance(noticed_exit, Event)
|
||||
proc.terminate()
|
||||
# without giving trio a chance to do so,
|
||||
with assert_no_checkpoints():
|
||||
# wait until the process has actually exited;
|
||||
proc._proc.wait()
|
||||
# force a call to poll (that closes the pidfd on linux)
|
||||
proc.poll()
|
||||
with move_on_after(5):
|
||||
# Some platforms use threads to wait for exit, so it might take a bit
|
||||
# for everything to notice
|
||||
await noticed_exit.wait()
|
||||
assert noticed_exit.is_set(), "child task wasn't woken after poll, DEADLOCK"
|
655
lib/python3.13/site-packages/trio/_tests/test_sync.py
Normal file
655
lib/python3.13/site-packages/trio/_tests/test_sync.py
Normal file
@ -0,0 +1,655 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import weakref
|
||||
from typing import TYPE_CHECKING, Callable, Union
|
||||
|
||||
import pytest
|
||||
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
from .. import _core
|
||||
from .._core._parking_lot import GLOBAL_PARKING_LOT_BREAKER
|
||||
from .._sync import *
|
||||
from .._timeouts import sleep_forever
|
||||
from ..testing import assert_checkpoints, wait_all_tasks_blocked
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
|
||||
async def test_Event() -> None:
|
||||
e = Event()
|
||||
assert not e.is_set()
|
||||
assert e.statistics().tasks_waiting == 0
|
||||
|
||||
e.set()
|
||||
assert e.is_set()
|
||||
with assert_checkpoints():
|
||||
await e.wait()
|
||||
|
||||
e = Event()
|
||||
|
||||
record = []
|
||||
|
||||
async def child() -> None:
|
||||
record.append("sleeping")
|
||||
await e.wait()
|
||||
record.append("woken")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child)
|
||||
nursery.start_soon(child)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping"]
|
||||
assert e.statistics().tasks_waiting == 2
|
||||
e.set()
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["sleeping", "sleeping", "woken", "woken"]
|
||||
|
||||
|
||||
async def test_CapacityLimiter() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
CapacityLimiter(1.0)
|
||||
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
|
||||
CapacityLimiter(-1)
|
||||
c = CapacityLimiter(2)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == 2
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == 2
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == 1
|
||||
|
||||
stats = c.statistics()
|
||||
assert stats.borrowed_tokens == 1
|
||||
assert stats.total_tokens == 2
|
||||
assert stats.borrowers == [_core.current_task()]
|
||||
assert stats.tasks_waiting == 0
|
||||
|
||||
# Can't re-acquire when we already have it
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
with pytest.raises(RuntimeError):
|
||||
await c.acquire()
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
# We can acquire on behalf of someone else though
|
||||
with assert_checkpoints():
|
||||
await c.acquire_on_behalf_of("someone")
|
||||
|
||||
# But then we've run out of capacity
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_on_behalf_of_nowait("third party")
|
||||
|
||||
assert set(c.statistics().borrowers) == {_core.current_task(), "someone"}
|
||||
|
||||
# Until we release one
|
||||
c.release_on_behalf_of(_core.current_task())
|
||||
assert c.statistics().borrowers == ["someone"]
|
||||
|
||||
c.release_on_behalf_of("someone")
|
||||
assert c.borrowed_tokens == 0
|
||||
with assert_checkpoints():
|
||||
async with c:
|
||||
assert c.borrowed_tokens == 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
await c.acquire_on_behalf_of("value 1")
|
||||
await c.acquire_on_behalf_of("value 2")
|
||||
nursery.start_soon(c.acquire_on_behalf_of, "value 3")
|
||||
await wait_all_tasks_blocked()
|
||||
assert c.borrowed_tokens == 2
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of("value 2")
|
||||
# Fairness:
|
||||
assert c.borrowed_tokens == 2
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
c.acquire_nowait()
|
||||
|
||||
c.release_on_behalf_of("value 3")
|
||||
c.release_on_behalf_of("value 1")
|
||||
|
||||
|
||||
async def test_CapacityLimiter_inf() -> None:
|
||||
from math import inf
|
||||
|
||||
c = CapacityLimiter(inf)
|
||||
repr(c) # smoke test
|
||||
assert c.total_tokens == inf
|
||||
assert c.borrowed_tokens == 0
|
||||
assert c.available_tokens == inf
|
||||
with pytest.raises(RuntimeError):
|
||||
c.release()
|
||||
assert c.borrowed_tokens == 0
|
||||
c.acquire_nowait()
|
||||
assert c.borrowed_tokens == 1
|
||||
assert c.available_tokens == inf
|
||||
|
||||
|
||||
async def test_CapacityLimiter_change_total_tokens() -> None:
|
||||
c = CapacityLimiter(2)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
c.total_tokens = 1.0
|
||||
|
||||
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
|
||||
c.total_tokens = 0
|
||||
|
||||
with pytest.raises(ValueError, match="^total_tokens must be >= 1$"):
|
||||
c.total_tokens = -10
|
||||
|
||||
assert c.total_tokens == 2
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(5):
|
||||
nursery.start_soon(c.acquire_on_behalf_of, i)
|
||||
await wait_all_tasks_blocked()
|
||||
assert set(c.statistics().borrowers) == {0, 1}
|
||||
assert c.statistics().tasks_waiting == 3
|
||||
c.total_tokens += 2
|
||||
assert set(c.statistics().borrowers) == {0, 1, 2, 3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.total_tokens -= 3
|
||||
assert c.borrowed_tokens == 4
|
||||
assert c.total_tokens == 1
|
||||
c.release_on_behalf_of(0)
|
||||
c.release_on_behalf_of(1)
|
||||
c.release_on_behalf_of(2)
|
||||
assert set(c.statistics().borrowers) == {3}
|
||||
assert c.statistics().tasks_waiting == 1
|
||||
c.release_on_behalf_of(3)
|
||||
assert set(c.statistics().borrowers) == {4}
|
||||
assert c.statistics().tasks_waiting == 0
|
||||
|
||||
|
||||
# regression test for issue #548
|
||||
async def test_CapacityLimiter_memleak_548() -> None:
|
||||
limiter = CapacityLimiter(total_tokens=1)
|
||||
await limiter.acquire()
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(limiter.acquire)
|
||||
await wait_all_tasks_blocked() # give it a chance to run the task
|
||||
n.cancel_scope.cancel()
|
||||
|
||||
# if this is 1, the acquire call (despite being killed) is still there in the task, and will
|
||||
# leak memory all the while the limiter is active
|
||||
assert len(limiter._pending_borrowers) == 0
|
||||
|
||||
|
||||
async def test_Semaphore() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="^initial value must be >= 0$"):
|
||||
Semaphore(-1)
|
||||
s = Semaphore(1)
|
||||
repr(s) # smoke test
|
||||
assert s.value == 1
|
||||
assert s.max_value is None
|
||||
s.release()
|
||||
assert s.value == 2
|
||||
assert s.statistics().tasks_waiting == 0
|
||||
s.acquire_nowait()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
await s.acquire()
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
|
||||
s.release()
|
||||
assert s.value == 1
|
||||
with assert_checkpoints():
|
||||
async with s:
|
||||
assert s.value == 0
|
||||
assert s.value == 1
|
||||
s.acquire_nowait()
|
||||
|
||||
record = []
|
||||
|
||||
async def do_acquire(s: Semaphore) -> None:
|
||||
record.append("started")
|
||||
await s.acquire()
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_acquire, s)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
assert s.value == 0
|
||||
s.release()
|
||||
# Fairness:
|
||||
assert s.value == 0
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
s.acquire_nowait()
|
||||
assert record == ["started", "finished"]
|
||||
|
||||
|
||||
async def test_Semaphore_bounded() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Semaphore(1, max_value=1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(ValueError, match="^max_values must be >= initial_value$"):
|
||||
Semaphore(2, max_value=1)
|
||||
bs = Semaphore(1, max_value=1)
|
||||
assert bs.max_value == 1
|
||||
repr(bs) # smoke test
|
||||
with pytest.raises(ValueError, match="^semaphore released too many times$"):
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
bs.acquire_nowait()
|
||||
assert bs.value == 0
|
||||
bs.release()
|
||||
assert bs.value == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lockcls", [Lock, StrictFIFOLock], ids=lambda fn: fn.__name__)
|
||||
async def test_Lock_and_StrictFIFOLock(
|
||||
lockcls: type[Lock | StrictFIFOLock],
|
||||
) -> None:
|
||||
l = lockcls() # noqa
|
||||
assert not l.locked()
|
||||
|
||||
# make sure locks can be weakref'ed (gh-331)
|
||||
r = weakref.ref(l)
|
||||
assert r() is l
|
||||
|
||||
repr(l) # smoke test
|
||||
# make sure repr uses the right name for subclasses
|
||||
assert lockcls.__name__ in repr(l)
|
||||
with assert_checkpoints():
|
||||
async with l:
|
||||
assert l.locked()
|
||||
repr(l) # smoke test (repr branches on locked/unlocked)
|
||||
assert not l.locked()
|
||||
l.acquire_nowait()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
with assert_checkpoints():
|
||||
await l.acquire()
|
||||
assert l.locked()
|
||||
l.release()
|
||||
assert not l.locked()
|
||||
|
||||
l.acquire_nowait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we already own the lock
|
||||
l.acquire_nowait()
|
||||
l.release()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Error out if we don't own the lock
|
||||
l.release()
|
||||
|
||||
holder_task = None
|
||||
|
||||
async def holder() -> None:
|
||||
nonlocal holder_task
|
||||
holder_task = _core.current_task()
|
||||
async with l:
|
||||
await sleep_forever()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
assert not l.locked()
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
assert l.locked()
|
||||
# WouldBlock if someone else holds the lock
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
l.acquire_nowait()
|
||||
# Can't release a lock someone else holds
|
||||
with pytest.raises(RuntimeError):
|
||||
l.release()
|
||||
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.locked
|
||||
assert statistics.owner is holder_task
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
nursery.start_soon(holder)
|
||||
await wait_all_tasks_blocked()
|
||||
statistics = l.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
statistics = l.statistics()
|
||||
assert not statistics.locked
|
||||
assert statistics.owner is None
|
||||
assert statistics.tasks_waiting == 0
|
||||
|
||||
|
||||
async def test_Condition() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Condition(Semaphore(1)) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
Condition(StrictFIFOLock) # type: ignore[arg-type]
|
||||
l = Lock() # noqa
|
||||
c = Condition(l)
|
||||
assert not l.locked()
|
||||
assert not c.locked()
|
||||
with assert_checkpoints():
|
||||
await c.acquire()
|
||||
assert l.locked()
|
||||
assert c.locked()
|
||||
|
||||
c = Condition()
|
||||
assert not c.locked()
|
||||
c.acquire_nowait()
|
||||
assert c.locked()
|
||||
with pytest.raises(RuntimeError):
|
||||
c.acquire_nowait()
|
||||
c.release()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't wait without holding the lock
|
||||
await c.wait()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify()
|
||||
with pytest.raises(RuntimeError):
|
||||
# Can't notify without holding the lock
|
||||
c.notify_all()
|
||||
|
||||
finished_waiters = set()
|
||||
|
||||
async def waiter(i: int) -> None:
|
||||
async with c:
|
||||
await c.wait()
|
||||
finished_waiters.add(i)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify()
|
||||
assert c.locked()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0}
|
||||
async with c:
|
||||
c.notify_all()
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1, 2}
|
||||
|
||||
finished_waiters = set()
|
||||
async with _core.open_nursery() as nursery:
|
||||
for i in range(3):
|
||||
nursery.start_soon(waiter, i)
|
||||
await wait_all_tasks_blocked()
|
||||
async with c:
|
||||
c.notify(2)
|
||||
statistics = c.statistics()
|
||||
print(statistics)
|
||||
assert statistics.tasks_waiting == 1
|
||||
assert statistics.lock_statistics.tasks_waiting == 2
|
||||
# exiting the context manager hands off the lock to the first task
|
||||
assert c.statistics().lock_statistics.tasks_waiting == 1
|
||||
|
||||
await wait_all_tasks_blocked()
|
||||
assert finished_waiters == {0, 1}
|
||||
|
||||
async with c:
|
||||
c.notify_all()
|
||||
|
||||
# After being cancelled still hold the lock (!)
|
||||
# (Note that c.__aexit__ checks that we hold the lock as well)
|
||||
with _core.CancelScope() as scope:
|
||||
async with c:
|
||||
scope.cancel()
|
||||
try:
|
||||
await c.wait()
|
||||
finally:
|
||||
assert c.locked()
|
||||
|
||||
|
||||
from .._channel import open_memory_channel
|
||||
from .._sync import AsyncContextManagerMixin
|
||||
|
||||
# Three ways of implementing a Lock in terms of a channel. Used to let us put
|
||||
# the channel through the generic lock tests.
|
||||
|
||||
|
||||
class ChannelLock1(AsyncContextManagerMixin):
|
||||
def __init__(self, capacity: int) -> None:
|
||||
self.s, self.r = open_memory_channel[None](capacity)
|
||||
for _ in range(capacity - 1):
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self.s.send_nowait(None)
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.s.send(None)
|
||||
|
||||
def release(self) -> None:
|
||||
self.r.receive_nowait()
|
||||
|
||||
|
||||
class ChannelLock2(AsyncContextManagerMixin):
|
||||
def __init__(self) -> None:
|
||||
self.s, self.r = open_memory_channel[None](10)
|
||||
self.s.send_nowait(None)
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
self.r.receive_nowait()
|
||||
|
||||
async def acquire(self) -> None:
|
||||
await self.r.receive()
|
||||
|
||||
def release(self) -> None:
|
||||
self.s.send_nowait(None)
|
||||
|
||||
|
||||
class ChannelLock3(AsyncContextManagerMixin):
|
||||
def __init__(self) -> None:
|
||||
self.s, self.r = open_memory_channel[None](0)
|
||||
# self.acquired is true when one task acquires the lock and
|
||||
# only becomes false when it's released and no tasks are
|
||||
# waiting to acquire.
|
||||
self.acquired = False
|
||||
|
||||
def acquire_nowait(self) -> None:
|
||||
assert not self.acquired
|
||||
self.acquired = True
|
||||
|
||||
async def acquire(self) -> None:
|
||||
if self.acquired:
|
||||
await self.s.send(None)
|
||||
else:
|
||||
self.acquired = True
|
||||
await _core.checkpoint()
|
||||
|
||||
def release(self) -> None:
|
||||
try:
|
||||
self.r.receive_nowait()
|
||||
except _core.WouldBlock:
|
||||
assert self.acquired
|
||||
self.acquired = False
|
||||
|
||||
|
||||
lock_factories = [
|
||||
lambda: CapacityLimiter(1),
|
||||
lambda: Semaphore(1),
|
||||
Lock,
|
||||
StrictFIFOLock,
|
||||
lambda: ChannelLock1(10),
|
||||
lambda: ChannelLock1(1),
|
||||
ChannelLock2,
|
||||
ChannelLock3,
|
||||
]
|
||||
lock_factory_names = [
|
||||
"CapacityLimiter(1)",
|
||||
"Semaphore(1)",
|
||||
"Lock",
|
||||
"StrictFIFOLock",
|
||||
"ChannelLock1(10)",
|
||||
"ChannelLock1(1)",
|
||||
"ChannelLock2",
|
||||
"ChannelLock3",
|
||||
]
|
||||
|
||||
generic_lock_test = pytest.mark.parametrize(
|
||||
"lock_factory",
|
||||
lock_factories,
|
||||
ids=lock_factory_names,
|
||||
)
|
||||
|
||||
LockLike: TypeAlias = Union[
|
||||
CapacityLimiter,
|
||||
Semaphore,
|
||||
Lock,
|
||||
StrictFIFOLock,
|
||||
ChannelLock1,
|
||||
ChannelLock2,
|
||||
ChannelLock3,
|
||||
]
|
||||
LockFactory: TypeAlias = Callable[[], LockLike]
|
||||
|
||||
|
||||
# Spawn a bunch of workers that take a lock and then yield; make sure that
|
||||
# only one worker is ever in the critical section at a time.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_exclusion(lock_factory: LockFactory) -> None:
|
||||
LOOPS = 10
|
||||
WORKERS = 5
|
||||
in_critical_section = False
|
||||
acquires = 0
|
||||
|
||||
async def worker(lock_like: LockLike) -> None:
|
||||
nonlocal in_critical_section, acquires
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
acquires += 1
|
||||
assert not in_critical_section
|
||||
in_critical_section = True
|
||||
await _core.checkpoint()
|
||||
await _core.checkpoint()
|
||||
assert in_critical_section
|
||||
in_critical_section = False
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like = lock_factory()
|
||||
for _ in range(WORKERS):
|
||||
nursery.start_soon(worker, lock_like)
|
||||
assert not in_critical_section
|
||||
assert acquires == LOOPS * WORKERS
|
||||
|
||||
|
||||
# Several workers queue on the same lock; make sure they each get it, in
|
||||
# order.
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_fifo_fairness(lock_factory: LockFactory) -> None:
|
||||
initial_order = []
|
||||
record = []
|
||||
LOOPS = 5
|
||||
|
||||
async def loopy(name: int, lock_like: LockLike) -> None:
|
||||
# Record the order each task was initially scheduled in
|
||||
initial_order.append(name)
|
||||
for _ in range(LOOPS):
|
||||
async with lock_like:
|
||||
record.append(name)
|
||||
|
||||
lock_like = lock_factory()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(loopy, 1, lock_like)
|
||||
nursery.start_soon(loopy, 2, lock_like)
|
||||
nursery.start_soon(loopy, 3, lock_like)
|
||||
# The first three could be in any order due to scheduling randomness,
|
||||
# but after that they should repeat in the same order
|
||||
for i in range(LOOPS):
|
||||
assert record[3 * i : 3 * (i + 1)] == initial_order
|
||||
|
||||
|
||||
@generic_lock_test
|
||||
async def test_generic_lock_acquire_nowait_blocks_acquire(
|
||||
lock_factory: LockFactory,
|
||||
) -> None:
|
||||
lock_like = lock_factory()
|
||||
|
||||
record = []
|
||||
|
||||
async def lock_taker() -> None:
|
||||
record.append("started")
|
||||
async with lock_like:
|
||||
pass
|
||||
record.append("finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
lock_like.acquire_nowait()
|
||||
nursery.start_soon(lock_taker)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["started"]
|
||||
lock_like.release()
|
||||
|
||||
|
||||
async def test_lock_acquire_unowned_lock() -> None:
|
||||
"""Test that trying to acquire a lock whose owner has exited raises an error.
|
||||
see https://github.com/python-trio/trio/issues/3035
|
||||
"""
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
lock = trio.Lock()
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
owner_str = re.escape(str(lock._lot.broken_by[0]))
|
||||
with pytest.raises(
|
||||
trio.BrokenResourceError,
|
||||
match=f"^Owner of this lock exited without releasing: {owner_str}$",
|
||||
):
|
||||
await lock.acquire()
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
|
||||
async def test_lock_multiple_acquire() -> None:
|
||||
"""Test for error if awaiting on a lock whose owner exits without releasing.
|
||||
see https://github.com/python-trio/trio/issues/3035"""
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
lock = trio.Lock()
|
||||
with RaisesGroup(
|
||||
Matcher(
|
||||
trio.BrokenResourceError,
|
||||
match="^Owner of this lock exited without releasing: ",
|
||||
),
|
||||
):
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
nursery.start_soon(lock.acquire)
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
|
||||
|
||||
async def test_lock_handover() -> None:
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
||||
child_task: Task | None = None
|
||||
lock = trio.Lock()
|
||||
|
||||
# this task acquires the lock
|
||||
lock.acquire_nowait()
|
||||
assert GLOBAL_PARKING_LOT_BREAKER == {
|
||||
_core.current_task(): [
|
||||
lock._lot,
|
||||
],
|
||||
}
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(lock.acquire)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# hand over the lock to the child task
|
||||
lock.release()
|
||||
|
||||
# check values, and get the identifier out of the dict for later check
|
||||
assert len(GLOBAL_PARKING_LOT_BREAKER) == 1
|
||||
child_task = next(iter(GLOBAL_PARKING_LOT_BREAKER))
|
||||
assert GLOBAL_PARKING_LOT_BREAKER[child_task] == [lock._lot]
|
||||
|
||||
assert lock._lot.broken_by == [child_task]
|
||||
assert not GLOBAL_PARKING_LOT_BREAKER
|
684
lib/python3.13/site-packages/trio/_tests/test_testing.py
Normal file
684
lib/python3.13/site-packages/trio/_tests/test_testing.py
Normal file
@ -0,0 +1,684 @@
|
||||
from __future__ import annotations
|
||||
|
||||
# XX this should get broken up, like testing.py did
|
||||
import tempfile
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio.testing import RaisesGroup
|
||||
|
||||
from .. import _core, sleep, socket as tsocket
|
||||
from .._core._tests.tutil import can_bind_ipv6
|
||||
from .._highlevel_generic import StapledStream, aclose_forcefully
|
||||
from .._highlevel_socket import SocketListener
|
||||
from ..testing import *
|
||||
from ..testing._check_streams import _assert_raises
|
||||
from ..testing._memory_streams import _UnboundedByteQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trio import Nursery
|
||||
from trio.abc import ReceiveStream, SendStream
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked() -> None:
|
||||
record = []
|
||||
|
||||
async def busy_bee() -> None:
|
||||
for _ in range(10):
|
||||
await _core.checkpoint()
|
||||
record.append("busy bee exhausted")
|
||||
|
||||
async def waiting_for_bee_to_leave() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("quiet at last!")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(busy_bee)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
nursery.start_soon(waiting_for_bee_to_leave)
|
||||
|
||||
# check cancellation
|
||||
record = []
|
||||
|
||||
async def cancelled_while_waiting() -> None:
|
||||
try:
|
||||
await wait_all_tasks_blocked()
|
||||
except _core.Cancelled:
|
||||
record.append("ok")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancelled_while_waiting)
|
||||
nursery.cancel_scope.cancel()
|
||||
assert record == ["ok"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_timeouts(mock_clock: MockClock) -> None:
|
||||
record = []
|
||||
|
||||
async def timeout_task() -> None:
|
||||
record.append("tt start")
|
||||
await sleep(5)
|
||||
record.append("tt finished")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(timeout_task)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start"]
|
||||
mock_clock.jump(10)
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == ["tt start", "tt finished"]
|
||||
|
||||
|
||||
async def test_wait_all_tasks_blocked_with_cushion() -> None:
|
||||
record = []
|
||||
|
||||
async def blink() -> None:
|
||||
record.append("blink start")
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
await sleep(0.01)
|
||||
record.append("blink end")
|
||||
|
||||
async def wait_no_cushion() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
record.append("wait_no_cushion end")
|
||||
|
||||
async def wait_small_cushion() -> None:
|
||||
await wait_all_tasks_blocked(0.02)
|
||||
record.append("wait_small_cushion end")
|
||||
|
||||
async def wait_big_cushion() -> None:
|
||||
await wait_all_tasks_blocked(0.03)
|
||||
record.append("wait_big_cushion end")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(blink)
|
||||
nursery.start_soon(wait_no_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_small_cushion)
|
||||
nursery.start_soon(wait_big_cushion)
|
||||
|
||||
assert record == [
|
||||
"blink start",
|
||||
"wait_no_cushion end",
|
||||
"blink end",
|
||||
"wait_small_cushion end",
|
||||
"wait_small_cushion end",
|
||||
"wait_big_cushion end",
|
||||
]
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_assert_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that's not a checkpoint.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# But both together count as a checkpoint
|
||||
with assert_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
async def test_assert_no_checkpoints(recwarn: pytest.WarningsRecorder) -> None:
|
||||
with assert_no_checkpoints():
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint()
|
||||
|
||||
# partial yield cases
|
||||
# if you have a schedule point but not a cancel point, or vice-versa, then
|
||||
# that doesn't make *either* version of assert_{no_,}yields happy.
|
||||
for partial_yield in [
|
||||
_core.checkpoint_if_cancelled,
|
||||
_core.cancel_shielded_checkpoint,
|
||||
]:
|
||||
print(partial_yield)
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await partial_yield()
|
||||
|
||||
# And both together also count as a checkpoint
|
||||
with pytest.raises(AssertionError):
|
||||
with assert_no_checkpoints():
|
||||
await _core.checkpoint_if_cancelled()
|
||||
await _core.cancel_shielded_checkpoint()
|
||||
|
||||
|
||||
################################################################
|
||||
|
||||
|
||||
async def test_Sequencer() -> None:
|
||||
record = []
|
||||
|
||||
def t(val: object) -> None:
|
||||
print(val)
|
||||
record.append(val)
|
||||
|
||||
async def f1(seq: Sequencer) -> None:
|
||||
async with seq(1):
|
||||
t(("f1", 1))
|
||||
async with seq(3):
|
||||
t(("f1", 3))
|
||||
async with seq(4):
|
||||
t(("f1", 4))
|
||||
|
||||
async def f2(seq: Sequencer) -> None:
|
||||
async with seq(0):
|
||||
t(("f2", 0))
|
||||
async with seq(2):
|
||||
t(("f2", 2))
|
||||
|
||||
seq = Sequencer()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(f1, seq)
|
||||
nursery.start_soon(f2, seq)
|
||||
async with seq(5):
|
||||
await wait_all_tasks_blocked()
|
||||
assert record == [("f2", 0), ("f1", 1), ("f2", 2), ("f1", 3), ("f1", 4)]
|
||||
|
||||
seq = Sequencer()
|
||||
# Catches us if we try to reuse a sequence point:
|
||||
async with seq(0):
|
||||
pass
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_Sequencer_cancel() -> None:
|
||||
# Killing a blocked task makes everything blow up
|
||||
record = []
|
||||
seq = Sequencer()
|
||||
|
||||
async def child(i: int) -> None:
|
||||
with _core.CancelScope() as scope:
|
||||
if i == 1:
|
||||
scope.cancel()
|
||||
try:
|
||||
async with seq(i):
|
||||
pass # pragma: no cover
|
||||
except RuntimeError:
|
||||
record.append(f"seq({i}) RuntimeError")
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(child, 1)
|
||||
nursery.start_soon(child, 2)
|
||||
async with seq(0):
|
||||
pass # pragma: no cover
|
||||
|
||||
assert record == ["seq(1) RuntimeError", "seq(2) RuntimeError"]
|
||||
|
||||
# Late arrivals also get errors
|
||||
with pytest.raises(RuntimeError):
|
||||
async with seq(3):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
################################################################
|
||||
async def test__assert_raises() -> None:
|
||||
with pytest.raises(AssertionError):
|
||||
with _assert_raises(RuntimeError):
|
||||
1 + 1 # noqa: B018 # "useless expression"
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
with _assert_raises(RuntimeError):
|
||||
"foo" + 1 # type: ignore[operator] # noqa: B018 # "useless expression"
|
||||
|
||||
with _assert_raises(RuntimeError):
|
||||
raise RuntimeError
|
||||
|
||||
|
||||
# This is a private implementation detail, but it's complex enough to be worth
|
||||
# testing directly
|
||||
async def test__UnboundeByteQueue() -> None:
|
||||
ubq = _UnboundedByteQueue()
|
||||
|
||||
ubq.put(b"123")
|
||||
ubq.put(b"456")
|
||||
assert ubq.get_nowait(1) == b"1"
|
||||
assert ubq.get_nowait(10) == b"23456"
|
||||
ubq.put(b"789")
|
||||
assert ubq.get_nowait() == b"789"
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait(10)
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
ubq.get_nowait()
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
ubq.put("string") # type: ignore[arg-type]
|
||||
|
||||
ubq.put(b"abc")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(10) == b"abc"
|
||||
ubq.put(b"def")
|
||||
ubq.put(b"ghi")
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get(1) == b"d"
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == b"efghi"
|
||||
|
||||
async def putter(data: bytes) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
ubq.put(data)
|
||||
|
||||
async def getter(expect: bytes) -> None:
|
||||
with assert_checkpoints():
|
||||
assert await ubq.get() == expect
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"xyz")
|
||||
nursery.start_soon(putter, b"xyz")
|
||||
|
||||
# Two gets at the same time -> BusyResourceError
|
||||
with RaisesGroup(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
nursery.start_soon(getter, b"asdf")
|
||||
|
||||
# Closing
|
||||
|
||||
ubq.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
ubq.put(b"---")
|
||||
|
||||
assert ubq.get_nowait(10) == b""
|
||||
assert ubq.get_nowait() == b""
|
||||
assert await ubq.get(10) == b""
|
||||
assert await ubq.get() == b""
|
||||
|
||||
# close is idempotent
|
||||
ubq.close()
|
||||
|
||||
# close wakes up blocked getters
|
||||
ubq2 = _UnboundedByteQueue()
|
||||
|
||||
async def closer() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
ubq2.close()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(getter, b"")
|
||||
nursery.start_soon(closer)
|
||||
|
||||
|
||||
async def test_MemorySendStream() -> None:
|
||||
mss = MemorySendStream()
|
||||
|
||||
async def do_send_all(data: bytes) -> None:
|
||||
with assert_checkpoints():
|
||||
await mss.send_all(data)
|
||||
|
||||
await do_send_all(b"123")
|
||||
assert mss.get_data_nowait(1) == b"1"
|
||||
assert mss.get_data_nowait() == b"23"
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.wait_send_all_might_not_block()
|
||||
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait()
|
||||
with pytest.raises(_core.WouldBlock):
|
||||
mss.get_data_nowait(10)
|
||||
|
||||
await do_send_all(b"456")
|
||||
with assert_checkpoints():
|
||||
assert await mss.get_data() == b"456"
|
||||
|
||||
# Call send_all twice at once; one should get BusyResourceError and one
|
||||
# should succeed. But we can't let the error propagate, because it might
|
||||
# cause the other to be cancelled before it can finish doing its thing,
|
||||
# and we don't know which one will get the error.
|
||||
resource_busy_count = 0
|
||||
|
||||
async def do_send_all_count_resourcebusy() -> None:
|
||||
nonlocal resource_busy_count
|
||||
try:
|
||||
await do_send_all(b"xxx")
|
||||
except _core.BusyResourceError:
|
||||
resource_busy_count += 1
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
nursery.start_soon(do_send_all_count_resourcebusy)
|
||||
|
||||
assert resource_busy_count == 1
|
||||
|
||||
with assert_checkpoints():
|
||||
await mss.aclose()
|
||||
|
||||
assert await mss.get_data() == b"xxx"
|
||||
assert await mss.get_data() == b""
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await do_send_all(b"---")
|
||||
|
||||
# hooks
|
||||
|
||||
assert mss.send_all_hook is None
|
||||
assert mss.wait_send_all_might_not_block_hook is None
|
||||
assert mss.close_hook is None
|
||||
|
||||
record = []
|
||||
|
||||
async def send_all_hook() -> None:
|
||||
# hook runs after send_all does its work (can pull data out)
|
||||
assert mss2.get_data_nowait() == b"abc"
|
||||
record.append("send_all_hook")
|
||||
|
||||
async def wait_send_all_might_not_block_hook() -> None:
|
||||
record.append("wait_send_all_might_not_block_hook")
|
||||
|
||||
def close_hook() -> None:
|
||||
record.append("close_hook")
|
||||
|
||||
mss2 = MemorySendStream(
|
||||
send_all_hook,
|
||||
wait_send_all_might_not_block_hook,
|
||||
close_hook,
|
||||
)
|
||||
|
||||
assert mss2.send_all_hook is send_all_hook
|
||||
assert mss2.wait_send_all_might_not_block_hook is wait_send_all_might_not_block_hook
|
||||
assert mss2.close_hook is close_hook
|
||||
|
||||
await mss2.send_all(b"abc")
|
||||
await mss2.wait_send_all_might_not_block()
|
||||
await aclose_forcefully(mss2)
|
||||
mss2.close()
|
||||
|
||||
assert record == [
|
||||
"send_all_hook",
|
||||
"wait_send_all_might_not_block_hook",
|
||||
"close_hook",
|
||||
"close_hook",
|
||||
]
|
||||
|
||||
|
||||
async def test_MemoryReceiveStream() -> None:
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
async def do_receive_some(max_bytes: int | None) -> bytes:
|
||||
with assert_checkpoints():
|
||||
return await mrs.receive_some(max_bytes)
|
||||
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(1) == b"a"
|
||||
assert await do_receive_some(10) == b"bc"
|
||||
mrs.put_data(b"abc")
|
||||
assert await do_receive_some(None) == b"abc"
|
||||
|
||||
with RaisesGroup(_core.BusyResourceError):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
nursery.start_soon(do_receive_some, 10)
|
||||
|
||||
assert mrs.receive_some_hook is None
|
||||
|
||||
mrs.put_data(b"def")
|
||||
mrs.put_eof()
|
||||
mrs.put_eof()
|
||||
|
||||
assert await do_receive_some(10) == b"def"
|
||||
assert await do_receive_some(10) == b""
|
||||
assert await do_receive_some(10) == b""
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"---")
|
||||
|
||||
async def receive_some_hook() -> None:
|
||||
mrs2.put_data(b"xxx")
|
||||
|
||||
record = []
|
||||
|
||||
def close_hook() -> None:
|
||||
record.append("closed")
|
||||
|
||||
mrs2 = MemoryReceiveStream(receive_some_hook, close_hook)
|
||||
assert mrs2.receive_some_hook is receive_some_hook
|
||||
assert mrs2.close_hook is close_hook
|
||||
|
||||
mrs2.put_data(b"yyy")
|
||||
assert await mrs2.receive_some(10) == b"yyyxxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
assert await mrs2.receive_some(10) == b"xxx"
|
||||
|
||||
mrs2.put_data(b"zzz")
|
||||
mrs2.receive_some_hook = None
|
||||
assert await mrs2.receive_some(10) == b"zzz"
|
||||
|
||||
mrs2.put_data(b"lost on close")
|
||||
with assert_checkpoints():
|
||||
await mrs2.aclose()
|
||||
assert record == ["closed"]
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_MemoryRecvStream_closing() -> None:
|
||||
mrs = MemoryReceiveStream()
|
||||
# close with no pending data
|
||||
mrs.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
assert await mrs.receive_some(10) == b""
|
||||
# repeated closes ok
|
||||
mrs.close()
|
||||
# put_data now fails
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
mrs.put_data(b"123")
|
||||
|
||||
mrs2 = MemoryReceiveStream()
|
||||
# close with pending data
|
||||
mrs2.put_data(b"xyz")
|
||||
mrs2.close()
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await mrs2.receive_some(10)
|
||||
|
||||
|
||||
async def test_memory_stream_pump() -> None:
|
||||
mss = MemorySendStream()
|
||||
mrs = MemoryReceiveStream()
|
||||
|
||||
# no-op if no data present
|
||||
memory_stream_pump(mss, mrs)
|
||||
|
||||
await mss.send_all(b"123")
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b"123"
|
||||
|
||||
await mss.send_all(b"456")
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"4"
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert not memory_stream_pump(mss, mrs, max_bytes=1)
|
||||
assert await mrs.receive_some(10) == b"56"
|
||||
|
||||
mss.close()
|
||||
memory_stream_pump(mss, mrs)
|
||||
assert await mrs.receive_some(10) == b""
|
||||
|
||||
|
||||
async def test_memory_stream_one_way_pair() -> None:
|
||||
s, r = memory_stream_one_way_pair()
|
||||
assert s.send_all_hook is not None
|
||||
assert s.wait_send_all_might_not_block_hook is None
|
||||
assert s.close_hook is not None
|
||||
assert r.receive_some_hook is None
|
||||
await s.send_all(b"123")
|
||||
assert await r.receive_some(10) == b"123"
|
||||
|
||||
async def receiver(expected: bytes) -> None:
|
||||
assert await r.receive_some(10) == expected
|
||||
|
||||
# This fails if we pump on r.receive_some_hook; we need to pump on s.send_all_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"abc")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.send_all(b"abc")
|
||||
|
||||
# And this fails if we don't pump from close_hook
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
await s.aclose()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver, b"")
|
||||
await wait_all_tasks_blocked()
|
||||
s.close()
|
||||
|
||||
s, r = memory_stream_one_way_pair()
|
||||
|
||||
old = s.send_all_hook
|
||||
s.send_all_hook = None
|
||||
await s.send_all(b"456")
|
||||
|
||||
async def cancel_after_idle(nursery: Nursery) -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
async def check_for_cancel() -> None:
|
||||
with pytest.raises(_core.Cancelled):
|
||||
# This should block forever... or until cancelled. Even though we
|
||||
# sent some data on the send stream.
|
||||
await r.receive_some(10)
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(cancel_after_idle, nursery)
|
||||
nursery.start_soon(check_for_cancel)
|
||||
|
||||
s.send_all_hook = old
|
||||
await s.send_all(b"789")
|
||||
assert await r.receive_some(10) == b"456789"
|
||||
|
||||
|
||||
async def test_memory_stream_pair() -> None:
|
||||
a, b = memory_stream_pair()
|
||||
await a.send_all(b"123")
|
||||
await b.send_all(b"abc")
|
||||
assert await b.receive_some(10) == b"123"
|
||||
assert await a.receive_some(10) == b"abc"
|
||||
|
||||
await a.send_eof()
|
||||
assert await b.receive_some(10) == b""
|
||||
|
||||
async def sender() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
await b.send_all(b"xyz")
|
||||
|
||||
async def receiver() -> None:
|
||||
assert await a.receive_some(10) == b"xyz"
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(receiver)
|
||||
nursery.start_soon(sender)
|
||||
|
||||
|
||||
async def test_memory_streams_with_generic_tests() -> None:
|
||||
async def one_way_stream_maker() -> tuple[MemorySendStream, MemoryReceiveStream]:
|
||||
return memory_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, None)
|
||||
|
||||
async def half_closeable_stream_maker() -> tuple[
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
StapledStream[MemorySendStream, MemoryReceiveStream],
|
||||
]:
|
||||
return memory_stream_pair()
|
||||
|
||||
await check_half_closeable_stream(half_closeable_stream_maker, None)
|
||||
|
||||
|
||||
async def test_lockstep_streams_with_generic_tests() -> None:
|
||||
async def one_way_stream_maker() -> tuple[SendStream, ReceiveStream]:
|
||||
return lockstep_stream_one_way_pair()
|
||||
|
||||
await check_one_way_stream(one_way_stream_maker, one_way_stream_maker)
|
||||
|
||||
async def two_way_stream_maker() -> tuple[
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
StapledStream[SendStream, ReceiveStream],
|
||||
]:
|
||||
return lockstep_stream_pair()
|
||||
|
||||
await check_two_way_stream(two_way_stream_maker, two_way_stream_maker)
|
||||
|
||||
|
||||
async def test_open_stream_to_socket_listener() -> None:
|
||||
async def check(listener: SocketListener) -> None:
|
||||
async with listener:
|
||||
client_stream = await open_stream_to_socket_listener(listener)
|
||||
async with client_stream:
|
||||
server_stream = await listener.accept()
|
||||
async with server_stream:
|
||||
await client_stream.send_all(b"x")
|
||||
assert await server_stream.receive_some(1) == b"x"
|
||||
|
||||
# Listener bound to localhost
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("127.0.0.1", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# Listener bound to IPv4 wildcard (needs special handling)
|
||||
sock = tsocket.socket()
|
||||
await sock.bind(("0.0.0.0", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
# true on all CI systems
|
||||
if can_bind_ipv6: # pragma: no branch
|
||||
# Listener bound to IPv6 wildcard (needs special handling)
|
||||
sock = tsocket.socket(family=tsocket.AF_INET6)
|
||||
await sock.bind(("::", 0))
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
if hasattr(tsocket, "AF_UNIX"):
|
||||
# Listener bound to Unix-domain socket
|
||||
sock = tsocket.socket(family=tsocket.AF_UNIX)
|
||||
# can't use pytest's tmpdir; if we try then macOS says "OSError:
|
||||
# AF_UNIX path too long"
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
path = f"{tmpdir}/sock"
|
||||
await sock.bind(path)
|
||||
sock.listen(10)
|
||||
await check(SocketListener(sock))
|
||||
|
||||
|
||||
def test_trio_test() -> None:
|
||||
async def busy_kitchen(
|
||||
*,
|
||||
mock_clock: object,
|
||||
autojump_clock: object,
|
||||
) -> None: ... # pragma: no cover
|
||||
|
||||
with pytest.raises(ValueError, match="^too many clocks spoil the broth!$"):
|
||||
trio_test(busy_kitchen)(
|
||||
mock_clock=MockClock(),
|
||||
autojump_clock=MockClock(autojump_threshold=0),
|
||||
)
|
@ -0,0 +1,374 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import sys
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import ExceptionGroup
|
||||
|
||||
|
||||
def wrap_escape(s: str) -> str:
|
||||
return "^" + re.escape(s) + "$"
|
||||
|
||||
|
||||
def test_raises_group() -> None:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=wrap_escape(
|
||||
f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.',
|
||||
),
|
||||
):
|
||||
RaisesGroup(TypeError())
|
||||
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
|
||||
with RaisesGroup(SyntaxError):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("foo", (SyntaxError(),))
|
||||
|
||||
# multiple exceptions
|
||||
with RaisesGroup(ValueError, SyntaxError):
|
||||
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
|
||||
|
||||
# order doesn't matter
|
||||
with RaisesGroup(SyntaxError, ValueError):
|
||||
raise ExceptionGroup("foo", (ValueError(), SyntaxError()))
|
||||
|
||||
# nested exceptions
|
||||
with RaisesGroup(RaisesGroup(ValueError)):
|
||||
raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),))
|
||||
|
||||
with RaisesGroup(
|
||||
SyntaxError,
|
||||
RaisesGroup(ValueError),
|
||||
RaisesGroup(RuntimeError),
|
||||
):
|
||||
raise ExceptionGroup(
|
||||
"foo",
|
||||
(
|
||||
SyntaxError(),
|
||||
ExceptionGroup("bar", (ValueError(),)),
|
||||
ExceptionGroup("", (RuntimeError(),)),
|
||||
),
|
||||
)
|
||||
|
||||
# will error if there's excess exceptions
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("", (ValueError(), ValueError()))
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ExceptionGroup("", (RuntimeError(), ValueError()))
|
||||
|
||||
# will error if there's missing exceptions
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, ValueError):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, SyntaxError):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
|
||||
def test_flatten_subgroups() -> None:
|
||||
# loose semantics, as with expect*
|
||||
with RaisesGroup(ValueError, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
|
||||
|
||||
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(), TypeError())),))
|
||||
with RaisesGroup(ValueError, TypeError, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()]), TypeError()])
|
||||
|
||||
# mixed loose is possible if you want it to be at least N deep
|
||||
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
|
||||
raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),))
|
||||
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
|
||||
raise ExceptionGroup(
|
||||
"",
|
||||
(ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),),
|
||||
)
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(RaisesGroup(ValueError, flatten_subgroups=True)):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
# but not the other way around
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^You cannot specify a nested structure inside a RaisesGroup with",
|
||||
):
|
||||
RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore[call-overload]
|
||||
|
||||
|
||||
def test_catch_unwrapped_exceptions() -> None:
|
||||
# Catches lone exceptions with strict=False
|
||||
# just as except* would
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True):
|
||||
raise ValueError
|
||||
|
||||
# expecting multiple unwrapped exceptions is not possible
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^You cannot specify multiple exceptions with",
|
||||
):
|
||||
RaisesGroup(SyntaxError, ValueError, allow_unwrapped=True) # type: ignore[call-overload]
|
||||
# if users want one of several exception types they need to use a Matcher
|
||||
# (which the error message suggests)
|
||||
with RaisesGroup(
|
||||
Matcher(check=lambda e: isinstance(e, (SyntaxError, ValueError))),
|
||||
allow_unwrapped=True,
|
||||
):
|
||||
raise ValueError
|
||||
|
||||
# Unwrapped nested `RaisesGroup` is likely a user error, so we raise an error.
|
||||
with pytest.raises(ValueError, match="has no effect when expecting"):
|
||||
RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore[call-overload]
|
||||
|
||||
# But it *can* be used to check for nesting level +- 1 if they move it to
|
||||
# the nested RaisesGroup. Users should probably use `Matcher`s instead though.
|
||||
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
|
||||
with RaisesGroup(RaisesGroup(ValueError, allow_unwrapped=True)):
|
||||
raise ExceptionGroup("", [ValueError()])
|
||||
|
||||
# with allow_unwrapped=False (default) it will not be caught
|
||||
with pytest.raises(ValueError, match="^value error text$"):
|
||||
with RaisesGroup(ValueError):
|
||||
raise ValueError("value error text")
|
||||
|
||||
# allow_unwrapped on it's own won't match against nested groups
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
|
||||
|
||||
# for that you need both allow_unwrapped and flatten_subgroups
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True):
|
||||
raise ExceptionGroup("", [ExceptionGroup("", [ValueError()])])
|
||||
|
||||
# code coverage
|
||||
with pytest.raises(TypeError):
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True):
|
||||
raise TypeError
|
||||
|
||||
|
||||
def test_match() -> None:
|
||||
# supports match string
|
||||
with RaisesGroup(ValueError, match="bar"):
|
||||
raise ExceptionGroup("bar", (ValueError(),))
|
||||
|
||||
# now also works with ^$
|
||||
with RaisesGroup(ValueError, match="^bar$"):
|
||||
raise ExceptionGroup("bar", (ValueError(),))
|
||||
|
||||
# it also includes notes
|
||||
with RaisesGroup(ValueError, match="my note"):
|
||||
e = ExceptionGroup("bar", (ValueError(),))
|
||||
e.add_note("my note")
|
||||
raise e
|
||||
|
||||
# and technically you can match it all with ^$
|
||||
# but you're probably better off using a Matcher at that point
|
||||
with RaisesGroup(ValueError, match="^bar\nmy note$"):
|
||||
e = ExceptionGroup("bar", (ValueError(),))
|
||||
e.add_note("my note")
|
||||
raise e
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, match="foo"):
|
||||
raise ExceptionGroup("bar", (ValueError(),))
|
||||
|
||||
|
||||
def test_check() -> None:
|
||||
exc = ExceptionGroup("", (ValueError(),))
|
||||
with RaisesGroup(ValueError, check=lambda x: x is exc):
|
||||
raise exc
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(ValueError, check=lambda x: x is exc):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
|
||||
def test_unwrapped_match_check() -> None:
|
||||
def my_check(e: object) -> bool: # pragma: no cover
|
||||
return True
|
||||
|
||||
msg = (
|
||||
"`allow_unwrapped=True` bypasses the `match` and `check` parameters"
|
||||
" if the exception is unwrapped. If you intended to match/check the"
|
||||
" exception you should use a `Matcher` object. If you want to match/check"
|
||||
" the exceptiongroup when the exception *is* wrapped you need to"
|
||||
" do e.g. `if isinstance(exc.value, ExceptionGroup):"
|
||||
" assert RaisesGroup(...).matches(exc.value)` afterwards."
|
||||
)
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
RaisesGroup(ValueError, allow_unwrapped=True, match="foo") # type: ignore[call-overload]
|
||||
with pytest.raises(ValueError, match=re.escape(msg)):
|
||||
RaisesGroup(ValueError, allow_unwrapped=True, check=my_check) # type: ignore[call-overload]
|
||||
|
||||
# Users should instead use a Matcher
|
||||
rg = RaisesGroup(Matcher(ValueError, match="^foo$"), allow_unwrapped=True)
|
||||
with rg:
|
||||
raise ValueError("foo")
|
||||
with rg:
|
||||
raise ExceptionGroup("", [ValueError("foo")])
|
||||
|
||||
# or if they wanted to match/check the group, do a conditional `.matches()`
|
||||
with RaisesGroup(ValueError, allow_unwrapped=True) as exc:
|
||||
raise ExceptionGroup("bar", [ValueError("foo")])
|
||||
if isinstance(exc.value, ExceptionGroup): # pragma: no branch
|
||||
assert RaisesGroup(ValueError, match="bar").matches(exc.value)
|
||||
|
||||
|
||||
def test_RaisesGroup_matches() -> None:
|
||||
rg = RaisesGroup(ValueError)
|
||||
assert not rg.matches(None)
|
||||
assert not rg.matches(ValueError())
|
||||
assert rg.matches(ExceptionGroup("", (ValueError(),)))
|
||||
|
||||
|
||||
def test_message() -> None:
|
||||
def check_message(message: str, body: RaisesGroup[Any]) -> None:
|
||||
with pytest.raises(
|
||||
AssertionError,
|
||||
match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$",
|
||||
):
|
||||
with body:
|
||||
...
|
||||
|
||||
# basic
|
||||
check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError))
|
||||
# multiple exceptions
|
||||
check_message(
|
||||
"ExceptionGroup(ValueError, ValueError)",
|
||||
RaisesGroup(ValueError, ValueError),
|
||||
)
|
||||
# nested
|
||||
check_message(
|
||||
"ExceptionGroup(ExceptionGroup(ValueError))",
|
||||
RaisesGroup(RaisesGroup(ValueError)),
|
||||
)
|
||||
|
||||
# Matcher
|
||||
check_message(
|
||||
"ExceptionGroup(Matcher(ValueError, match='my_str'))",
|
||||
RaisesGroup(Matcher(ValueError, "my_str")),
|
||||
)
|
||||
check_message(
|
||||
"ExceptionGroup(Matcher(match='my_str'))",
|
||||
RaisesGroup(Matcher(match="my_str")),
|
||||
)
|
||||
|
||||
# BaseExceptionGroup
|
||||
check_message(
|
||||
"BaseExceptionGroup(KeyboardInterrupt)",
|
||||
RaisesGroup(KeyboardInterrupt),
|
||||
)
|
||||
# BaseExceptionGroup with type inside Matcher
|
||||
check_message(
|
||||
"BaseExceptionGroup(Matcher(KeyboardInterrupt))",
|
||||
RaisesGroup(Matcher(KeyboardInterrupt)),
|
||||
)
|
||||
# Base-ness transfers to parent containers
|
||||
check_message(
|
||||
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt))",
|
||||
RaisesGroup(RaisesGroup(KeyboardInterrupt)),
|
||||
)
|
||||
# but not to child containers
|
||||
check_message(
|
||||
"BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt), ExceptionGroup(ValueError))",
|
||||
RaisesGroup(RaisesGroup(KeyboardInterrupt), RaisesGroup(ValueError)),
|
||||
)
|
||||
|
||||
|
||||
def test_matcher() -> None:
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^You must specify at least one parameter to match on.$",
|
||||
):
|
||||
Matcher() # type: ignore[call-overload]
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"^exception_type {re.escape(repr(object))} must be a subclass of BaseException$",
|
||||
):
|
||||
Matcher(object) # type: ignore[type-var]
|
||||
|
||||
with RaisesGroup(Matcher(ValueError)):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(TypeError)):
|
||||
raise ExceptionGroup("", (ValueError(),))
|
||||
|
||||
|
||||
def test_matcher_match() -> None:
|
||||
with RaisesGroup(Matcher(ValueError, "foo")):
|
||||
raise ExceptionGroup("", (ValueError("foo"),))
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(ValueError, "foo")):
|
||||
raise ExceptionGroup("", (ValueError("bar"),))
|
||||
|
||||
# Can be used without specifying the type
|
||||
with RaisesGroup(Matcher(match="foo")):
|
||||
raise ExceptionGroup("", (ValueError("foo"),))
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(match="foo")):
|
||||
raise ExceptionGroup("", (ValueError("bar"),))
|
||||
|
||||
# check ^$
|
||||
with RaisesGroup(Matcher(ValueError, match="^bar$")):
|
||||
raise ExceptionGroup("", [ValueError("bar")])
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(ValueError, match="^bar$")):
|
||||
raise ExceptionGroup("", [ValueError("barr")])
|
||||
|
||||
|
||||
def test_Matcher_check() -> None:
|
||||
def check_oserror_and_errno_is_5(e: BaseException) -> bool:
|
||||
return isinstance(e, OSError) and e.errno == 5
|
||||
|
||||
with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)):
|
||||
raise ExceptionGroup("", (OSError(5, ""),))
|
||||
|
||||
# specifying exception_type narrows the parameter type to the callable
|
||||
def check_errno_is_5(e: OSError) -> bool:
|
||||
return e.errno == 5
|
||||
|
||||
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
|
||||
raise ExceptionGroup("", (OSError(5, ""),))
|
||||
|
||||
with pytest.raises(ExceptionGroup):
|
||||
with RaisesGroup(Matcher(OSError, check=check_errno_is_5)):
|
||||
raise ExceptionGroup("", (OSError(6, ""),))
|
||||
|
||||
|
||||
def test_matcher_tostring() -> None:
|
||||
assert str(Matcher(ValueError)) == "Matcher(ValueError)"
|
||||
assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')"
|
||||
pattern_no_flags = re.compile("noflag", 0)
|
||||
assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')"
|
||||
pattern_flags = re.compile("noflag", re.IGNORECASE)
|
||||
assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})"
|
||||
assert (
|
||||
str(Matcher(ValueError, match="re", check=bool))
|
||||
== f"Matcher(ValueError, match='re', check={bool!r})"
|
||||
)
|
||||
|
||||
|
||||
def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(
|
||||
trio.testing._raises_group,
|
||||
"ExceptionInfo",
|
||||
trio.testing._raises_group._ExceptionInfo,
|
||||
)
|
||||
with trio.testing.RaisesGroup(ValueError) as excinfo:
|
||||
raise ExceptionGroup("", (ValueError("hello"),))
|
||||
assert excinfo.type is ExceptionGroup
|
||||
assert excinfo.value.exceptions[0].args == ("hello",)
|
||||
assert isinstance(excinfo.tb, TracebackType)
|
1148
lib/python3.13/site-packages/trio/_tests/test_threads.py
Normal file
1148
lib/python3.13/site-packages/trio/_tests/test_threads.py
Normal file
File diff suppressed because it is too large
Load Diff
272
lib/python3.13/site-packages/trio/_tests/test_timeouts.py
Normal file
272
lib/python3.13/site-packages/trio/_tests/test_timeouts.py
Normal file
@ -0,0 +1,272 @@
|
||||
import time
|
||||
from typing import Awaitable, Callable, Protocol, TypeVar
|
||||
|
||||
import outcome
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import slow
|
||||
from .._timeouts import (
|
||||
TooSlowError,
|
||||
fail_after,
|
||||
fail_at,
|
||||
move_on_after,
|
||||
move_on_at,
|
||||
sleep,
|
||||
sleep_forever,
|
||||
sleep_until,
|
||||
)
|
||||
from ..testing import assert_checkpoints
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def check_takes_about(f: Callable[[], Awaitable[T]], expected_dur: float) -> T:
|
||||
start = time.perf_counter()
|
||||
result = await outcome.acapture(f)
|
||||
dur = time.perf_counter() - start
|
||||
print(dur / expected_dur)
|
||||
# 1.5 is an arbitrary fudge factor because there's always some delay
|
||||
# between when we become eligible to wake up and when we actually do. We
|
||||
# used to sleep for 0.05, and regularly observed overruns of 1.6x on
|
||||
# Appveyor, and then started seeing overruns of 2.3x on Travis's macOS, so
|
||||
# now we bumped up the sleep to 1 second, marked the tests as slow, and
|
||||
# hopefully now the proportional error will be less huge.
|
||||
#
|
||||
# We also also for durations that are a hair shorter than expected. For
|
||||
# example, here's a run on Windows where a 1.0 second sleep was measured
|
||||
# to take 0.9999999999999858 seconds:
|
||||
# https://ci.appveyor.com/project/njsmith/trio/build/1.0.768/job/3lbdyxl63q3h9s21
|
||||
# I believe that what happened here is that Windows's low clock resolution
|
||||
# meant that our calls to time.monotonic() returned exactly the same
|
||||
# values as the calls inside the actual run loop, but the two subtractions
|
||||
# returned slightly different values because the run loop's clock adds a
|
||||
# random floating point offset to both times, which should cancel out, but
|
||||
# lol floating point we got slightly different rounding errors. (That
|
||||
# value above is exactly 128 ULPs below 1.0, which would make sense if it
|
||||
# started as a 1 ULP error at a different dynamic range.)
|
||||
assert (1 - 1e-8) <= (dur / expected_dur) < 1.5
|
||||
|
||||
return result.unwrap()
|
||||
|
||||
|
||||
# How long to (attempt to) sleep for when testing. Smaller numbers make the
|
||||
# test suite go faster.
|
||||
TARGET = 1.0
|
||||
|
||||
|
||||
@slow
|
||||
async def test_sleep() -> None:
|
||||
async def sleep_1() -> None:
|
||||
await sleep_until(_core.current_time() + TARGET)
|
||||
|
||||
await check_takes_about(sleep_1, TARGET)
|
||||
|
||||
async def sleep_2() -> None:
|
||||
await sleep(TARGET)
|
||||
|
||||
await check_takes_about(sleep_2, TARGET)
|
||||
|
||||
with assert_checkpoints():
|
||||
await sleep(0)
|
||||
# This also serves as a test of the trivial move_on_at
|
||||
with move_on_at(_core.current_time()):
|
||||
with pytest.raises(_core.Cancelled):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after() -> None:
|
||||
async def sleep_3() -> None:
|
||||
with move_on_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
await check_takes_about(sleep_3, TARGET)
|
||||
|
||||
|
||||
async def test_cannot_wake_sleep_forever() -> None:
|
||||
# Test an error occurs if you manually wake sleep_forever().
|
||||
task = trio.lowlevel.current_task()
|
||||
|
||||
async def wake_task() -> None:
|
||||
await trio.lowlevel.checkpoint()
|
||||
trio.lowlevel.reschedule(task, outcome.Value(None))
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
nursery.start_soon(wake_task)
|
||||
with pytest.raises(RuntimeError):
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
class TimeoutScope(Protocol):
|
||||
def __call__(self, seconds: float, *, shield: bool) -> trio.CancelScope: ...
|
||||
|
||||
|
||||
@pytest.mark.parametrize("scope", [move_on_after, fail_after])
|
||||
async def test_context_shields_from_outer(scope: TimeoutScope) -> None:
|
||||
with _core.CancelScope() as outer, scope(TARGET, shield=True) as inner:
|
||||
outer.cancel()
|
||||
try:
|
||||
await trio.lowlevel.checkpoint()
|
||||
except trio.Cancelled:
|
||||
pytest.fail("shield didn't work")
|
||||
inner.shield = False
|
||||
with pytest.raises(trio.Cancelled):
|
||||
await trio.lowlevel.checkpoint()
|
||||
|
||||
|
||||
@slow
|
||||
async def test_move_on_after_moves_on_even_if_shielded() -> None:
|
||||
async def task() -> None:
|
||||
with _core.CancelScope() as outer, move_on_after(TARGET, shield=True):
|
||||
outer.cancel()
|
||||
# The outer scope is cancelled, but this task is protected by the
|
||||
# shield, so it manages to get to sleep until deadline is met
|
||||
await sleep_forever()
|
||||
|
||||
await check_takes_about(task, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail_after_fails_even_if_shielded() -> None:
|
||||
async def task() -> None:
|
||||
with pytest.raises(TooSlowError), _core.CancelScope() as outer, fail_after(
|
||||
TARGET,
|
||||
shield=True,
|
||||
):
|
||||
outer.cancel()
|
||||
# The outer scope is cancelled, but this task is protected by the
|
||||
# shield, so it manages to get to sleep until deadline is met
|
||||
await sleep_forever()
|
||||
|
||||
await check_takes_about(task, TARGET)
|
||||
|
||||
|
||||
@slow
|
||||
async def test_fail() -> None:
|
||||
async def sleep_4() -> None:
|
||||
with fail_at(_core.current_time() + TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_4, TARGET)
|
||||
|
||||
with fail_at(_core.current_time() + 100):
|
||||
await sleep(0)
|
||||
|
||||
async def sleep_5() -> None:
|
||||
with fail_after(TARGET):
|
||||
await sleep(100)
|
||||
|
||||
with pytest.raises(TooSlowError):
|
||||
await check_takes_about(sleep_5, TARGET)
|
||||
|
||||
with fail_after(100):
|
||||
await sleep(0)
|
||||
|
||||
|
||||
async def test_timeouts_raise_value_error() -> None:
|
||||
# deadlines are allowed to be negative, but not delays.
|
||||
# neither delays nor deadlines are allowed to be NaN
|
||||
|
||||
nan = float("nan")
|
||||
|
||||
for fun, val in (
|
||||
(sleep, -1),
|
||||
(sleep, nan),
|
||||
(sleep_until, nan),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
|
||||
):
|
||||
await fun(val)
|
||||
|
||||
for cm, val in (
|
||||
(fail_after, -1),
|
||||
(fail_after, nan),
|
||||
(fail_at, nan),
|
||||
(move_on_after, -1),
|
||||
(move_on_after, nan),
|
||||
(move_on_at, nan),
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="^(deadline|`seconds`) must (not )*be (non-negative|NaN)$",
|
||||
):
|
||||
with cm(val):
|
||||
pass # pragma: no cover
|
||||
|
||||
|
||||
async def test_timeout_deadline_on_entry(mock_clock: _core.MockClock) -> None:
|
||||
rcs = move_on_after(5)
|
||||
assert rcs.relative_deadline == 5
|
||||
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
with rcs as cs:
|
||||
assert cs.is_relative is None
|
||||
|
||||
# This would previously be start+2
|
||||
assert cs.deadline == start + 5
|
||||
assert cs.relative_deadline == 5
|
||||
|
||||
cs.deadline = start + 3
|
||||
assert cs.deadline == start + 3
|
||||
assert cs.relative_deadline == 3
|
||||
|
||||
cs.relative_deadline = 4
|
||||
assert cs.deadline == start + 4
|
||||
assert cs.relative_deadline == 4
|
||||
|
||||
rcs = move_on_after(5)
|
||||
assert rcs.shield is False
|
||||
rcs.shield = True
|
||||
assert rcs.shield is True
|
||||
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
with rcs as cs:
|
||||
assert cs.deadline == start + 5
|
||||
|
||||
assert rcs is cs
|
||||
|
||||
|
||||
async def test_invalid_access_unentered(mock_clock: _core.MockClock) -> None:
|
||||
cs = move_on_after(5)
|
||||
mock_clock.jump(3)
|
||||
start = _core.current_time()
|
||||
|
||||
match_str = "^unentered relative cancel scope does not have an absolute deadline"
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
assert cs.deadline == start + 5
|
||||
mock_clock.jump(1)
|
||||
# this is hella sketchy, but they *have* been warned
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
assert cs.deadline == start + 6
|
||||
|
||||
with pytest.warns(DeprecationWarning, match=match_str):
|
||||
cs.deadline = 7
|
||||
# now transformed into absolute
|
||||
assert cs.deadline == 7
|
||||
assert not cs.is_relative
|
||||
|
||||
cs = move_on_at(5)
|
||||
|
||||
match_str = (
|
||||
"^unentered non-relative cancel scope does not have a relative deadline$"
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=match_str):
|
||||
assert cs.relative_deadline
|
||||
with pytest.raises(RuntimeError, match=match_str):
|
||||
cs.relative_deadline = 7
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="not implemented")
|
||||
async def test_fail_access_before_entering() -> None: # pragma: no cover
|
||||
my_fail_at = fail_at(5)
|
||||
assert my_fail_at.deadline # type: ignore[attr-defined]
|
||||
my_fail_after = fail_after(5)
|
||||
assert my_fail_after.relative_deadline # type: ignore[attr-defined]
|
66
lib/python3.13/site-packages/trio/_tests/test_tracing.py
Normal file
66
lib/python3.13/site-packages/trio/_tests/test_tracing.py
Normal file
@ -0,0 +1,66 @@
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import trio
|
||||
|
||||
|
||||
async def coro1(event: trio.Event) -> None:
|
||||
event.set()
|
||||
await trio.sleep_forever()
|
||||
|
||||
|
||||
async def coro2(event: trio.Event) -> None:
|
||||
await coro1(event)
|
||||
|
||||
|
||||
async def coro3(event: trio.Event) -> None:
|
||||
await coro2(event)
|
||||
|
||||
|
||||
async def coro2_async_gen(event: trio.Event) -> AsyncGenerator[None, None]:
|
||||
# mypy does not like `yield await trio.lowlevel.checkpoint()` - but that
|
||||
# should be equivalent to splitting the statement
|
||||
await trio.lowlevel.checkpoint()
|
||||
yield
|
||||
await coro1(event)
|
||||
yield # pragma: no cover
|
||||
await trio.lowlevel.checkpoint() # pragma: no cover
|
||||
yield # pragma: no cover
|
||||
|
||||
|
||||
async def coro3_async_gen(event: trio.Event) -> None:
|
||||
async for _ in coro2_async_gen(event):
|
||||
pass
|
||||
|
||||
|
||||
async def test_task_iter_await_frames() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
event = trio.Event()
|
||||
nursery.start_soon(coro3, event)
|
||||
await event.wait()
|
||||
|
||||
(task,) = nursery.child_tasks
|
||||
|
||||
assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [
|
||||
"coro3",
|
||||
"coro2",
|
||||
"coro1",
|
||||
]
|
||||
|
||||
nursery.cancel_scope.cancel()
|
||||
|
||||
|
||||
async def test_task_iter_await_frames_async_gen() -> None:
|
||||
async with trio.open_nursery() as nursery:
|
||||
event = trio.Event()
|
||||
nursery.start_soon(coro3_async_gen, event)
|
||||
await event.wait()
|
||||
|
||||
(task,) = nursery.child_tasks
|
||||
|
||||
assert [frame.f_code.co_name for frame, _ in task.iter_await_frames()][:3] == [
|
||||
"coro3_async_gen",
|
||||
"coro2_async_gen",
|
||||
"coro1",
|
||||
]
|
||||
|
||||
nursery.cancel_scope.cancel()
|
8
lib/python3.13/site-packages/trio/_tests/test_trio.py
Normal file
8
lib/python3.13/site-packages/trio/_tests/test_trio.py
Normal file
@ -0,0 +1,8 @@
|
||||
def test_trio_import() -> None:
|
||||
import sys
|
||||
|
||||
for module in list(sys.modules.keys()):
|
||||
if module.startswith("trio"):
|
||||
del sys.modules[module]
|
||||
|
||||
import trio # noqa: F401
|
282
lib/python3.13/site-packages/trio/_tests/test_unix_pipes.py
Normal file
282
lib/python3.13/site-packages/trio/_tests/test_unix_pipes.py
Normal file
@ -0,0 +1,282 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import errno
|
||||
import os
|
||||
import select
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import gc_collect_harder, skip_if_fbsd_pipes_broken
|
||||
from ..testing import check_one_way_stream, wait_all_tasks_blocked
|
||||
|
||||
posix = os.name == "posix"
|
||||
pytestmark = pytest.mark.skipif(not posix, reason="posix only")
|
||||
|
||||
assert not TYPE_CHECKING or sys.platform == "unix"
|
||||
|
||||
if posix:
|
||||
from .._unix_pipes import FdStream
|
||||
else:
|
||||
with pytest.raises(ImportError):
|
||||
from .._unix_pipes import FdStream
|
||||
|
||||
|
||||
async def make_pipe() -> tuple[FdStream, FdStream]:
|
||||
"""Makes a new pair of pipes."""
|
||||
(r, w) = os.pipe()
|
||||
return FdStream(w), FdStream(r)
|
||||
|
||||
|
||||
async def make_clogged_pipe():
|
||||
s, r = await make_pipe()
|
||||
try:
|
||||
while True:
|
||||
# We want to totally fill up the pipe buffer.
|
||||
# This requires working around a weird feature that POSIX pipes
|
||||
# have.
|
||||
# If you do a write of <= PIPE_BUF bytes, then it's guaranteed
|
||||
# to either complete entirely, or not at all. So if we tried to
|
||||
# write PIPE_BUF bytes, and the buffer's free space is only
|
||||
# PIPE_BUF/2, then the write will raise BlockingIOError... even
|
||||
# though a smaller write could still succeed! To avoid this,
|
||||
# make sure to write >PIPE_BUF bytes each time, which disables
|
||||
# the special behavior.
|
||||
# For details, search for PIPE_BUF here:
|
||||
# http://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html
|
||||
|
||||
# for the getattr:
|
||||
# https://bitbucket.org/pypy/pypy/issues/2876/selectpipe_buf-is-missing-on-pypy3
|
||||
buf_size = getattr(select, "PIPE_BUF", 8192)
|
||||
os.write(s.fileno(), b"x" * buf_size * 2)
|
||||
except BlockingIOError:
|
||||
pass
|
||||
return s, r
|
||||
|
||||
|
||||
async def test_send_pipe() -> None:
|
||||
r, w = os.pipe()
|
||||
async with FdStream(w) as send:
|
||||
assert send.fileno() == w
|
||||
await send.send_all(b"123")
|
||||
assert (os.read(r, 8)) == b"123"
|
||||
|
||||
os.close(r)
|
||||
|
||||
|
||||
async def test_receive_pipe() -> None:
|
||||
r, w = os.pipe()
|
||||
async with FdStream(r) as recv:
|
||||
assert (recv.fileno()) == r
|
||||
os.write(w, b"123")
|
||||
assert (await recv.receive_some(8)) == b"123"
|
||||
|
||||
os.close(w)
|
||||
|
||||
|
||||
async def test_pipes_combined() -> None:
|
||||
write, read = await make_pipe()
|
||||
count = 2**20
|
||||
|
||||
async def sender() -> None:
|
||||
big = bytearray(count)
|
||||
await write.send_all(big)
|
||||
|
||||
async def reader() -> None:
|
||||
await wait_all_tasks_blocked()
|
||||
received = 0
|
||||
while received < count:
|
||||
received += len(await read.receive_some(4096))
|
||||
|
||||
assert received == count
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(sender)
|
||||
n.start_soon(reader)
|
||||
|
||||
await read.aclose()
|
||||
await write.aclose()
|
||||
|
||||
|
||||
async def test_pipe_errors() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
FdStream(None)
|
||||
|
||||
r, w = os.pipe()
|
||||
os.close(w)
|
||||
async with FdStream(r) as s:
|
||||
with pytest.raises(ValueError, match="^max_bytes must be integer >= 1$"):
|
||||
await s.receive_some(0)
|
||||
|
||||
|
||||
async def test_del() -> None:
|
||||
w, r = await make_pipe()
|
||||
f1, f2 = w.fileno(), r.fileno()
|
||||
del w, r
|
||||
gc_collect_harder()
|
||||
|
||||
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
|
||||
os.close(f1)
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
|
||||
os.close(f2)
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
|
||||
async def test_async_with() -> None:
|
||||
w, r = await make_pipe()
|
||||
async with w, r:
|
||||
pass
|
||||
|
||||
assert w.fileno() == -1
|
||||
assert r.fileno() == -1
|
||||
|
||||
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
|
||||
os.close(w.fileno())
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
with pytest.raises(OSError, match="Bad file descriptor$") as excinfo:
|
||||
os.close(r.fileno())
|
||||
assert excinfo.value.errno == errno.EBADF
|
||||
|
||||
|
||||
async def test_misdirected_aclose_regression() -> None:
|
||||
# https://github.com/python-trio/trio/issues/661#issuecomment-456582356
|
||||
w, r = await make_pipe()
|
||||
old_r_fd = r.fileno()
|
||||
|
||||
# Close the original objects
|
||||
await w.aclose()
|
||||
await r.aclose()
|
||||
|
||||
# Do a little dance to get a new pipe whose receive handle matches the old
|
||||
# receive handle.
|
||||
r2_fd, w2_fd = os.pipe()
|
||||
if r2_fd != old_r_fd: # pragma: no cover
|
||||
os.dup2(r2_fd, old_r_fd)
|
||||
os.close(r2_fd)
|
||||
async with FdStream(old_r_fd) as r2:
|
||||
assert r2.fileno() == old_r_fd
|
||||
|
||||
# And now set up a background task that's working on the new receive
|
||||
# handle
|
||||
async def expect_eof() -> None:
|
||||
assert await r2.receive_some(10) == b""
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_eof)
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Here's the key test: does calling aclose() again on the *old*
|
||||
# handle, cause the task blocked on the *new* handle to raise
|
||||
# ClosedResourceError?
|
||||
await r.aclose()
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
# Guess we survived! Close the new write handle so that the task
|
||||
# gets an EOF and can exit cleanly.
|
||||
os.close(w2_fd)
|
||||
|
||||
|
||||
async def test_close_at_bad_time_for_receive_some(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
# We used to have race conditions where if one task was using the pipe,
|
||||
# and another closed it at *just* the wrong moment, it would give an
|
||||
# unexpected error instead of ClosedResourceError:
|
||||
# https://github.com/python-trio/trio/issues/661
|
||||
#
|
||||
# This tests what happens if the pipe gets closed in the moment *between*
|
||||
# when receive_some wakes up, and when it tries to call os.read
|
||||
async def expect_closedresourceerror() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
orig_wait_readable = _core._run.TheIOManager.wait_readable
|
||||
|
||||
async def patched_wait_readable(*args, **kwargs) -> None:
|
||||
await orig_wait_readable(*args, **kwargs)
|
||||
await r.aclose()
|
||||
|
||||
monkeypatch.setattr(_core._run.TheIOManager, "wait_readable", patched_wait_readable)
|
||||
s, r = await make_pipe()
|
||||
async with s, r:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_closedresourceerror)
|
||||
await wait_all_tasks_blocked()
|
||||
# Trigger everything by waking up the receiver
|
||||
await s.send_all(b"x")
|
||||
|
||||
|
||||
async def test_close_at_bad_time_for_send_all(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# We used to have race conditions where if one task was using the pipe,
|
||||
# and another closed it at *just* the wrong moment, it would give an
|
||||
# unexpected error instead of ClosedResourceError:
|
||||
# https://github.com/python-trio/trio/issues/661
|
||||
#
|
||||
# This tests what happens if the pipe gets closed in the moment *between*
|
||||
# when send_all wakes up, and when it tries to call os.write
|
||||
async def expect_closedresourceerror() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await s.send_all(b"x" * 100)
|
||||
|
||||
orig_wait_writable = _core._run.TheIOManager.wait_writable
|
||||
|
||||
async def patched_wait_writable(*args, **kwargs) -> None:
|
||||
await orig_wait_writable(*args, **kwargs)
|
||||
await s.aclose()
|
||||
|
||||
monkeypatch.setattr(_core._run.TheIOManager, "wait_writable", patched_wait_writable)
|
||||
s, r = await make_clogged_pipe()
|
||||
async with s, r:
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(expect_closedresourceerror)
|
||||
await wait_all_tasks_blocked()
|
||||
# Trigger everything by waking up the sender. On ppc64el, PIPE_BUF
|
||||
# is 8192 but make_clogged_pipe() ends up writing a total of
|
||||
# 1048576 bytes before the pipe is full, and then a subsequent
|
||||
# receive_some(10000) isn't sufficient for orig_wait_writable() to
|
||||
# return for our subsequent aclose() call. It's necessary to empty
|
||||
# the pipe further before this happens. So we loop here until the
|
||||
# pipe is empty to make sure that the sender wakes up even in this
|
||||
# case. Otherwise patched_wait_writable() never gets to the
|
||||
# aclose(), so expect_closedresourceerror() never returns, the
|
||||
# nursery never finishes all tasks and this test hangs.
|
||||
received_data = await r.receive_some(10000)
|
||||
while received_data:
|
||||
received_data = await r.receive_some(10000)
|
||||
|
||||
|
||||
# On FreeBSD, directories are readable, and we haven't found any other trick
|
||||
# for making an unreadable fd, so there's no way to run this test. Fortunately
|
||||
# the logic this is testing doesn't depend on the platform, so testing on
|
||||
# other platforms is probably good enough.
|
||||
@pytest.mark.skipif(
|
||||
sys.platform.startswith("freebsd"),
|
||||
reason="no way to make read() return a bizarro error on FreeBSD",
|
||||
)
|
||||
async def test_bizarro_OSError_from_receive() -> None:
|
||||
# Make sure that if the read syscall returns some bizarro error, then we
|
||||
# get a BrokenResourceError. This is incredibly unlikely; there's almost
|
||||
# no way to trigger a failure here intentionally (except for EBADF, but we
|
||||
# exploit that to detect file closure, so it takes a different path). So
|
||||
# we set up a strange scenario where the pipe fd somehow transmutes into a
|
||||
# directory fd, causing os.read to raise IsADirectoryError (yes, that's a
|
||||
# real built-in exception type).
|
||||
s, r = await make_pipe()
|
||||
async with s, r:
|
||||
dir_fd = os.open("/", os.O_DIRECTORY, 0)
|
||||
try:
|
||||
os.dup2(dir_fd, r.fileno())
|
||||
with pytest.raises(_core.BrokenResourceError):
|
||||
await r.receive_some(10)
|
||||
finally:
|
||||
os.close(dir_fd)
|
||||
|
||||
|
||||
@skip_if_fbsd_pipes_broken
|
||||
async def test_pipe_fully() -> None:
|
||||
await check_one_way_stream(make_pipe, make_clogged_pipe)
|
275
lib/python3.13/site-packages/trio/_tests/test_util.py
Normal file
275
lib/python3.13/site-packages/trio/_tests/test_util.py
Normal file
@ -0,0 +1,275 @@
|
||||
import signal
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import pytest
|
||||
|
||||
import trio
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
|
||||
from .. import _core
|
||||
from .._core._tests.tutil import (
|
||||
create_asyncio_future_in_new_loop,
|
||||
ignore_coroutine_never_awaited_warnings,
|
||||
)
|
||||
from .._util import (
|
||||
ConflictDetector,
|
||||
NoPublicConstructor,
|
||||
coroutine_or_error,
|
||||
final,
|
||||
fixup_module_metadata,
|
||||
generic_function,
|
||||
is_main_thread,
|
||||
signal_raise,
|
||||
)
|
||||
from ..testing import wait_all_tasks_blocked
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def test_signal_raise() -> None:
|
||||
record = []
|
||||
|
||||
def handler(signum: int, _: object) -> None:
|
||||
record.append(signum)
|
||||
|
||||
old = signal.signal(signal.SIGFPE, handler)
|
||||
try:
|
||||
signal_raise(signal.SIGFPE)
|
||||
finally:
|
||||
signal.signal(signal.SIGFPE, old)
|
||||
assert record == [signal.SIGFPE]
|
||||
|
||||
|
||||
async def test_ConflictDetector() -> None:
|
||||
ul1 = ConflictDetector("ul1")
|
||||
ul2 = ConflictDetector("ul2")
|
||||
|
||||
with ul1:
|
||||
with ul2:
|
||||
print("ok")
|
||||
|
||||
with pytest.raises(_core.BusyResourceError, match="ul1"):
|
||||
with ul1:
|
||||
with ul1:
|
||||
pass # pragma: no cover
|
||||
|
||||
async def wait_with_ul1() -> None:
|
||||
with ul1:
|
||||
await wait_all_tasks_blocked()
|
||||
|
||||
with RaisesGroup(Matcher(_core.BusyResourceError, "ul1")):
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(wait_with_ul1)
|
||||
nursery.start_soon(wait_with_ul1)
|
||||
|
||||
|
||||
def test_module_metadata_is_fixed_up() -> None:
|
||||
import trio
|
||||
import trio.testing
|
||||
|
||||
assert trio.Cancelled.__module__ == "trio"
|
||||
assert trio.open_nursery.__module__ == "trio"
|
||||
assert trio.abc.Stream.__module__ == "trio.abc"
|
||||
assert trio.lowlevel.wait_task_rescheduled.__module__ == "trio.lowlevel"
|
||||
assert trio.testing.trio_test.__module__ == "trio.testing"
|
||||
|
||||
# Also check methods
|
||||
assert trio.lowlevel.ParkingLot.__init__.__module__ == "trio.lowlevel"
|
||||
assert trio.abc.Stream.send_all.__module__ == "trio.abc"
|
||||
|
||||
# And names
|
||||
assert trio.Cancelled.__name__ == "Cancelled"
|
||||
assert trio.Cancelled.__qualname__ == "Cancelled"
|
||||
assert trio.abc.SendStream.send_all.__name__ == "send_all"
|
||||
assert trio.abc.SendStream.send_all.__qualname__ == "SendStream.send_all"
|
||||
assert trio.to_thread.__name__ == "trio.to_thread"
|
||||
assert trio.to_thread.run_sync.__name__ == "run_sync"
|
||||
assert trio.to_thread.run_sync.__qualname__ == "run_sync"
|
||||
|
||||
|
||||
async def test_is_main_thread() -> None:
|
||||
assert is_main_thread()
|
||||
|
||||
def not_main_thread() -> None:
|
||||
assert not is_main_thread()
|
||||
|
||||
await trio.to_thread.run_sync(not_main_thread)
|
||||
|
||||
|
||||
# @coroutine is deprecated since python 3.8, which is fine with us.
|
||||
@pytest.mark.filterwarnings("ignore:.*@coroutine.*:DeprecationWarning")
|
||||
def test_coroutine_or_error() -> None:
|
||||
class Deferred:
|
||||
"Just kidding"
|
||||
|
||||
with ignore_coroutine_never_awaited_warnings():
|
||||
|
||||
async def f() -> None: # pragma: no cover
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(f()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "expecting an async function" in str(excinfo.value)
|
||||
|
||||
import asyncio
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
# not bothering to type this one
|
||||
@asyncio.coroutine # type: ignore[misc]
|
||||
def generator_based_coro() -> Any: # pragma: no cover
|
||||
yield from asyncio.sleep(1)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(generator_based_coro()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(create_asyncio_future_in_new_loop()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
# does not raise arg-type error
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(create_asyncio_future_in_new_loop) # type: ignore[unused-coroutine]
|
||||
assert "asyncio" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(Deferred()) # type: ignore[arg-type, unused-coroutine]
|
||||
assert "twisted" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(lambda: Deferred()) # type: ignore[arg-type, unused-coroutine, return-value]
|
||||
assert "twisted" in str(excinfo.value)
|
||||
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(len, [[1, 2, 3]]) # type: ignore[arg-type, unused-coroutine]
|
||||
|
||||
assert "appears to be synchronous" in str(excinfo.value)
|
||||
|
||||
async def async_gen(_: object) -> Any: # pragma: no cover
|
||||
yield
|
||||
|
||||
# does not give arg-type typing error
|
||||
with pytest.raises(TypeError) as excinfo:
|
||||
coroutine_or_error(async_gen, [0]) # type: ignore[unused-coroutine]
|
||||
msg = "expected an async function but got an async generator"
|
||||
assert msg in str(excinfo.value)
|
||||
|
||||
# Make sure no references are kept around to keep anything alive
|
||||
del excinfo
|
||||
|
||||
|
||||
def test_generic_function() -> None:
|
||||
@generic_function # Decorated function contains "Any".
|
||||
def test_func(arg: T) -> T: # type: ignore[misc]
|
||||
"""Look, a docstring!"""
|
||||
return arg
|
||||
|
||||
assert test_func is test_func[int] is test_func[int, str]
|
||||
assert test_func(42) == test_func[int](42) == 42
|
||||
assert test_func.__doc__ == "Look, a docstring!"
|
||||
assert test_func.__qualname__ == "test_generic_function.<locals>.test_func" # type: ignore[attr-defined]
|
||||
assert test_func.__name__ == "test_func" # type: ignore[attr-defined]
|
||||
assert test_func.__module__ == __name__
|
||||
|
||||
|
||||
def test_final_decorator() -> None:
|
||||
"""Test that subclassing a @final-annotated class is not allowed.
|
||||
|
||||
This checks both runtime results, and verifies that type checkers detect
|
||||
the error statically through the type-ignore comment.
|
||||
"""
|
||||
|
||||
@final
|
||||
class FinalClass:
|
||||
pass
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class SubClass(FinalClass): # type: ignore[misc]
|
||||
pass
|
||||
|
||||
|
||||
def test_no_public_constructor_metaclass() -> None:
|
||||
"""The NoPublicConstructor metaclass prevents calling the constructor directly."""
|
||||
|
||||
class SpecialClass(metaclass=NoPublicConstructor):
|
||||
def __init__(self, a: int, b: float) -> None:
|
||||
"""Check arguments can be passed to __init__."""
|
||||
assert a == 8
|
||||
assert b == 3.14
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
SpecialClass(8, 3.14)
|
||||
|
||||
# Private constructor should not raise, and passes args to __init__.
|
||||
assert isinstance(SpecialClass._create(8, b=3.14), SpecialClass)
|
||||
|
||||
|
||||
def test_fixup_module_metadata() -> None:
|
||||
# Ignores modules not in the trio.X tree.
|
||||
non_trio_module = types.ModuleType("not_trio")
|
||||
non_trio_module.some_func = lambda: None # type: ignore[attr-defined]
|
||||
non_trio_module.some_func.__name__ = "some_func"
|
||||
non_trio_module.some_func.__qualname__ = "some_func"
|
||||
|
||||
fixup_module_metadata(non_trio_module.__name__, vars(non_trio_module))
|
||||
|
||||
assert non_trio_module.some_func.__name__ == "some_func"
|
||||
assert non_trio_module.some_func.__qualname__ == "some_func"
|
||||
|
||||
# Bulild up a fake module to test. Just use lambdas since all we care about is the names.
|
||||
mod = types.ModuleType("trio._somemodule_impl")
|
||||
mod.some_func = lambda: None # type: ignore[attr-defined]
|
||||
mod.some_func.__name__ = "_something_else"
|
||||
mod.some_func.__qualname__ = "_something_else"
|
||||
|
||||
# No __module__ means it's unchanged.
|
||||
mod.not_funclike = types.SimpleNamespace() # type: ignore[attr-defined]
|
||||
mod.not_funclike.__name__ = "not_funclike"
|
||||
|
||||
# Check __qualname__ being absent works.
|
||||
mod.only_has_name = types.SimpleNamespace() # type: ignore[attr-defined]
|
||||
mod.only_has_name.__module__ = "trio._somemodule_impl"
|
||||
mod.only_has_name.__name__ = "only_name"
|
||||
|
||||
# Underscored names are unchanged.
|
||||
mod._private = lambda: None # type: ignore[attr-defined]
|
||||
mod._private.__module__ = "trio._somemodule_impl"
|
||||
mod._private.__name__ = mod._private.__qualname__ = "_private"
|
||||
|
||||
# We recurse into classes.
|
||||
mod.SomeClass = type( # type: ignore[attr-defined]
|
||||
"SomeClass",
|
||||
(),
|
||||
{
|
||||
"__init__": lambda self: None,
|
||||
"method": lambda self: None,
|
||||
},
|
||||
)
|
||||
# Reference loop is fine.
|
||||
mod.SomeClass.recursion = mod.SomeClass # type: ignore[attr-defined]
|
||||
|
||||
fixup_module_metadata("trio.somemodule", vars(mod))
|
||||
assert mod.some_func.__name__ == "some_func"
|
||||
assert mod.some_func.__module__ == "trio.somemodule"
|
||||
assert mod.some_func.__qualname__ == "some_func"
|
||||
|
||||
assert mod.not_funclike.__name__ == "not_funclike"
|
||||
assert mod._private.__name__ == "_private"
|
||||
assert mod._private.__module__ == "trio._somemodule_impl"
|
||||
assert mod._private.__qualname__ == "_private"
|
||||
|
||||
assert mod.only_has_name.__name__ == "only_has_name"
|
||||
assert mod.only_has_name.__module__ == "trio.somemodule"
|
||||
assert not hasattr(mod.only_has_name, "__qualname__")
|
||||
|
||||
assert mod.SomeClass.method.__name__ == "method" # type: ignore[attr-defined]
|
||||
assert mod.SomeClass.method.__module__ == "trio.somemodule" # type: ignore[attr-defined]
|
||||
assert mod.SomeClass.method.__qualname__ == "SomeClass.method" # type: ignore[attr-defined]
|
||||
# Make coverage happy.
|
||||
non_trio_module.some_func()
|
||||
mod.some_func()
|
||||
mod._private()
|
||||
mod.SomeClass().method()
|
225
lib/python3.13/site-packages/trio/_tests/test_wait_for_object.py
Normal file
225
lib/python3.13/site-packages/trio/_tests/test_wait_for_object.py
Normal file
@ -0,0 +1,225 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
on_windows = os.name == "nt"
|
||||
# Mark all the tests in this file as being windows-only
|
||||
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
|
||||
|
||||
import trio
|
||||
|
||||
from .. import _core, _timeouts
|
||||
from .._core._tests.tutil import slow
|
||||
|
||||
if on_windows:
|
||||
from .._core._windows_cffi import Handle, ffi, kernel32
|
||||
from .._wait_for_object import WaitForMultipleObjects_sync, WaitForSingleObject
|
||||
|
||||
|
||||
async def test_WaitForMultipleObjects_sync() -> None:
|
||||
# This does a series of tests where we set/close the handle before
|
||||
# initiating the waiting for it.
|
||||
#
|
||||
# Note that closing the handle (not signaling) will cause the
|
||||
# *initiation* of a wait to return immediately. But closing a handle
|
||||
# that is already being waited on will not stop whatever is waiting
|
||||
# for it.
|
||||
|
||||
# One handle
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle1)
|
||||
WaitForMultipleObjects_sync(handle1)
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync one OK")
|
||||
|
||||
# Two handles, signal first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle1)
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync set first OK")
|
||||
|
||||
# Two handles, signal second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle2)
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync set second OK")
|
||||
|
||||
# Two handles, close first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle1)
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync close first OK")
|
||||
|
||||
# Two handles, close second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle2)
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
WaitForMultipleObjects_sync(handle1, handle2)
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync close second OK")
|
||||
|
||||
|
||||
@slow
|
||||
async def test_WaitForMultipleObjects_sync_slow() -> None:
|
||||
# This does a series of test in which the main thread sync-waits for
|
||||
# handles, while we spawn a thread to set the handles after a short while.
|
||||
|
||||
TIMEOUT = 0.3
|
||||
|
||||
# One handle
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync,
|
||||
WaitForMultipleObjects_sync,
|
||||
handle1,
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
# If we would comment the line below, the above thread will be stuck,
|
||||
# and Trio won't exit this scope
|
||||
kernel32.SetEvent(handle1)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
print("test_WaitForMultipleObjects_sync_slow one OK")
|
||||
|
||||
# Two handles, signal first
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync,
|
||||
WaitForMultipleObjects_sync,
|
||||
handle1,
|
||||
handle2,
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle1)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync_slow thread-set first OK")
|
||||
|
||||
# Two handles, signal second
|
||||
handle1 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle2 = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(
|
||||
trio.to_thread.run_sync,
|
||||
WaitForMultipleObjects_sync,
|
||||
handle1,
|
||||
handle2,
|
||||
)
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle2)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
kernel32.CloseHandle(handle1)
|
||||
kernel32.CloseHandle(handle2)
|
||||
print("test_WaitForMultipleObjects_sync_slow thread-set second OK")
|
||||
|
||||
|
||||
async def test_WaitForSingleObject() -> None:
|
||||
# This does a series of test for setting/closing the handle before
|
||||
# initiating the wait.
|
||||
|
||||
# Test already set
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.SetEvent(handle)
|
||||
await WaitForSingleObject(handle) # should return at once
|
||||
kernel32.CloseHandle(handle)
|
||||
print("test_WaitForSingleObject already set OK")
|
||||
|
||||
# Test already set, as int
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle_int = int(ffi.cast("intptr_t", handle))
|
||||
kernel32.SetEvent(handle)
|
||||
await WaitForSingleObject(handle_int) # should return at once
|
||||
kernel32.CloseHandle(handle)
|
||||
print("test_WaitForSingleObject already set OK")
|
||||
|
||||
# Test already closed
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
kernel32.CloseHandle(handle)
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
await WaitForSingleObject(handle) # should return at once
|
||||
print("test_WaitForSingleObject already closed OK")
|
||||
|
||||
# Not a handle
|
||||
with pytest.raises(TypeError):
|
||||
await WaitForSingleObject("not a handle") # type: ignore[arg-type] # Wrong type
|
||||
# with pytest.raises(OSError):
|
||||
# await WaitForSingleObject(99) # If you're unlucky, it actually IS a handle :(
|
||||
print("test_WaitForSingleObject not a handle OK")
|
||||
|
||||
|
||||
@slow
|
||||
async def test_WaitForSingleObject_slow() -> None:
|
||||
# This does a series of test for setting the handle in another task,
|
||||
# and cancelling the wait task.
|
||||
|
||||
# Set the timeout used in the tests. We test the waiting time against
|
||||
# the timeout with a certain margin.
|
||||
TIMEOUT = 0.3
|
||||
|
||||
async def signal_soon_async(handle: Handle) -> None:
|
||||
await _timeouts.sleep(TIMEOUT)
|
||||
kernel32.SetEvent(handle)
|
||||
|
||||
# Test handle is SET after TIMEOUT in separate coroutine
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(WaitForSingleObject, handle)
|
||||
nursery.start_soon(signal_soon_async, handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow set from task OK")
|
||||
|
||||
# Test handle is SET after TIMEOUT in separate coroutine, as int
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
handle_int = int(ffi.cast("intptr_t", handle))
|
||||
t0 = _core.current_time()
|
||||
|
||||
async with _core.open_nursery() as nursery:
|
||||
nursery.start_soon(WaitForSingleObject, handle_int)
|
||||
nursery.start_soon(signal_soon_async, handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow set from task as int OK")
|
||||
|
||||
# Test handle is CLOSED after 1 sec - NOPE see comment above
|
||||
|
||||
# Test cancellation
|
||||
|
||||
handle = kernel32.CreateEventA(ffi.NULL, True, False, ffi.NULL)
|
||||
t0 = _core.current_time()
|
||||
|
||||
with _timeouts.move_on_after(TIMEOUT):
|
||||
await WaitForSingleObject(handle)
|
||||
|
||||
kernel32.CloseHandle(handle)
|
||||
t1 = _core.current_time()
|
||||
assert TIMEOUT <= (t1 - t0) < 2.0 * TIMEOUT
|
||||
print("test_WaitForSingleObject_slow cancellation OK")
|
112
lib/python3.13/site-packages/trio/_tests/test_windows_pipes.py
Normal file
112
lib/python3.13/site-packages/trio/_tests/test_windows_pipes.py
Normal file
@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from .. import _core
|
||||
from ..testing import check_one_way_stream, wait_all_tasks_blocked
|
||||
|
||||
# Mark all the tests in this file as being windows-only
|
||||
pytestmark = pytest.mark.skipif(sys.platform != "win32", reason="windows only")
|
||||
|
||||
assert ( # Skip type checking when not on Windows
|
||||
sys.platform == "win32" or not TYPE_CHECKING
|
||||
)
|
||||
|
||||
if sys.platform == "win32":
|
||||
from asyncio.windows_utils import pipe
|
||||
|
||||
from .._core._windows_cffi import _handle, kernel32
|
||||
from .._windows_pipes import PipeReceiveStream, PipeSendStream
|
||||
|
||||
|
||||
async def make_pipe() -> tuple[PipeSendStream, PipeReceiveStream]:
|
||||
"""Makes a new pair of pipes."""
|
||||
(r, w) = pipe()
|
||||
return PipeSendStream(w), PipeReceiveStream(r)
|
||||
|
||||
|
||||
async def test_pipe_typecheck() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
PipeSendStream(1.0) # type: ignore[arg-type]
|
||||
with pytest.raises(TypeError):
|
||||
PipeReceiveStream(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_pipe_error_on_close() -> None:
|
||||
# Make sure we correctly handle a failure from kernel32.CloseHandle
|
||||
r, w = pipe()
|
||||
|
||||
send_stream = PipeSendStream(w)
|
||||
receive_stream = PipeReceiveStream(r)
|
||||
|
||||
assert kernel32.CloseHandle(_handle(r))
|
||||
assert kernel32.CloseHandle(_handle(w))
|
||||
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
await send_stream.aclose()
|
||||
with pytest.raises(OSError, match=r"^\[WinError 6\] The handle is invalid$"):
|
||||
await receive_stream.aclose()
|
||||
|
||||
|
||||
async def test_pipes_combined() -> None:
|
||||
write, read = await make_pipe()
|
||||
count = 2**20
|
||||
replicas = 3
|
||||
|
||||
async def sender() -> None:
|
||||
async with write:
|
||||
big = bytearray(count)
|
||||
for _ in range(replicas):
|
||||
await write.send_all(big)
|
||||
|
||||
async def reader() -> None:
|
||||
async with read:
|
||||
await wait_all_tasks_blocked()
|
||||
total_received = 0
|
||||
while True:
|
||||
# 5000 is chosen because it doesn't evenly divide 2**20
|
||||
received = len(await read.receive_some(5000))
|
||||
if not received:
|
||||
break
|
||||
total_received += received
|
||||
|
||||
assert total_received == count * replicas
|
||||
|
||||
async with _core.open_nursery() as n:
|
||||
n.start_soon(sender)
|
||||
n.start_soon(reader)
|
||||
|
||||
|
||||
async def test_async_with() -> None:
|
||||
w, r = await make_pipe()
|
||||
async with w, r:
|
||||
pass
|
||||
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await w.send_all(b"")
|
||||
with pytest.raises(_core.ClosedResourceError):
|
||||
await r.receive_some(10)
|
||||
|
||||
|
||||
async def test_close_during_write() -> None:
|
||||
w, r = await make_pipe()
|
||||
async with _core.open_nursery() as nursery:
|
||||
|
||||
async def write_forever() -> None:
|
||||
with pytest.raises(_core.ClosedResourceError) as excinfo:
|
||||
while True:
|
||||
await w.send_all(b"x" * 4096)
|
||||
assert "another task" in str(excinfo.value)
|
||||
|
||||
nursery.start_soon(write_forever)
|
||||
await wait_all_tasks_blocked(0.1)
|
||||
await w.aclose()
|
||||
|
||||
|
||||
async def test_pipe_fully() -> None:
|
||||
# passing make_clogged_pipe tests wait_send_all_might_not_block, and we
|
||||
# can't implement that on Windows
|
||||
await check_one_way_stream(make_pipe, None)
|
@ -0,0 +1,176 @@
|
||||
import ast
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from trio._tests.pytest_plugin import skip_if_optional_else_raise
|
||||
|
||||
# imports in gen_exports that are not in `install_requires` in setup.py
|
||||
try:
|
||||
import astor # noqa: F401
|
||||
import isort # noqa: F401
|
||||
except ImportError as error:
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
|
||||
from trio._tools.gen_exports import (
|
||||
File,
|
||||
create_passthrough_args,
|
||||
get_public_methods,
|
||||
process,
|
||||
run_linters,
|
||||
run_ruff,
|
||||
)
|
||||
|
||||
SOURCE = '''from _run import _public
|
||||
from collections import Counter
|
||||
|
||||
class Test:
|
||||
@_public
|
||||
def public_func(self):
|
||||
"""With doc string"""
|
||||
|
||||
@ignore_this
|
||||
@_public
|
||||
@another_decorator
|
||||
async def public_async_func(self) -> Counter:
|
||||
pass # no doc string
|
||||
|
||||
def not_public(self):
|
||||
pass
|
||||
|
||||
async def not_public_async(self):
|
||||
pass
|
||||
'''
|
||||
|
||||
IMPORT_1 = """\
|
||||
from collections import Counter
|
||||
"""
|
||||
|
||||
IMPORT_2 = """\
|
||||
from collections import Counter
|
||||
import os
|
||||
"""
|
||||
|
||||
IMPORT_3 = """\
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from collections import Counter
|
||||
"""
|
||||
|
||||
|
||||
def test_get_public_methods() -> None:
|
||||
methods = list(get_public_methods(ast.parse(SOURCE)))
|
||||
assert {m.name for m in methods} == {"public_func", "public_async_func"}
|
||||
|
||||
|
||||
def test_create_pass_through_args() -> None:
|
||||
testcases = [
|
||||
("def f()", "()"),
|
||||
("def f(one)", "(one)"),
|
||||
("def f(one, two)", "(one, two)"),
|
||||
("def f(one, *args)", "(one, *args)"),
|
||||
(
|
||||
"def f(one, *args, kw1, kw2=None, **kwargs)",
|
||||
"(one, *args, kw1=kw1, kw2=kw2, **kwargs)",
|
||||
),
|
||||
]
|
||||
|
||||
for funcdef, expected in testcases:
|
||||
func_node = ast.parse(funcdef + ":\n pass").body[0]
|
||||
assert isinstance(func_node, ast.FunctionDef)
|
||||
assert create_passthrough_args(func_node) == expected
|
||||
|
||||
|
||||
skip_lints = pytest.mark.skipif(
|
||||
sys.implementation.name != "cpython",
|
||||
reason="gen_exports is internal, black/isort only runs on CPython",
|
||||
)
|
||||
|
||||
|
||||
@skip_lints
|
||||
@pytest.mark.parametrize("imports", [IMPORT_1, IMPORT_2, IMPORT_3])
|
||||
def test_process(
|
||||
tmp_path: Path,
|
||||
imports: str,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
try:
|
||||
import black # noqa: F401
|
||||
# there's no dedicated CI run that has astor+isort, but lacks black.
|
||||
except ImportError as error: # pragma: no cover
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
modpath = tmp_path / "_module.py"
|
||||
genpath = tmp_path / "_generated_module.py"
|
||||
modpath.write_text(SOURCE, encoding="utf-8")
|
||||
file = File(modpath, "runner", platform="linux", imports=imports)
|
||||
assert not genpath.exists()
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([file], do_test=True)
|
||||
assert excinfo.value.code == 1
|
||||
captured = capsys.readouterr()
|
||||
assert "Generated sources are outdated. Please regenerate." in captured.out
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([file], do_test=False)
|
||||
assert excinfo.value.code == 1
|
||||
captured = capsys.readouterr()
|
||||
assert "Regenerated sources successfully." in captured.out
|
||||
assert genpath.exists()
|
||||
process([file], do_test=True)
|
||||
# But if we change the lookup path it notices
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process(
|
||||
[File(modpath, "runner.io_manager", platform="linux", imports=imports)],
|
||||
do_test=True,
|
||||
)
|
||||
assert excinfo.value.code == 1
|
||||
# Also if the platform is changed.
|
||||
with pytest.raises(SystemExit) as excinfo:
|
||||
process([File(modpath, "runner", imports=imports)], do_test=True)
|
||||
assert excinfo.value.code == 1
|
||||
|
||||
|
||||
@skip_lints
|
||||
def test_run_ruff(tmp_path: Path) -> None:
|
||||
"""Test that processing properly fails if ruff does."""
|
||||
try:
|
||||
import ruff # noqa: F401
|
||||
except ImportError as error: # pragma: no cover
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
file = File(tmp_path / "module.py", "module")
|
||||
|
||||
success, _ = run_ruff(file, "class not valid code ><")
|
||||
assert not success
|
||||
|
||||
test_function = '''def combine_and(data: list[str]) -> str:
|
||||
"""Join values of text, and have 'and' with the last one properly."""
|
||||
if len(data) >= 2:
|
||||
data[-1] = 'and ' + data[-1]
|
||||
if len(data) > 2:
|
||||
return ', '.join(data)
|
||||
return ' '.join(data)'''
|
||||
|
||||
success, response = run_ruff(file, test_function)
|
||||
assert success
|
||||
assert response == test_function
|
||||
|
||||
|
||||
@skip_lints
|
||||
def test_lint_failure(tmp_path: Path) -> None:
|
||||
"""Test that processing properly fails if black or ruff does."""
|
||||
try:
|
||||
import black # noqa: F401
|
||||
import ruff # noqa: F401
|
||||
except ImportError as error: # pragma: no cover
|
||||
skip_if_optional_else_raise(error)
|
||||
|
||||
file = File(tmp_path / "module.py", "module")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
run_linters(file, "class not valid code ><")
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
run_linters(file, "import waffle\n;import trio")
|
@ -0,0 +1,140 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
|
||||
from trio._tools.mypy_annotate import Result, export, main, process_line
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("src", "expected"),
|
||||
[
|
||||
("", None),
|
||||
("a regular line\n", None),
|
||||
(
|
||||
"package\\filename.py:42:8: note: Some info\n",
|
||||
Result(
|
||||
kind="notice",
|
||||
filename="package\\filename.py",
|
||||
start_line=42,
|
||||
start_col=8,
|
||||
end_line=None,
|
||||
end_col=None,
|
||||
message=" Some info",
|
||||
),
|
||||
),
|
||||
(
|
||||
"package/filename.py:42:1:46:3: error: Type error here [code]\n",
|
||||
Result(
|
||||
kind="error",
|
||||
filename="package/filename.py",
|
||||
start_line=42,
|
||||
start_col=1,
|
||||
end_line=46,
|
||||
end_col=3,
|
||||
message=" Type error here [code]",
|
||||
),
|
||||
),
|
||||
(
|
||||
"package/module.py:87: warn: Bad code\n",
|
||||
Result(
|
||||
kind="warning",
|
||||
filename="package/module.py",
|
||||
start_line=87,
|
||||
message=" Bad code",
|
||||
),
|
||||
),
|
||||
],
|
||||
ids=["blank", "normal", "note-wcol", "error-wend", "warn-lineonly"],
|
||||
)
|
||||
def test_processing(src: str, expected: Result | None) -> None:
|
||||
result = process_line(src)
|
||||
assert result == expected
|
||||
|
||||
|
||||
def test_export(capsys: pytest.CaptureFixture[str]) -> None:
|
||||
results = {
|
||||
Result(
|
||||
kind="notice",
|
||||
filename="package\\filename.py",
|
||||
start_line=42,
|
||||
start_col=8,
|
||||
end_line=None,
|
||||
end_col=None,
|
||||
message=" Some info",
|
||||
): ["Windows", "Mac"],
|
||||
Result(
|
||||
kind="error",
|
||||
filename="package/filename.py",
|
||||
start_line=42,
|
||||
start_col=1,
|
||||
end_line=46,
|
||||
end_col=3,
|
||||
message=" Type error here [code]",
|
||||
): ["Linux", "Mac"],
|
||||
Result(
|
||||
kind="warning",
|
||||
filename="package/module.py",
|
||||
start_line=87,
|
||||
message=" Bad code",
|
||||
): ["Linux"],
|
||||
}
|
||||
export(results)
|
||||
std = capsys.readouterr()
|
||||
assert std.err == ""
|
||||
assert std.out == (
|
||||
"::notice file=package\\filename.py,line=42,col=8,"
|
||||
"title=Mypy-Windows+Mac::package\\filename.py:(42:8): Some info"
|
||||
"\n"
|
||||
"::error file=package/filename.py,line=42,col=1,endLine=46,endColumn=3,"
|
||||
"title=Mypy-Linux+Mac::package/filename.py:(42:1 - 46:3): Type error here [code]"
|
||||
"\n"
|
||||
"::warning file=package/module.py,line=87,"
|
||||
"title=Mypy-Linux::package/module.py:87: Bad code\n"
|
||||
)
|
||||
|
||||
|
||||
def test_endtoend(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
import trio._tools.mypy_annotate as mypy_annotate
|
||||
|
||||
inp_text = """\
|
||||
Mypy begun
|
||||
trio/core.py:15: error: Bad types here [misc]
|
||||
trio/package/module.py:48:4:56:18: warn: Missing annotations [no-untyped-def]
|
||||
Found 3 errors in 29 files
|
||||
"""
|
||||
result_file = tmp_path / "dump.dat"
|
||||
assert not result_file.exists()
|
||||
with monkeypatch.context():
|
||||
monkeypatch.setattr(sys, "stdin", io.StringIO(inp_text))
|
||||
|
||||
mypy_annotate.main(
|
||||
["--dumpfile", str(result_file), "--platform", "SomePlatform"],
|
||||
)
|
||||
|
||||
std = capsys.readouterr()
|
||||
assert std.err == ""
|
||||
assert std.out == inp_text # Echos the original.
|
||||
|
||||
assert result_file.exists()
|
||||
|
||||
main(["--dumpfile", str(result_file)])
|
||||
|
||||
std = capsys.readouterr()
|
||||
assert std.err == ""
|
||||
assert std.out == (
|
||||
"::error file=trio/core.py,line=15,title=Mypy-SomePlatform::trio/core.py:15: Bad types here [misc]\n"
|
||||
"::warning file=trio/package/module.py,line=48,col=4,endLine=56,endColumn=18,"
|
||||
"title=Mypy-SomePlatform::trio/package/module.py:(48:4 - 56:18): Missing "
|
||||
"annotations [no-untyped-def]\n"
|
||||
)
|
@ -0,0 +1,9 @@
|
||||
# https://github.com/python-trio/trio/issues/2775#issuecomment-1702267589
|
||||
# (except platform independent...)
|
||||
import trio
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
async def fn(s: trio.SocketStream) -> None:
|
||||
result = await s.socket.sendto(b"a", "h")
|
||||
assert_type(result, int)
|
@ -0,0 +1,4 @@
|
||||
# https://github.com/python-trio/trio/issues/2873
|
||||
import trio
|
||||
|
||||
s, r = trio.open_memory_channel[int](0)
|
144
lib/python3.13/site-packages/trio/_tests/type_tests/path.py
Normal file
144
lib/python3.13/site-packages/trio/_tests/type_tests/path.py
Normal file
@ -0,0 +1,144 @@
|
||||
"""Path wrapping is quite complex, ensure all methods are understood as wrapped correctly."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import IO, Any, BinaryIO, List, Tuple
|
||||
|
||||
import trio
|
||||
from trio._file_io import AsyncIOWrapper
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
def operator_checks(text: str, tpath: trio.Path, ppath: pathlib.Path) -> None:
|
||||
"""Verify operators produce the right results."""
|
||||
assert_type(tpath / ppath, trio.Path)
|
||||
assert_type(tpath / tpath, trio.Path)
|
||||
assert_type(tpath / text, trio.Path)
|
||||
assert_type(text / tpath, trio.Path)
|
||||
|
||||
assert_type(tpath > tpath, bool)
|
||||
assert_type(tpath >= tpath, bool)
|
||||
assert_type(tpath < tpath, bool)
|
||||
assert_type(tpath <= tpath, bool)
|
||||
|
||||
assert_type(tpath > ppath, bool)
|
||||
assert_type(tpath >= ppath, bool)
|
||||
assert_type(tpath < ppath, bool)
|
||||
assert_type(tpath <= ppath, bool)
|
||||
|
||||
assert_type(ppath > tpath, bool)
|
||||
assert_type(ppath >= tpath, bool)
|
||||
assert_type(ppath < tpath, bool)
|
||||
assert_type(ppath <= tpath, bool)
|
||||
|
||||
|
||||
def sync_attrs(path: trio.Path) -> None:
|
||||
assert_type(path.parts, Tuple[str, ...])
|
||||
assert_type(path.drive, str)
|
||||
assert_type(path.root, str)
|
||||
assert_type(path.anchor, str)
|
||||
assert_type(path.parents[3], trio.Path)
|
||||
assert_type(path.parent, trio.Path)
|
||||
assert_type(path.name, str)
|
||||
assert_type(path.suffix, str)
|
||||
assert_type(path.suffixes, List[str])
|
||||
assert_type(path.stem, str)
|
||||
assert_type(path.as_posix(), str)
|
||||
assert_type(path.as_uri(), str)
|
||||
assert_type(path.is_absolute(), bool)
|
||||
if sys.version_info > (3, 9):
|
||||
assert_type(path.is_relative_to(path), bool)
|
||||
assert_type(path.is_reserved(), bool)
|
||||
assert_type(path.joinpath(path, "folder"), trio.Path)
|
||||
assert_type(path.match("*.py"), bool)
|
||||
assert_type(path.relative_to("/usr"), trio.Path)
|
||||
if sys.version_info > (3, 12):
|
||||
assert_type(path.relative_to("/", walk_up=True), bool)
|
||||
assert_type(path.with_name("filename.txt"), trio.Path)
|
||||
if sys.version_info > (3, 9):
|
||||
assert_type(path.with_stem("readme"), trio.Path)
|
||||
assert_type(path.with_suffix(".log"), trio.Path)
|
||||
|
||||
|
||||
async def async_attrs(path: trio.Path) -> None:
|
||||
assert_type(await trio.Path.cwd(), trio.Path)
|
||||
assert_type(await trio.Path.home(), trio.Path)
|
||||
assert_type(await path.stat(), os.stat_result)
|
||||
assert_type(await path.chmod(0o777), None)
|
||||
assert_type(await path.exists(), bool)
|
||||
assert_type(await path.expanduser(), trio.Path)
|
||||
for result in await path.glob("*.py"):
|
||||
assert_type(result, trio.Path)
|
||||
if sys.platform != "win32":
|
||||
assert_type(await path.group(), str)
|
||||
assert_type(await path.is_dir(), bool)
|
||||
assert_type(await path.is_file(), bool)
|
||||
if sys.version_info > (3, 12):
|
||||
assert_type(await path.is_junction(), bool)
|
||||
if sys.platform != "win32":
|
||||
assert_type(await path.is_mount(), bool)
|
||||
assert_type(await path.is_symlink(), bool)
|
||||
assert_type(await path.is_socket(), bool)
|
||||
assert_type(await path.is_fifo(), bool)
|
||||
assert_type(await path.is_block_device(), bool)
|
||||
assert_type(await path.is_char_device(), bool)
|
||||
for child_iter in await path.iterdir():
|
||||
assert_type(child_iter, trio.Path)
|
||||
# TODO: Path.walk() in 3.12
|
||||
assert_type(await path.lchmod(0o111), None)
|
||||
assert_type(await path.lstat(), os.stat_result)
|
||||
assert_type(await path.mkdir(mode=0o777, parents=True, exist_ok=False), None)
|
||||
# Open done separately.
|
||||
if sys.platform != "win32":
|
||||
assert_type(await path.owner(), str)
|
||||
assert_type(await path.read_bytes(), bytes)
|
||||
assert_type(await path.read_text(encoding="utf16", errors="replace"), str)
|
||||
if sys.version_info > (3, 9):
|
||||
assert_type(await path.readlink(), trio.Path)
|
||||
assert_type(await path.rename("another"), trio.Path)
|
||||
assert_type(await path.replace(path), trio.Path)
|
||||
assert_type(await path.resolve(), trio.Path)
|
||||
for child_glob in await path.glob("*.py"):
|
||||
assert_type(child_glob, trio.Path)
|
||||
for child_rglob in await path.rglob("*.py"):
|
||||
assert_type(child_rglob, trio.Path)
|
||||
assert_type(await path.rmdir(), None)
|
||||
assert_type(await path.samefile("something_else"), bool)
|
||||
assert_type(await path.symlink_to("somewhere"), None)
|
||||
if sys.version_info > (3, 10):
|
||||
assert_type(await path.hardlink_to("elsewhere"), None)
|
||||
assert_type(await path.touch(), None)
|
||||
assert_type(await path.unlink(missing_ok=True), None)
|
||||
assert_type(await path.write_bytes(b"123"), int)
|
||||
assert_type(
|
||||
await path.write_text("hello", encoding="utf32le", errors="ignore"),
|
||||
int,
|
||||
)
|
||||
|
||||
|
||||
async def open_results(path: trio.Path, some_int: int, some_str: str) -> None:
|
||||
# Check the overloads.
|
||||
assert_type(await path.open(), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("r"), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("r+"), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("w"), AsyncIOWrapper[io.TextIOWrapper])
|
||||
assert_type(await path.open("rb", buffering=0), AsyncIOWrapper[io.FileIO])
|
||||
assert_type(await path.open("rb+"), AsyncIOWrapper[io.BufferedRandom])
|
||||
assert_type(await path.open("wb"), AsyncIOWrapper[io.BufferedWriter])
|
||||
assert_type(await path.open("rb"), AsyncIOWrapper[io.BufferedReader])
|
||||
assert_type(await path.open("rb", buffering=some_int), AsyncIOWrapper[BinaryIO])
|
||||
assert_type(await path.open(some_str), AsyncIOWrapper[IO[Any]])
|
||||
|
||||
# Check they produce the right types.
|
||||
file_bin = await path.open("rb+")
|
||||
assert_type(await file_bin.read(), bytes)
|
||||
assert_type(await file_bin.write(b"test"), int)
|
||||
assert_type(await file_bin.seek(32), int)
|
||||
|
||||
file_text = await path.open("r+t")
|
||||
assert_type(await file_text.read(), str)
|
||||
assert_type(await file_text.write("test"), int)
|
||||
# TODO: report mypy bug: equiv to https://github.com/microsoft/pyright/issues/6833
|
||||
assert_type(await file_text.readlines(), List[str])
|
@ -0,0 +1,254 @@
|
||||
"""The typing of RaisesGroup involves a lot of deception and lies, since AFAIK what we
|
||||
actually want to achieve is ~impossible. This is because we specify what we expect with
|
||||
instances of RaisesGroup and exception classes, but excinfo.value will be instances of
|
||||
[Base]ExceptionGroup and instances of exceptions. So we need to "translate" from
|
||||
RaisesGroup to ExceptionGroup.
|
||||
|
||||
The way it currently works is that RaisesGroup[E] corresponds to
|
||||
ExceptionInfo[BaseExceptionGroup[E]], so the top-level group will be correct. But
|
||||
RaisesGroup[RaisesGroup[ValueError]] will become
|
||||
ExceptionInfo[BaseExceptionGroup[RaisesGroup[ValueError]]]. To get around that we specify
|
||||
RaisesGroup as a subclass of BaseExceptionGroup during type checking - which should mean
|
||||
that most static type checking for end users should be mostly correct.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Union
|
||||
|
||||
from trio.testing import Matcher, RaisesGroup
|
||||
from typing_extensions import assert_type
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from exceptiongroup import BaseExceptionGroup, ExceptionGroup
|
||||
|
||||
# split into functions to isolate the different scopes
|
||||
|
||||
|
||||
def check_inheritance_and_assignments() -> None:
|
||||
# Check inheritance
|
||||
_: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError)
|
||||
_ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore
|
||||
|
||||
a: BaseExceptionGroup[BaseExceptionGroup[ValueError]]
|
||||
a = RaisesGroup(RaisesGroup(ValueError))
|
||||
a = BaseExceptionGroup("", (BaseExceptionGroup("", (ValueError(),)),))
|
||||
assert a
|
||||
|
||||
|
||||
def check_matcher_typevar_default(e: Matcher) -> object:
|
||||
assert e.exception_type is not None
|
||||
exc: type[BaseException] = e.exception_type
|
||||
# this would previously pass, as the type would be `Any`
|
||||
e.exception_type().blah() # type: ignore
|
||||
return exc # Silence Pyright unused var warning
|
||||
|
||||
|
||||
def check_basic_contextmanager() -> None:
|
||||
# One level of Group is correctly translated - except it's a BaseExceptionGroup
|
||||
# instead of an ExceptionGroup.
|
||||
with RaisesGroup(ValueError) as e:
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
assert_type(e.value, BaseExceptionGroup[ValueError])
|
||||
|
||||
|
||||
def check_basic_matches() -> None:
|
||||
# check that matches gets rid of the naked ValueError in the union
|
||||
exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup("", (ValueError(),))
|
||||
if RaisesGroup(ValueError).matches(exc):
|
||||
assert_type(exc, BaseExceptionGroup[ValueError])
|
||||
|
||||
|
||||
def check_matches_with_different_exception_type() -> None:
|
||||
# This should probably raise some type error somewhere, since
|
||||
# ValueError != KeyboardInterrupt
|
||||
e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup(
|
||||
"",
|
||||
(KeyboardInterrupt(),),
|
||||
)
|
||||
if RaisesGroup(ValueError).matches(e):
|
||||
assert_type(e, BaseExceptionGroup[ValueError])
|
||||
|
||||
|
||||
def check_matcher_init() -> None:
|
||||
def check_exc(exc: BaseException) -> bool:
|
||||
return isinstance(exc, ValueError)
|
||||
|
||||
# Check various combinations of constructor signatures.
|
||||
# At least 1 arg must be provided.
|
||||
Matcher() # type: ignore
|
||||
Matcher(ValueError)
|
||||
Matcher(ValueError, "regex")
|
||||
Matcher(ValueError, "regex", check_exc)
|
||||
Matcher(exception_type=ValueError)
|
||||
Matcher(match="regex")
|
||||
Matcher(check=check_exc)
|
||||
Matcher(ValueError, match="regex")
|
||||
Matcher(match="regex", check=check_exc)
|
||||
|
||||
def check_filenotfound(exc: FileNotFoundError) -> bool:
|
||||
return not exc.filename.endswith(".tmp")
|
||||
|
||||
# If exception_type is provided, that narrows the `check` method's argument.
|
||||
Matcher(FileNotFoundError, check=check_filenotfound)
|
||||
Matcher(ValueError, check=check_filenotfound) # type: ignore
|
||||
Matcher(check=check_filenotfound) # type: ignore
|
||||
Matcher(FileNotFoundError, match="regex", check=check_filenotfound)
|
||||
|
||||
|
||||
def raisesgroup_check_type_narrowing() -> None:
|
||||
"""Check type narrowing on the `check` argument to `RaisesGroup`.
|
||||
All `type: ignore`s are correctly pointing out type errors, except
|
||||
where otherwise noted.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def handle_exc(e: BaseExceptionGroup[BaseException]) -> bool:
|
||||
return True
|
||||
|
||||
def handle_kbi(e: BaseExceptionGroup[KeyboardInterrupt]) -> bool:
|
||||
return True
|
||||
|
||||
def handle_value(e: BaseExceptionGroup[ValueError]) -> bool:
|
||||
return True
|
||||
|
||||
RaisesGroup(BaseException, check=handle_exc)
|
||||
RaisesGroup(BaseException, check=handle_kbi) # type: ignore
|
||||
|
||||
RaisesGroup(Exception, check=handle_exc)
|
||||
RaisesGroup(Exception, check=handle_value) # type: ignore
|
||||
|
||||
RaisesGroup(KeyboardInterrupt, check=handle_exc)
|
||||
RaisesGroup(KeyboardInterrupt, check=handle_kbi)
|
||||
RaisesGroup(KeyboardInterrupt, check=handle_value) # type: ignore
|
||||
|
||||
RaisesGroup(ValueError, check=handle_exc)
|
||||
RaisesGroup(ValueError, check=handle_kbi) # type: ignore
|
||||
RaisesGroup(ValueError, check=handle_value)
|
||||
|
||||
RaisesGroup(ValueError, KeyboardInterrupt, check=handle_exc)
|
||||
RaisesGroup(ValueError, KeyboardInterrupt, check=handle_kbi) # type: ignore
|
||||
RaisesGroup(ValueError, KeyboardInterrupt, check=handle_value) # type: ignore
|
||||
|
||||
|
||||
def raisesgroup_narrow_baseexceptiongroup() -> None:
|
||||
"""Check type narrowing specifically for the container exceptiongroup.
|
||||
This is not currently working, and after playing around with it for a bit
|
||||
I think the only way is to introduce a subclass `NonBaseRaisesGroup`, and overload
|
||||
`__new__` in Raisesgroup to return the subclass when exceptions are non-base.
|
||||
(or make current class BaseRaisesGroup and introduce RaisesGroup for non-base)
|
||||
I encountered problems trying to type this though, see
|
||||
https://github.com/python/mypy/issues/17251
|
||||
That is probably possible to work around by entirely using `__new__` instead of
|
||||
`__init__`, but........ ugh.
|
||||
"""
|
||||
|
||||
def handle_group(e: ExceptionGroup[Exception]) -> bool:
|
||||
return True
|
||||
|
||||
def handle_group_value(e: ExceptionGroup[ValueError]) -> bool:
|
||||
return True
|
||||
|
||||
# should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup
|
||||
RaisesGroup(ValueError, check=handle_group_value) # type: ignore
|
||||
|
||||
# should work, but BaseExceptionGroup does not get narrowed to ExceptionGroup
|
||||
RaisesGroup(Exception, check=handle_group) # type: ignore
|
||||
|
||||
|
||||
def check_matcher_transparent() -> None:
|
||||
with RaisesGroup(Matcher(ValueError)) as e:
|
||||
...
|
||||
_: BaseExceptionGroup[ValueError] = e.value
|
||||
assert_type(e.value, BaseExceptionGroup[ValueError])
|
||||
|
||||
|
||||
def check_nested_raisesgroups_contextmanager() -> None:
|
||||
with RaisesGroup(RaisesGroup(ValueError)) as excinfo:
|
||||
raise ExceptionGroup("foo", (ValueError(),))
|
||||
|
||||
# thanks to inheritance this assignment works
|
||||
_: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value
|
||||
# and it can mostly be treated like an exceptiongroup
|
||||
print(excinfo.value.exceptions[0].exceptions[0])
|
||||
|
||||
# but assert_type reveals the lies
|
||||
print(type(excinfo.value)) # would print "ExceptionGroup"
|
||||
# typing says it's a BaseExceptionGroup
|
||||
assert_type(
|
||||
excinfo.value,
|
||||
BaseExceptionGroup[RaisesGroup[ValueError]],
|
||||
)
|
||||
|
||||
print(type(excinfo.value.exceptions[0])) # would print "ExceptionGroup"
|
||||
# but type checkers are utterly confused
|
||||
assert_type(
|
||||
excinfo.value.exceptions[0],
|
||||
Union[RaisesGroup[ValueError], BaseExceptionGroup[RaisesGroup[ValueError]]],
|
||||
)
|
||||
|
||||
|
||||
def check_nested_raisesgroups_matches() -> None:
|
||||
"""Check nested RaisesGroups with .matches"""
|
||||
exc: ExceptionGroup[ExceptionGroup[ValueError]] = ExceptionGroup(
|
||||
"",
|
||||
(ExceptionGroup("", (ValueError(),)),),
|
||||
)
|
||||
# has the same problems as check_nested_raisesgroups_contextmanager
|
||||
if RaisesGroup(RaisesGroup(ValueError)).matches(exc):
|
||||
assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]])
|
||||
|
||||
|
||||
def check_multiple_exceptions_1() -> None:
|
||||
a = RaisesGroup(ValueError, ValueError)
|
||||
b = RaisesGroup(Matcher(ValueError), Matcher(ValueError))
|
||||
c = RaisesGroup(ValueError, Matcher(ValueError))
|
||||
|
||||
d: BaseExceptionGroup[ValueError]
|
||||
d = a
|
||||
d = b
|
||||
d = c
|
||||
assert d
|
||||
|
||||
|
||||
def check_multiple_exceptions_2() -> None:
|
||||
# This previously failed due to lack of covariance in the TypeVar
|
||||
a = RaisesGroup(Matcher(ValueError), Matcher(TypeError))
|
||||
b = RaisesGroup(Matcher(ValueError), TypeError)
|
||||
c = RaisesGroup(ValueError, TypeError)
|
||||
|
||||
d: BaseExceptionGroup[Exception]
|
||||
d = a
|
||||
d = b
|
||||
d = c
|
||||
assert d
|
||||
|
||||
|
||||
def check_raisesgroup_overloads() -> None:
|
||||
# allow_unwrapped=True does not allow:
|
||||
# multiple exceptions
|
||||
RaisesGroup(ValueError, TypeError, allow_unwrapped=True) # type: ignore
|
||||
# nested RaisesGroup
|
||||
RaisesGroup(RaisesGroup(ValueError), allow_unwrapped=True) # type: ignore
|
||||
# specifying match
|
||||
RaisesGroup(ValueError, match="foo", allow_unwrapped=True) # type: ignore
|
||||
# specifying check
|
||||
RaisesGroup(ValueError, check=bool, allow_unwrapped=True) # type: ignore
|
||||
# allowed variants
|
||||
RaisesGroup(ValueError, allow_unwrapped=True)
|
||||
RaisesGroup(ValueError, allow_unwrapped=True, flatten_subgroups=True)
|
||||
RaisesGroup(Matcher(ValueError), allow_unwrapped=True)
|
||||
|
||||
# flatten_subgroups=True does not allow nested RaisesGroup
|
||||
RaisesGroup(RaisesGroup(ValueError), flatten_subgroups=True) # type: ignore
|
||||
# but rest is plenty fine
|
||||
RaisesGroup(ValueError, TypeError, flatten_subgroups=True)
|
||||
RaisesGroup(ValueError, match="foo", flatten_subgroups=True)
|
||||
RaisesGroup(ValueError, check=bool, flatten_subgroups=True)
|
||||
RaisesGroup(ValueError, flatten_subgroups=True)
|
||||
RaisesGroup(Matcher(ValueError), flatten_subgroups=True)
|
||||
|
||||
# if they're both false we can of course specify nested raisesgroup
|
||||
RaisesGroup(RaisesGroup(ValueError))
|
@ -0,0 +1,29 @@
|
||||
"""Check that started() can only be called for TaskStatus[None]."""
|
||||
|
||||
from trio import TaskStatus
|
||||
from typing_extensions import assert_type
|
||||
|
||||
|
||||
async def check_status(
|
||||
none_status_explicit: TaskStatus[None],
|
||||
none_status_implicit: TaskStatus,
|
||||
int_status: TaskStatus[int],
|
||||
) -> None:
|
||||
assert_type(none_status_explicit, TaskStatus[None])
|
||||
assert_type(none_status_implicit, TaskStatus[None]) # Default typevar
|
||||
assert_type(int_status, TaskStatus[int])
|
||||
|
||||
# Omitting the parameter is only allowed for None.
|
||||
none_status_explicit.started()
|
||||
none_status_implicit.started()
|
||||
int_status.started() # type: ignore
|
||||
|
||||
# Explicit None is allowed.
|
||||
none_status_explicit.started(None)
|
||||
none_status_implicit.started(None)
|
||||
int_status.started(None) # type: ignore
|
||||
|
||||
none_status_explicit.started(42) # type: ignore
|
||||
none_status_implicit.started(42) # type: ignore
|
||||
int_status.started(42)
|
||||
int_status.started(True)
|
Reference in New Issue
Block a user