Updated script that can be controled by Nodejs web app

This commit is contained in:
mac OS
2024-11-25 12:24:18 +07:00
parent c440eda1f4
commit 8b0ab2bd3a
8662 changed files with 1803808 additions and 34 deletions

View File

@ -0,0 +1,87 @@
"""
This namespace represents the core functionality that has to be built-in
and deal with private internal data structures. Things in this namespace
are publicly available in either trio, trio.lowlevel, or trio.testing.
"""
import sys
from ._entry_queue import TrioToken
from ._exceptions import (
BrokenResourceError,
BusyResourceError,
Cancelled,
ClosedResourceError,
EndOfChannel,
RunFinishedError,
TrioInternalError,
WouldBlock,
)
from ._ki import currently_ki_protected, disable_ki_protection, enable_ki_protection
from ._local import RunVar, RunVarToken
from ._mock_clock import MockClock
from ._parking_lot import (
ParkingLot,
ParkingLotStatistics,
add_parking_lot_breaker,
remove_parking_lot_breaker,
)
# Imports that always exist
from ._run import (
TASK_STATUS_IGNORED,
CancelScope,
Nursery,
RunStatistics,
Task,
TaskStatus,
add_instrument,
checkpoint,
checkpoint_if_cancelled,
current_clock,
current_effective_deadline,
current_root_task,
current_statistics,
current_task,
current_time,
current_trio_token,
notify_closing,
open_nursery,
remove_instrument,
reschedule,
run,
spawn_system_task,
start_guest_run,
wait_all_tasks_blocked,
wait_readable,
wait_writable,
)
from ._thread_cache import start_thread_soon
# Has to come after _run to resolve a circular import
from ._traps import (
Abort,
RaiseCancelT,
cancel_shielded_checkpoint,
permanently_detach_coroutine_object,
reattach_detached_coroutine_object,
temporarily_detach_coroutine_object,
wait_task_rescheduled,
)
from ._unbounded_queue import UnboundedQueue, UnboundedQueueStatistics
# Windows imports
if sys.platform == "win32":
from ._run import (
current_iocp,
monitor_completion_key,
readinto_overlapped,
register_with_iocp,
wait_overlapped,
write_overlapped,
)
# Kqueue imports
elif sys.platform != "linux" and sys.platform != "win32":
from ._run import current_kqueue, monitor_kevent, wait_kevent
del sys # It would be better to import sys as _sys, but mypy does not understand it

View File

@ -0,0 +1,216 @@
from __future__ import annotations
import logging
import sys
import warnings
import weakref
from typing import TYPE_CHECKING, NoReturn
import attrs
from .. import _core
from .._util import name_asyncgen
from . import _run
# Used to log exceptions in async generator finalizers
ASYNCGEN_LOGGER = logging.getLogger("trio.async_generator_errors")
if TYPE_CHECKING:
from types import AsyncGeneratorType
from typing import Set
_WEAK_ASYNC_GEN_SET = weakref.WeakSet[AsyncGeneratorType[object, NoReturn]]
_ASYNC_GEN_SET = Set[AsyncGeneratorType[object, NoReturn]]
else:
_WEAK_ASYNC_GEN_SET = weakref.WeakSet
_ASYNC_GEN_SET = set
@attrs.define(eq=False)
class AsyncGenerators:
# Async generators are added to this set when first iterated. Any
# left after the main task exits will be closed before trio.run()
# returns. During most of the run, this is a WeakSet so GC works.
# During shutdown, when we're finalizing all the remaining
# asyncgens after the system nursery has been closed, it's a
# regular set so we don't have to deal with GC firing at
# unexpected times.
alive: _WEAK_ASYNC_GEN_SET | _ASYNC_GEN_SET = attrs.Factory(_WEAK_ASYNC_GEN_SET)
# This collects async generators that get garbage collected during
# the one-tick window between the system nursery closing and the
# init task starting end-of-run asyncgen finalization.
trailing_needs_finalize: _ASYNC_GEN_SET = attrs.Factory(_ASYNC_GEN_SET)
prev_hooks: sys._asyncgen_hooks = attrs.field(init=False)
def install_hooks(self, runner: _run.Runner) -> None:
def firstiter(agen: AsyncGeneratorType[object, NoReturn]) -> None:
if hasattr(_run.GLOBAL_RUN_CONTEXT, "task"):
self.alive.add(agen)
else:
# An async generator first iterated outside of a Trio
# task doesn't belong to Trio. Probably we're in guest
# mode and the async generator belongs to our host.
# The locals dictionary is the only good place to
# remember this fact, at least until
# https://bugs.python.org/issue40916 is implemented.
agen.ag_frame.f_locals["@trio_foreign_asyncgen"] = True
if self.prev_hooks.firstiter is not None:
self.prev_hooks.firstiter(agen)
def finalize_in_trio_context(
agen: AsyncGeneratorType[object, NoReturn],
agen_name: str,
) -> None:
try:
runner.spawn_system_task(
self._finalize_one,
agen,
agen_name,
name=f"close asyncgen {agen_name} (abandoned)",
)
except RuntimeError:
# There is a one-tick window where the system nursery
# is closed but the init task hasn't yet made
# self.asyncgens a strong set to disable GC. We seem to
# have hit it.
self.trailing_needs_finalize.add(agen)
def finalizer(agen: AsyncGeneratorType[object, NoReturn]) -> None:
agen_name = name_asyncgen(agen)
try:
is_ours = not agen.ag_frame.f_locals.get("@trio_foreign_asyncgen")
except AttributeError: # pragma: no cover
is_ours = True
if is_ours:
runner.entry_queue.run_sync_soon(
finalize_in_trio_context,
agen,
agen_name,
)
# Do this last, because it might raise an exception
# depending on the user's warnings filter. (That
# exception will be printed to the terminal and
# ignored, since we're running in GC context.)
warnings.warn(
f"Async generator {agen_name!r} was garbage collected before it "
"had been exhausted. Surround its use in 'async with "
"aclosing(...):' to ensure that it gets cleaned up as soon as "
"you're done using it.",
ResourceWarning,
stacklevel=2,
source=agen,
)
else:
# Not ours -> forward to the host loop's async generator finalizer
if self.prev_hooks.finalizer is not None:
self.prev_hooks.finalizer(agen)
else:
# Host has no finalizer. Reimplement the default
# Python behavior with no hooks installed: throw in
# GeneratorExit, step once, raise RuntimeError if
# it doesn't exit.
closer = agen.aclose()
try:
# If the next thing is a yield, this will raise RuntimeError
# which we allow to propagate
closer.send(None)
except StopIteration:
pass
else:
# If the next thing is an await, we get here. Give a nicer
# error than the default "async generator ignored GeneratorExit"
raise RuntimeError(
f"Non-Trio async generator {agen_name!r} awaited something "
"during finalization; install a finalization hook to "
"support this, or wrap it in 'async with aclosing(...):'",
)
self.prev_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(firstiter=firstiter, finalizer=finalizer) # type: ignore[arg-type] # Finalizer doesn't use AsyncGeneratorType
async def finalize_remaining(self, runner: _run.Runner) -> None:
# This is called from init after shutting down the system nursery.
# The only tasks running at this point are init and
# the run_sync_soon task, and since the system nursery is closed,
# there's no way for user code to spawn more.
assert _core.current_task() is runner.init_task
assert len(runner.tasks) == 2
# To make async generator finalization easier to reason
# about, we'll shut down asyncgen garbage collection by turning
# the alive WeakSet into a regular set.
self.alive = set(self.alive)
# Process all pending run_sync_soon callbacks, in case one of
# them was an asyncgen finalizer that snuck in under the wire.
runner.entry_queue.run_sync_soon(runner.reschedule, runner.init_task)
await _core.wait_task_rescheduled(
lambda _: _core.Abort.FAILED, # pragma: no cover
)
self.alive.update(self.trailing_needs_finalize)
self.trailing_needs_finalize.clear()
# None of the still-living tasks use async generators, so
# every async generator must be suspended at a yield point --
# there's no one to be doing the iteration. That's good,
# because aclose() only works on an asyncgen that's suspended
# at a yield point. (If it's suspended at an event loop trap,
# because someone is in the middle of iterating it, then you
# get a RuntimeError on 3.8+, and a nasty surprise on earlier
# versions due to https://bugs.python.org/issue32526.)
#
# However, once we start aclose() of one async generator, it
# might start fetching the next value from another, thus
# preventing us from closing that other (at least until
# aclose() of the first one is complete). This constraint
# effectively requires us to finalize the remaining asyncgens
# in arbitrary order, rather than doing all of them at the
# same time. On 3.8+ we could defer any generator with
# ag_running=True to a later batch, but that only catches
# the case where our aclose() starts after the user's
# asend()/etc. If our aclose() starts first, then the
# user's asend()/etc will raise RuntimeError, since they're
# probably not checking ag_running.
#
# It might be possible to allow some parallelized cleanup if
# we can determine that a certain set of asyncgens have no
# interdependencies, using gc.get_referents() and such.
# But just doing one at a time will typically work well enough
# (since each aclose() executes in a cancelled scope) and
# is much easier to reason about.
# It's possible that that cleanup code will itself create
# more async generators, so we iterate repeatedly until
# all are gone.
while self.alive:
batch = self.alive
self.alive = _ASYNC_GEN_SET()
for agen in batch:
await self._finalize_one(agen, name_asyncgen(agen))
def close(self) -> None:
sys.set_asyncgen_hooks(*self.prev_hooks)
async def _finalize_one(
self,
agen: AsyncGeneratorType[object, NoReturn],
name: object,
) -> None:
try:
# This shield ensures that finalize_asyncgen never exits
# with an exception, not even a Cancelled. The inside
# is cancelled so there's no deadlock risk.
with _core.CancelScope(shield=True) as cancel_scope:
cancel_scope.cancel()
await agen.aclose()
except BaseException:
ASYNCGEN_LOGGER.exception(
"Exception ignored during finalization of async generator %r -- "
"surround your use of the generator in 'async with aclosing(...):' "
"to raise exceptions like this in the context where they're generated",
name,
)

View File

@ -0,0 +1,130 @@
from __future__ import annotations
from types import TracebackType
from typing import Any, ClassVar, cast
################################################################
# concat_tb
################################################################
# We need to compute a new traceback that is the concatenation of two existing
# tracebacks. This requires copying the entries in 'head' and then pointing
# the final tb_next to 'tail'.
#
# NB: 'tail' might be None, which requires some special handling in the ctypes
# version.
#
# The complication here is that Python doesn't actually support copying or
# modifying traceback objects, so we have to get creative...
#
# On CPython, we use ctypes. On PyPy, we use "transparent proxies".
#
# Jinja2 is a useful source of inspiration:
# https://github.com/pallets/jinja/blob/main/src/jinja2/debug.py
try:
import tputil
except ImportError:
# ctypes it is
# How to handle refcounting? I don't want to use ctypes.py_object because
# I don't understand or trust it, and I don't want to use
# ctypes.pythonapi.Py_{Inc,Dec}Ref because we might clash with user code
# that also tries to use them but with different types. So private _ctypes
# APIs it is!
import _ctypes
import ctypes
class CTraceback(ctypes.Structure):
_fields_: ClassVar = [
("PyObject_HEAD", ctypes.c_byte * object().__sizeof__()),
("tb_next", ctypes.c_void_p),
("tb_frame", ctypes.c_void_p),
("tb_lasti", ctypes.c_int),
("tb_lineno", ctypes.c_int),
]
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
# TracebackType has no public constructor, so allocate one the hard way
try:
raise ValueError
except ValueError as exc:
new_tb = exc.__traceback__
assert new_tb is not None
c_new_tb = CTraceback.from_address(id(new_tb))
# At the C level, tb_next either points to the next traceback or is
# NULL. c_void_p and the .tb_next accessor both convert NULL to None,
# but we shouldn't DECREF None just because we assigned to a NULL
# pointer! Here we know that our new traceback has only 1 frame in it,
# so we can assume the tb_next field is NULL.
assert c_new_tb.tb_next is None
# If tb_next is None, then we want to set c_new_tb.tb_next to NULL,
# which it already is, so we're done. Otherwise, we have to actually
# do some work:
if tb_next is not None:
_ctypes.Py_INCREF(tb_next) # type: ignore[attr-defined]
c_new_tb.tb_next = id(tb_next)
assert c_new_tb.tb_frame is not None
_ctypes.Py_INCREF(base_tb.tb_frame) # type: ignore[attr-defined]
old_tb_frame = new_tb.tb_frame
c_new_tb.tb_frame = id(base_tb.tb_frame)
_ctypes.Py_DECREF(old_tb_frame) # type: ignore[attr-defined]
c_new_tb.tb_lasti = base_tb.tb_lasti
c_new_tb.tb_lineno = base_tb.tb_lineno
try:
return new_tb
finally:
# delete references from locals to avoid creating cycles
# see test_cancel_scope_exit_doesnt_create_cyclic_garbage
del new_tb, old_tb_frame
else:
# http://doc.pypy.org/en/latest/objspace-proxies.html
def copy_tb(base_tb: TracebackType, tb_next: TracebackType | None) -> TracebackType:
# tputil.ProxyOperation is PyPy-only, and there's no way to specify
# cpython/pypy in current type checkers.
def controller(operation: tputil.ProxyOperation) -> Any | None: # type: ignore[no-any-unimported]
# Rationale for pragma: I looked fairly carefully and tried a few
# things, and AFAICT it's not actually possible to get any
# 'opname' that isn't __getattr__ or __getattribute__. So there's
# no missing test we could add, and no value in coverage nagging
# us about adding one.
if (
operation.opname
in {
"__getattribute__",
"__getattr__",
}
and operation.args[0] == "tb_next"
): # pragma: no cover
return tb_next
return operation.delegate() # Delegate is reverting to original behaviour
return cast(
TracebackType,
tputil.make_proxy(controller, type(base_tb), base_tb),
) # Returns proxy to traceback
# this is used for collapsing single-exception ExceptionGroups when using
# `strict_exception_groups=False`. Once that is retired this function and its helper can
# be removed as well.
def concat_tb(
head: TracebackType | None,
tail: TracebackType | None,
) -> TracebackType | None:
# We have to use an iterative algorithm here, because in the worst case
# this might be a RecursionError stack that is by definition too deep to
# process by recursion!
head_tbs = []
pointer = head
while pointer is not None:
head_tbs.append(pointer)
pointer = pointer.tb_next
current_head = tail
for head_tb in reversed(head_tbs):
current_head = copy_tb(head_tb, tb_next=current_head)
return current_head

View File

@ -0,0 +1,220 @@
from __future__ import annotations
import threading
from collections import deque
from typing import TYPE_CHECKING, Callable, NoReturn, Tuple
import attrs
from .. import _core
from .._util import NoPublicConstructor, final
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
from typing_extensions import TypeVarTuple, Unpack
PosArgsT = TypeVarTuple("PosArgsT")
Function = Callable[..., object]
Job = Tuple[Function, Tuple[object, ...]]
@attrs.define
class EntryQueue:
# This used to use a queue.Queue. but that was broken, because Queues are
# implemented in Python, and not reentrant -- so it was thread-safe, but
# not signal-safe. deque is implemented in C, so each operation is atomic
# WRT threads (and this is guaranteed in the docs), AND each operation is
# atomic WRT signal delivery (signal handlers can run on either side, but
# not *during* a deque operation). dict makes similar guarantees - and
# it's even ordered!
queue: deque[Job] = attrs.Factory(deque)
idempotent_queue: dict[Job, None] = attrs.Factory(dict)
wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
done: bool = False
# Must be a reentrant lock, because it's acquired from signal handlers.
# RLock is signal-safe as of cpython 3.2. NB that this does mean that the
# lock is effectively *disabled* when we enter from signal context. The
# way we use the lock this is OK though, because when
# run_sync_soon is called from a signal it's atomic WRT the
# main thread -- it just might happen at some inconvenient place. But if
# you look at the one place where the main thread holds the lock, it's
# just to make 1 assignment, so that's atomic WRT a signal anyway.
lock: threading.RLock = attrs.Factory(threading.RLock)
async def task(self) -> None:
assert _core.currently_ki_protected()
# RLock has two implementations: a signal-safe version in _thread, and
# and signal-UNsafe version in threading. We need the signal safe
# version. Python 3.2 and later should always use this anyway, but,
# since the symptoms if this goes wrong are just "weird rare
# deadlocks", then let's make a little check.
# See:
# https://bugs.python.org/issue13697#msg237140
assert self.lock.__class__.__module__ == "_thread"
def run_cb(job: Job) -> None:
# We run this with KI protection enabled; it's the callback's
# job to disable it if it wants it disabled. Exceptions are
# treated like system task exceptions (i.e., converted into
# TrioInternalError and cause everything to shut down).
sync_fn, args = job
try:
sync_fn(*args)
except BaseException as exc:
async def kill_everything(exc: BaseException) -> NoReturn:
raise exc
try:
_core.spawn_system_task(kill_everything, exc)
except RuntimeError:
# We're quite late in the shutdown process and the
# system nursery is already closed.
# TODO(2020-06): this is a gross hack and should
# be fixed soon when we address #1607.
parent_nursery = _core.current_task().parent_nursery
if parent_nursery is None:
raise AssertionError(
"Internal error: `parent_nursery` should never be `None`",
) from exc # pragma: no cover
parent_nursery.start_soon(kill_everything, exc)
# This has to be carefully written to be safe in the face of new items
# being queued while we iterate, and to do a bounded amount of work on
# each pass:
def run_all_bounded() -> None:
for _ in range(len(self.queue)):
run_cb(self.queue.popleft())
for job in list(self.idempotent_queue):
del self.idempotent_queue[job]
run_cb(job)
try:
while True:
run_all_bounded()
if not self.queue and not self.idempotent_queue:
await self.wakeup.wait_woken()
else:
await _core.checkpoint()
except _core.Cancelled:
# Keep the work done with this lock held as minimal as possible,
# because it doesn't protect us against concurrent signal delivery
# (see the comment above). Notice that this code would still be
# correct if written like:
# self.done = True
# with self.lock:
# pass
# because all we want is to force run_sync_soon
# to either be completely before or completely after the write to
# done. That's why we don't need the lock to protect
# against signal handlers.
with self.lock:
self.done = True
# No more jobs will be submitted, so just clear out any residual
# ones:
run_all_bounded()
assert not self.queue
assert not self.idempotent_queue
def close(self) -> None:
self.wakeup.close()
def size(self) -> int:
return len(self.queue) + len(self.idempotent_queue)
def run_sync_soon(
self,
sync_fn: Callable[[Unpack[PosArgsT]], object],
*args: Unpack[PosArgsT],
idempotent: bool = False,
) -> None:
with self.lock:
if self.done:
raise _core.RunFinishedError("run() has exited")
# We have to hold the lock all the way through here, because
# otherwise the main thread might exit *while* we're doing these
# calls, and then our queue item might not be processed, or the
# wakeup call might trigger an OSError b/c the IO manager has
# already been shut down.
if idempotent:
self.idempotent_queue[(sync_fn, args)] = None
else:
self.queue.append((sync_fn, args))
self.wakeup.wakeup_thread_and_signal_safe()
@final
@attrs.define(eq=False)
class TrioToken(metaclass=NoPublicConstructor):
"""An opaque object representing a single call to :func:`trio.run`.
It has no public constructor; instead, see :func:`current_trio_token`.
This object has two uses:
1. It lets you re-enter the Trio run loop from external threads or signal
handlers. This is the low-level primitive that :func:`trio.to_thread`
and `trio.from_thread` use to communicate with worker threads, that
`trio.open_signal_receiver` uses to receive notifications about
signals, and so forth.
2. Each call to :func:`trio.run` has exactly one associated
:class:`TrioToken` object, so you can use it to identify a particular
call.
"""
_reentry_queue: EntryQueue
def run_sync_soon(
self,
sync_fn: Callable[[Unpack[PosArgsT]], object],
*args: Unpack[PosArgsT],
idempotent: bool = False,
) -> None:
"""Schedule a call to ``sync_fn(*args)`` to occur in the context of a
Trio task.
This is safe to call from the main thread, from other threads, and
from signal handlers. This is the fundamental primitive used to
re-enter the Trio run loop from outside of it.
The call will happen "soon", but there's no guarantee about exactly
when, and no mechanism provided for finding out when it's happened.
If you need this, you'll have to build your own.
The call is effectively run as part of a system task (see
:func:`~trio.lowlevel.spawn_system_task`). In particular this means
that:
* :exc:`KeyboardInterrupt` protection is *enabled* by default; if
you want ``sync_fn`` to be interruptible by control-C, then you
need to use :func:`~trio.lowlevel.disable_ki_protection`
explicitly.
* If ``sync_fn`` raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. You
should be careful that ``sync_fn`` doesn't crash.
All calls with ``idempotent=False`` are processed in strict
first-in first-out order.
If ``idempotent=True``, then ``sync_fn`` and ``args`` must be
hashable, and Trio will make a best-effort attempt to discard any
call submission which is equal to an already-pending call. Trio
will process these in first-in first-out order.
Any ordering guarantees apply separately to ``idempotent=False``
and ``idempotent=True`` calls; there's no rule for how calls in the
different categories are ordered with respect to each other.
:raises trio.RunFinishedError:
if the associated call to :func:`trio.run`
has already exited. (Any call that *doesn't* raise this error
is guaranteed to be fully processed before :func:`trio.run`
exits.)
"""
self._reentry_queue.run_sync_soon(sync_fn, *args, idempotent=idempotent)

View File

@ -0,0 +1,113 @@
from trio._util import NoPublicConstructor, final
class TrioInternalError(Exception):
"""Raised by :func:`run` if we encounter a bug in Trio, or (possibly) a
misuse of one of the low-level :mod:`trio.lowlevel` APIs.
This should never happen! If you get this error, please file a bug.
Unfortunately, if you get this error it also means that all bets are off
Trio doesn't know what is going on and its normal invariants may be void.
(For example, we might have "lost track" of a task. Or lost track of all
tasks.) Again, though, this shouldn't happen.
"""
class RunFinishedError(RuntimeError):
"""Raised by `trio.from_thread.run` and similar functions if the
corresponding call to :func:`trio.run` has already finished.
"""
class WouldBlock(Exception):
"""Raised by ``X_nowait`` functions if ``X`` would block."""
@final
class Cancelled(BaseException, metaclass=NoPublicConstructor):
"""Raised by blocking calls if the surrounding scope has been cancelled.
You should let this exception propagate, to be caught by the relevant
cancel scope. To remind you of this, it inherits from :exc:`BaseException`
instead of :exc:`Exception`, just like :exc:`KeyboardInterrupt` and
:exc:`SystemExit` do. This means that if you write something like::
try:
...
except Exception:
...
then this *won't* catch a :exc:`Cancelled` exception.
You cannot raise :exc:`Cancelled` yourself. Attempting to do so
will produce a :exc:`TypeError`. Use :meth:`cancel_scope.cancel()
<trio.CancelScope.cancel>` instead.
.. note::
In the US it's also common to see this word spelled "canceled", with
only one "l". This is a `recent
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=5&smoothing=3&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
and `US-specific
<https://books.google.com/ngrams/graph?content=canceled%2Ccancelled&year_start=1800&year_end=2000&corpus=18&smoothing=3&share=&direct_url=t1%3B%2Ccanceled%3B%2Cc0%3B.t1%3B%2Ccancelled%3B%2Cc0>`__
innovation, and even in the US both forms are still commonly used. So
for consistency with the rest of the world and with "cancellation"
(which always has two "l"s), Trio uses the two "l" spelling
everywhere.
"""
def __str__(self) -> str:
return "Cancelled"
class BusyResourceError(Exception):
"""Raised when a task attempts to use a resource that some other task is
already using, and this would lead to bugs and nonsense.
For example, if two tasks try to send data through the same socket at the
same time, Trio will raise :class:`BusyResourceError` instead of letting
the data get scrambled.
"""
class ClosedResourceError(Exception):
"""Raised when attempting to use a resource after it has been closed.
Note that "closed" here means that *your* code closed the resource,
generally by calling a method with a name like ``close`` or ``aclose``, or
by exiting a context manager. If a problem arises elsewhere for example,
because of a network failure, or because a remote peer closed their end of
a connection then that should be indicated by a different exception
class, like :exc:`BrokenResourceError` or an :exc:`OSError` subclass.
"""
class BrokenResourceError(Exception):
"""Raised when an attempt to use a resource fails due to external
circumstances.
For example, you might get this if you try to send data on a stream where
the remote side has already closed the connection.
You *don't* get this error if *you* closed the resource in that case you
get :class:`ClosedResourceError`.
This exception's ``__cause__`` attribute will often contain more
information about the underlying error.
"""
class EndOfChannel(Exception):
"""Raised when trying to receive from a :class:`trio.abc.ReceiveChannel`
that has no more data to receive.
This is analogous to an "end-of-file" condition, but for channels.
"""

View File

@ -0,0 +1,51 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
from ._instrumentation import Instrument
__all__ = ["add_instrument", "remove_instrument"]
def add_instrument(instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.add_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def remove_instrument(instrument: Instrument) -> None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.instruments.remove_instrument(instrument)
except AttributeError:
raise RuntimeError("must be called from async context") from None

View File

@ -0,0 +1,98 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
from .._file_io import _HasFileNo
assert not TYPE_CHECKING or sys.platform == "linux"
__all__ = ["notify_closing", "wait_readable", "wait_writable"]
async def wait_readable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_writable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def notify_closing(fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None

View File

@ -0,0 +1,156 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Callable, ContextManager
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
import select
from .. import _core
from .._file_io import _HasFileNo
from ._traps import Abort, RaiseCancelT
assert not TYPE_CHECKING or sys.platform == "darwin"
__all__ = [
"current_kqueue",
"monitor_kevent",
"notify_closing",
"wait_kevent",
"wait_readable",
"wait_writable",
]
def current_kqueue() -> select.kqueue:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_kqueue()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def monitor_kevent(
ident: int,
filter: int,
) -> ContextManager[_core.UnboundedQueue[select.kevent]]:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_kevent(ident, filter)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_kevent(
ident: int,
filter: int,
abort_func: Callable[[RaiseCancelT], Abort],
) -> Abort:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_kevent(
ident,
filter,
abort_func,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_readable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_writable(fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def notify_closing(fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(fd)
except AttributeError:
raise RuntimeError("must be called from async context") from None

View File

@ -0,0 +1,209 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, ContextManager
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
if TYPE_CHECKING:
from typing_extensions import Buffer
from .._file_io import _HasFileNo
from ._unbounded_queue import UnboundedQueue
from ._windows_cffi import CData, Handle
assert not TYPE_CHECKING or sys.platform == "win32"
__all__ = [
"current_iocp",
"monitor_completion_key",
"notify_closing",
"readinto_overlapped",
"register_with_iocp",
"wait_overlapped",
"wait_readable",
"wait_writable",
"write_overlapped",
]
async def wait_readable(sock: _HasFileNo | int) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``sock`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``sock`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_readable(sock)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_writable(sock: _HasFileNo | int) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``sock``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_writable(sock)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def notify_closing(handle: Handle | int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.notify_closing(handle)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def register_with_iocp(handle: int | CData) -> None:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.register_with_iocp(handle)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_overlapped(handle_: int | CData, lpOverlapped: CData | int) -> object:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.wait_overlapped(
handle_,
lpOverlapped,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def write_overlapped(
handle: int | CData,
data: Buffer,
file_offset: int = 0,
) -> int:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.write_overlapped(
handle,
data,
file_offset,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def readinto_overlapped(
handle: int | CData,
buffer: Buffer,
file_offset: int = 0,
) -> int:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.io_manager.readinto_overlapped(
handle,
buffer,
file_offset,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_iocp() -> int:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.current_iocp()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def monitor_completion_key() -> ContextManager[tuple[int, UnboundedQueue[object]]]:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__ and `#52
<https://github.com/python-trio/trio/issues/52>`__.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.io_manager.monitor_completion_key()
except AttributeError:
raise RuntimeError("must be called from async context") from None

View File

@ -0,0 +1,273 @@
# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations
import sys
from typing import TYPE_CHECKING, Any
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import _NO_SEND, GLOBAL_RUN_CONTEXT, RunStatistics, Task
if TYPE_CHECKING:
import contextvars
from collections.abc import Awaitable, Callable
from outcome import Outcome
from typing_extensions import Unpack
from .._abc import Clock
from ._entry_queue import TrioToken
from ._run import PosArgT
__all__ = [
"current_clock",
"current_root_task",
"current_statistics",
"current_time",
"current_trio_token",
"reschedule",
"spawn_system_task",
"wait_all_tasks_blocked",
]
def current_statistics() -> RunStatistics:
"""Returns ``RunStatistics``, which contains run-loop-level debugging information.
Currently, the following fields are defined:
* ``tasks_living`` (int): The number of tasks that have been spawned
and not yet exited.
* ``tasks_runnable`` (int): The number of tasks that are currently
queued on the run queue (as opposed to blocked waiting for something
to happen).
* ``seconds_to_next_deadline`` (float): The time until the next
pending cancel scope deadline. May be negative if the deadline has
expired but we haven't yet processed cancellations. May be
:data:`~math.inf` if there are no pending deadlines.
* ``run_sync_soon_queue_size`` (int): The number of
unprocessed callbacks queued via
:meth:`trio.lowlevel.TrioToken.run_sync_soon`.
* ``io_statistics`` (object): Some statistics from Trio's I/O
backend. This always has an attribute ``backend`` which is a string
naming which operating-system-specific I/O backend is in use; the
other attributes vary between backends.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_statistics()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_time() -> float:
"""Returns the current time according to Trio's internal clock.
Returns:
float: The current time.
Raises:
RuntimeError: if not inside a call to :func:`trio.run`.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_time()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_clock() -> Clock:
"""Returns the current :class:`~trio.abc.Clock`."""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_clock()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_root_task() -> Task | None:
"""Returns the current root :class:`Task`.
This is the task that is the ultimate parent of all other tasks.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_root_task()
except AttributeError:
raise RuntimeError("must be called from async context") from None
def reschedule(task: Task, next_send: Outcome[Any] = _NO_SEND) -> None:
"""Reschedule the given task with the given
:class:`outcome.Outcome`.
See :func:`wait_task_rescheduled` for the gory details.
There must be exactly one call to :func:`reschedule` for every call to
:func:`wait_task_rescheduled`. (And when counting, keep in mind that
returning :data:`Abort.SUCCEEDED` from an abort callback is equivalent
to calling :func:`reschedule` once.)
Args:
task (trio.lowlevel.Task): the task to be rescheduled. Must be blocked
in a call to :func:`wait_task_rescheduled`.
next_send (outcome.Outcome): the value (or error) to return (or
raise) from :func:`wait_task_rescheduled`.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.reschedule(task, next_send)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def spawn_system_task(
async_fn: Callable[[Unpack[PosArgT]], Awaitable[object]],
*args: Unpack[PosArgT],
name: object = None,
context: contextvars.Context | None = None,
) -> Task:
"""Spawn a "system" task.
System tasks have a few differences from regular tasks:
* They don't need an explicit nursery; instead they go into the
internal "system nursery".
* If a system task raises an exception, then it's converted into a
:exc:`~trio.TrioInternalError` and *all* tasks are cancelled. If you
write a system task, you should be careful to make sure it doesn't
crash.
* System tasks are automatically cancelled when the main task exits.
* By default, system tasks have :exc:`KeyboardInterrupt` protection
*enabled*. If you want your task to be interruptible by control-C,
then you need to use :func:`disable_ki_protection` explicitly (and
come up with some plan for what to do with a
:exc:`KeyboardInterrupt`, given that system tasks aren't allowed to
raise exceptions).
* System tasks do not inherit context variables from their creator.
Towards the end of a call to :meth:`trio.run`, after the main
task and all system tasks have exited, the system nursery
becomes closed. At this point, new calls to
:func:`spawn_system_task` will raise ``RuntimeError("Nursery
is closed to new arrivals")`` instead of creating a system
task. It's possible to encounter this state either in
a ``finally`` block in an async generator, or in a callback
passed to :meth:`TrioToken.run_sync_soon` at the right moment.
Args:
async_fn: An async callable.
args: Positional arguments for ``async_fn``. If you want to pass
keyword arguments, use :func:`functools.partial`.
name: The name for this task. Only used for debugging/introspection
(e.g. ``repr(task_obj)``). If this isn't a string,
:func:`spawn_system_task` will try to make it one. A common use
case is if you're wrapping a function before spawning a new
task, you might pass the original function as the ``name=`` to
make debugging easier.
context: An optional ``contextvars.Context`` object with context variables
to use for this task. You would normally get a copy of the current
context with ``context = contextvars.copy_context()`` and then you would
pass that ``context`` object here.
Returns:
Task: the newly spawned task
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.spawn_system_task(
async_fn,
*args,
name=name,
context=context,
)
except AttributeError:
raise RuntimeError("must be called from async context") from None
def current_trio_token() -> TrioToken:
"""Retrieve the :class:`TrioToken` for the current call to
:func:`trio.run`.
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return GLOBAL_RUN_CONTEXT.runner.current_trio_token()
except AttributeError:
raise RuntimeError("must be called from async context") from None
async def wait_all_tasks_blocked(cushion: float = 0.0) -> None:
"""Block until there are no runnable tasks.
This is useful in testing code when you want to give other tasks a
chance to "settle down". The calling task is blocked, and doesn't wake
up until all other tasks are also blocked for at least ``cushion``
seconds. (Setting a non-zero ``cushion`` is intended to handle cases
like two tasks talking to each other over a local socket, where we
want to ignore the potential brief moment between a send and receive
when all tasks are blocked.)
Note that ``cushion`` is measured in *real* time, not the Trio clock
time.
If there are multiple tasks blocked in :func:`wait_all_tasks_blocked`,
then the one with the shortest ``cushion`` is the one woken (and
this task becoming unblocked resets the timers for the remaining
tasks). If there are multiple tasks that have exactly the same
``cushion``, then all are woken.
You should also consider :class:`trio.testing.Sequencer`, which
provides a more explicit way to control execution ordering within a
test, and will often produce more readable tests.
Example:
Here's an example of one way to test that Trio's locks are fair: we
take the lock in the parent, start a child, wait for the child to be
blocked waiting for the lock (!), and then check that we can't
release and immediately re-acquire the lock::
async def lock_taker(lock):
await lock.acquire()
lock.release()
async def test_lock_fairness():
lock = trio.Lock()
await lock.acquire()
async with trio.open_nursery() as nursery:
nursery.start_soon(lock_taker, lock)
# child hasn't run yet, we have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
await trio.testing.wait_all_tasks_blocked()
# now the child has run and is blocked on lock.acquire(), we
# still have the lock
assert lock.locked()
assert lock._owner is trio.lowlevel.current_task()
lock.release()
try:
# The child has a prior claim, so we can't have it
lock.acquire_nowait()
except trio.WouldBlock:
assert lock._owner is not trio.lowlevel.current_task()
print("PASS")
else:
print("FAIL")
"""
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
return await GLOBAL_RUN_CONTEXT.runner.wait_all_tasks_blocked(cushion)
except AttributeError:
raise RuntimeError("must be called from async context") from None

View File

@ -0,0 +1,108 @@
import logging
import types
from typing import Any, Callable, Dict, Sequence, TypeVar
from .._abc import Instrument
# Used to log exceptions in instruments
INSTRUMENT_LOGGER = logging.getLogger("trio.abc.Instrument")
F = TypeVar("F", bound=Callable[..., Any])
# Decorator to mark methods public. This does nothing by itself, but
# trio/_tools/gen_exports.py looks for it.
def _public(fn: F) -> F:
return fn
class Instruments(Dict[str, Dict[Instrument, None]]):
"""A collection of `trio.abc.Instrument` organized by hook.
Instrumentation calls are rather expensive, and we don't want a
rarely-used instrument (like before_run()) to slow down hot
operations (like before_task_step()). Thus, we cache the set of
instruments to be called for each hook, and skip the instrumentation
call if there's nothing currently installed for that hook.
"""
__slots__ = ()
def __init__(self, incoming: Sequence[Instrument]):
self["_all"] = {}
for instrument in incoming:
self.add_instrument(instrument)
@_public
def add_instrument(self, instrument: Instrument) -> None:
"""Start instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to activate.
If ``instrument`` is already active, does nothing.
"""
if instrument in self["_all"]:
return
self["_all"][instrument] = None
try:
for name in dir(instrument):
if name.startswith("_"):
continue
try:
prototype = getattr(Instrument, name)
except AttributeError:
continue
impl = getattr(instrument, name)
if isinstance(impl, types.MethodType) and impl.__func__ is prototype:
# Inherited unchanged from _abc.Instrument
continue
self.setdefault(name, {})[instrument] = None
except:
self.remove_instrument(instrument)
raise
@_public
def remove_instrument(self, instrument: Instrument) -> None:
"""Stop instrumenting the current run loop with the given instrument.
Args:
instrument (trio.abc.Instrument): The instrument to de-activate.
Raises:
KeyError: if the instrument is not currently active. This could
occur either because you never added it, or because you added it
and then it raised an unhandled exception and was automatically
deactivated.
"""
# If instrument isn't present, the KeyError propagates out
self["_all"].pop(instrument)
for hookname, instruments in list(self.items()):
if instrument in instruments:
del instruments[instrument]
if not instruments:
del self[hookname]
def call(self, hookname: str, *args: Any) -> None:
"""Call hookname(*args) on each applicable instrument.
You must first check whether there are any instruments installed for
that hook, e.g.::
if "before_task_step" in instruments:
instruments.call("before_task_step", task)
"""
for instrument in list(self[hookname]):
try:
getattr(instrument, hookname)(*args)
except BaseException:
self.remove_instrument(instrument)
INSTRUMENT_LOGGER.exception(
"Exception raised when calling %r on instrument %r. "
"Instrument has been disabled.",
hookname,
instrument,
)

View File

@ -0,0 +1,31 @@
from __future__ import annotations
import copy
from typing import TYPE_CHECKING
import outcome
from .. import _core
if TYPE_CHECKING:
from ._io_epoll import EpollWaiters
from ._io_windows import AFDWaiters
# Utility function shared between _io_epoll and _io_windows
def wake_all(waiters: EpollWaiters | AFDWaiters, exc: BaseException) -> None:
try:
current_task = _core.current_task()
except RuntimeError:
current_task = None
raise_at_end = False
for attr_name in ["read_task", "write_task"]:
task = getattr(waiters, attr_name)
if task is not None:
if task is current_task:
raise_at_end = True
else:
_core.reschedule(task, outcome.Error(copy.copy(exc)))
setattr(waiters, attr_name, None)
if raise_at_end:
raise exc

View File

@ -0,0 +1,387 @@
from __future__ import annotations
import contextlib
import select
import sys
from collections import defaultdict
from typing import TYPE_CHECKING, Literal
import attrs
from .. import _core
from ._io_common import wake_all
from ._run import Task, _public
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT
from .._file_io import _HasFileNo
@attrs.define(eq=False)
class EpollWaiters:
read_task: Task | None = None
write_task: Task | None = None
current_flags: int = 0
assert not TYPE_CHECKING or sys.platform == "linux"
EventResult: TypeAlias = "list[tuple[int, int]]"
@attrs.frozen(eq=False)
class _EpollStatistics:
tasks_waiting_read: int
tasks_waiting_write: int
backend: Literal["epoll"] = attrs.field(init=False, default="epoll")
# Some facts about epoll
# ----------------------
#
# Internally, an epoll object is sort of like a WeakKeyDictionary where the
# keys are tuples of (fd number, file object). When you call epoll_ctl, you
# pass in an fd; that gets converted to an (fd number, file object) tuple by
# looking up the fd in the process's fd table at the time of the call. When an
# event happens on the file object, epoll_wait drops the file object part, and
# just returns the fd number in its event. So from the outside it looks like
# it's keeping a table of fds, but really it's a bit more complicated. This
# has some subtle consequences.
#
# In general, file objects inside the kernel are reference counted. Each entry
# in a process's fd table holds a strong reference to the corresponding file
# object, and most operations that use file objects take a temporary strong
# reference while they're working. So when you call close() on an fd, that
# might or might not cause the file object to be deallocated -- it depends on
# whether there are any other references to that file object. Some common ways
# this can happen:
#
# - after calling dup(), you have two fds in the same process referring to the
# same file object. Even if you close one fd (= remove that entry from the
# fd table), the file object will be kept alive by the other fd.
# - when calling fork(), the child inherits a copy of the parent's fd table,
# so all the file objects get another reference. (But if the fork() is
# followed by exec(), then all of the child's fds that have the CLOEXEC flag
# set will be closed at that point.)
# - most syscalls that work on fds take a strong reference to the underlying
# file object while they're using it. So there's one thread blocked in
# read(fd), and then another thread calls close() on the last fd referring
# to that object, the underlying file won't actually be closed until
# after read() returns.
#
# However, epoll does *not* take a reference to any of the file objects in its
# interest set (that's what makes it similar to a WeakKeyDictionary). File
# objects inside an epoll interest set will be deallocated if all *other*
# references to them are closed. And when that happens, the epoll object will
# automatically deregister that file object and stop reporting events on it.
# So that's quite handy.
#
# But, what happens if we do this?
#
# fd1 = open(...)
# epoll_ctl(EPOLL_CTL_ADD, fd1, ...)
# fd2 = dup(fd1)
# close(fd1)
#
# In this case, the dup() keeps the underlying file object alive, so it
# remains registered in the epoll object's interest set, as the tuple (fd1,
# file object). But, fd1 no longer refers to this file object! You might think
# there was some magic to handle this, but unfortunately no; the consequences
# are totally predictable from what I said above:
#
# If any events occur on the file object, then epoll will report them as
# happening on fd1, even though that doesn't make sense.
#
# Perhaps we would like to deregister fd1 to stop getting nonsensical events.
# But how? When we call epoll_ctl, we have to pass an fd number, which will
# get expanded to an (fd number, file object) tuple. We can't pass fd1,
# because when epoll_ctl tries to look it up, it won't find our file object.
# And we can't pass fd2, because that will get expanded to (fd2, file object),
# which is a different lookup key. In fact, it's *impossible* to de-register
# this fd!
#
# We could even have fd1 get assigned to another file object, and then we can
# have multiple keys registered simultaneously using the same fd number, like:
# (fd1, file object 1), (fd1, file object 2). And if events happen on either
# file object, then epoll will happily report that something happened to
# "fd1".
#
# Now here's what makes this especially nasty: suppose the old file object
# becomes, say, readable. That means that every time we call epoll_wait, it
# will return immediately to tell us that "fd1" is readable. Normally, we
# would handle this by de-registering fd1, waking up the corresponding call to
# wait_readable, then the user will call read() or recv() or something, and
# we're fine. But if this happens on a stale fd where we can't remove the
# registration, then we might get stuck in a state where epoll_wait *always*
# returns immediately, so our event loop becomes unable to sleep, and now our
# program is burning 100% of the CPU doing nothing, with no way out.
#
#
# What does this mean for Trio?
# -----------------------------
#
# Since we don't control the user's code, we have no way to guarantee that we
# don't get stuck with stale fd's in our epoll interest set. For example, a
# user could call wait_readable(fd) in one task, and then while that's
# running, they might close(fd) from another task. In this situation, they're
# *supposed* to call notify_closing(fd) to let us know what's happening, so we
# can interrupt the wait_readable() call and avoid getting into this mess. And
# that's the only thing that can possibly work correctly in all cases. But
# sometimes user code has bugs. So if this does happen, we'd like to degrade
# gracefully, and survive without corrupting Trio's internal state or
# otherwise causing the whole program to explode messily.
#
# Our solution: we always use EPOLLONESHOT. This way, we might get *one*
# spurious event on a stale fd, but then epoll will automatically silence it
# until we explicitly say that we want more events... and if we have a stale
# fd, then we actually can't re-enable it! So we can't get stuck in an
# infinite busy-loop. If there's a stale fd hanging around, then it might
# cause a spurious `BusyResourceError`, or cause one wait_* call to return
# before it should have... but in general, the wait_* functions are allowed to
# have some spurious wakeups; the user code will just attempt the operation,
# get EWOULDBLOCK, and call wait_* again. And the program as a whole will
# survive, any exceptions will propagate, etc.
#
# As a bonus, EPOLLONESHOT also saves us having to explicitly deregister fds
# on the normal wakeup path, so it's a bit more efficient in general.
#
# However, EPOLLONESHOT has a few trade-offs to consider:
#
# First, you can't combine EPOLLONESHOT with EPOLLEXCLUSIVE. This is a bit sad
# in one somewhat rare case: if you have a multi-process server where a group
# of processes all share the same listening socket, then EPOLLEXCLUSIVE can be
# used to avoid "thundering herd" problems when a new connection comes in. But
# this isn't too bad. It's not clear if EPOLLEXCLUSIVE even works for us
# anyway:
#
# https://stackoverflow.com/questions/41582560/how-does-epolls-epollexclusive-mode-interact-with-level-triggering
#
# And it's not clear that EPOLLEXCLUSIVE is a great approach either:
#
# https://blog.cloudflare.com/the-sad-state-of-linux-socket-balancing/
#
# And if we do need to support this, we could always add support through some
# more-specialized API in the future. So this isn't a blocker to using
# EPOLLONESHOT.
#
# Second, EPOLLONESHOT does not actually *deregister* the fd after delivering
# an event (EPOLL_CTL_DEL). Instead, it keeps the fd registered, but
# effectively does an EPOLL_CTL_MOD to set the fd's interest flags to
# all-zeros. So we could still end up with an fd hanging around in the
# interest set for a long time, even if we're not using it.
#
# Fortunately, this isn't a problem, because it's only a weak reference if
# we have a stale fd that's been silenced by EPOLLONESHOT, then it wastes a
# tiny bit of kernel memory remembering this fd that can never be revived, but
# when the underlying file object is eventually closed, that memory will be
# reclaimed. So that's OK.
#
# The other issue is that when someone calls wait_*, using EPOLLONESHOT means
# that if we have ever waited for this fd before, we have to use EPOLL_CTL_MOD
# to re-enable it; but if it's a new fd, we have to use EPOLL_CTL_ADD. How do
# we know which one to use? There's no reasonable way to track which fds are
# currently registered -- remember, we're assuming the user might have gone
# and rearranged their fds without telling us!
#
# Fortunately, this also has a simple solution: if we wait on a socket or
# other fd once, then we'll probably wait on it lots of times. And the epoll
# object itself knows which fds it already has registered. So when an fd comes
# in, we optimistically assume that it's been waited on before, and try doing
# EPOLL_CTL_MOD. And if that fails with an ENOENT error, then we try again
# with EPOLL_CTL_ADD.
#
# So that's why this code is the way it is. And now you know more than you
# wanted to about how epoll works.
@attrs.define(eq=False)
class EpollIOManager:
# Using lambda here because otherwise crash on import with gevent monkey patching
# See https://github.com/python-trio/trio/issues/2848
_epoll: select.epoll = attrs.Factory(lambda: select.epoll())
# {fd: EpollWaiters}
_registered: defaultdict[int, EpollWaiters] = attrs.Factory(
lambda: defaultdict(EpollWaiters),
)
_force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
_force_wakeup_fd: int | None = None
def __attrs_post_init__(self) -> None:
self._epoll.register(self._force_wakeup.wakeup_sock, select.EPOLLIN)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self) -> _EpollStatistics:
tasks_waiting_read = 0
tasks_waiting_write = 0
for waiter in self._registered.values():
if waiter.read_task is not None:
tasks_waiting_read += 1
if waiter.write_task is not None:
tasks_waiting_write += 1
return _EpollStatistics(
tasks_waiting_read=tasks_waiting_read,
tasks_waiting_write=tasks_waiting_write,
)
def close(self) -> None:
self._epoll.close()
self._force_wakeup.close()
def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
# Return value must be False-y IFF the timeout expired, NOT if any I/O
# happened or force_wakeup was called. Otherwise it can be anything; gets
# passed straight through to process_events.
def get_events(self, timeout: float) -> EventResult:
# max_events must be > 0 or epoll gets cranky
# accessing self._registered from a thread looks dangerous, but it's
# OK because it doesn't matter if our value is a little bit off.
max_events = max(1, len(self._registered))
return self._epoll.poll(timeout, max_events)
def process_events(self, events: EventResult) -> None:
for fd, flags in events:
if fd == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
waiters = self._registered[fd]
# EPOLLONESHOT always clears the flags when an event is delivered
waiters.current_flags = 0
# Clever hack stolen from selectors.EpollSelector: an event
# with EPOLLHUP or EPOLLERR flags wakes both readers and
# writers.
if flags & ~select.EPOLLIN and waiters.write_task is not None:
_core.reschedule(waiters.write_task)
waiters.write_task = None
if flags & ~select.EPOLLOUT and waiters.read_task is not None:
_core.reschedule(waiters.read_task)
waiters.read_task = None
self._update_registrations(fd)
def _update_registrations(self, fd: int) -> None:
waiters = self._registered[fd]
wanted_flags = 0
if waiters.read_task is not None:
wanted_flags |= select.EPOLLIN
if waiters.write_task is not None:
wanted_flags |= select.EPOLLOUT
if wanted_flags != waiters.current_flags:
try:
try:
# First try EPOLL_CTL_MOD
self._epoll.modify(fd, wanted_flags | select.EPOLLONESHOT)
except OSError:
# If that fails, it might be a new fd; try EPOLL_CTL_ADD
self._epoll.register(fd, wanted_flags | select.EPOLLONESHOT)
waiters.current_flags = wanted_flags
except OSError as exc:
# If everything fails, probably it's a bad fd, e.g. because
# the fd was closed behind our back. In this case we don't
# want to try to unregister the fd, because that will probably
# fail too. Just clear our state and wake everyone up.
del self._registered[fd]
# This could raise (in case we're calling this inside one of
# the to-be-woken tasks), so we have to do it last.
wake_all(waiters, exc)
return
if not wanted_flags:
del self._registered[fd]
async def _epoll_wait(self, fd: int | _HasFileNo, attr_name: str) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
waiters = self._registered[fd]
if getattr(waiters, attr_name) is not None:
raise _core.BusyResourceError(
"another task is already reading / writing this fd",
)
setattr(waiters, attr_name, _core.current_task())
self._update_registrations(fd)
def abort(_: RaiseCancelT) -> Abort:
setattr(waiters, attr_name, None)
self._update_registrations(fd)
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort)
@_public
async def wait_readable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._epoll_wait(fd, "read_task")
@_public
async def wait_writable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._epoll_wait(fd, "write_task")
@_public
def notify_closing(self, fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
if not isinstance(fd, int):
fd = fd.fileno()
wake_all(
self._registered[fd],
_core.ClosedResourceError("another task closed this fd"),
)
del self._registered[fd]
with contextlib.suppress(OSError, ValueError):
self._epoll.unregister(fd)

View File

@ -0,0 +1,292 @@
from __future__ import annotations
import errno
import select
import sys
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Iterator, Literal
import attrs
import outcome
from .. import _core
from ._run import _public
from ._wakeup_socketpair import WakeupSocketpair
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from .._core import Abort, RaiseCancelT, Task, UnboundedQueue
from .._file_io import _HasFileNo
assert not TYPE_CHECKING or (sys.platform != "linux" and sys.platform != "win32")
EventResult: TypeAlias = "list[select.kevent]"
@attrs.frozen(eq=False)
class _KqueueStatistics:
tasks_waiting: int
monitors: int
backend: Literal["kqueue"] = attrs.field(init=False, default="kqueue")
@attrs.define(eq=False)
class KqueueIOManager:
_kqueue: select.kqueue = attrs.Factory(select.kqueue)
# {(ident, filter): Task or UnboundedQueue}
_registered: dict[tuple[int, int], Task | UnboundedQueue[select.kevent]] = (
attrs.Factory(dict)
)
_force_wakeup: WakeupSocketpair = attrs.Factory(WakeupSocketpair)
_force_wakeup_fd: int | None = None
def __attrs_post_init__(self) -> None:
force_wakeup_event = select.kevent(
self._force_wakeup.wakeup_sock,
select.KQ_FILTER_READ,
select.KQ_EV_ADD,
)
self._kqueue.control([force_wakeup_event], 0)
self._force_wakeup_fd = self._force_wakeup.wakeup_sock.fileno()
def statistics(self) -> _KqueueStatistics:
tasks_waiting = 0
monitors = 0
for receiver in self._registered.values():
if type(receiver) is _core.Task:
tasks_waiting += 1
else:
monitors += 1
return _KqueueStatistics(tasks_waiting=tasks_waiting, monitors=monitors)
def close(self) -> None:
self._kqueue.close()
self._force_wakeup.close()
def force_wakeup(self) -> None:
self._force_wakeup.wakeup_thread_and_signal_safe()
def get_events(self, timeout: float) -> EventResult:
# max_events must be > 0 or kqueue gets cranky
# and we generally want this to be strictly larger than the actual
# number of events we get, so that we can tell that we've gotten
# all the events in just 1 call.
max_events = len(self._registered) + 1
events = []
while True:
batch = self._kqueue.control([], max_events, timeout)
events += batch
if len(batch) < max_events:
break
else:
timeout = 0
# and loop back to the start
return events
def process_events(self, events: EventResult) -> None:
for event in events:
key = (event.ident, event.filter)
if event.ident == self._force_wakeup_fd:
self._force_wakeup.drain()
continue
receiver = self._registered[key]
if event.flags & select.KQ_EV_ONESHOT:
del self._registered[key]
if isinstance(receiver, _core.Task):
_core.reschedule(receiver, outcome.Value(event))
else:
receiver.put_nowait(event)
# kevent registration is complicated -- e.g. aio submission can
# implicitly perform a EV_ADD, and EVFILT_PROC with NOTE_TRACK will
# automatically register filters for child processes. So our lowlevel
# API is *very* low-level: we expose the kqueue itself for adding
# events or sticking into AIO submission structs, and split waiting
# off into separate methods. It's your responsibility to make sure
# that handle_io never receives an event without a corresponding
# registration! This may be challenging if you want to be careful
# about e.g. KeyboardInterrupt. Possibly this API could be improved to
# be more ergonomic...
@_public
def current_kqueue(self) -> select.kqueue:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
return self._kqueue
@contextmanager
@_public
def monitor_kevent(
self,
ident: int,
filter: int,
) -> Iterator[_core.UnboundedQueue[select.kevent]]:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair",
)
q = _core.UnboundedQueue[select.kevent]()
self._registered[key] = q
try:
yield q
finally:
del self._registered[key]
@_public
async def wait_kevent(
self,
ident: int,
filter: int,
abort_func: Callable[[RaiseCancelT], Abort],
) -> Abort:
"""TODO: these are implemented, but are currently more of a sketch than
anything real. See `#26
<https://github.com/python-trio/trio/issues/26>`__.
"""
key = (ident, filter)
if key in self._registered:
raise _core.BusyResourceError(
"attempt to register multiple listeners for same ident/filter pair",
)
self._registered[key] = _core.current_task()
def abort(raise_cancel: RaiseCancelT) -> Abort:
r = abort_func(raise_cancel)
if r is _core.Abort.SUCCEEDED:
del self._registered[key]
return r
# wait_task_rescheduled does not have its return type typed
return await _core.wait_task_rescheduled(abort) # type: ignore[no-any-return]
async def _wait_common(
self,
fd: int | _HasFileNo,
filter: int,
) -> None:
if not isinstance(fd, int):
fd = fd.fileno()
flags = select.KQ_EV_ADD | select.KQ_EV_ONESHOT
event = select.kevent(fd, filter, flags)
self._kqueue.control([event], 0)
def abort(_: RaiseCancelT) -> Abort:
event = select.kevent(fd, filter, select.KQ_EV_DELETE)
try:
self._kqueue.control([event], 0)
except OSError as exc:
# kqueue tracks individual fds (*not* the underlying file
# object, see _io_epoll.py for a long discussion of why this
# distinction matters), and automatically deregisters an event
# if the fd is closed. So if kqueue.control says that it
# doesn't know about this event, then probably it's because
# the fd was closed behind our backs. (Too bad we can't ask it
# to wake us up when this happens, versus discovering it after
# the fact... oh well, you can't have everything.)
#
# FreeBSD reports this using EBADF. macOS uses ENOENT.
if exc.errno in (errno.EBADF, errno.ENOENT): # pragma: no branch
pass
else: # pragma: no cover
# As far as we know, this branch can't happen.
raise
return _core.Abort.SUCCEEDED
await self.wait_kevent(fd, filter, abort)
@_public
async def wait_readable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is readable.
On Unix systems, ``fd`` must either be an integer file descriptor,
or else an object with a ``.fileno()`` method which returns an
integer file descriptor. Any kind of file descriptor can be passed,
though the exact semantics will depend on your kernel. For example,
this probably won't do anything useful for on-disk files.
On Windows systems, ``fd`` must either be an integer ``SOCKET``
handle, or else an object with a ``.fileno()`` method which returns
an integer ``SOCKET`` handle. File descriptors aren't supported,
and neither are handles that refer to anything besides a
``SOCKET``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become readable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._wait_common(fd, select.KQ_FILTER_READ)
@_public
async def wait_writable(self, fd: int | _HasFileNo) -> None:
"""Block until the kernel reports that the given object is writable.
See `wait_readable` for the definition of ``fd``.
:raises trio.BusyResourceError:
if another task is already waiting for the given socket to
become writable.
:raises trio.ClosedResourceError:
if another task calls :func:`notify_closing` while this
function is still working.
"""
await self._wait_common(fd, select.KQ_FILTER_WRITE)
@_public
def notify_closing(self, fd: int | _HasFileNo) -> None:
"""Notify waiters of the given object that it will be closed.
Call this before closing a file descriptor (on Unix) or socket (on
Windows). This will cause any `wait_readable` or `wait_writable`
calls on the given object to immediately wake up and raise
`~trio.ClosedResourceError`.
This doesn't actually close the object you still have to do that
yourself afterwards. Also, you want to be careful to make sure no
new tasks start waiting on the object in between when you call this
and when it's actually closed. So to close something properly, you
usually want to do these steps in order:
1. Explicitly mark the object as closed, so that any new attempts
to use it will abort before they start.
2. Call `notify_closing` to wake up any already-existing users.
3. Actually close the object.
It's also possible to do them in a different order if that's more
convenient, *but only if* you make sure not to have any checkpoints in
between the steps. This way they all happen in a single atomic
step, so other tasks won't be able to tell what order they happened
in anyway.
"""
if not isinstance(fd, int):
fd = fd.fileno()
for filter_ in [select.KQ_FILTER_READ, select.KQ_FILTER_WRITE]:
key = (fd, filter_)
receiver = self._registered.get(key)
if receiver is None:
continue
if type(receiver) is _core.Task:
event = select.kevent(fd, filter_, select.KQ_EV_DELETE)
self._kqueue.control([event], 0)
exc = _core.ClosedResourceError("another task closed this fd")
_core.reschedule(receiver, outcome.Error(exc))
del self._registered[key]
else:
# XX this is an interesting example of a case where being able
# to close a queue would be useful...
raise NotImplementedError(
"can't close an fd that monitor_kevent is using",
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,237 @@
from __future__ import annotations
import inspect
import signal
import sys
from functools import wraps
from typing import TYPE_CHECKING, Final, Protocol, TypeVar
import attrs
from .._util import is_main_thread
CallableT = TypeVar("CallableT", bound="Callable[..., object]")
RetT = TypeVar("RetT")
if TYPE_CHECKING:
import types
from collections.abc import Callable
from typing_extensions import ParamSpec, TypeGuard
ArgsT = ParamSpec("ArgsT")
# In ordinary single-threaded Python code, when you hit control-C, it raises
# an exception and automatically does all the regular unwinding stuff.
#
# In Trio code, we would like hitting control-C to raise an exception and
# automatically do all the regular unwinding stuff. In particular, we would
# like to maintain our invariant that all tasks always run to completion (one
# way or another), by unwinding all of them.
#
# But it's basically impossible to write the core task running code in such a
# way that it can maintain this invariant in the face of KeyboardInterrupt
# exceptions arising at arbitrary bytecode positions. Similarly, if a
# KeyboardInterrupt happened at the wrong moment inside pretty much any of our
# inter-task synchronization or I/O primitives, then the system state could
# get corrupted and prevent our being able to clean up properly.
#
# So, we need a way to defer KeyboardInterrupt processing from these critical
# sections.
#
# Things that don't work:
#
# - Listen for SIGINT and process it in a system task: works fine for
# well-behaved programs that regularly pass through the event loop, but if
# user-code goes into an infinite loop then it can't be interrupted. Which
# is unfortunate, since dealing with infinite loops is what
# KeyboardInterrupt is for!
#
# - Use pthread_sigmask to disable signal delivery during critical section:
# (a) windows has no pthread_sigmask, (b) python threads start with all
# signals unblocked, so if there are any threads around they'll receive the
# signal and then tell the main thread to run the handler, even if the main
# thread has that signal blocked.
#
# - Install a signal handler which checks a global variable to decide whether
# to raise the exception immediately (if we're in a non-critical section),
# or to schedule it on the event loop (if we're in a critical section). The
# problem here is that it's impossible to transition safely out of user code:
#
# with keyboard_interrupt_enabled:
# msg = coro.send(value)
#
# If this raises a KeyboardInterrupt, it might be because the coroutine got
# interrupted and has unwound... or it might be the KeyboardInterrupt
# arrived just *after* 'send' returned, so the coroutine is still running,
# but we just lost the message it sent. (And worse, in our actual task
# runner, the send is hidden inside a utility function etc.)
#
# Solution:
#
# Mark *stack frames* as being interrupt-safe or interrupt-unsafe, and from
# the signal handler check which kind of frame we're currently in when
# deciding whether to raise or schedule the exception.
#
# There are still some cases where this can fail, like if someone hits
# control-C while the process is in the event loop, and then it immediately
# enters an infinite loop in user code. In this case the user has to hit
# control-C a second time. And of course if the user code is written so that
# it doesn't actually exit after a task crashes and everything gets cancelled,
# then there's not much to be done. (Hitting control-C repeatedly might help,
# but in general the solution is to kill the process some other way, just like
# for any Python program that's written to catch and ignore
# KeyboardInterrupt.)
# We use this special string as a unique key into the frame locals dictionary.
# The @ ensures it is not a valid identifier and can't clash with any possible
# real local name. See: https://github.com/python-trio/trio/issues/469
LOCALS_KEY_KI_PROTECTION_ENABLED: Final = "@TRIO_KI_PROTECTION_ENABLED"
# NB: according to the signal.signal docs, 'frame' can be None on entry to
# this function:
def ki_protection_enabled(frame: types.FrameType | None) -> bool:
while frame is not None:
if LOCALS_KEY_KI_PROTECTION_ENABLED in frame.f_locals:
return bool(frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED])
if frame.f_code.co_name == "__del__":
return True
frame = frame.f_back
return True
def currently_ki_protected() -> bool:
r"""Check whether the calling code has :exc:`KeyboardInterrupt` protection
enabled.
It's surprisingly easy to think that one's :exc:`KeyboardInterrupt`
protection is enabled when it isn't, or vice-versa. This function tells
you what Trio thinks of the matter, which makes it useful for ``assert``\s
and unit tests.
Returns:
bool: True if protection is enabled, and False otherwise.
"""
return ki_protection_enabled(sys._getframe())
# This is to support the async_generator package necessary for aclosing on <3.10
# functions decorated @async_generator are given this magic property that's a
# reference to the object itself
# see python-trio/async_generator/async_generator/_impl.py
def legacy_isasyncgenfunction(
obj: object,
) -> TypeGuard[Callable[..., types.AsyncGeneratorType[object, object]]]:
return getattr(obj, "_async_gen_function", None) == id(obj)
def _ki_protection_decorator(
enabled: bool,
) -> Callable[[Callable[ArgsT, RetT]], Callable[ArgsT, RetT]]:
# The "ignore[return-value]" below is because the inspect functions cast away the
# original return type of fn, making it just CoroutineType[Any, Any, Any] etc.
# ignore[misc] is because @wraps() is passed a callable with Any in the return type.
def decorator(fn: Callable[ArgsT, RetT]) -> Callable[ArgsT, RetT]:
# In some version of Python, isgeneratorfunction returns true for
# coroutine functions, so we have to check for coroutine functions
# first.
if inspect.iscoroutinefunction(fn):
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
# See the comment for regular generators below
coro = fn(*args, **kwargs)
assert coro.cr_frame is not None, "Coroutine frame should exist"
coro.cr_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return coro # type: ignore[return-value]
return wrapper
elif inspect.isgeneratorfunction(fn):
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
# It's important that we inject this directly into the
# generator's locals, as opposed to setting it here and then
# doing 'yield from'. The reason is, if a generator is
# throw()n into, then it may magically pop to the top of the
# stack. And @contextmanager generators in particular are a
# case where we often want KI protection, and which are often
# thrown into! See:
# https://bugs.python.org/issue29590
gen = fn(*args, **kwargs)
gen.gi_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return gen # type: ignore[return-value]
return wrapper
elif inspect.isasyncgenfunction(fn) or legacy_isasyncgenfunction(fn):
@wraps(fn) # type: ignore[arg-type]
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT: # type: ignore[misc]
# See the comment for regular generators above
agen = fn(*args, **kwargs)
agen.ag_frame.f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return agen # type: ignore[return-value]
return wrapper
else:
@wraps(fn)
def wrapper(*args: ArgsT.args, **kwargs: ArgsT.kwargs) -> RetT:
sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = enabled
return fn(*args, **kwargs)
return wrapper
return decorator
# pyright workaround: https://github.com/microsoft/pyright/issues/5866
class KIProtectionSignature(Protocol):
__name__: str
def __call__(self, f: CallableT, /) -> CallableT:
pass
# the following `type: ignore`s are because we use ParamSpec internally, but want to allow overloads
enable_ki_protection: KIProtectionSignature = _ki_protection_decorator(True) # type: ignore[assignment]
enable_ki_protection.__name__ = "enable_ki_protection"
disable_ki_protection: KIProtectionSignature = _ki_protection_decorator(False) # type: ignore[assignment]
disable_ki_protection.__name__ = "disable_ki_protection"
@attrs.define(slots=False)
class KIManager:
handler: Callable[[int, types.FrameType | None], None] | None = None
def install(
self,
deliver_cb: Callable[[], object],
restrict_keyboard_interrupt_to_checkpoints: bool,
) -> None:
assert self.handler is None
if (
not is_main_thread()
or signal.getsignal(signal.SIGINT) != signal.default_int_handler
):
return
def handler(signum: int, frame: types.FrameType | None) -> None:
assert signum == signal.SIGINT
protection_enabled = ki_protection_enabled(frame)
if protection_enabled or restrict_keyboard_interrupt_to_checkpoints:
deliver_cb()
else:
raise KeyboardInterrupt
self.handler = handler
signal.signal(signal.SIGINT, handler)
def close(self) -> None:
if self.handler is not None:
if signal.getsignal(signal.SIGINT) is self.handler:
signal.signal(signal.SIGINT, signal.default_int_handler)
self.handler = None

View File

@ -0,0 +1,104 @@
from __future__ import annotations
from typing import Generic, TypeVar, cast
# Runvar implementations
import attrs
from .._util import NoPublicConstructor, final
from . import _run
T = TypeVar("T")
@final
class _NoValue: ...
@final
@attrs.define(eq=False)
class RunVarToken(Generic[T], metaclass=NoPublicConstructor):
_var: RunVar[T]
previous_value: T | type[_NoValue] = _NoValue
redeemed: bool = attrs.field(default=False, init=False)
@classmethod
def _empty(cls, var: RunVar[T]) -> RunVarToken[T]:
return cls._create(var)
@final
@attrs.define(eq=False, repr=False)
class RunVar(Generic[T]):
"""The run-local variant of a context variable.
:class:`RunVar` objects are similar to context variable objects,
except that they are shared across a single call to :func:`trio.run`
rather than a single task.
"""
_name: str
_default: T | type[_NoValue] = _NoValue
def get(self, default: T | type[_NoValue] = _NoValue) -> T:
"""Gets the value of this :class:`RunVar` for the current run call."""
try:
return cast(T, _run.GLOBAL_RUN_CONTEXT.runner._locals[self])
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
except KeyError:
# contextvars consistency
# `type: ignore` awaiting https://github.com/python/mypy/issues/15553 to be fixed & released
if default is not _NoValue:
return default # type: ignore[return-value]
if self._default is not _NoValue:
return self._default # type: ignore[return-value]
raise LookupError(self) from None
def set(self, value: T) -> RunVarToken[T]:
"""Sets the value of this :class:`RunVar` for this current run
call.
"""
try:
old_value = self.get()
except LookupError:
token = RunVarToken._empty(self)
else:
token = RunVarToken[T]._create(self, old_value)
# This can't fail, because if we weren't in Trio context then the
# get() above would have failed.
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = value
return token
def reset(self, token: RunVarToken[T]) -> None:
"""Resets the value of this :class:`RunVar` to what it was
previously specified by the token.
"""
if token is None:
raise TypeError("token must not be none")
if token.redeemed:
raise ValueError("token has already been used")
if token._var is not self:
raise ValueError("token is not for us")
previous = token.previous_value
try:
if previous is _NoValue:
_run.GLOBAL_RUN_CONTEXT.runner._locals.pop(self)
else:
_run.GLOBAL_RUN_CONTEXT.runner._locals[self] = previous
except AttributeError:
raise RuntimeError("Cannot be used outside of a run context") from None
token.redeemed = True
def __repr__(self) -> str:
return f"<RunVar name={self._name!r}>"

View File

@ -0,0 +1,164 @@
import time
from math import inf
from .. import _core
from .._abc import Clock
from .._util import final
from ._run import GLOBAL_RUN_CONTEXT
################################################################
# The glorious MockClock
################################################################
# Prior art:
# https://twistedmatrix.com/documents/current/api/twisted.internet.task.Clock.html
# https://github.com/ztellman/manifold/issues/57
@final
class MockClock(Clock):
"""A user-controllable clock suitable for writing tests.
Args:
rate (float): the initial :attr:`rate`.
autojump_threshold (float): the initial :attr:`autojump_threshold`.
.. attribute:: rate
How many seconds of clock time pass per second of real time. Default is
0.0, i.e. the clock only advances through manuals calls to :meth:`jump`
or when the :attr:`autojump_threshold` is triggered. You can assign to
this attribute to change it.
.. attribute:: autojump_threshold
The clock keeps an eye on the run loop, and if at any point it detects
that all tasks have been blocked for this many real seconds (i.e.,
according to the actual clock, not this clock), then the clock
automatically jumps ahead to the run loop's next scheduled
timeout. Default is :data:`math.inf`, i.e., to never autojump. You can
assign to this attribute to change it.
Basically the idea is that if you have code or tests that use sleeps
and timeouts, you can use this to make it run much faster, totally
automatically. (At least, as long as those sleeps/timeouts are
happening inside Trio; if your test involves talking to external
service and waiting for it to timeout then obviously we can't help you
there.)
You should set this to the smallest value that lets you reliably avoid
"false alarms" where some I/O is in flight (e.g. between two halves of
a socketpair) but the threshold gets triggered and time gets advanced
anyway. This will depend on the details of your tests and test
environment. If you aren't doing any I/O (like in our sleeping example
above) then just set it to zero, and the clock will jump whenever all
tasks are blocked.
.. note:: If you use ``autojump_threshold`` and
`wait_all_tasks_blocked` at the same time, then you might wonder how
they interact, since they both cause things to happen after the run
loop goes idle for some time. The answer is:
`wait_all_tasks_blocked` takes priority. If there's a task blocked
in `wait_all_tasks_blocked`, then the autojump feature treats that
as active task and does *not* jump the clock.
"""
def __init__(self, rate: float = 0.0, autojump_threshold: float = inf):
# when the real clock said 'real_base', the virtual time was
# 'virtual_base', and since then it's advanced at 'rate' virtual
# seconds per real second.
self._real_base = 0.0
self._virtual_base = 0.0
self._rate = 0.0
self._autojump_threshold = 0.0
# kept as an attribute so that our tests can monkeypatch it
self._real_clock = time.perf_counter
# use the property update logic to set initial values
self.rate = rate
self.autojump_threshold = autojump_threshold
def __repr__(self) -> str:
return f"<MockClock, time={self.current_time():.7f}, rate={self._rate} @ {id(self):#x}>"
@property
def rate(self) -> float:
return self._rate
@rate.setter
def rate(self, new_rate: float) -> None:
if new_rate < 0:
raise ValueError("rate must be >= 0")
else:
real = self._real_clock()
virtual = self._real_to_virtual(real)
self._virtual_base = virtual
self._real_base = real
self._rate = float(new_rate)
@property
def autojump_threshold(self) -> float:
return self._autojump_threshold
@autojump_threshold.setter
def autojump_threshold(self, new_autojump_threshold: float) -> None:
self._autojump_threshold = float(new_autojump_threshold)
self._try_resync_autojump_threshold()
# runner.clock_autojump_threshold is an internal API that isn't easily
# usable by custom third-party Clock objects. If you need access to this
# functionality, let us know, and we'll figure out how to make a public
# API. Discussion:
#
# https://github.com/python-trio/trio/issues/1587
def _try_resync_autojump_threshold(self) -> None:
try:
runner = GLOBAL_RUN_CONTEXT.runner
if runner.is_guest:
runner.force_guest_tick_asap()
except AttributeError:
pass
else:
runner.clock_autojump_threshold = self._autojump_threshold
# Invoked by the run loop when runner.clock_autojump_threshold is
# exceeded.
def _autojump(self) -> None:
statistics = _core.current_statistics()
jump = statistics.seconds_to_next_deadline
if 0 < jump < inf:
self.jump(jump)
def _real_to_virtual(self, real: float) -> float:
real_offset = real - self._real_base
virtual_offset = self._rate * real_offset
return self._virtual_base + virtual_offset
def start_clock(self) -> None:
self._try_resync_autojump_threshold()
def current_time(self) -> float:
return self._real_to_virtual(self._real_clock())
def deadline_to_sleep_time(self, deadline: float) -> float:
virtual_timeout = deadline - self.current_time()
if virtual_timeout <= 0:
return 0
elif self._rate > 0:
return virtual_timeout / self._rate
else:
return 999999999
def jump(self, seconds: float) -> None:
"""Manually advance the clock by the given number of seconds.
Args:
seconds (float): the number of seconds to jump the clock forward.
Raises:
ValueError: if you try to pass a negative value for ``seconds``.
"""
if seconds < 0:
raise ValueError("time can't go backwards")
self._virtual_base += seconds

View File

@ -0,0 +1,317 @@
# ParkingLot provides an abstraction for a fair waitqueue with cancellation
# and requeueing support. Inspiration:
#
# https://webkit.org/blog/6161/locking-in-webkit/
# https://amanieu.github.io/parking_lot/
#
# which were in turn heavily influenced by
#
# http://gee.cs.oswego.edu/dl/papers/aqs.pdf
#
# Compared to these, our use of cooperative scheduling allows some
# simplifications (no need for internal locking). On the other hand, the need
# to support Trio's strong cancellation semantics adds some complications
# (tasks need to know where they're queued so they can cancel). Also, in the
# above work, the ParkingLot is a global structure that holds a collection of
# waitqueues keyed by lock address, and which are opportunistically allocated
# and destroyed as contention arises; this allows the worst-case memory usage
# for all waitqueues to be O(#tasks). Here we allocate a separate wait queue
# for each synchronization object, so we're O(#objects + #tasks). This isn't
# *so* bad since compared to our synchronization objects are heavier than
# theirs and our tasks are lighter, so for us #objects is smaller and #tasks
# is larger.
#
# This is in the core because for two reasons. First, it's used by
# UnboundedQueue, and UnboundedQueue is used for a number of things in the
# core. And second, it's responsible for providing fairness to all of our
# high-level synchronization primitives (locks, queues, etc.). For now with
# our FIFO scheduler this is relatively trivial (it's just a FIFO waitqueue),
# but in the future we ever start support task priorities or fair scheduling
#
# https://github.com/python-trio/trio/issues/32
#
# then all we'll have to do is update this. (Well, full-fledged task
# priorities might also require priority inheritance, which would require more
# work.)
#
# For discussion of data structures to use here, see:
#
# https://github.com/dabeaz/curio/issues/136
#
# (and also the articles above). Currently we use a SortedDict ordered by a
# global monotonic counter that ensures FIFO ordering. The main advantage of
# this is that it's easy to implement :-). An intrusive doubly-linked list
# would also be a natural approach, so long as we only handle FIFO ordering.
#
# XX: should we switch to the shared global ParkingLot approach?
#
# XX: we should probably add support for "parking tokens" to allow for
# task-fair RWlock (basically: when parking a task needs to be able to mark
# itself as a reader or a writer, and then a task-fair wakeup policy is, wake
# the next task, and if it's a reader than keep waking tasks so long as they
# are readers). Without this I think you can implement write-biased or
# read-biased RWlocks (by using two parking lots and drawing from whichever is
# preferred), but not task-fair -- and task-fair plays much more nicely with
# WFQ. (Consider what happens in the two-lot implementation if you're
# write-biased but all the pending writers are blocked at the scheduler level
# by the WFQ logic...)
# ...alternatively, "phase-fair" RWlocks are pretty interesting:
# http://www.cs.unc.edu/~anderson/papers/ecrts09b.pdf
# Useful summary:
# https://docs.oracle.com/javase/7/docs/api/java/util/concurrent/locks/ReadWriteLock.html
#
# XX: if we do add WFQ, then we might have to drop the current feature where
# unpark returns the tasks that were unparked. Rationale: suppose that at the
# time we call unpark, the next task is deprioritized... and then, before it
# becomes runnable, a new task parks which *is* runnable. Ideally we should
# immediately wake the new task, and leave the old task on the queue for
# later. But this means we can't commit to which task we are unparking when
# unpark is called.
#
# See: https://github.com/python-trio/trio/issues/53
from __future__ import annotations
import inspect
import math
from collections import OrderedDict
from typing import TYPE_CHECKING
import attrs
import outcome
from .. import _core
from .._util import final
if TYPE_CHECKING:
from collections.abc import Iterator
from ._run import Task
GLOBAL_PARKING_LOT_BREAKER: dict[Task, list[ParkingLot]] = {}
def add_parking_lot_breaker(task: Task, lot: ParkingLot) -> None:
"""Register a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`.
raises:
trio.BrokenResourceError: if the task has already exited.
"""
if inspect.getcoroutinestate(task.coro) == inspect.CORO_CLOSED:
raise _core._exceptions.BrokenResourceError(
"Attempted to add already exited task as lot breaker.",
)
if task not in GLOBAL_PARKING_LOT_BREAKER:
GLOBAL_PARKING_LOT_BREAKER[task] = [lot]
else:
GLOBAL_PARKING_LOT_BREAKER[task].append(lot)
def remove_parking_lot_breaker(task: Task, lot: ParkingLot) -> None:
"""Deregister a task as a breaker for a lot. See :meth:`ParkingLot.break_lot`"""
try:
GLOBAL_PARKING_LOT_BREAKER[task].remove(lot)
except (KeyError, ValueError):
raise RuntimeError(
"Attempted to remove task as breaker for a lot it is not registered for",
) from None
if not GLOBAL_PARKING_LOT_BREAKER[task]:
del GLOBAL_PARKING_LOT_BREAKER[task]
@attrs.frozen
class ParkingLotStatistics:
"""An object containing debugging information for a ParkingLot.
Currently, the following fields are defined:
* ``tasks_waiting`` (int): The number of tasks blocked on this lot's
:meth:`trio.lowlevel.ParkingLot.park` method.
"""
tasks_waiting: int
@final
@attrs.define(eq=False)
class ParkingLot:
"""A fair wait queue with cancellation and requeueing.
This class encapsulates the tricky parts of implementing a wait
queue. It's useful for implementing higher-level synchronization
primitives like queues and locks.
In addition to the methods below, you can use ``len(parking_lot)`` to get
the number of parked tasks, and ``if parking_lot: ...`` to check whether
there are any parked tasks.
"""
# {task: None}, we just want a deque where we can quickly delete random
# items
_parked: OrderedDict[Task, None] = attrs.field(factory=OrderedDict, init=False)
broken_by: list[Task] = attrs.field(factory=list, init=False)
def __len__(self) -> int:
"""Returns the number of parked tasks."""
return len(self._parked)
def __bool__(self) -> bool:
"""True if there are parked tasks, False otherwise."""
return bool(self._parked)
# XX this currently returns None
# if we ever add the ability to repark while one's resuming place in
# line (for false wakeups), then we could have it return a ticket that
# abstracts the "place in line" concept.
@_core.enable_ki_protection
async def park(self) -> None:
"""Park the current task until woken by a call to :meth:`unpark` or
:meth:`unpark_all`.
Raises:
BrokenResourceError: if attempting to park in a broken lot, or the lot
breaks before we get to unpark.
"""
if self.broken_by:
raise _core.BrokenResourceError(
f"Attempted to park in parking lot broken by {self.broken_by}",
)
task = _core.current_task()
self._parked[task] = None
task.custom_sleep_data = self
def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
del task.custom_sleep_data._parked[task]
return _core.Abort.SUCCEEDED
await _core.wait_task_rescheduled(abort_fn)
def _pop_several(self, count: int | float) -> Iterator[Task]: # noqa: PYI041
if isinstance(count, float):
if math.isinf(count):
count = len(self._parked)
else:
raise ValueError("Cannot pop a non-integer number of tasks.")
else:
count = min(count, len(self._parked))
for _ in range(count):
task, _ = self._parked.popitem(last=False)
yield task
@_core.enable_ki_protection
def unpark(self, *, count: int | float = 1) -> list[Task]: # noqa: PYI041
"""Unpark one or more tasks.
This wakes up ``count`` tasks that are blocked in :meth:`park`. If
there are fewer than ``count`` tasks parked, then wakes as many tasks
are available and then returns successfully.
Args:
count (int | math.inf): the number of tasks to unpark.
"""
tasks = list(self._pop_several(count))
for task in tasks:
_core.reschedule(task)
return tasks
def unpark_all(self) -> list[Task]:
"""Unpark all parked tasks."""
return self.unpark(count=len(self))
@_core.enable_ki_protection
def repark(
self,
new_lot: ParkingLot,
*,
count: int | float = 1, # noqa: PYI041
) -> None:
"""Move parked tasks from one :class:`ParkingLot` object to another.
This dequeues ``count`` tasks from one lot, and requeues them on
another, preserving order. For example::
async def parker(lot):
print("sleeping")
await lot.park()
print("woken")
async def main():
lot1 = trio.lowlevel.ParkingLot()
lot2 = trio.lowlevel.ParkingLot()
async with trio.open_nursery() as nursery:
nursery.start_soon(parker, lot1)
await trio.testing.wait_all_tasks_blocked()
assert len(lot1) == 1
assert len(lot2) == 0
lot1.repark(lot2)
assert len(lot1) == 0
assert len(lot2) == 1
# This wakes up the task that was originally parked in lot1
lot2.unpark()
If there are fewer than ``count`` tasks parked, then reparks as many
tasks as are available and then returns successfully.
Args:
new_lot (ParkingLot): the parking lot to move tasks to.
count (int|math.inf): the number of tasks to move.
"""
if not isinstance(new_lot, ParkingLot):
raise TypeError("new_lot must be a ParkingLot")
for task in self._pop_several(count):
new_lot._parked[task] = None
task.custom_sleep_data = new_lot
def repark_all(self, new_lot: ParkingLot) -> None:
"""Move all parked tasks from one :class:`ParkingLot` object to
another.
See :meth:`repark` for details.
"""
return self.repark(new_lot, count=len(self))
def break_lot(self, task: Task | None = None) -> None:
"""Break this lot, with ``task`` noted as the task that broke it.
This causes all parked tasks to raise an error, and any
future tasks attempting to park to error. Unpark & repark become no-ops as the
parking lot is empty.
The error raised contains a reference to the task sent as a parameter. The task
is also saved in the parking lot in the ``broken_by`` attribute.
"""
if task is None:
task = _core.current_task()
# if lot is already broken, just mark this as another breaker and return
if self.broken_by:
self.broken_by.append(task)
return
self.broken_by.append(task)
for parked_task in self._parked:
_core.reschedule(
parked_task,
outcome.Error(
_core.BrokenResourceError(f"Parking lot broken by {task}"),
),
)
self._parked.clear()
def statistics(self) -> ParkingLotStatistics:
"""Return an object containing debugging information.
Currently the following fields are defined:
* ``tasks_waiting``: The number of tasks blocked on this lot's
:meth:`park` method.
"""
return ParkingLotStatistics(tasks_waiting=len(self._parked))

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,333 @@
from __future__ import annotations
import contextlib
import sys
import weakref
from math import inf
from typing import TYPE_CHECKING, NoReturn
import pytest
from ... import _core
from .tutil import gc_collect_harder, restore_unraisablehook
if TYPE_CHECKING:
from collections.abc import AsyncGenerator
def test_asyncgen_basics() -> None:
collected = []
async def example(cause: str) -> AsyncGenerator[int, None]:
try:
with contextlib.suppress(GeneratorExit):
yield 42
await _core.checkpoint()
except _core.Cancelled:
assert "exhausted" not in cause
task_name = _core.current_task().name
assert cause in task_name or task_name == "<init>"
assert _core.current_effective_deadline() == -inf
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
collected.append(cause)
else:
assert "async_main" in _core.current_task().name
assert "exhausted" in cause
assert _core.current_effective_deadline() == inf
await _core.checkpoint()
collected.append(cause)
saved = []
async def async_main() -> None:
# GC'ed before exhausted
with pytest.warns(
ResourceWarning,
match="Async generator.*collected before.*exhausted",
):
assert await example("abandoned").asend(None) == 42
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert collected.pop() == "abandoned"
aiter_ = example("exhausted 1")
try:
assert await aiter_.asend(None) == 42
finally:
await aiter_.aclose()
assert collected.pop() == "exhausted 1"
# Also fine if you exhaust it at point of use
async for val in example("exhausted 2"):
assert val == 42
assert collected.pop() == "exhausted 2"
gc_collect_harder()
# No problems saving the geniter when using either of these patterns
aiter_ = example("exhausted 3")
try:
saved.append(aiter_)
assert await aiter_.asend(None) == 42
finally:
await aiter_.aclose()
assert collected.pop() == "exhausted 3"
# Also fine if you exhaust it at point of use
saved.append(example("exhausted 4"))
async for val in saved[-1]:
assert val == 42
assert collected.pop() == "exhausted 4"
# Leave one referenced-but-unexhausted and make sure it gets cleaned up
saved.append(example("outlived run"))
assert await saved[-1].asend(None) == 42
assert collected == []
_core.run(async_main)
assert collected.pop() == "outlived run"
for agen in saved:
assert agen.ag_frame is None # all should now be exhausted
async def test_asyncgen_throws_during_finalization(
caplog: pytest.LogCaptureFixture,
) -> None:
record = []
async def agen() -> AsyncGenerator[int, None]:
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("crashing")
raise ValueError("oops")
with restore_unraisablehook():
await agen().asend(None)
gc_collect_harder()
await _core.wait_all_tasks_blocked()
assert record == ["crashing"]
# Following type ignore is because typing for LogCaptureFixture is wrong
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info # type: ignore[misc]
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "during finalization of async generator" in caplog.records[0].message
def test_firstiter_after_closing() -> None:
saved = []
record = []
async def funky_agen() -> AsyncGenerator[int, None]:
try:
yield 1
except GeneratorExit:
record.append("cleanup 1")
raise
try:
yield 2
finally:
record.append("cleanup 2")
await funky_agen().asend(None)
async def async_main() -> None:
aiter_ = funky_agen()
saved.append(aiter_)
assert await aiter_.asend(None) == 1
assert await aiter_.asend(None) == 2
_core.run(async_main)
assert record == ["cleanup 2", "cleanup 1"]
def test_interdependent_asyncgen_cleanup_order() -> None:
saved: list[AsyncGenerator[int, None]] = []
record: list[int | str] = []
async def innermost() -> AsyncGenerator[int, None]:
try:
yield 1
finally:
await _core.cancel_shielded_checkpoint()
record.append("innermost")
async def agen(
label: int,
inner: AsyncGenerator[int, None],
) -> AsyncGenerator[int, None]:
try:
yield await inner.asend(None)
finally:
# Either `inner` has already been cleaned up, or
# we're about to exhaust it. Either way, we wind
# up with `record` containing the labels in
# innermost-to-outermost order.
with pytest.raises(StopAsyncIteration):
await inner.asend(None)
record.append(label)
async def async_main() -> None:
# This makes a chain of 101 interdependent asyncgens:
# agen(99)'s cleanup will iterate agen(98)'s will iterate
# ... agen(0)'s will iterate innermost()'s
ag_chain = innermost()
for idx in range(100):
ag_chain = agen(idx, ag_chain)
saved.append(ag_chain)
assert await ag_chain.asend(None) == 1
assert record == []
_core.run(async_main)
assert record == ["innermost", *range(100)]
@restore_unraisablehook()
def test_last_minute_gc_edge_case() -> None:
saved: list[AsyncGenerator[int, None]] = []
record = []
needs_retry = True
async def agen() -> AsyncGenerator[int, None]:
try:
yield 1
finally:
record.append("cleaned up")
def collect_at_opportune_moment(token: _core._entry_queue.TrioToken) -> None:
runner = _core._run.GLOBAL_RUN_CONTEXT.runner
assert runner.system_nursery is not None
if runner.system_nursery._closed and isinstance(
runner.asyncgens.alive,
weakref.WeakSet,
):
saved.clear()
record.append("final collection")
gc_collect_harder()
record.append("done")
else:
try:
token.run_sync_soon(collect_at_opportune_moment, token)
except _core.RunFinishedError: # pragma: no cover
nonlocal needs_retry
needs_retry = True
async def async_main() -> None:
token = _core.current_trio_token()
token.run_sync_soon(collect_at_opportune_moment, token)
saved.append(agen())
await saved[-1].asend(None)
# Actually running into the edge case requires that the run_sync_soon task
# execute in between the system nursery's closure and the strong-ification
# of runner.asyncgens. There's about a 25% chance that it doesn't
# (if the run_sync_soon task runs before init on one tick and after init
# on the next tick); if we try enough times, we can make the chance of
# failure as small as we want.
for _attempt in range(50):
needs_retry = False
del record[:]
del saved[:]
_core.run(async_main)
if needs_retry: # pragma: no cover
assert record == ["cleaned up"]
else:
assert record == ["final collection", "done", "cleaned up"]
break
else: # pragma: no cover
pytest.fail(
"Didn't manage to hit the trailing_finalizer_asyncgens case "
f"despite trying {_attempt} times",
)
async def step_outside_async_context(aiter_: AsyncGenerator[int, None]) -> None:
# abort_fns run outside of task context, at least if they're
# triggered by a deadline expiry rather than a direct
# cancellation. Thus, an asyncgen first iterated inside one
# will appear non-Trio, and since no other hooks were installed,
# will use the last-ditch fallback handling (that tries to mimic
# CPython's behavior with no hooks).
#
# NB: the strangeness with aiter being an attribute of abort_fn is
# to make it as easy as possible to ensure we don't hang onto a
# reference to aiter inside the guts of the run loop.
def abort_fn(_: _core.RaiseCancelT) -> _core.Abort:
with pytest.raises(StopIteration, match="42"):
abort_fn.aiter.asend(None).send(None) # type: ignore[attr-defined] # Callables don't have attribute "aiter"
del abort_fn.aiter # type: ignore[attr-defined]
return _core.Abort.SUCCEEDED
abort_fn.aiter = aiter_ # type: ignore[attr-defined]
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_task_rescheduled, abort_fn)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.deadline = _core.current_time()
async def test_fallback_when_no_hook_claims_it(
capsys: pytest.CaptureFixture[str],
) -> None:
async def well_behaved() -> AsyncGenerator[int, None]:
yield 42
async def yields_after_yield() -> AsyncGenerator[int, None]:
with pytest.raises(GeneratorExit):
yield 42
yield 100
async def awaits_after_yield() -> AsyncGenerator[int, None]:
with pytest.raises(GeneratorExit):
yield 42
await _core.cancel_shielded_checkpoint()
with restore_unraisablehook():
await step_outside_async_context(well_behaved())
gc_collect_harder()
assert capsys.readouterr().err == ""
await step_outside_async_context(yields_after_yield())
gc_collect_harder()
assert "ignored GeneratorExit" in capsys.readouterr().err
await step_outside_async_context(awaits_after_yield())
gc_collect_harder()
assert "awaited something during finalization" in capsys.readouterr().err
def test_delegation_to_existing_hooks() -> None:
record = []
def my_firstiter(agen: AsyncGenerator[object, NoReturn]) -> None:
record.append("firstiter " + agen.ag_frame.f_locals["arg"])
def my_finalizer(agen: AsyncGenerator[object, NoReturn]) -> None:
record.append("finalizer " + agen.ag_frame.f_locals["arg"])
async def example(arg: str) -> AsyncGenerator[int, None]:
try:
yield 42
finally:
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
record.append("trio collected " + arg)
async def async_main() -> None:
await step_outside_async_context(example("theirs"))
assert await example("ours").asend(None) == 42
gc_collect_harder()
assert record == ["firstiter theirs", "finalizer theirs"]
record[:] = []
await _core.wait_all_tasks_blocked()
assert record == ["trio collected ours"]
with restore_unraisablehook():
old_hooks = sys.get_asyncgen_hooks()
sys.set_asyncgen_hooks(my_firstiter, my_finalizer)
try:
_core.run(async_main)
finally:
assert sys.get_asyncgen_hooks() == (my_firstiter, my_finalizer)
sys.set_asyncgen_hooks(*old_hooks)

View File

@ -0,0 +1,102 @@
from __future__ import annotations
import gc
import sys
from traceback import extract_tb
from typing import TYPE_CHECKING, Callable, NoReturn
import pytest
from .._concat_tb import concat_tb
if TYPE_CHECKING:
from types import TracebackType
if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
def raiser1() -> NoReturn:
raiser1_2()
def raiser1_2() -> NoReturn:
raiser1_3()
def raiser1_3() -> NoReturn:
raise ValueError("raiser1_string")
def raiser2() -> NoReturn:
raiser2_2()
def raiser2_2() -> NoReturn:
raise KeyError("raiser2_string")
def get_exc(raiser: Callable[[], NoReturn]) -> Exception:
try:
raiser()
except Exception as exc:
return exc
raise AssertionError("raiser should always raise") # pragma: no cover
def get_tb(raiser: Callable[[], NoReturn]) -> TracebackType | None:
return get_exc(raiser).__traceback__
def test_concat_tb() -> None:
tb1 = get_tb(raiser1)
tb2 = get_tb(raiser2)
# These return a list of (filename, lineno, fn name, text) tuples
# https://docs.python.org/3/library/traceback.html#traceback.extract_tb
entries1 = extract_tb(tb1)
entries2 = extract_tb(tb2)
tb12 = concat_tb(tb1, tb2)
assert extract_tb(tb12) == entries1 + entries2
tb21 = concat_tb(tb2, tb1)
assert extract_tb(tb21) == entries2 + entries1
# Check degenerate cases
assert extract_tb(concat_tb(None, tb1)) == entries1
assert extract_tb(concat_tb(tb1, None)) == entries1
assert concat_tb(None, None) is None
# Make sure the original tracebacks didn't get mutated by mistake
assert extract_tb(get_tb(raiser1)) == entries1
assert extract_tb(get_tb(raiser2)) == entries2
# Unclear if this can still fail, removing the `del` from _concat_tb.copy_tb does not seem
# to trigger it (on a platform where the `del` is executed)
@pytest.mark.skipif(
sys.implementation.name != "cpython",
reason="Only makes sense with refcounting GC",
)
def test_ExceptionGroup_catch_doesnt_create_cyclic_garbage() -> None:
# https://github.com/python-trio/trio/pull/2063
gc.collect()
old_flags = gc.get_debug()
def make_multi() -> NoReturn:
raise ExceptionGroup("", [get_exc(raiser1), get_exc(raiser2)])
try:
gc.set_debug(gc.DEBUG_SAVEALL)
with pytest.raises(ExceptionGroup) as excinfo:
# covers ~~MultiErrorCatcher.__exit__ and~~ _concat_tb.copy_tb
# TODO: is the above comment true anymore? as this no longer uses MultiError.catch
raise make_multi()
for exc in excinfo.value.exceptions:
assert isinstance(exc, (ValueError, KeyError))
gc.collect()
assert not gc.garbage
finally:
gc.set_debug(old_flags)
gc.garbage.clear()

View File

@ -0,0 +1,666 @@
from __future__ import annotations
import asyncio
import contextlib
import contextvars
import queue
import signal
import socket
import sys
import threading
import time
import traceback
import warnings
from functools import partial
from math import inf
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Awaitable,
Callable,
NoReturn,
TypeVar,
)
import pytest
from outcome import Outcome
import trio
import trio.testing
from trio.abc import Instrument
from ..._util import signal_raise
from .tutil import gc_collect_harder, restore_unraisablehook
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from trio._channel import MemorySendChannel
T = TypeVar("T")
InHost: TypeAlias = Callable[[object], None]
# The simplest possible "host" loop.
# Nice features:
# - we can run code "outside" of trio using the schedule function passed to
# our main
# - final result is returned
# - any unhandled exceptions cause an immediate crash
def trivial_guest_run(
trio_fn: Callable[..., Awaitable[T]],
*,
in_host_after_start: Callable[[], None] | None = None,
**start_guest_run_kwargs: Any,
) -> T:
todo: queue.Queue[tuple[str, Outcome[T] | Callable[..., object]]] = queue.Queue()
host_thread = threading.current_thread()
def run_sync_soon_threadsafe(fn: Callable[[], object]) -> None:
nonlocal todo
if host_thread is threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail,
"run_sync_soon_threadsafe called from host thread",
)
todo.put(("run", crash))
todo.put(("run", fn))
def run_sync_soon_not_threadsafe(fn: Callable[[], object]) -> None:
nonlocal todo
if host_thread is not threading.current_thread(): # pragma: no cover
crash = partial(
pytest.fail,
"run_sync_soon_not_threadsafe called from worker thread",
)
todo.put(("run", crash))
todo.put(("run", fn))
def done_callback(outcome: Outcome[T]) -> None:
nonlocal todo
todo.put(("unwrap", outcome))
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_not_threadsafe,
run_sync_soon_threadsafe=run_sync_soon_threadsafe,
run_sync_soon_not_threadsafe=run_sync_soon_not_threadsafe,
done_callback=done_callback,
**start_guest_run_kwargs,
)
if in_host_after_start is not None:
in_host_after_start()
try:
while True:
op, obj = todo.get()
if op == "run":
assert not isinstance(obj, Outcome)
obj()
elif op == "unwrap":
assert isinstance(obj, Outcome)
return obj.unwrap()
else: # pragma: no cover
raise NotImplementedError(f"{op!r} not handled")
finally:
# Make sure that exceptions raised here don't capture these, so that
# if an exception does cause us to abandon a run then the Trio state
# has a chance to be GC'ed and warn about it.
del todo, run_sync_soon_threadsafe, done_callback
def test_guest_trivial() -> None:
async def trio_return(in_host: InHost) -> str:
await trio.lowlevel.checkpoint()
return "ok"
assert trivial_guest_run(trio_return) == "ok"
async def trio_fail(in_host: InHost) -> NoReturn:
raise KeyError("whoopsiedaisy")
with pytest.raises(KeyError, match="whoopsiedaisy"):
trivial_guest_run(trio_fail)
def test_guest_can_do_io() -> None:
async def trio_main(in_host: InHost) -> None:
record = []
a, b = trio.socket.socketpair()
with a, b:
async with trio.open_nursery() as nursery:
async def do_receive() -> None:
record.append(await a.recv(1))
nursery.start_soon(do_receive)
await trio.testing.wait_all_tasks_blocked()
await b.send(b"x")
assert record == [b"x"]
trivial_guest_run(trio_main)
def test_guest_is_initialized_when_start_returns() -> None:
trio_token = None
record = []
async def trio_main(in_host: InHost) -> str:
record.append("main task ran")
await trio.lowlevel.checkpoint()
assert trio.lowlevel.current_trio_token() is trio_token
return "ok"
def after_start() -> None:
# We should get control back before the main task executes any code
assert record == []
nonlocal trio_token
trio_token = trio.lowlevel.current_trio_token()
trio_token.run_sync_soon(record.append, "run_sync_soon cb ran")
@trio.lowlevel.spawn_system_task
async def early_task() -> None:
record.append("system task ran")
await trio.lowlevel.checkpoint()
res = trivial_guest_run(trio_main, in_host_after_start=after_start)
assert res == "ok"
assert set(record) == {"system task ran", "main task ran", "run_sync_soon cb ran"}
class BadClock:
def start_clock(self) -> NoReturn:
raise ValueError("whoops")
def after_start_never_runs() -> None: # pragma: no cover
pytest.fail("shouldn't get here")
# Errors during initialization (which can only be TrioInternalErrors)
# are raised out of start_guest_run, not out of the done_callback
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(
trio_main,
clock=BadClock(),
in_host_after_start=after_start_never_runs,
)
def test_host_can_directly_wake_trio_task() -> None:
async def trio_main(in_host: InHost) -> str:
ev = trio.Event()
in_host(ev.set)
await ev.wait()
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_host_altering_deadlines_wakes_trio_up() -> None:
def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
cscope.deadline = new_deadline
async def trio_main(in_host: InHost) -> str:
with trio.CancelScope() as cscope:
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep_forever()
assert cscope.cancelled_caught
with trio.CancelScope() as cscope:
# also do a change that doesn't affect the next deadline, just to
# exercise that path
in_host(lambda: set_deadline(cscope, 1e6))
in_host(lambda: set_deadline(cscope, -inf))
await trio.sleep(999)
assert cscope.cancelled_caught
return "ok"
assert trivial_guest_run(trio_main) == "ok"
def test_guest_mode_sniffio_integration() -> None:
from sniffio import current_async_library, thread_local as sniffio_library
async def trio_main(in_host: InHost) -> str:
async def synchronize() -> None:
"""Wait for all in_host() calls issued so far to complete."""
evt = trio.Event()
in_host(evt.set)
await evt.wait()
# Host and guest have separate sniffio_library contexts
in_host(partial(setattr, sniffio_library, "name", "nullio"))
await synchronize()
assert current_async_library() == "trio"
record = []
in_host(lambda: record.append(current_async_library()))
await synchronize()
assert record == ["nullio"]
assert current_async_library() == "trio"
return "ok"
try:
assert trivial_guest_run(trio_main) == "ok"
finally:
sniffio_library.name = None
def test_warn_set_wakeup_fd_overwrite() -> None:
assert signal.set_wakeup_fd(-1) == -1
async def trio_main(in_host: InHost) -> str:
return "ok"
a, b = socket.socketpair()
with a, b:
a.setblocking(False)
# Warn if there's already a wakeup fd
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert trivial_guest_run(trio_main) == "ok"
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
signal.set_wakeup_fd(a.fileno())
try:
with pytest.warns(RuntimeWarning, match="signal handling code.*collided"):
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=False)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
# Don't warn if there isn't already a wakeup fd
with warnings.catch_warnings():
warnings.simplefilter("error")
assert trivial_guest_run(trio_main) == "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(trio_main, host_uses_signal_set_wakeup_fd=True)
== "ok"
)
# If there's already a wakeup fd, but we've been told to trust it,
# then it's left alone and there's no warning
signal.set_wakeup_fd(a.fileno())
try:
async def trio_check_wakeup_fd_unaltered(in_host: InHost) -> str:
fd = signal.set_wakeup_fd(-1)
assert fd == a.fileno()
signal.set_wakeup_fd(fd)
return "ok"
with warnings.catch_warnings():
warnings.simplefilter("error")
assert (
trivial_guest_run(
trio_check_wakeup_fd_unaltered,
host_uses_signal_set_wakeup_fd=True,
)
== "ok"
)
finally:
assert signal.set_wakeup_fd(-1) == a.fileno()
def test_host_wakeup_doesnt_trigger_wait_all_tasks_blocked() -> None:
# This is designed to hit the branch in unrolled_run where:
# idle_primed=True
# runner.runq is empty
# events is Truth-y
# ...and confirm that in this case, wait_all_tasks_blocked does not get
# triggered.
def set_deadline(cscope: trio.CancelScope, new_deadline: float) -> None:
print(f"setting deadline {new_deadline}")
cscope.deadline = new_deadline
async def trio_main(in_host: InHost) -> str:
async def sit_in_wait_all_tasks_blocked(watb_cscope: trio.CancelScope) -> None:
with watb_cscope:
# Overall point of this test is that this
# wait_all_tasks_blocked should *not* return normally, but
# only by cancellation.
await trio.testing.wait_all_tasks_blocked(cushion=9999)
raise AssertionError( # pragma: no cover
"wait_all_tasks_blocked should *not* return normally, "
"only by cancellation.",
)
assert watb_cscope.cancelled_caught
async def get_woken_by_host_deadline(watb_cscope: trio.CancelScope) -> None:
with trio.CancelScope() as cscope:
print("scheduling stuff to happen")
# Altering the deadline from the host, to something in the
# future, will cause the run loop to wake up, but then
# discover that there is nothing to do and go back to sleep.
# This should *not* trigger wait_all_tasks_blocked.
#
# So the 'before_io_wait' here will wait until we're blocking
# with the wait_all_tasks_blocked primed, and then schedule a
# deadline change. The critical test is that this should *not*
# wake up 'sit_in_wait_all_tasks_blocked'.
#
# The after we've had a chance to wake up
# 'sit_in_wait_all_tasks_blocked', we want the test to
# actually end. So in after_io_wait we schedule a second host
# call to tear things down.
class InstrumentHelper(Instrument):
def __init__(self) -> None:
self.primed = False
def before_io_wait(self, timeout: float) -> None:
print(f"before_io_wait({timeout})")
if timeout == 9999: # pragma: no branch
assert not self.primed
in_host(lambda: set_deadline(cscope, 1e9))
self.primed = True
def after_io_wait(self, timeout: float) -> None:
if self.primed: # pragma: no branch
print("instrument triggered")
in_host(lambda: cscope.cancel())
trio.lowlevel.remove_instrument(self)
trio.lowlevel.add_instrument(InstrumentHelper())
await trio.sleep_forever()
assert cscope.cancelled_caught
watb_cscope.cancel()
async with trio.open_nursery() as nursery:
watb_cscope = trio.CancelScope()
nursery.start_soon(sit_in_wait_all_tasks_blocked, watb_cscope)
await trio.testing.wait_all_tasks_blocked()
nursery.start_soon(get_woken_by_host_deadline, watb_cscope)
return "ok"
assert trivial_guest_run(trio_main) == "ok"
@restore_unraisablehook()
def test_guest_warns_if_abandoned() -> None:
# This warning is emitted from the garbage collector. So we have to make
# sure that our abandoned run is garbage. The easiest way to do this is to
# put it into a function, so that we're sure all the local state,
# traceback frames, etc. are garbage once it returns.
def do_abandoned_guest_run() -> None:
async def abandoned_main(in_host: InHost) -> None:
in_host(lambda: 1 / 0)
while True:
await trio.lowlevel.checkpoint()
with pytest.raises(ZeroDivisionError):
trivial_guest_run(abandoned_main)
with pytest.warns(RuntimeWarning, match="Trio guest run got abandoned"):
do_abandoned_guest_run()
gc_collect_harder()
# If you have problems some day figuring out what's holding onto a
# reference to the unrolled_run generator and making this test fail,
# then this might be useful to help track it down. (It assumes you
# also hack start_guest_run so that it does 'global W; W =
# weakref(unrolled_run_gen)'.)
#
# import gc
# print(trio._core._run.W)
# targets = [trio._core._run.W()]
# for i in range(15):
# new_targets = []
# for target in targets:
# new_targets += gc.get_referrers(target)
# new_targets.remove(targets)
# print("#####################")
# print(f"depth {i}: {len(new_targets)}")
# print(new_targets)
# targets = new_targets
with pytest.raises(RuntimeError):
trio.current_time()
def aiotrio_run(
trio_fn: Callable[..., Awaitable[T]],
*,
pass_not_threadsafe: bool = True,
**start_guest_run_kwargs: Any,
) -> T:
loop = asyncio.new_event_loop()
async def aio_main() -> T:
trio_done_fut = loop.create_future()
def trio_done_callback(main_outcome: Outcome[object]) -> None:
print(f"trio_fn finished: {main_outcome!r}")
trio_done_fut.set_result(main_outcome)
if pass_not_threadsafe:
start_guest_run_kwargs["run_sync_soon_not_threadsafe"] = loop.call_soon
trio.lowlevel.start_guest_run(
trio_fn,
run_sync_soon_threadsafe=loop.call_soon_threadsafe,
done_callback=trio_done_callback,
**start_guest_run_kwargs,
)
return (await trio_done_fut).unwrap() # type: ignore[no-any-return]
try:
return loop.run_until_complete(aio_main())
finally:
loop.close()
def test_guest_mode_on_asyncio() -> None:
async def trio_main() -> str:
print("trio_main!")
to_trio, from_aio = trio.open_memory_channel[int](float("inf"))
from_trio: asyncio.Queue[int] = asyncio.Queue()
aio_task = asyncio.ensure_future(aio_pingpong(from_trio, to_trio))
# Make sure we have at least one tick where we don't need to go into
# the thread
await trio.lowlevel.checkpoint()
from_trio.put_nowait(0)
async for n in from_aio:
print(f"trio got: {n}")
from_trio.put_nowait(n + 1)
if n >= 10:
aio_task.cancel()
return "trio-main-done"
raise AssertionError("should never be reached") # pragma: no cover
async def aio_pingpong(
from_trio: asyncio.Queue[int],
to_trio: MemorySendChannel[int],
) -> None:
print("aio_pingpong!")
try:
while True:
n = await from_trio.get()
print(f"aio got: {n}")
to_trio.send_nowait(n + 1)
except asyncio.CancelledError:
raise
except: # pragma: no cover
traceback.print_exc()
raise
assert (
aiotrio_run(
trio_main,
# Not all versions of asyncio we test on can actually be trusted,
# but this test doesn't care about signal handling, and it's
# easier to just avoid the warnings.
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
assert (
aiotrio_run(
trio_main,
# Also check that passing only call_soon_threadsafe works, via the
# fallback path where we use it for everything.
pass_not_threadsafe=False,
host_uses_signal_set_wakeup_fd=True,
)
== "trio-main-done"
)
def test_guest_mode_internal_errors(
monkeypatch: pytest.MonkeyPatch,
recwarn: pytest.WarningsRecorder,
) -> None:
with monkeypatch.context() as m:
async def crash_in_run_loop(in_host: InHost) -> None:
m.setattr("trio._core._run.GLOBAL_RUN_CONTEXT.runner.runq", "HI")
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_run_loop)
with monkeypatch.context() as m:
async def crash_in_io(in_host: InHost) -> None:
m.setattr("trio._core._run.TheIOManager.get_events", None)
await trio.lowlevel.checkpoint()
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_io)
with monkeypatch.context() as m:
async def crash_in_worker_thread_io(in_host: InHost) -> None:
t = threading.current_thread()
old_get_events = trio._core._run.TheIOManager.get_events
def bad_get_events(*args: Any) -> object:
if threading.current_thread() is not t:
raise ValueError("oh no!")
else:
return old_get_events(*args)
m.setattr("trio._core._run.TheIOManager.get_events", bad_get_events)
await trio.sleep(1)
with pytest.raises(trio.TrioInternalError):
trivial_guest_run(crash_in_worker_thread_io)
gc_collect_harder()
def test_guest_mode_ki() -> None:
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Check SIGINT in Trio func and in host func
async def trio_main(in_host: InHost) -> None:
with pytest.raises(KeyboardInterrupt):
signal_raise(signal.SIGINT)
# Host SIGINT should get injected into Trio
in_host(partial(signal_raise, signal.SIGINT))
await trio.sleep(10)
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main)
assert excinfo.value.__context__ is None
# Signal handler should be restored properly on exit
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
# Also check chaining in the case where KI is injected after main exits
final_exc = KeyError("whoa")
async def trio_main_raising(in_host: InHost) -> NoReturn:
in_host(partial(signal_raise, signal.SIGINT))
raise final_exc
with pytest.raises(KeyboardInterrupt) as excinfo:
trivial_guest_run(trio_main_raising)
assert excinfo.value.__context__ is final_exc
assert signal.getsignal(signal.SIGINT) is signal.default_int_handler
def test_guest_mode_autojump_clock_threshold_changing() -> None:
# This is super obscure and probably no-one will ever notice, but
# technically mutating the MockClock.autojump_threshold from the host
# should wake up the guest, so let's test it.
clock = trio.testing.MockClock()
DURATION = 120
async def trio_main(in_host: InHost) -> None:
assert trio.current_time() == 0
in_host(lambda: setattr(clock, "autojump_threshold", 0))
await trio.sleep(DURATION)
assert trio.current_time() == DURATION
start = time.monotonic()
trivial_guest_run(trio_main, clock=clock)
end = time.monotonic()
# Should be basically instantaneous, but we'll leave a generous buffer to
# account for any CI weirdness
assert end - start < DURATION / 2
@restore_unraisablehook()
def test_guest_mode_asyncgens() -> None:
import sniffio
record = set()
async def agen(label: str) -> AsyncGenerator[int, None]:
assert sniffio.current_async_library() == label
try:
yield 1
finally:
library = sniffio.current_async_library()
with contextlib.suppress(trio.Cancelled):
await sys.modules[library].sleep(0)
record.add((label, library))
async def iterate_in_aio() -> None:
await agen("asyncio").asend(None)
async def trio_main() -> None:
task = asyncio.ensure_future(iterate_in_aio())
done_evt = trio.Event()
task.add_done_callback(lambda _: done_evt.set())
with trio.fail_after(1):
await done_evt.wait()
await agen("trio").asend(None)
gc_collect_harder()
# Ensure we don't pollute the thread-level context if run under
# an asyncio without contextvars support (3.6)
context = contextvars.copy_context()
context.run(aiotrio_run, trio_main, host_uses_signal_set_wakeup_fd=True)
assert record == {("asyncio", "asyncio"), ("trio", "trio")}

View File

@ -0,0 +1,266 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Container, Iterable, NoReturn
import attrs
import pytest
from ... import _abc, _core
from .tutil import check_sequence_matches
if TYPE_CHECKING:
from ...lowlevel import Task
@attrs.define(eq=False, slots=False)
class TaskRecorder(_abc.Instrument):
record: list[tuple[str, Task | None]] = attrs.Factory(list)
def before_run(self) -> None:
self.record.append(("before_run", None))
def task_scheduled(self, task: Task) -> None:
self.record.append(("schedule", task))
def before_task_step(self, task: Task) -> None:
assert task is _core.current_task()
self.record.append(("before", task))
def after_task_step(self, task: Task) -> None:
assert task is _core.current_task()
self.record.append(("after", task))
def after_run(self) -> None:
self.record.append(("after_run", None))
def filter_tasks(self, tasks: Container[Task]) -> Iterable[tuple[str, Task | None]]:
for item in self.record:
if item[0] in ("schedule", "before", "after") and item[1] in tasks:
yield item
if item[0] in ("before_run", "after_run"):
yield item
def test_instruments(recwarn: object) -> None:
r1 = TaskRecorder()
r2 = TaskRecorder()
r3 = TaskRecorder()
task = None
# We use a child task for this, because the main task does some extra
# bookkeeping stuff that can leak into the instrument results, and we
# don't want to deal with it.
async def task_fn() -> None:
nonlocal task
task = _core.current_task()
for _ in range(4):
await _core.checkpoint()
# replace r2 with r3, to test that we can manipulate them as we go
_core.remove_instrument(r2)
with pytest.raises(KeyError):
_core.remove_instrument(r2)
# add is idempotent
_core.add_instrument(r3)
_core.add_instrument(r3)
for _ in range(1):
await _core.checkpoint()
async def main() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(task_fn)
_core.run(main, instruments=[r1, r2])
# It sleeps 5 times, so it runs 6 times. Note that checkpoint()
# reschedules the task immediately upon yielding, before the
# after_task_step event fires.
expected = (
[("before_run", None), ("schedule", task)]
+ [("before", task), ("schedule", task), ("after", task)] * 5
+ [("before", task), ("after", task), ("after_run", None)]
)
assert r1.record == r2.record + r3.record
assert task is not None
assert list(r1.filter_tasks([task])) == expected
def test_instruments_interleave() -> None:
tasks = {}
async def two_step1() -> None:
tasks["t1"] = _core.current_task()
await _core.checkpoint()
async def two_step2() -> None:
tasks["t2"] = _core.current_task()
await _core.checkpoint()
async def main() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(two_step1)
nursery.start_soon(two_step2)
r = TaskRecorder()
_core.run(main, instruments=[r])
expected = [
("before_run", None),
("schedule", tasks["t1"]),
("schedule", tasks["t2"]),
{
("before", tasks["t1"]),
("schedule", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("schedule", tasks["t2"]),
("after", tasks["t2"]),
},
{
("before", tasks["t1"]),
("after", tasks["t1"]),
("before", tasks["t2"]),
("after", tasks["t2"]),
},
("after_run", None),
]
print(list(r.filter_tasks(tasks.values())))
check_sequence_matches(list(r.filter_tasks(tasks.values())), expected)
def test_null_instrument() -> None:
# undefined instrument methods are skipped
class NullInstrument(_abc.Instrument):
def something_unrelated(self) -> None:
pass # pragma: no cover
async def main() -> None:
await _core.checkpoint()
_core.run(main, instruments=[NullInstrument()])
def test_instrument_before_after_run() -> None:
record = []
class BeforeAfterRun(_abc.Instrument):
def before_run(self) -> None:
record.append("before_run")
def after_run(self) -> None:
record.append("after_run")
async def main() -> None:
pass
_core.run(main, instruments=[BeforeAfterRun()])
assert record == ["before_run", "after_run"]
def test_instrument_task_spawn_exit() -> None:
record = []
class SpawnExitRecorder(_abc.Instrument):
def task_spawned(self, task: Task) -> None:
record.append(("spawned", task))
def task_exited(self, task: Task) -> None:
record.append(("exited", task))
async def main() -> Task:
return _core.current_task()
main_task = _core.run(main, instruments=[SpawnExitRecorder()])
assert ("spawned", main_task) in record
assert ("exited", main_task) in record
# This test also tests having a crash before the initial task is even spawned,
# which is very difficult to handle.
def test_instruments_crash(caplog: pytest.LogCaptureFixture) -> None:
record = []
class BrokenInstrument(_abc.Instrument):
def task_scheduled(self, task: Task) -> NoReturn:
record.append("scheduled")
raise ValueError("oops")
def close(self) -> None:
# Shouldn't be called -- tests that the instrument disabling logic
# works right.
record.append("closed") # pragma: no cover
async def main() -> Task:
record.append("main ran")
return _core.current_task()
r = TaskRecorder()
main_task = _core.run(main, instruments=[r, BrokenInstrument()])
assert record == ["scheduled", "main ran"]
# the TaskRecorder kept going throughout, even though the BrokenInstrument
# was disabled
assert ("after", main_task) in r.record
assert ("after_run", None) in r.record
# And we got a log message
assert caplog.records[0].exc_info is not None
exc_type, exc_value, exc_traceback = caplog.records[0].exc_info
assert exc_type is ValueError
assert str(exc_value) == "oops"
assert "Instrument has been disabled" in caplog.records[0].message
def test_instruments_monkeypatch() -> None:
class NullInstrument(_abc.Instrument):
pass
instrument = NullInstrument()
async def main() -> None:
record: list[Task] = []
# Changing the set of hooks implemented by an instrument after
# it's installed doesn't make them start being called right away
instrument.before_task_step = ( # type: ignore[method-assign]
record.append # type: ignore[assignment] # append is pos-only
)
await _core.checkpoint()
await _core.checkpoint()
assert len(record) == 0
# But if we remove and re-add the instrument, the new hooks are
# picked up
_core.remove_instrument(instrument)
_core.add_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.remove_instrument(instrument)
await _core.checkpoint()
await _core.checkpoint()
assert record.count(_core.current_task()) == 2
_core.run(main, instruments=[instrument])
def test_instrument_that_raises_on_getattr() -> None:
class EvilInstrument(_abc.Instrument):
def task_exited(self, task: Task) -> NoReturn:
raise AssertionError("this should never happen") # pragma: no cover
@property
def after_run(self) -> NoReturn:
raise ValueError("oops")
async def main() -> None:
with pytest.raises(ValueError, match="^oops$"):
_core.add_instrument(EvilInstrument())
# Make sure the instrument is fully removed from the per-method lists
runner = _core.current_task()._runner
assert "after_run" not in runner.instruments
assert "task_exited" not in runner.instruments
_core.run(main)

View File

@ -0,0 +1,480 @@
from __future__ import annotations
import random
import socket as stdlib_socket
from contextlib import suppress
from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
import pytest
import trio
from ... import _core
from ...testing import assert_checkpoints, wait_all_tasks_blocked
# Cross-platform tests for IO handling
if TYPE_CHECKING:
from collections.abc import Generator
from typing_extensions import ParamSpec
ArgsT = ParamSpec("ArgsT")
def fill_socket(sock: stdlib_socket.socket) -> None:
try:
while True:
sock.send(b"x" * 65536)
except BlockingIOError:
pass
def drain_socket(sock: stdlib_socket.socket) -> None:
try:
while True:
sock.recv(65536)
except BlockingIOError:
pass
WaitSocket = Callable[[stdlib_socket.socket], Awaitable[object]]
SocketPair = Tuple[stdlib_socket.socket, stdlib_socket.socket]
RetT = TypeVar("RetT")
@pytest.fixture
def socketpair() -> Generator[SocketPair, None, None]:
pair = stdlib_socket.socketpair()
for sock in pair:
sock.setblocking(False)
yield pair
for sock in pair:
sock.close()
def also_using_fileno(
fn: Callable[[stdlib_socket.socket | int], RetT],
) -> list[Callable[[stdlib_socket.socket], RetT]]:
def fileno_wrapper(fileobj: stdlib_socket.socket) -> RetT:
return fn(fileobj.fileno())
name = f"<{fn.__name__} on fileno>"
fileno_wrapper.__name__ = fileno_wrapper.__qualname__ = name
return [fn, fileno_wrapper]
# Decorators that feed in different settings for wait_readable / wait_writable
# / notify_closing.
# Note that if you use all three decorators on the same test, it will run all
# N**3 *combinations*
read_socket_test = pytest.mark.parametrize(
"wait_readable",
also_using_fileno(trio.lowlevel.wait_readable),
ids=lambda fn: fn.__name__,
)
write_socket_test = pytest.mark.parametrize(
"wait_writable",
also_using_fileno(trio.lowlevel.wait_writable),
ids=lambda fn: fn.__name__,
)
notify_closing_test = pytest.mark.parametrize(
"notify_closing",
also_using_fileno(trio.lowlevel.notify_closing),
ids=lambda fn: fn.__name__,
)
# XX These tests are all a bit dicey because they can't distinguish between
# wait_on_{read,writ}able blocking the way it should, versus blocking
# momentarily and then immediately resuming.
@read_socket_test
@write_socket_test
async def test_wait_basic(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
) -> None:
a, b = socketpair
# They start out writable()
with assert_checkpoints():
await wait_writable(a)
# But readable() blocks until data arrives
record = []
async def block_on_read() -> None:
try:
with assert_checkpoints():
await wait_readable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("readable")
assert a.recv(10) == b"x"
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
fill_socket(a)
# Now writable will block, but readable won't
with assert_checkpoints():
await wait_readable(b)
record = []
async def block_on_write() -> None:
try:
with assert_checkpoints():
await wait_writable(a)
except _core.Cancelled:
record.append("cancelled")
else:
record.append("writable")
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
assert record == []
drain_socket(b)
# check cancellation
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_read)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
fill_socket(a)
record = []
async with _core.open_nursery() as nursery:
nursery.start_soon(block_on_write)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == ["cancelled"]
@read_socket_test
async def test_double_read(socketpair: SocketPair, wait_readable: WaitSocket) -> None:
a, b = socketpair
# You can't have two tasks trying to read from a socket at the same time
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_readable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_readable(a)
nursery.cancel_scope.cancel()
@write_socket_test
async def test_double_write(socketpair: SocketPair, wait_writable: WaitSocket) -> None:
a, b = socketpair
# You can't have two tasks trying to write to a socket at the same time
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(wait_writable, a)
await wait_all_tasks_blocked()
with pytest.raises(_core.BusyResourceError):
await wait_writable(a)
nursery.cancel_scope.cancel()
@read_socket_test
@write_socket_test
@notify_closing_test
async def test_interrupted_by_close(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
notify_closing: Callable[[stdlib_socket.socket], object],
) -> None:
a, b = socketpair
async def reader() -> None:
with pytest.raises(_core.ClosedResourceError):
await wait_readable(a)
async def writer() -> None:
with pytest.raises(_core.ClosedResourceError):
await wait_writable(a)
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(reader)
nursery.start_soon(writer)
await wait_all_tasks_blocked()
notify_closing(a)
@read_socket_test
@write_socket_test
async def test_socket_simultaneous_read_write(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
) -> None:
record: list[str] = []
async def r_task(sock: stdlib_socket.socket) -> None:
await wait_readable(sock)
record.append("r_task")
async def w_task(sock: stdlib_socket.socket) -> None:
await wait_writable(sock)
record.append("w_task")
a, b = socketpair
fill_socket(a)
async with _core.open_nursery() as nursery:
nursery.start_soon(r_task, a)
nursery.start_soon(w_task, a)
await wait_all_tasks_blocked()
assert record == []
b.send(b"x")
await wait_all_tasks_blocked()
assert record == ["r_task"]
drain_socket(b)
await wait_all_tasks_blocked()
assert record == ["r_task", "w_task"]
@read_socket_test
@write_socket_test
async def test_socket_actual_streaming(
socketpair: SocketPair,
wait_readable: WaitSocket,
wait_writable: WaitSocket,
) -> None:
a, b = socketpair
# Use a small send buffer on one of the sockets to increase the chance of
# getting partial writes
a.setsockopt(stdlib_socket.SOL_SOCKET, stdlib_socket.SO_SNDBUF, 10000)
N = 1000000 # 1 megabyte
MAX_CHUNK = 65536
results: dict[str, int] = {}
async def sender(sock: stdlib_socket.socket, seed: int, key: str) -> None:
r = random.Random(seed)
sent = 0
while sent < N:
print("sent", sent)
chunk = bytearray(r.randrange(MAX_CHUNK))
while chunk:
with assert_checkpoints():
await wait_writable(sock)
this_chunk_size = sock.send(chunk)
sent += this_chunk_size
del chunk[:this_chunk_size]
sock.shutdown(stdlib_socket.SHUT_WR)
results[key] = sent
async def receiver(sock: stdlib_socket.socket, key: str) -> None:
received = 0
while True:
print("received", received)
with assert_checkpoints():
await wait_readable(sock)
this_chunk_size = len(sock.recv(MAX_CHUNK))
if not this_chunk_size:
break
received += this_chunk_size
results[key] = received
async with _core.open_nursery() as nursery:
nursery.start_soon(sender, a, 0, "send_a")
nursery.start_soon(sender, b, 1, "send_b")
nursery.start_soon(receiver, a, "recv_a")
nursery.start_soon(receiver, b, "recv_b")
assert results["send_a"] == results["recv_b"]
assert results["send_b"] == results["recv_a"]
async def test_notify_closing_on_invalid_object() -> None:
# It should either be a no-op (generally on Unix, where we don't know
# which fds are valid), or an OSError (on Windows, where we currently only
# support sockets, so we have to do some validation to figure out whether
# it's a socket or a regular handle).
got_oserror = False
got_no_error = False
try:
trio.lowlevel.notify_closing(-1)
except OSError:
got_oserror = True
else:
got_no_error = True
assert got_oserror or got_no_error
async def test_wait_on_invalid_object() -> None:
# We definitely want to raise an error everywhere if you pass in an
# invalid fd to wait_*
for wait in [trio.lowlevel.wait_readable, trio.lowlevel.wait_writable]:
with stdlib_socket.socket() as s:
fileno = s.fileno()
# We just closed the socket and don't do anything else in between, so
# we can be confident that the fileno hasn't be reassigned.
with pytest.raises(
OSError,
match=r"^\[\w+ \d+] (Bad file descriptor|An operation was attempted on something that is not a socket)$",
):
await wait(fileno)
async def test_io_manager_statistics() -> None:
def check(*, expected_readers: int, expected_writers: int) -> None:
statistics = _core.current_statistics()
print(statistics)
iostats = statistics.io_statistics
if iostats.backend == "epoll" or iostats.backend == "windows":
assert iostats.tasks_waiting_read == expected_readers
assert iostats.tasks_waiting_write == expected_writers
else:
assert iostats.backend == "kqueue"
assert iostats.tasks_waiting == expected_readers + expected_writers
a1, b1 = stdlib_socket.socketpair()
a2, b2 = stdlib_socket.socketpair()
a3, b3 = stdlib_socket.socketpair()
for sock in [a1, b1, a2, b2, a3, b3]:
sock.setblocking(False)
with a1, b1, a2, b2, a3, b3:
# let the call_soon_task settle down
await wait_all_tasks_blocked()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
# We want:
# - one socket with a writer blocked
# - two sockets with a reader blocked
# - a socket with both blocked
fill_socket(a1)
fill_socket(a3)
async with _core.open_nursery() as nursery:
nursery.start_soon(_core.wait_writable, a1)
nursery.start_soon(_core.wait_readable, a2)
nursery.start_soon(_core.wait_readable, b2)
nursery.start_soon(_core.wait_writable, a3)
nursery.start_soon(_core.wait_readable, a3)
await wait_all_tasks_blocked()
# +1 for call_soon_task
check(expected_readers=3 + 1, expected_writers=2)
nursery.cancel_scope.cancel()
# 1 for call_soon_task
check(expected_readers=1, expected_writers=0)
async def test_can_survive_unnotified_close() -> None:
# An "unnotified" close is when the user closes an fd/socket/handle
# directly, without calling notify_closing first. This should never happen
# -- users should call notify_closing before closing things. But, just in
# case they don't, we would still like to avoid exploding.
#
# Acceptable behaviors:
# - wait_* never return, but can be cancelled cleanly
# - wait_* exit cleanly
# - wait_* raise an OSError
#
# Not acceptable:
# - getting stuck in an uncancellable state
# - TrioInternalError blowing up the whole run
#
# This test exercises some tricky "unnotified close" scenarios, to make
# sure we get the "acceptable" behaviors.
async def allow_OSError(
async_func: Callable[ArgsT, Awaitable[object]],
*args: ArgsT.args,
**kwargs: ArgsT.kwargs,
) -> None:
with suppress(OSError):
await async_func(*args, **kwargs)
with stdlib_socket.socket() as s:
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# We hit different paths on Windows depending on whether we close the last
# handle to the object (which produces a LOCAL_CLOSE notification and
# wakes up wait_readable), or only close one of the handles (which leaves
# wait_readable pending until cancelled).
with stdlib_socket.socket() as s, s.dup() as s2: # noqa: F841
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, s)
await wait_all_tasks_blocked()
s.close()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
# A more elaborate case, with two tasks waiting. On windows and epoll,
# the two tasks get muxed together onto a single underlying wait
# operation. So when they're cancelled, there's a brief moment where one
# of the tasks is cancelled but the other isn't, so we try to re-issue the
# underlying wait operation. But here, the handle we were going to use to
# do that has been pulled out from under our feet... so test that we can
# survive this.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2:
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
await wait_all_tasks_blocked()
a.close()
nursery.cancel_scope.cancel()
# A similar case, but now the single-task-wakeup happens due to I/O
# arriving, not a cancellation, so the operation gets re-issued from
# handle_io context rather than abort context.
a, b = stdlib_socket.socketpair()
with a, b, a.dup() as a2:
print(f"a={a.fileno()}, b={b.fileno()}, a2={a2.fileno()}")
a.setblocking(False)
b.setblocking(False)
fill_socket(a)
e = trio.Event()
# We want to wait for the kernel to process the wakeup on 'a', if any.
# But depending on the platform, we might not get a wakeup on 'a'. So
# we put one task to sleep waiting on 'a', and we put a second task to
# sleep waiting on 'a2', with the idea that the 'a2' notification will
# definitely arrive, and when it does then we can assume that whatever
# notification was going to arrive for 'a' has also arrived.
async def wait_readable_a2_then_set() -> None:
await trio.lowlevel.wait_readable(a2)
e.set()
async with trio.open_nursery() as nursery:
nursery.start_soon(allow_OSError, trio.lowlevel.wait_readable, a)
nursery.start_soon(allow_OSError, trio.lowlevel.wait_writable, a)
nursery.start_soon(wait_readable_a2_then_set)
await wait_all_tasks_blocked()
a.close()
b.send(b"x")
# Make sure that the wakeup has been received and everything has
# settled before cancelling the wait_writable.
await e.wait()
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()

View File

@ -0,0 +1,517 @@
from __future__ import annotations
import contextlib
import inspect
import signal
import threading
from typing import TYPE_CHECKING, AsyncIterator, Callable, Iterator
import outcome
import pytest
from trio.testing import RaisesGroup
try:
from async_generator import async_generator, yield_
except ImportError: # pragma: no cover
async_generator = yield_ = None
from ... import _core
from ..._abc import Instrument
from ..._timeouts import sleep
from ..._util import signal_raise
from ...testing import wait_all_tasks_blocked
if TYPE_CHECKING:
from ..._core import Abort, RaiseCancelT
def ki_self() -> None:
signal_raise(signal.SIGINT)
def test_ki_self() -> None:
with pytest.raises(KeyboardInterrupt):
ki_self()
async def test_ki_enabled() -> None:
# Regular tasks aren't KI-protected
assert not _core.currently_ki_protected()
# Low-level call-soon callbacks are KI-protected
token = _core.current_trio_token()
record = []
def check() -> None:
record.append(_core.currently_ki_protected())
token.run_sync_soon(check)
await wait_all_tasks_blocked()
assert record == [True]
@_core.enable_ki_protection
def protected() -> None:
assert _core.currently_ki_protected()
unprotected()
@_core.disable_ki_protection
def unprotected() -> None:
assert not _core.currently_ki_protected()
protected()
@_core.enable_ki_protection
async def aprotected() -> None:
assert _core.currently_ki_protected()
await aunprotected()
@_core.disable_ki_protection
async def aunprotected() -> None:
assert not _core.currently_ki_protected()
await aprotected()
# make sure that the decorator here overrides the automatic manipulation
# that start_soon() does:
async with _core.open_nursery() as nursery:
nursery.start_soon(aprotected)
nursery.start_soon(aunprotected)
@_core.enable_ki_protection
def gen_protected() -> Iterator[None]:
assert _core.currently_ki_protected()
yield
for _ in gen_protected():
pass
@_core.disable_ki_protection
def gen_unprotected() -> Iterator[None]:
assert not _core.currently_ki_protected()
yield
for _ in gen_unprotected():
pass
# This used to be broken due to
#
# https://bugs.python.org/issue29590
#
# Specifically, after a coroutine is resumed with .throw(), then the stack
# makes it look like the immediate caller is the function that called
# .throw(), not the actual caller. So child() here would have a caller deep in
# the guts of the run loop, and always be protected, even when it shouldn't
# have been. (Solution: we don't use .throw() anymore.)
async def test_ki_enabled_after_yield_briefly() -> None:
@_core.enable_ki_protection
async def protected() -> None:
await child(True)
@_core.disable_ki_protection
async def unprotected() -> None:
await child(False)
async def child(expected: bool) -> None:
import traceback
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await _core.checkpoint()
traceback.print_stack()
assert _core.currently_ki_protected() == expected
await protected()
await unprotected()
# This also used to be broken due to
# https://bugs.python.org/issue29590
async def test_generator_based_context_manager_throw() -> None:
@contextlib.contextmanager
@_core.enable_ki_protection
def protected_manager() -> Iterator[None]:
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
with protected_manager():
assert not _core.currently_ki_protected()
with pytest.raises(KeyError):
# This is the one that used to fail
with protected_manager():
raise KeyError
# the async_generator package isn't typed, hence all the type: ignores
@pytest.mark.skipif(async_generator is None, reason="async_generator not installed")
async def test_async_generator_agen_protection() -> None:
@_core.enable_ki_protection
@async_generator # type: ignore[misc] # untyped generator
async def agen_protected1() -> None:
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
@async_generator # type: ignore[misc] # untyped generator
async def agen_unprotected1() -> None:
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
# Swap the order of the decorators:
@async_generator # type: ignore[misc] # untyped generator
@_core.enable_ki_protection
async def agen_protected2() -> None:
assert _core.currently_ki_protected()
try:
await yield_()
finally:
assert _core.currently_ki_protected()
@async_generator # type: ignore[misc] # untyped generator
@_core.disable_ki_protection
async def agen_unprotected2() -> None:
assert not _core.currently_ki_protected()
try:
await yield_()
finally:
assert not _core.currently_ki_protected()
await _check_agen(agen_protected1)
await _check_agen(agen_protected2)
await _check_agen(agen_unprotected1)
await _check_agen(agen_unprotected2)
async def test_native_agen_protection() -> None:
# Native async generators
@_core.enable_ki_protection
async def agen_protected() -> AsyncIterator[None]:
assert _core.currently_ki_protected()
try:
yield
finally:
assert _core.currently_ki_protected()
@_core.disable_ki_protection
async def agen_unprotected() -> AsyncIterator[None]:
assert not _core.currently_ki_protected()
try:
yield
finally:
assert not _core.currently_ki_protected()
await _check_agen(agen_protected)
await _check_agen(agen_unprotected)
async def _check_agen(agen_fn: Callable[[], AsyncIterator[None]]) -> None:
async for _ in agen_fn():
assert not _core.currently_ki_protected()
# asynccontextmanager insists that the function passed must itself be an
# async gen function, not a wrapper around one
if inspect.isasyncgenfunction(agen_fn):
async with contextlib.asynccontextmanager(agen_fn)():
assert not _core.currently_ki_protected()
# Another case that's tricky due to:
# https://bugs.python.org/issue29590
with pytest.raises(KeyError):
async with contextlib.asynccontextmanager(agen_fn)():
raise KeyError
# Test the case where there's no magic local anywhere in the call stack
def test_ki_disabled_out_of_context() -> None:
assert _core.currently_ki_protected()
def test_ki_disabled_in_del() -> None:
def nestedfunction() -> bool:
return _core.currently_ki_protected()
def __del__() -> None:
assert _core.currently_ki_protected()
assert nestedfunction()
@_core.disable_ki_protection
def outerfunction() -> None:
assert not _core.currently_ki_protected()
assert not nestedfunction()
__del__()
__del__()
outerfunction()
assert nestedfunction()
def test_ki_protection_works() -> None:
async def sleeper(name: str, record: set[str]) -> None:
try:
while True:
await _core.checkpoint()
except _core.Cancelled:
record.add(name + " ok")
async def raiser(name: str, record: set[str]) -> None:
try:
# os.kill runs signal handlers before returning, so we don't need
# to worry that the handler will be delayed
print("killing, protection =", _core.currently_ki_protected())
ki_self()
except KeyboardInterrupt:
print("raised!")
# Make sure we aren't getting cancelled as well as siginted
await _core.checkpoint()
record.add(name + " raise ok")
raise
else:
print("didn't raise!")
# If we didn't raise (b/c protected), then we *should* get
# cancelled at the next opportunity
try:
await _core.wait_task_rescheduled(lambda _: _core.Abort.SUCCEEDED)
except _core.Cancelled:
record.add(name + " cancel ok")
# simulated control-C during raiser, which is *unprotected*
print("check 1")
record_set: set[str] = set()
async def check_unprotected_kill() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record_set)
nursery.start_soon(sleeper, "s2", record_set)
nursery.start_soon(raiser, "r1", record_set)
# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
with RaisesGroup(KeyboardInterrupt):
_core.run(check_unprotected_kill)
assert record_set == {"s1 ok", "s2 ok", "r1 raise ok"}
# simulated control-C during raiser, which is *protected*, so the KI gets
# delivered to the main task instead
print("check 2")
record_set = set()
async def check_protected_kill() -> None:
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper, "s1", record_set)
nursery.start_soon(sleeper, "s2", record_set)
nursery.start_soon(_core.enable_ki_protection(raiser), "r1", record_set)
# __aexit__ blocks, and then receives the KI
# raises inside a nursery, so the KeyboardInterrupt is wrapped in an ExceptionGroup
with RaisesGroup(KeyboardInterrupt):
_core.run(check_protected_kill)
assert record_set == {"s1 ok", "s2 ok", "r1 cancel ok"}
# kill at last moment still raises (run_sync_soon until it raises an
# error, then kill)
print("check 3")
async def check_kill_during_shutdown() -> None:
token = _core.current_trio_token()
def kill_during_shutdown() -> None:
assert _core.currently_ki_protected()
try:
token.run_sync_soon(kill_during_shutdown)
except _core.RunFinishedError:
# it's too late for regular handling! handle this!
print("kill! kill!")
ki_self()
token.run_sync_soon(kill_during_shutdown)
# no nurseries involved, so the KeyboardInterrupt isn't wrapped
with pytest.raises(KeyboardInterrupt):
_core.run(check_kill_during_shutdown)
# KI arrives very early, before main is even spawned
print("check 4")
class InstrumentOfDeath(Instrument):
def before_run(self) -> None:
ki_self()
async def main_1() -> None:
await _core.checkpoint()
# no nurseries involved, so the KeyboardInterrupt isn't wrapped
with pytest.raises(KeyboardInterrupt):
_core.run(main_1, instruments=[InstrumentOfDeath()])
# checkpoint_if_cancelled notices pending KI
print("check 5")
@_core.enable_ki_protection
async def main_2() -> None:
assert _core.currently_ki_protected()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint_if_cancelled()
_core.run(main_2)
# KI arrives while main task is not abortable, b/c already scheduled
print("check 6")
@_core.enable_ki_protection
async def main_3() -> None:
assert _core.currently_ki_protected()
ki_self()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
await _core.cancel_shielded_checkpoint()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main_3)
# KI arrives while main task is not abortable, b/c refuses to be aborted
print("check 7")
@_core.enable_ki_protection
async def main_4() -> None:
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(_: RaiseCancelT) -> Abort:
_core.reschedule(task, outcome.Value(1))
return _core.Abort.FAILED
assert await _core.wait_task_rescheduled(abort) == 1
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
_core.run(main_4)
# KI delivered via slow abort
print("check 8")
@_core.enable_ki_protection
async def main_5() -> None:
assert _core.currently_ki_protected()
ki_self()
task = _core.current_task()
def abort(raise_cancel: RaiseCancelT) -> Abort:
result = outcome.capture(raise_cancel)
_core.reschedule(task, result)
return _core.Abort.FAILED
with pytest.raises(KeyboardInterrupt):
assert await _core.wait_task_rescheduled(abort)
await _core.checkpoint()
_core.run(main_5)
# KI arrives just before main task exits, so the run_sync_soon machinery
# is still functioning and will accept the callback to deliver the KI, but
# by the time the callback is actually run, main has exited and can't be
# aborted.
print("check 9")
@_core.enable_ki_protection
async def main_6() -> None:
ki_self()
with pytest.raises(KeyboardInterrupt):
_core.run(main_6)
print("check 10")
# KI in unprotected code, with
# restrict_keyboard_interrupt_to_checkpoints=True
record_list = []
async def main_7() -> None:
# We're not KI protected...
assert not _core.currently_ki_protected()
ki_self()
# ...but even after the KI, we keep running uninterrupted...
record_list.append("ok")
# ...until we hit a checkpoint:
with pytest.raises(KeyboardInterrupt):
await sleep(10)
_core.run(main_7, restrict_keyboard_interrupt_to_checkpoints=True)
assert record_list == ["ok"]
record_list = []
# Exact same code raises KI early if we leave off the argument, doesn't
# even reach the record.append call:
with pytest.raises(KeyboardInterrupt):
_core.run(main_7)
assert record_list == []
# KI arrives while main task is inside a cancelled cancellation scope
# the KeyboardInterrupt should take priority
print("check 11")
@_core.enable_ki_protection
async def main_8() -> None:
assert _core.currently_ki_protected()
with _core.CancelScope() as cancel_scope:
cancel_scope.cancel()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
ki_self()
with pytest.raises(KeyboardInterrupt):
await _core.checkpoint()
with pytest.raises(_core.Cancelled):
await _core.checkpoint()
_core.run(main_8)
def test_ki_is_good_neighbor() -> None:
# in the unlikely event someone overwrites our signal handler, we leave
# the overwritten one be
try:
orig = signal.getsignal(signal.SIGINT)
def my_handler(signum: object, frame: object) -> None: # pragma: no cover
pass
async def main() -> None:
signal.signal(signal.SIGINT, my_handler)
_core.run(main)
assert signal.getsignal(signal.SIGINT) is my_handler
finally:
signal.signal(signal.SIGINT, orig)
# Regression test for #461
# don't know if _active not being visible is a problem
def test_ki_with_broken_threads() -> None:
thread = threading.main_thread()
# scary!
original = threading._active[thread.ident] # type: ignore[attr-defined]
# put this in a try finally so we don't have a chance of cascading a
# breakage down to everything else
try:
del threading._active[thread.ident] # type: ignore[attr-defined]
@_core.enable_ki_protection
async def inner() -> None:
assert signal.getsignal(signal.SIGINT) != signal.default_int_handler
_core.run(inner)
finally:
threading._active[thread.ident] = original # type: ignore[attr-defined]

View File

@ -0,0 +1,118 @@
import pytest
from trio import run
from trio.lowlevel import RunVar, RunVarToken
from ... import _core
# scary runvar tests
def test_runvar_smoketest() -> None:
t1 = RunVar[str]("test1")
t2 = RunVar[str]("test2", default="catfish")
assert repr(t1) == "<RunVar name='test1'>"
async def first_check() -> None:
with pytest.raises(LookupError):
t1.get()
t1.set("swordfish")
assert t1.get() == "swordfish"
assert t2.get() == "catfish"
assert t2.get(default="eel") == "eel"
t2.set("goldfish")
assert t2.get() == "goldfish"
assert t2.get(default="tuna") == "goldfish"
async def second_check() -> None:
with pytest.raises(LookupError):
t1.get()
assert t2.get() == "catfish"
run(first_check)
run(second_check)
def test_runvar_resetting() -> None:
t1 = RunVar[str]("test1")
t2 = RunVar[str]("test2", default="dogfish")
t3 = RunVar[str]("test3")
async def reset_check() -> None:
token = t1.set("moonfish")
assert t1.get() == "moonfish"
t1.reset(token)
with pytest.raises(TypeError):
t1.reset(None) # type: ignore[arg-type]
with pytest.raises(LookupError):
t1.get()
token2 = t2.set("catdogfish")
assert t2.get() == "catdogfish"
t2.reset(token2)
assert t2.get() == "dogfish"
with pytest.raises(ValueError, match="^token has already been used$"):
t2.reset(token2)
token3 = t3.set("basculin")
assert t3.get() == "basculin"
with pytest.raises(ValueError, match="^token is not for us$"):
t1.reset(token3)
run(reset_check)
def test_runvar_sync() -> None:
t1 = RunVar[str]("test1")
async def sync_check() -> None:
async def task1() -> None:
t1.set("plaice")
assert t1.get() == "plaice"
async def task2(tok: RunVarToken[str]) -> None:
t1.reset(tok)
with pytest.raises(LookupError):
t1.get()
t1.set("haddock")
async with _core.open_nursery() as n:
token = t1.set("cod")
assert t1.get() == "cod"
n.start_soon(task1)
await _core.wait_all_tasks_blocked()
assert t1.get() == "plaice"
n.start_soon(task2, token)
await _core.wait_all_tasks_blocked()
assert t1.get() == "haddock"
run(sync_check)
def test_accessing_runvar_outside_run_call_fails() -> None:
t1 = RunVar[str]("test1")
with pytest.raises(RuntimeError):
t1.set("asdf")
with pytest.raises(RuntimeError):
t1.get()
async def get_token() -> RunVarToken[str]:
return t1.set("ok")
token = run(get_token)
with pytest.raises(RuntimeError):
t1.reset(token)

View File

@ -0,0 +1,175 @@
import time
from math import inf
import pytest
from trio import sleep
from ... import _core
from .. import wait_all_tasks_blocked
from .._mock_clock import MockClock
from .tutil import slow
def test_mock_clock() -> None:
REAL_NOW = 123.0
c = MockClock()
c._real_clock = lambda: REAL_NOW
repr(c) # smoke test
assert c.rate == 0
assert c.current_time() == 0
c.jump(1.2)
assert c.current_time() == 1.2
with pytest.raises(ValueError, match="^time can't go backwards$"):
c.jump(-1)
assert c.current_time() == 1.2
assert c.deadline_to_sleep_time(1.1) == 0
assert c.deadline_to_sleep_time(1.2) == 0
assert c.deadline_to_sleep_time(1.3) > 999999
with pytest.raises(ValueError, match="^rate must be >= 0$"):
c.rate = -1
assert c.rate == 0
c.rate = 2
assert c.current_time() == 1.2
REAL_NOW += 1
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 0.5
c.rate = 0.5
assert c.current_time() == 3.2
assert c.deadline_to_sleep_time(3.1) == 0
assert c.deadline_to_sleep_time(3.2) == 0
assert c.deadline_to_sleep_time(4.2) == 2.0
c.jump(0.8)
assert c.current_time() == 4.0
REAL_NOW += 1
assert c.current_time() == 4.5
c2 = MockClock(rate=3)
assert c2.rate == 3
assert c2.current_time() < 10
async def test_mock_clock_autojump(mock_clock: MockClock) -> None:
assert mock_clock.autojump_threshold == inf
mock_clock.autojump_threshold = 0
assert mock_clock.autojump_threshold == 0
real_start = time.perf_counter()
virtual_start = _core.current_time()
for i in range(10):
print(f"sleeping {10 * i} seconds")
await sleep(10 * i)
print("woke up!")
assert virtual_start + 10 * i == _core.current_time()
virtual_start = _core.current_time()
real_duration = time.perf_counter() - real_start
print(f"Slept {10 * sum(range(10))} seconds in {real_duration} seconds")
assert real_duration < 1
mock_clock.autojump_threshold = 0.02
t = _core.current_time()
# this should wake up before the autojump threshold triggers, so time
# shouldn't change
await wait_all_tasks_blocked()
assert t == _core.current_time()
# this should too
await wait_all_tasks_blocked(0.01)
assert t == _core.current_time()
# set up a situation where the autojump task is blocked for a long long
# time, to make sure that cancel-and-adjust-threshold logic is working
mock_clock.autojump_threshold = 10000
await wait_all_tasks_blocked()
mock_clock.autojump_threshold = 0
# if the above line didn't take affect immediately, then this would be
# bad:
await sleep(100000)
async def test_mock_clock_autojump_interference(mock_clock: MockClock) -> None:
mock_clock.autojump_threshold = 0.02
mock_clock2 = MockClock()
# messing with the autojump threshold of a clock that isn't actually
# installed in the run loop shouldn't do anything.
mock_clock2.autojump_threshold = 0.01
# if the autojump_threshold of 0.01 were in effect, then the next line
# would block forever, as the autojump task kept waking up to try to
# jump the clock.
await wait_all_tasks_blocked(0.015)
# but the 0.02 limit does apply
await sleep(100000)
def test_mock_clock_autojump_preset() -> None:
# Check that we can set the autojump_threshold before the clock is
# actually in use, and it gets picked up
mock_clock = MockClock(autojump_threshold=0.1)
mock_clock.autojump_threshold = 0.01
real_start = time.perf_counter()
_core.run(sleep, 10000, clock=mock_clock)
assert time.perf_counter() - real_start < 1
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_0(
mock_clock: MockClock,
) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with the default cushion=0.
mock_clock.autojump_threshold = 0
record = []
async def sleeper() -> None:
await sleep(100)
record.append("yawn")
async def waiter() -> None:
await wait_all_tasks_blocked()
record.append("waiter woke")
await sleep(1000)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter woke", "yawn", "waiter done"]
@slow
async def test_mock_clock_autojump_0_and_wait_all_tasks_blocked_nonzero(
mock_clock: MockClock,
) -> None:
# Checks that autojump_threshold=0 doesn't interfere with
# calling wait_all_tasks_blocked with a non-zero cushion.
mock_clock.autojump_threshold = 0
record = []
async def sleeper() -> None:
await sleep(100)
record.append("yawn")
async def waiter() -> None:
await wait_all_tasks_blocked(1)
record.append("waiter done")
async with _core.open_nursery() as nursery:
nursery.start_soon(sleeper)
nursery.start_soon(waiter)
assert record == ["waiter done", "yawn"]

View File

@ -0,0 +1,384 @@
from __future__ import annotations
import re
from typing import TypeVar
import pytest
import trio
from trio.lowlevel import (
add_parking_lot_breaker,
current_task,
remove_parking_lot_breaker,
)
from trio.testing import Matcher, RaisesGroup
from ... import _core
from ...testing import wait_all_tasks_blocked
from .._parking_lot import ParkingLot
from .tutil import check_sequence_matches
T = TypeVar("T")
async def test_parking_lot_basic() -> None:
record = []
async def waiter(i: int, lot: ParkingLot) -> None:
record.append(f"sleep {i}")
await lot.park()
record.append(f"wake {i}")
async with _core.open_nursery() as nursery:
lot = ParkingLot()
assert not lot
assert len(lot) == 0
assert lot.statistics().tasks_waiting == 0
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
assert bool(lot)
assert len(lot) == 3
assert lot.statistics().tasks_waiting == 3
lot.unpark_all()
assert lot.statistics().tasks_waiting == 0
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record,
[{"sleep 0", "sleep 1", "sleep 2"}, {"wake 0", "wake 1", "wake 2"}],
)
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
assert len(record) == 3
for _ in range(3):
lot.unpark()
await wait_all_tasks_blocked()
# 1-by-1 wakeups are strict FIFO
assert record == [
"sleep 0",
"sleep 1",
"sleep 2",
"wake 0",
"wake 1",
"wake 2",
]
# It's legal (but a no-op) to try and unpark while there's nothing parked
lot.unpark()
lot.unpark(count=1)
lot.unpark(count=100)
# Check unpark with count
async with _core.open_nursery() as nursery:
record = []
for i in range(3):
nursery.start_soon(waiter, i, lot)
await wait_all_tasks_blocked()
lot.unpark(count=2)
await wait_all_tasks_blocked()
check_sequence_matches(
record,
["sleep 0", "sleep 1", "sleep 2", {"wake 0", "wake 1"}],
)
lot.unpark_all()
with pytest.raises(
ValueError,
match=r"^Cannot pop a non-integer number of tasks\.$",
):
lot.unpark(count=1.5)
async def cancellable_waiter(
name: T,
lot: ParkingLot,
scopes: dict[T, _core.CancelScope],
record: list[str],
) -> None:
with _core.CancelScope() as scope:
scopes[name] = scope
record.append(f"sleep {name}")
try:
await lot.park()
except _core.Cancelled:
record.append(f"cancelled {name}")
else:
record.append(f"wake {name}")
async def test_parking_lot_cancel() -> None:
record: list[str] = []
scopes: dict[int, _core.CancelScope] = {}
async with _core.open_nursery() as nursery:
lot = ParkingLot()
nursery.start_soon(cancellable_waiter, 1, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(record) == 4
lot.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 6
check_sequence_matches(
record,
["sleep 1", "sleep 2", "sleep 3", "cancelled 2", {"wake 1", "wake 3"}],
)
async def test_parking_lot_repark() -> None:
record: list[str] = []
scopes: dict[int, _core.CancelScope] = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
with pytest.raises(TypeError):
lot1.repark([]) # type: ignore[arg-type]
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
lot1.repark(lot2)
assert len(lot1) == 2
assert len(lot2) == 1
lot2.unpark_all()
await wait_all_tasks_blocked()
assert len(record) == 4
assert record == ["sleep 1", "sleep 2", "sleep 3", "wake 1"]
lot1.repark_all(lot2)
assert len(lot1) == 0
assert len(lot2) == 2
scopes[2].cancel()
await wait_all_tasks_blocked()
assert len(lot2) == 1
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
]
lot2.unpark_all()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"cancelled 2",
"wake 3",
]
async def test_parking_lot_repark_with_count() -> None:
record: list[str] = []
scopes: dict[int, _core.CancelScope] = {}
lot1 = ParkingLot()
lot2 = ParkingLot()
async with _core.open_nursery() as nursery:
nursery.start_soon(cancellable_waiter, 1, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 2, lot1, scopes, record)
await wait_all_tasks_blocked()
nursery.start_soon(cancellable_waiter, 3, lot1, scopes, record)
await wait_all_tasks_blocked()
assert len(record) == 3
assert len(lot1) == 3
assert len(lot2) == 0
lot1.repark(lot2, count=2)
assert len(lot1) == 1
assert len(lot2) == 2
while lot2:
lot2.unpark()
await wait_all_tasks_blocked()
assert record == [
"sleep 1",
"sleep 2",
"sleep 3",
"wake 1",
"wake 2",
]
lot1.unpark_all()
async def dummy_task(
task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED,
) -> None:
task_status.started(_core.current_task())
await trio.sleep_forever()
async def test_parking_lot_breaker_basic() -> None:
"""Test basic functionality for breaking lots."""
lot = ParkingLot()
task = current_task()
# defaults to current task
lot.break_lot()
assert lot.broken_by == [task]
# breaking the lot again with the same task appends another copy in `broken_by`
lot.break_lot()
assert lot.broken_by == [task, task]
# trying to park in broken lot errors
broken_by_str = re.escape(str([task, task]))
with pytest.raises(
_core.BrokenResourceError,
match=f"^Attempted to park in parking lot broken by {broken_by_str}$",
):
await lot.park()
async def test_parking_lot_break_parking_tasks() -> None:
"""Checks that tasks currently waiting to park raise an error when the breaker exits."""
async def bad_parker(lot: ParkingLot, scope: _core.CancelScope) -> None:
add_parking_lot_breaker(current_task(), lot)
with scope:
await trio.sleep_forever()
lot = ParkingLot()
cs = _core.CancelScope()
# check that parked task errors
with RaisesGroup(
Matcher(_core.BrokenResourceError, match="^Parking lot broken by"),
):
async with _core.open_nursery() as nursery:
nursery.start_soon(bad_parker, lot, cs)
await wait_all_tasks_blocked()
nursery.start_soon(lot.park)
await wait_all_tasks_blocked()
cs.cancel()
async def test_parking_lot_breaker_registration() -> None:
lot = ParkingLot()
task = current_task()
with pytest.raises(
RuntimeError,
match="Attempted to remove task as breaker for a lot it is not registered for",
):
remove_parking_lot_breaker(task, lot)
# check that a task can be registered as breaker for the same lot multiple times
add_parking_lot_breaker(task, lot)
add_parking_lot_breaker(task, lot)
remove_parking_lot_breaker(task, lot)
remove_parking_lot_breaker(task, lot)
with pytest.raises(
RuntimeError,
match="Attempted to remove task as breaker for a lot it is not registered for",
):
remove_parking_lot_breaker(task, lot)
# registering a task as breaker on an already broken lot is fine
lot.break_lot()
child_task = None
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
add_parking_lot_breaker(child_task, lot)
nursery.cancel_scope.cancel()
assert lot.broken_by == [task, child_task]
# manually breaking a lot with an already exited task is fine
lot = ParkingLot()
lot.break_lot(child_task)
assert lot.broken_by == [child_task]
async def test_parking_lot_breaker_rebreak() -> None:
lot = ParkingLot()
task = current_task()
lot.break_lot()
# breaking an already broken lot with a different task is allowed
# The nursery is only to create a task we can pass to lot.break_lot
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
lot.break_lot(child_task)
nursery.cancel_scope.cancel()
assert lot.broken_by == [task, child_task]
async def test_parking_lot_multiple_breakers_exit() -> None:
# register multiple tasks as lot breakers, then have them all exit
lot = ParkingLot()
async with trio.open_nursery() as nursery:
child_task1 = await nursery.start(dummy_task)
child_task2 = await nursery.start(dummy_task)
child_task3 = await nursery.start(dummy_task)
add_parking_lot_breaker(child_task1, lot)
add_parking_lot_breaker(child_task2, lot)
add_parking_lot_breaker(child_task3, lot)
nursery.cancel_scope.cancel()
# I think the order is guaranteed currently, but doesn't hurt to be safe.
assert set(lot.broken_by) == {child_task1, child_task2, child_task3}
async def test_parking_lot_breaker_register_exited_task() -> None:
lot = ParkingLot()
child_task = None
async with trio.open_nursery() as nursery:
child_task = await nursery.start(dummy_task)
nursery.cancel_scope.cancel()
# trying to register an exited task as lot breaker errors
with pytest.raises(
trio.BrokenResourceError,
match="^Attempted to add already exited task as lot breaker.$",
):
add_parking_lot_breaker(child_task, lot)
async def test_parking_lot_break_itself() -> None:
"""Break a parking lot, where the breakee is parked.
Doing this is weird, but should probably be supported.
"""
async def return_me_and_park(
lot: ParkingLot,
*,
task_status: _core.TaskStatus[_core.Task] = trio.TASK_STATUS_IGNORED,
) -> None:
task_status.started(_core.current_task())
await lot.park()
lot = ParkingLot()
with RaisesGroup(
Matcher(_core.BrokenResourceError, match="^Parking lot broken by"),
):
async with _core.open_nursery() as nursery:
child_task = await nursery.start(return_me_and_park, lot)
lot.break_lot(child_task)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,195 @@
from __future__ import annotations
import threading
import time
from contextlib import contextmanager
from queue import Queue
from typing import TYPE_CHECKING, Iterator, NoReturn
import pytest
from .. import _thread_cache
from .._thread_cache import ThreadCache, start_thread_soon
from .tutil import gc_collect_harder, slow
if TYPE_CHECKING:
from outcome import Outcome
def test_thread_cache_basics() -> None:
q: Queue[Outcome[object]] = Queue()
def fn() -> NoReturn:
raise RuntimeError("hi")
def deliver(outcome: Outcome[object]) -> None:
q.put(outcome)
start_thread_soon(fn, deliver)
outcome = q.get()
with pytest.raises(RuntimeError, match="hi"):
outcome.unwrap()
def test_thread_cache_deref() -> None:
res = [False]
class del_me:
def __call__(self) -> int:
return 42
def __del__(self) -> None:
res[0] = True
q: Queue[Outcome[int]] = Queue()
def deliver(outcome: Outcome[int]) -> None:
q.put(outcome)
start_thread_soon(del_me(), deliver)
outcome = q.get()
assert outcome.unwrap() == 42
gc_collect_harder()
assert res[0]
@slow
def test_spawning_new_thread_from_deliver_reuses_starting_thread() -> None:
# We know that no-one else is using the thread cache, so if we keep
# submitting new jobs the instant the previous one is finished, we should
# keep getting the same thread over and over. This tests both that the
# thread cache is LIFO, and that threads can be assigned new work *before*
# deliver exits.
# Make sure there are a few threads running, so if we weren't LIFO then we
# could grab the wrong one.
q: Queue[Outcome[object]] = Queue()
COUNT = 5
for _ in range(COUNT):
start_thread_soon(lambda: time.sleep(1), lambda result: q.put(result))
for _ in range(COUNT):
q.get().unwrap()
seen_threads = set()
done = threading.Event()
def deliver(n: int, _: object) -> None:
print(n)
seen_threads.add(threading.current_thread())
if n == 0:
done.set()
else:
start_thread_soon(lambda: None, lambda _: deliver(n - 1, _))
start_thread_soon(lambda: None, lambda _: deliver(5, _))
done.wait()
assert len(seen_threads) == 1
@slow
def test_idle_threads_exit(monkeypatch: pytest.MonkeyPatch) -> None:
# Temporarily set the idle timeout to something tiny, to speed up the
# test. (But non-zero, so that the worker loop will at least yield the
# CPU.)
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
q: Queue[threading.Thread] = Queue()
start_thread_soon(lambda: None, lambda _: q.put(threading.current_thread()))
seen_thread = q.get()
# Since the idle timeout is 0, after sleeping for 1 second, the thread
# should have exited
time.sleep(1)
assert not seen_thread.is_alive()
@contextmanager
def _join_started_threads() -> Iterator[None]:
before = frozenset(threading.enumerate())
try:
yield
finally:
for thread in threading.enumerate():
if thread not in before:
thread.join(timeout=1.0)
assert not thread.is_alive()
def test_race_between_idle_exit_and_job_assignment(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# This is a lock where the first few times you try to acquire it with a
# timeout, it waits until the lock is available and then pretends to time
# out. Using this in our thread cache implementation causes the following
# sequence:
#
# 1. start_thread_soon grabs the worker thread, assigns it a job, and
# releases its lock.
# 2. The worker thread wakes up (because the lock has been released), but
# the JankyLock lies to it and tells it that the lock timed out. So the
# worker thread tries to exit.
# 3. The worker thread checks for the race between exiting and being
# assigned a job, and discovers that it *is* in the process of being
# assigned a job, so it loops around and tries to acquire the lock
# again.
# 4. Eventually the JankyLock admits that the lock is available, and
# everything proceeds as normal.
class JankyLock:
def __init__(self) -> None:
self._lock = threading.Lock()
self._counter = 3
def acquire(self, timeout: int = -1) -> bool:
got_it = self._lock.acquire(timeout=timeout)
if timeout == -1:
return True
elif got_it:
if self._counter > 0:
self._counter -= 1
self._lock.release()
return False
return True
else:
return False
def release(self) -> None:
self._lock.release()
monkeypatch.setattr(_thread_cache, "Lock", JankyLock)
with _join_started_threads():
tc = ThreadCache()
done = threading.Event()
tc.start_thread_soon(lambda: None, lambda _: done.set())
done.wait()
# Let's kill the thread we started, so it doesn't hang around until the
# test suite finishes. Doesn't really do any harm, but it can be confusing
# to see it in debug output.
monkeypatch.setattr(_thread_cache, "IDLE_TIMEOUT", 0.0001)
tc.start_thread_soon(lambda: None, lambda _: None)
def test_raise_in_deliver(capfd: pytest.CaptureFixture[str]) -> None:
seen_threads = set()
def track_threads() -> None:
seen_threads.add(threading.current_thread())
def deliver(_: object) -> NoReturn:
done.set()
raise RuntimeError("don't do this")
done = threading.Event()
start_thread_soon(track_threads, deliver)
done.wait()
done = threading.Event()
start_thread_soon(track_threads, lambda _: done.set())
done.wait()
assert len(seen_threads) == 1
err = capfd.readouterr().err
assert "don't do this" in err
assert "delivering result" in err

View File

@ -0,0 +1,13 @@
import pytest
from .tutil import check_sequence_matches
def test_check_sequence_matches() -> None:
check_sequence_matches([1, 2, 3], [1, 2, 3])
with pytest.raises(AssertionError):
check_sequence_matches([1, 3, 2], [1, 2, 3])
check_sequence_matches([1, 2, 3, 4], [1, {2, 3}, 4])
check_sequence_matches([1, 3, 2, 4], [1, {2, 3}, 4])
with pytest.raises(AssertionError):
check_sequence_matches([1, 2, 4, 3], [1, {2, 3}, 4])

View File

@ -0,0 +1,154 @@
from __future__ import annotations
import itertools
import pytest
from ... import _core
from ...testing import assert_checkpoints, wait_all_tasks_blocked
pytestmark = pytest.mark.filterwarnings(
"ignore:.*UnboundedQueue:trio.TrioDeprecationWarning",
)
async def test_UnboundedQueue_basic() -> None:
q: _core.UnboundedQueue[str | int | None] = _core.UnboundedQueue()
q.put_nowait("hi")
assert await q.get_batch() == ["hi"]
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
q.put_nowait(1)
q.put_nowait(2)
q.put_nowait(3)
assert q.get_batch_nowait() == [1, 2, 3]
assert q.empty()
assert q.qsize() == 0
q.put_nowait(None)
assert not q.empty()
assert q.qsize() == 1
stats = q.statistics()
assert stats.qsize == 1
assert stats.tasks_waiting == 0
# smoke test
repr(q)
async def test_UnboundedQueue_blocking() -> None:
record = []
q = _core.UnboundedQueue[int]()
async def get_batch_consumer() -> None:
while True:
batch = await q.get_batch()
assert batch
record.append(batch)
async def aiter_consumer() -> None:
async for batch in q:
assert batch
record.append(batch)
for consumer in (get_batch_consumer, aiter_consumer):
record.clear()
async with _core.open_nursery() as nursery:
nursery.start_soon(consumer)
await _core.wait_all_tasks_blocked()
stats = q.statistics()
assert stats.qsize == 0
assert stats.tasks_waiting == 1
q.put_nowait(10)
q.put_nowait(11)
await _core.wait_all_tasks_blocked()
q.put_nowait(12)
await _core.wait_all_tasks_blocked()
assert record == [[10, 11], [12]]
nursery.cancel_scope.cancel()
async def test_UnboundedQueue_fairness() -> None:
q = _core.UnboundedQueue[int]()
# If there's no-one else around, we can put stuff in and take it out
# again, no problem
q.put_nowait(1)
q.put_nowait(2)
assert q.get_batch_nowait() == [1, 2]
result = None
async def get_batch(q: _core.UnboundedQueue[int]) -> None:
nonlocal result
result = await q.get_batch()
# But if someone else is waiting to read, then they get dibs
async with _core.open_nursery() as nursery:
nursery.start_soon(get_batch, q)
await _core.wait_all_tasks_blocked()
q.put_nowait(3)
q.put_nowait(4)
with pytest.raises(_core.WouldBlock):
q.get_batch_nowait()
assert result == [3, 4]
# If two tasks are trying to read, they alternate
record = []
async def reader(name: str) -> None:
while True:
record.append((name, await q.get_batch()))
async with _core.open_nursery() as nursery:
nursery.start_soon(reader, "a")
await _core.wait_all_tasks_blocked()
nursery.start_soon(reader, "b")
await _core.wait_all_tasks_blocked()
for i in range(20):
q.put_nowait(i)
await _core.wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
assert record == list(zip(itertools.cycle("ab"), [[i] for i in range(20)]))
async def test_UnboundedQueue_trivial_yields() -> None:
q = _core.UnboundedQueue[None]()
q.put_nowait(None)
with assert_checkpoints():
await q.get_batch()
q.put_nowait(None)
with assert_checkpoints():
async for _ in q: # pragma: no branch
break
async def test_UnboundedQueue_no_spurious_wakeups() -> None:
# If we have two tasks waiting, and put two items into the queue... then
# only one task wakes up
record = []
async def getter(q: _core.UnboundedQueue[int], i: int) -> None:
got = await q.get_batch()
record.append((i, got))
async with _core.open_nursery() as nursery:
q = _core.UnboundedQueue[int]()
nursery.start_soon(getter, q, 1)
await wait_all_tasks_blocked()
nursery.start_soon(getter, q, 2)
await wait_all_tasks_blocked()
for i in range(10):
q.put_nowait(i)
await wait_all_tasks_blocked()
assert record == [(1, list(range(10)))]
nursery.cancel_scope.cancel()

View File

@ -0,0 +1,299 @@
from __future__ import annotations
import os
import sys
import tempfile
from contextlib import contextmanager
from typing import TYPE_CHECKING
from unittest.mock import create_autospec
import pytest
on_windows = os.name == "nt"
# Mark all the tests in this file as being windows-only
pytestmark = pytest.mark.skipif(not on_windows, reason="windows only")
assert (
sys.platform == "win32" or not TYPE_CHECKING
) # Skip type checking when not on Windows
from ... import _core, sleep
from ...testing import wait_all_tasks_blocked
from .tutil import gc_collect_harder, restore_unraisablehook, slow
if TYPE_CHECKING:
from collections.abc import Generator
from io import BufferedWriter
if on_windows:
from .._windows_cffi import (
INVALID_HANDLE_VALUE,
FileFlags,
Handle,
ffi,
kernel32,
raise_winerror,
)
def test_winerror(monkeypatch: pytest.MonkeyPatch) -> None:
mock = create_autospec(ffi.getwinerror)
monkeypatch.setattr(ffi, "getwinerror", mock)
# Returning none = no error, should not happen.
mock.return_value = None
with pytest.raises(RuntimeError, match=r"^No error set\?$"):
raise_winerror()
mock.assert_called_once_with()
mock.reset_mock()
with pytest.raises(RuntimeError, match=r"^No error set\?$"):
raise_winerror(38)
mock.assert_called_once_with(38)
mock.reset_mock()
mock.return_value = (12, "test error")
with pytest.raises(
OSError,
match=r"^\[WinError 12\] test error: 'file_1' -> 'file_2'$",
) as exc:
raise_winerror(filename="file_1", filename2="file_2")
mock.assert_called_once_with()
mock.reset_mock()
assert exc.value.winerror == 12
assert exc.value.strerror == "test error"
assert exc.value.filename == "file_1"
assert exc.value.filename2 == "file_2"
# With an explicit number passed in, it overrides what getwinerror() returns.
with pytest.raises(
OSError,
match=r"^\[WinError 18\] test error: 'a/file' -> 'b/file'$",
) as exc:
raise_winerror(18, filename="a/file", filename2="b/file")
mock.assert_called_once_with(18)
mock.reset_mock()
assert exc.value.winerror == 18
assert exc.value.strerror == "test error"
assert exc.value.filename == "a/file"
assert exc.value.filename2 == "b/file"
# The undocumented API that this is testing should be changed to stop using
# UnboundedQueue (or just removed until we have time to redo it), but until
# then we filter out the warning.
@pytest.mark.filterwarnings("ignore:.*UnboundedQueue:trio.TrioDeprecationWarning")
async def test_completion_key_listen() -> None:
from .. import _io_windows
async def post(key: int) -> None:
iocp = Handle(ffi.cast("HANDLE", _core.current_iocp()))
for i in range(10):
print("post", i)
if i % 3 == 0:
await _core.checkpoint()
success = kernel32.PostQueuedCompletionStatus(iocp, i, key, ffi.NULL)
assert success
with _core.monitor_completion_key() as (key, queue):
async with _core.open_nursery() as nursery:
nursery.start_soon(post, key)
i = 0
print("loop")
async for batch in queue: # pragma: no branch
print("got some", batch)
for info in batch:
assert isinstance(info, _io_windows.CompletionKeyEventInfo)
assert info.lpOverlapped == 0
assert info.dwNumberOfBytesTransferred == i
i += 1
if i == 10:
break
print("end loop")
async def test_readinto_overlapped() -> None:
data = b"1" * 1024 + b"2" * 1024 + b"3" * 1024 + b"4" * 1024
buffer = bytearray(len(data))
with tempfile.TemporaryDirectory() as tdir:
tfile = os.path.join(tdir, "numbers.txt")
with open( # noqa: ASYNC230 # This is a test, synchronous is ok
tfile,
"wb",
) as fp:
fp.write(data)
fp.flush()
rawname = tfile.encode("utf-16le") + b"\0\0"
rawname_buf = ffi.from_buffer(rawname)
handle = kernel32.CreateFileW(
ffi.cast("LPCWSTR", rawname_buf),
FileFlags.GENERIC_READ,
FileFlags.FILE_SHARE_READ,
ffi.NULL, # no security attributes
FileFlags.OPEN_EXISTING,
FileFlags.FILE_FLAG_OVERLAPPED,
ffi.NULL, # no template file
)
if handle == INVALID_HANDLE_VALUE: # pragma: no cover
raise_winerror()
try:
with memoryview(buffer) as buffer_view:
async def read_region(start: int, end: int) -> None:
await _core.readinto_overlapped(
handle,
buffer_view[start:end],
start,
)
_core.register_with_iocp(handle)
async with _core.open_nursery() as nursery:
for start in range(0, 4096, 512):
nursery.start_soon(read_region, start, start + 512)
assert buffer == data
with pytest.raises((BufferError, TypeError)):
await _core.readinto_overlapped(handle, b"immutable")
finally:
kernel32.CloseHandle(handle)
@contextmanager
def pipe_with_overlapped_read() -> Generator[tuple[BufferedWriter, int], None, None]:
import msvcrt
from asyncio.windows_utils import pipe
read_handle, write_handle = pipe(overlapped=(True, False))
try:
write_fd = msvcrt.open_osfhandle(write_handle, 0)
yield os.fdopen(write_fd, "wb", closefd=False), read_handle
finally:
kernel32.CloseHandle(Handle(ffi.cast("HANDLE", read_handle)))
kernel32.CloseHandle(Handle(ffi.cast("HANDLE", write_handle)))
@restore_unraisablehook()
def test_forgot_to_register_with_iocp() -> None:
with pipe_with_overlapped_read() as (write_fp, read_handle):
with write_fp:
write_fp.write(b"test\n")
left_run_yet = False
async def main() -> None:
target = bytearray(1)
try:
async with _core.open_nursery() as nursery:
nursery.start_soon(
_core.readinto_overlapped,
read_handle,
target,
name="xyz",
)
await wait_all_tasks_blocked()
nursery.cancel_scope.cancel()
finally:
# Run loop is exited without unwinding running tasks, so
# we don't get here until the main() coroutine is GC'ed
assert left_run_yet
with pytest.raises(_core.TrioInternalError) as exc_info:
_core.run(main)
left_run_yet = True
assert "Failed to cancel overlapped I/O in xyz " in str(exc_info.value)
assert "forget to call register_with_iocp()?" in str(exc_info.value)
# Make sure the Nursery.__del__ assertion about dangling children
# gets put with the correct test
del exc_info
gc_collect_harder()
@slow
async def test_too_late_to_cancel() -> None:
import time
with pipe_with_overlapped_read() as (write_fp, read_handle):
_core.register_with_iocp(read_handle)
target = bytearray(6)
async with _core.open_nursery() as nursery:
# Start an async read in the background
nursery.start_soon(_core.readinto_overlapped, read_handle, target)
await wait_all_tasks_blocked()
# Synchronous write to the other end of the pipe
with write_fp:
write_fp.write(b"test1\ntest2\n")
# Note: not trio.sleep! We're making sure the OS level
# ReadFile completes, before Trio has a chance to execute
# another checkpoint and notice it completed.
time.sleep(1) # noqa: ASYNC251
nursery.cancel_scope.cancel()
assert target[:6] == b"test1\n"
# Do another I/O to make sure we've actually processed the
# fallback completion that was posted when CancelIoEx failed.
assert await _core.readinto_overlapped(read_handle, target) == 6
assert target[:6] == b"test2\n"
def test_lsp_that_hooks_select_gives_good_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
from .. import _io_windows
from .._windows_cffi import CData, WSAIoctls, _handle
def patched_get_underlying(
sock: int | CData,
*,
which: int = WSAIoctls.SIO_BASE_HANDLE,
) -> CData:
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BSP_HANDLE_SELECT:
return _handle(sock + 1)
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError,
match="SIO_BASE_HANDLE and SIO_BSP_HANDLE_SELECT differ",
):
_core.run(sleep, 0)
def test_lsp_that_completely_hides_base_socket_gives_good_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
# This tests behavior with an LSP that fails SIO_BASE_HANDLE and returns
# self for SIO_BSP_HANDLE_SELECT (like Komodia), but also returns
# self for SIO_BSP_HANDLE_POLL. No known LSP does this, but we want to
# make sure we get an error rather than an infinite loop.
from .. import _io_windows
from .._windows_cffi import CData, WSAIoctls, _handle
def patched_get_underlying(
sock: int | CData,
*,
which: int = WSAIoctls.SIO_BASE_HANDLE,
) -> CData:
if hasattr(sock, "fileno"): # pragma: no branch
sock = sock.fileno()
if which == WSAIoctls.SIO_BASE_HANDLE:
raise OSError("nope")
else:
return _handle(sock)
monkeypatch.setattr(_io_windows, "_get_underlying_socket", patched_get_underlying)
with pytest.raises(
RuntimeError,
match="SIO_BASE_HANDLE failed and SIO_BSP_HANDLE_POLL didn't return a diff",
):
_core.run(sleep, 0)

View File

@ -0,0 +1,117 @@
# Utilities for testing
from __future__ import annotations
import asyncio
import gc
import os
import socket as stdlib_socket
import sys
import warnings
from contextlib import closing, contextmanager
from typing import TYPE_CHECKING, TypeVar
import pytest
# See trio/_tests/conftest.py for the other half of this
from trio._tests.pytest_plugin import RUN_SLOW
if TYPE_CHECKING:
from collections.abc import Generator, Iterable, Sequence
slow = pytest.mark.skipif(not RUN_SLOW, reason="use --run-slow to run slow tests")
T = TypeVar("T")
try:
s = stdlib_socket.socket(stdlib_socket.AF_INET6, stdlib_socket.SOCK_STREAM, 0)
except OSError: # pragma: no cover
# Some systems don't even support creating an IPv6 socket, let alone
# binding it. (ex: Linux with 'ipv6.disable=1' in the kernel command line)
# We don't have any of those in our CI, and there's nothing that gets
# tested _only_ if can_create_ipv6 = False, so we'll just no-cover this.
can_create_ipv6 = False
can_bind_ipv6 = False
else:
can_create_ipv6 = True
with s:
try:
s.bind(("::1", 0))
except OSError: # pragma: no cover # since support for 3.7 was removed
can_bind_ipv6 = False
else:
can_bind_ipv6 = True
creates_ipv6 = pytest.mark.skipif(not can_create_ipv6, reason="need IPv6")
binds_ipv6 = pytest.mark.skipif(not can_bind_ipv6, reason="need IPv6")
def gc_collect_harder() -> None:
# In the test suite we sometimes want to call gc.collect() to make sure
# that any objects with noisy __del__ methods (e.g. unawaited coroutines)
# get collected before we continue, so their noise doesn't leak into
# unrelated tests.
#
# On PyPy, coroutine objects (for example) can survive at least 1 round of
# garbage collection, because executing their __del__ method to print the
# warning can cause them to be resurrected. So we call collect a few times
# to make sure.
for _ in range(5):
gc.collect()
# Some of our tests need to leak coroutines, and thus trigger the
# "RuntimeWarning: coroutine '...' was never awaited" message. This context
# manager should be used anywhere this happens to hide those messages, because
# when expected they're clutter.
@contextmanager
def ignore_coroutine_never_awaited_warnings() -> Generator[None, None, None]:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="coroutine '.*' was never awaited")
try:
yield
finally:
# Make sure to trigger any coroutine __del__ methods now, before
# we leave the context manager.
gc_collect_harder()
def _noop(*args: object, **kwargs: object) -> None:
pass # pragma: no cover
@contextmanager
def restore_unraisablehook() -> Generator[None, None, None]:
sys.unraisablehook, prev = sys.__unraisablehook__, sys.unraisablehook
try:
yield
finally:
sys.unraisablehook = prev
# Used to check sequences that might have some elements out of order.
# Example usage:
# The sequences [1, 2.1, 2.2, 3] and [1, 2.2, 2.1, 3] are both
# matched by the template [1, {2.1, 2.2}, 3]
def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) -> None:
i = 0
for pattern in template:
if not isinstance(pattern, set):
pattern = {pattern}
got = set(seq[i : i + len(pattern)])
assert got == pattern
i += len(got)
# https://bugs.freebsd.org/bugzilla/show_bug.cgi?id=246350
skip_if_fbsd_pipes_broken = pytest.mark.skipif(
sys.platform != "win32" # prevent mypy from complaining about missing uname
and hasattr(os, "uname")
and os.uname().sysname == "FreeBSD"
and os.uname().release[:4] < "12.2",
reason="hangs on FreeBSD 12.1 and earlier, due to FreeBSD bug #246350",
)
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()

View File

@ -0,0 +1,76 @@
"""Test variadic generic typing for Nursery.start[_soon]()."""
from typing import Awaitable, Callable
from trio import TASK_STATUS_IGNORED, Nursery, TaskStatus
async def task_0() -> None: ...
async def task_1a(value: int) -> None: ...
async def task_1b(value: str) -> None: ...
async def task_2a(a: int, b: str) -> None: ...
async def task_2b(a: str, b: int) -> None: ...
async def task_2c(a: str, b: int, optional: bool = False) -> None: ...
async def task_requires_kw(a: int, *, b: bool) -> None: ...
async def task_startable_1(
a: str,
*,
task_status: TaskStatus[bool] = TASK_STATUS_IGNORED,
) -> None: ...
async def task_startable_2(
a: str,
b: float,
*,
task_status: TaskStatus[bool] = TASK_STATUS_IGNORED,
) -> None: ...
async def task_requires_start(*, task_status: TaskStatus[str]) -> None:
"""Check a function requiring start() to be used."""
async def task_pos_or_kw(value: str, task_status: TaskStatus[int]) -> None:
"""Check a function which doesn't use the *-syntax works."""
def check_start_soon(nursery: Nursery) -> None:
"""start_soon() functionality."""
nursery.start_soon(task_0)
nursery.start_soon(task_1a) # type: ignore
nursery.start_soon(task_2b) # type: ignore
nursery.start_soon(task_0, 45) # type: ignore
nursery.start_soon(task_1a, 32)
nursery.start_soon(task_1b, 32) # type: ignore
nursery.start_soon(task_1a, "abc") # type: ignore
nursery.start_soon(task_1b, "abc")
nursery.start_soon(task_2b, "abc") # type: ignore
nursery.start_soon(task_2a, 38, "46")
nursery.start_soon(task_2c, "abc", 12, True)
nursery.start_soon(task_2c, "abc", 12)
task_2c_cast: Callable[[str, int], Awaitable[object]] = (
task_2c # The assignment makes it work.
)
nursery.start_soon(task_2c_cast, "abc", 12)
nursery.start_soon(task_requires_kw, 12, True) # type: ignore
# Tasks following the start() API can be made to work.
nursery.start_soon(task_startable_1, "cdf")

View File

@ -0,0 +1,48 @@
from __future__ import annotations
from typing import Sequence, overload
import trio
from typing_extensions import assert_type
async def sleep_sort(values: Sequence[float]) -> list[float]:
return [1]
async def has_optional(arg: int | None = None) -> int:
return 5
@overload
async def foo_overloaded(arg: int) -> str: ...
@overload
async def foo_overloaded(arg: str) -> int: ...
async def foo_overloaded(arg: int | str) -> int | str:
if isinstance(arg, str):
return 5
return "hello"
v = trio.run(
sleep_sort,
(1, 3, 5, 2, 4),
clock=trio.testing.MockClock(autojump_threshold=0),
)
assert_type(v, "list[float]")
trio.run(sleep_sort, ["hi", "there"]) # type: ignore[arg-type]
trio.run(sleep_sort) # type: ignore[arg-type]
r = trio.run(has_optional)
assert_type(r, int)
r = trio.run(has_optional, 5)
trio.run(has_optional, 7, 8) # type: ignore[arg-type]
trio.run(has_optional, "hello") # type: ignore[arg-type]
assert_type(trio.run(foo_overloaded, 5), str)
assert_type(trio.run(foo_overloaded, ""), int)

View File

@ -0,0 +1,295 @@
from __future__ import annotations
import ctypes
import ctypes.util
import sys
import traceback
from functools import partial
from itertools import count
from threading import Lock, Thread
from typing import Any, Callable, Generic, TypeVar
import outcome
RetT = TypeVar("RetT")
def _to_os_thread_name(name: str) -> bytes:
# ctypes handles the trailing \00
return name.encode("ascii", errors="replace")[:15]
# used to construct the method used to set os thread name, or None, depending on platform.
# called once on import
def get_os_thread_name_func() -> Callable[[int | None, str], None] | None:
def namefunc(
setname: Callable[[int, bytes], int],
ident: int | None,
name: str,
) -> None:
# Thread.ident is None "if it has not been started". Unclear if that can happen
# with current usage.
if ident is not None: # pragma: no cover
setname(ident, _to_os_thread_name(name))
# namefunc on Mac also takes an ident, even if pthread_setname_np doesn't/can't use it
# so the caller don't need to care about platform.
def darwin_namefunc(
setname: Callable[[bytes], int],
ident: int | None,
name: str,
) -> None:
# I don't know if Mac can rename threads that hasn't been started, but default
# to no to be on the safe side.
if ident is not None: # pragma: no cover
setname(_to_os_thread_name(name))
# find the pthread library
# this will fail on windows and musl
libpthread_path = ctypes.util.find_library("pthread")
if not libpthread_path:
# musl includes pthread functions directly in libc.so
# (but note that find_library("c") does not work on musl,
# see: https://github.com/python/cpython/issues/65821)
# so try that library instead
# if it doesn't exist, CDLL() will fail below
libpthread_path = "libc.so"
# Sometimes windows can find the path, but gives a permission error when
# accessing it. Catching a wider exception in case of more esoteric errors.
# https://github.com/python-trio/trio/issues/2688
try:
libpthread = ctypes.CDLL(libpthread_path)
except Exception: # pragma: no cover
return None
# get the setname method from it
# afaik this should never fail
pthread_setname_np = getattr(libpthread, "pthread_setname_np", None)
if pthread_setname_np is None: # pragma: no cover
return None
# specify function prototype
pthread_setname_np.restype = ctypes.c_int
# on mac OSX pthread_setname_np does not take a thread id,
# it only lets threads name themselves, which is not a problem for us.
# Just need to make sure to call it correctly
if sys.platform == "darwin":
pthread_setname_np.argtypes = [ctypes.c_char_p]
return partial(darwin_namefunc, pthread_setname_np)
# otherwise assume linux parameter conventions. Should also work on *BSD
pthread_setname_np.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
return partial(namefunc, pthread_setname_np)
# construct os thread name method
set_os_thread_name = get_os_thread_name_func()
# The "thread cache" is a simple unbounded thread pool, i.e., it automatically
# spawns as many threads as needed to handle all the requests its given. Its
# only purpose is to cache worker threads so that they don't have to be
# started from scratch every time we want to delegate some work to a thread.
# It's expected that some higher-level code will track how many threads are in
# use to avoid overwhelming the system (e.g. the limiter= argument to
# trio.to_thread.run_sync).
#
# To maximize sharing, there's only one thread cache per process, even if you
# have multiple calls to trio.run.
#
# Guarantees:
#
# It's safe to call start_thread_soon simultaneously from
# multiple threads.
#
# Idle threads are chosen in LIFO order, i.e. we *don't* spread work evenly
# over all threads. Instead we try to let some threads do most of the work
# while others sit idle as much as possible. Compared to FIFO, this has better
# memory cache behavior, and it makes it easier to detect when we have too
# many threads, so idle ones can exit.
#
# This code assumes that 'dict' has the following properties:
#
# - __setitem__, __delitem__, and popitem are all thread-safe and atomic with
# respect to each other. This is guaranteed by the GIL.
#
# - popitem returns the most-recently-added item (i.e., __setitem__ + popitem
# give you a LIFO queue). This relies on dicts being insertion-ordered, like
# they are in py36+.
# How long a thread will idle waiting for new work before gives up and exits.
# This value is pretty arbitrary; I don't think it matters too much.
IDLE_TIMEOUT = 10 # seconds
name_counter = count()
class WorkerThread(Generic[RetT]):
def __init__(self, thread_cache: ThreadCache) -> None:
self._job: (
tuple[
Callable[[], RetT],
Callable[[outcome.Outcome[RetT]], object],
str | None,
]
| None
) = None
self._thread_cache = thread_cache
# This Lock is used in an unconventional way.
#
# "Unlocked" means we have a pending job that's been assigned to us;
# "locked" means that we don't.
#
# Initially we have no job, so it starts out in locked state.
self._worker_lock = Lock()
self._worker_lock.acquire()
self._default_name = f"Trio thread {next(name_counter)}"
self._thread = Thread(target=self._work, name=self._default_name, daemon=True)
if set_os_thread_name:
set_os_thread_name(self._thread.ident, self._default_name)
self._thread.start()
def _handle_job(self) -> None:
# Handle job in a separate method to ensure user-created
# objects are cleaned up in a consistent manner.
assert self._job is not None
fn, deliver, name = self._job
self._job = None
# set name
if name is not None:
self._thread.name = name
if set_os_thread_name:
set_os_thread_name(self._thread.ident, name)
result = outcome.capture(fn)
# reset name if it was changed
if name is not None:
self._thread.name = self._default_name
if set_os_thread_name:
set_os_thread_name(self._thread.ident, self._default_name)
# Tell the cache that we're available to be assigned a new
# job. We do this *before* calling 'deliver', so that if
# 'deliver' triggers a new job, it can be assigned to us
# instead of spawning a new thread.
self._thread_cache._idle_workers[self] = None
try:
deliver(result)
except BaseException as e:
print("Exception while delivering result of thread", file=sys.stderr)
traceback.print_exception(type(e), e, e.__traceback__)
def _work(self) -> None:
while True:
if self._worker_lock.acquire(timeout=IDLE_TIMEOUT):
# We got a job
self._handle_job()
else:
# Timeout acquiring lock, so we can probably exit. But,
# there's a race condition: we might be assigned a job *just*
# as we're about to exit. So we have to check.
try:
del self._thread_cache._idle_workers[self]
except KeyError:
# Someone else removed us from the idle worker queue, so
# they must be in the process of assigning us a job - loop
# around and wait for it.
continue
else:
# We successfully removed ourselves from the idle
# worker queue, so no more jobs are incoming; it's safe to
# exit.
return
class ThreadCache:
def __init__(self) -> None:
self._idle_workers: dict[WorkerThread[Any], None] = {}
def start_thread_soon(
self,
fn: Callable[[], RetT],
deliver: Callable[[outcome.Outcome[RetT]], object],
name: str | None = None,
) -> None:
worker: WorkerThread[RetT]
try:
worker, _ = self._idle_workers.popitem()
except KeyError:
worker = WorkerThread(self)
worker._job = (fn, deliver, name)
worker._worker_lock.release()
THREAD_CACHE = ThreadCache()
def start_thread_soon(
fn: Callable[[], RetT],
deliver: Callable[[outcome.Outcome[RetT]], object],
name: str | None = None,
) -> None:
"""Runs ``deliver(outcome.capture(fn))`` in a worker thread.
Generally ``fn`` does some blocking work, and ``deliver`` delivers the
result back to whoever is interested.
This is a low-level, no-frills interface, very similar to using
`threading.Thread` to spawn a thread directly. The main difference is
that this function tries to reuse threads when possible, so it can be
a bit faster than `threading.Thread`.
Worker threads have the `~threading.Thread.daemon` flag set, which means
that if your main thread exits, worker threads will automatically be
killed. If you want to make sure that your ``fn`` runs to completion, then
you should make sure that the main thread remains alive until ``deliver``
is called.
It is safe to call this function simultaneously from multiple threads.
Args:
fn (sync function): Performs arbitrary blocking work.
deliver (sync function): Takes the `outcome.Outcome` of ``fn``, and
delivers it. *Must not block.*
Because worker threads are cached and reused for multiple calls, neither
function should mutate thread-level state, like `threading.local` objects
or if they do, they should be careful to revert their changes before
returning.
Note:
The split between ``fn`` and ``deliver`` serves two purposes. First,
it's convenient, since most callers need something like this anyway.
Second, it avoids a small race condition that could cause too many
threads to be spawned. Consider a program that wants to run several
jobs sequentially on a thread, so the main thread submits a job, waits
for it to finish, submits another job, etc. In theory, this program
should only need one worker thread. But what could happen is:
1. Worker thread: First job finishes, and calls ``deliver``.
2. Main thread: receives notification that the job finished, and calls
``start_thread_soon``.
3. Main thread: sees that no worker threads are marked idle, so spawns
a second worker thread.
4. Original worker thread: marks itself as idle.
To avoid this, threads mark themselves as idle *before* calling
``deliver``.
Is this potential extra thread a major problem? Maybe not, but it's
easy enough to avoid, and we figure that if the user is trying to
limit how many threads they're using then it's polite to respect that.
"""
THREAD_CACHE.start_thread_soon(fn, deliver, name)

View File

@ -0,0 +1,287 @@
"""These are the only functions that ever yield back to the task runner."""
from __future__ import annotations
import enum
import types
from typing import TYPE_CHECKING, Any, Callable, NoReturn
import attrs
import outcome
from . import _run
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from ._run import Task
# Helper for the bottommost 'yield'. You can't use 'yield' inside an async
# function, but you can inside a generator, and if you decorate your generator
# with @types.coroutine, then it's even awaitable. However, it's still not a
# real async function: in particular, it isn't recognized by
# inspect.iscoroutinefunction, and it doesn't trigger the unawaited coroutine
# tracking machinery. Since our traps are public APIs, we make them real async
# functions, and then this helper takes care of the actual yield:
@types.coroutine
def _async_yield(obj: Any) -> Any: # type: ignore[misc]
return (yield obj)
# This class object is used as a singleton.
# Not exported in the trio._core namespace, but imported directly by _run.
class CancelShieldedCheckpoint:
pass
async def cancel_shielded_checkpoint() -> None:
"""Introduce a schedule point, but not a cancel point.
This is *not* a :ref:`checkpoint <checkpoints>`, but it is half of a
checkpoint, and when combined with :func:`checkpoint_if_cancelled` it can
make a full checkpoint.
Equivalent to (but potentially more efficient than)::
with trio.CancelScope(shield=True):
await trio.lowlevel.checkpoint()
"""
(await _async_yield(CancelShieldedCheckpoint)).unwrap()
# Return values for abort functions
class Abort(enum.Enum):
""":class:`enum.Enum` used as the return value from abort functions.
See :func:`wait_task_rescheduled` for details.
.. data:: SUCCEEDED
FAILED
"""
SUCCEEDED = 1
FAILED = 2
# Not exported in the trio._core namespace, but imported directly by _run.
@attrs.frozen(slots=False)
class WaitTaskRescheduled:
abort_func: Callable[[RaiseCancelT], Abort]
RaiseCancelT: TypeAlias = Callable[[], NoReturn]
# Should always return the type a Task "expects", unless you willfully reschedule it
# with a bad value.
async def wait_task_rescheduled(abort_func: Callable[[RaiseCancelT], Abort]) -> Any:
"""Put the current task to sleep, with cancellation support.
This is the lowest-level API for blocking in Trio. Every time a
:class:`~trio.lowlevel.Task` blocks, it does so by calling this function
(usually indirectly via some higher-level API).
This is a tricky interface with no guard rails. If you can use
:class:`ParkingLot` or the built-in I/O wait functions instead, then you
should.
Generally the way it works is that before calling this function, you make
arrangements for "someone" to call :func:`reschedule` on the current task
at some later point.
Then you call :func:`wait_task_rescheduled`, passing in ``abort_func``, an
"abort callback".
(Terminology: in Trio, "aborting" is the process of attempting to
interrupt a blocked task to deliver a cancellation.)
There are two possibilities for what happens next:
1. "Someone" calls :func:`reschedule` on the current task, and
:func:`wait_task_rescheduled` returns or raises whatever value or error
was passed to :func:`reschedule`.
2. The call's context transitions to a cancelled state (e.g. due to a
timeout expiring). When this happens, the ``abort_func`` is called. Its
interface looks like::
def abort_func(raise_cancel):
...
return trio.lowlevel.Abort.SUCCEEDED # or FAILED
It should attempt to clean up any state associated with this call, and
in particular, arrange that :func:`reschedule` will *not* be called
later. If (and only if!) it is successful, then it should return
:data:`Abort.SUCCEEDED`, in which case the task will automatically be
rescheduled with an appropriate :exc:`~trio.Cancelled` error.
Otherwise, it should return :data:`Abort.FAILED`. This means that the
task can't be cancelled at this time, and still has to make sure that
"someone" eventually calls :func:`reschedule`.
At that point there are again two possibilities. You can simply ignore
the cancellation altogether: wait for the operation to complete and
then reschedule and continue as normal. (For example, this is what
:func:`trio.to_thread.run_sync` does if cancellation is disabled.)
The other possibility is that the ``abort_func`` does succeed in
cancelling the operation, but for some reason isn't able to report that
right away. (Example: on Windows, it's possible to request that an
async ("overlapped") I/O operation be cancelled, but this request is
*also* asynchronous you don't find out until later whether the
operation was actually cancelled or not.) To report a delayed
cancellation, then you should reschedule the task yourself, and call
the ``raise_cancel`` callback passed to ``abort_func`` to raise a
:exc:`~trio.Cancelled` (or possibly :exc:`KeyboardInterrupt`) exception
into this task. Either of the approaches sketched below can work::
# Option 1:
# Catch the exception from raise_cancel and inject it into the task.
# (This is what Trio does automatically for you if you return
# Abort.SUCCEEDED.)
trio.lowlevel.reschedule(task, outcome.capture(raise_cancel))
# Option 2:
# wait to be woken by "someone", and then decide whether to raise
# the error from inside the task.
outer_raise_cancel = None
def abort(inner_raise_cancel):
nonlocal outer_raise_cancel
outer_raise_cancel = inner_raise_cancel
TRY_TO_CANCEL_OPERATION()
return trio.lowlevel.Abort.FAILED
await wait_task_rescheduled(abort)
if OPERATION_WAS_SUCCESSFULLY_CANCELLED:
# raises the error
outer_raise_cancel()
In any case it's guaranteed that we only call the ``abort_func`` at most
once per call to :func:`wait_task_rescheduled`.
Sometimes, it's useful to be able to share some mutable sleep-related data
between the sleeping task, the abort function, and the waking task. You
can use the sleeping task's :data:`~Task.custom_sleep_data` attribute to
store this data, and Trio won't touch it, except to make sure that it gets
cleared when the task is rescheduled.
.. warning::
If your ``abort_func`` raises an error, or returns any value other than
:data:`Abort.SUCCEEDED` or :data:`Abort.FAILED`, then Trio will crash
violently. Be careful! Similarly, it is entirely possible to deadlock a
Trio program by failing to reschedule a blocked task, or cause havoc by
calling :func:`reschedule` too many times. Remember what we said up
above about how you should use a higher-level API if at all possible?
"""
return (await _async_yield(WaitTaskRescheduled(abort_func))).unwrap()
# Not exported in the trio._core namespace, but imported directly by _run.
@attrs.frozen(slots=False)
class PermanentlyDetachCoroutineObject:
final_outcome: outcome.Outcome[Any]
async def permanently_detach_coroutine_object(
final_outcome: outcome.Outcome[Any],
) -> Any:
"""Permanently detach the current task from the Trio scheduler.
Normally, a Trio task doesn't exit until its coroutine object exits. When
you call this function, Trio acts like the coroutine object just exited
and the task terminates with the given outcome. This is useful if you want
to permanently switch the coroutine object over to a different coroutine
runner.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
You should make sure that the coroutine object has released any
Trio-specific resources it has acquired (e.g. nurseries).
Args:
final_outcome (outcome.Outcome): Trio acts as if the current task exited
with the given return value or exception.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
if _run.current_task().child_nurseries:
raise RuntimeError(
"can't permanently detach a coroutine object with open nurseries",
)
return await _async_yield(PermanentlyDetachCoroutineObject(final_outcome))
async def temporarily_detach_coroutine_object(
abort_func: Callable[[RaiseCancelT], Abort],
) -> Any:
"""Temporarily detach the current coroutine object from the Trio
scheduler.
When the calling coroutine enters this function it's running under Trio,
and when the function returns it's running under the foreign coroutine
runner.
The Trio :class:`Task` will continue to exist, but will be suspended until
you use :func:`reattach_detached_coroutine_object` to resume it. In the
mean time, you can use another coroutine runner to schedule the coroutine
object. In fact, you have to the function doesn't return until the
coroutine is advanced from outside.
Note that you'll need to save the current :class:`Task` object to later
resume; you can retrieve it with :func:`current_task`. You can also use
this :class:`Task` object to retrieve the coroutine object see
:data:`Task.coro`.
Args:
abort_func: Same as for :func:`wait_task_rescheduled`, except that it
must return :data:`Abort.FAILED`. (If it returned
:data:`Abort.SUCCEEDED`, then Trio would attempt to reschedule the
detached task directly without going through
:func:`reattach_detached_coroutine_object`, which would be bad.)
Your ``abort_func`` should still arrange for whatever the coroutine
object is doing to be cancelled, and then reattach to Trio and call
the ``raise_cancel`` callback, if possible.
Returns or raises whatever value or exception the new coroutine runner
uses to resume the coroutine.
"""
return await _async_yield(WaitTaskRescheduled(abort_func))
async def reattach_detached_coroutine_object(task: Task, yield_value: object) -> None:
"""Reattach a coroutine object that was detached using
:func:`temporarily_detach_coroutine_object`.
When the calling coroutine enters this function it's running under the
foreign coroutine runner, and when the function returns it's running under
Trio.
This must be called from inside the coroutine being resumed, and yields
whatever value you pass in. (Presumably you'll pass a value that will
cause the current coroutine runner to stop scheduling this task.) Then the
coroutine is resumed by the Trio scheduler at the next opportunity.
Args:
task (Task): The Trio task object that the current coroutine was
detached from.
yield_value (object): The object to yield to the current coroutine
runner.
"""
# This is a kind of crude check in particular, it can fail if the
# passed-in task is where the coroutine *runner* is running. But this is
# an experts-only interface, and there's no easy way to do a more accurate
# check, so I guess that's OK.
if not task.coro.cr_running:
raise RuntimeError("given task does not match calling coroutine")
_run.reschedule(task, outcome.Value("reattaching"))
value = await _async_yield(yield_value)
assert value == outcome.Value("reattaching")

View File

@ -0,0 +1,163 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic, TypeVar
import attrs
from .. import _core
from .._deprecate import deprecated
from .._util import final
T = TypeVar("T")
if TYPE_CHECKING:
from typing_extensions import Self
@attrs.frozen
class UnboundedQueueStatistics:
"""An object containing debugging information.
Currently, the following fields are defined:
* ``qsize``: The number of items currently in the queue.
* ``tasks_waiting``: The number of tasks blocked on this queue's
:meth:`get_batch` method.
"""
qsize: int
tasks_waiting: int
@final
class UnboundedQueue(Generic[T]):
"""An unbounded queue suitable for certain unusual forms of inter-task
communication.
This class is designed for use as a queue in cases where the producer for
some reason cannot be subjected to back-pressure, i.e., :meth:`put_nowait`
has to always succeed. In order to prevent the queue backlog from actually
growing without bound, the consumer API is modified to dequeue items in
"batches". If a consumer task processes each batch without yielding, then
this helps achieve (but does not guarantee) an effective bound on the
queue's memory use, at the cost of potentially increasing system latencies
in general. You should generally prefer to use a memory channel
instead if you can.
Currently each batch completely empties the queue, but `this may change in
the future <https://github.com/python-trio/trio/issues/51>`__.
A :class:`UnboundedQueue` object can be used as an asynchronous iterator,
where each iteration returns a new batch of items. I.e., these two loops
are equivalent::
async for batch in queue:
...
while True:
obj = await queue.get_batch()
...
"""
@deprecated(
"0.9.0",
issue=497,
thing="trio.lowlevel.UnboundedQueue",
instead="trio.open_memory_channel(math.inf)",
use_triodeprecationwarning=True,
)
def __init__(self) -> None:
self._lot = _core.ParkingLot()
self._data: list[T] = []
# used to allow handoff from put to the first task in the lot
self._can_get = False
def __repr__(self) -> str:
return f"<UnboundedQueue holding {len(self._data)} items>"
def qsize(self) -> int:
"""Returns the number of items currently in the queue."""
return len(self._data)
def empty(self) -> bool:
"""Returns True if the queue is empty, False otherwise.
There is some subtlety to interpreting this method's return value: see
`issue #63 <https://github.com/python-trio/trio/issues/63>`__.
"""
return not self._data
@_core.enable_ki_protection
def put_nowait(self, obj: T) -> None:
"""Put an object into the queue, without blocking.
This always succeeds, because the queue is unbounded. We don't provide
a blocking ``put`` method, because it would never need to block.
Args:
obj (object): The object to enqueue.
"""
if not self._data:
assert not self._can_get
if self._lot:
self._lot.unpark(count=1)
else:
self._can_get = True
self._data.append(obj)
def _get_batch_protected(self) -> list[T]:
data = self._data.copy()
self._data.clear()
self._can_get = False
return data
def get_batch_nowait(self) -> list[T]:
"""Attempt to get the next batch from the queue, without blocking.
Returns:
list: A list of dequeued items, in order. On a successful call this
list is always non-empty; if it would be empty we raise
:exc:`~trio.WouldBlock` instead.
Raises:
~trio.WouldBlock: if the queue is empty.
"""
if not self._can_get:
raise _core.WouldBlock
return self._get_batch_protected()
async def get_batch(self) -> list[T]:
"""Get the next batch from the queue, blocking as necessary.
Returns:
list: A list of dequeued items, in order. This list is always
non-empty.
"""
await _core.checkpoint_if_cancelled()
if not self._can_get:
await self._lot.park()
return self._get_batch_protected()
else:
try:
return self._get_batch_protected()
finally:
await _core.cancel_shielded_checkpoint()
def statistics(self) -> UnboundedQueueStatistics:
"""Return an :class:`UnboundedQueueStatistics` object containing debugging information."""
return UnboundedQueueStatistics(
qsize=len(self._data),
tasks_waiting=self._lot.statistics().tasks_waiting,
)
def __aiter__(self) -> Self:
return self
async def __anext__(self) -> list[T]:
return await self.get_batch()

View File

@ -0,0 +1,75 @@
from __future__ import annotations
import contextlib
import signal
import socket
import warnings
from .. import _core
from .._util import is_main_thread
class WakeupSocketpair:
def __init__(self) -> None:
# explicitly typed to please `pyright --verifytypes` without `--ignoreexternal`
self.wakeup_sock: socket.socket
self.write_sock: socket.socket
self.wakeup_sock, self.write_sock = socket.socketpair()
self.wakeup_sock.setblocking(False)
self.write_sock.setblocking(False)
# This somewhat reduces the amount of memory wasted queueing up data
# for wakeups. With these settings, maximum number of 1-byte sends
# before getting BlockingIOError:
# Linux 4.8: 6
# macOS (darwin 15.5): 1
# Windows 10: 525347
# Windows you're weird. (And on Windows setting SNDBUF to 0 makes send
# blocking, even on non-blocking sockets, so don't do that.)
self.wakeup_sock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1)
self.write_sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1)
# On Windows this is a TCP socket so this might matter. On other
# platforms this fails b/c AF_UNIX sockets aren't actually TCP.
with contextlib.suppress(OSError):
self.write_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
self.old_wakeup_fd: int | None = None
def wakeup_thread_and_signal_safe(self) -> None:
with contextlib.suppress(BlockingIOError):
self.write_sock.send(b"\x00")
async def wait_woken(self) -> None:
await _core.wait_readable(self.wakeup_sock)
self.drain()
def drain(self) -> None:
try:
while True:
self.wakeup_sock.recv(2**16)
except BlockingIOError:
pass
def wakeup_on_signals(self) -> None:
assert self.old_wakeup_fd is None
if not is_main_thread():
return
fd = self.write_sock.fileno()
self.old_wakeup_fd = signal.set_wakeup_fd(fd, warn_on_full_buffer=False)
if self.old_wakeup_fd != -1:
warnings.warn(
RuntimeWarning(
"It looks like Trio's signal handling code might have "
"collided with another library you're using. If you're "
"running Trio in guest mode, then this might mean you "
"should set host_uses_signal_set_wakeup_fd=True. "
"Otherwise, file a bug on Trio and we'll help you figure "
"out what's going on.",
),
stacklevel=1,
)
def close(self) -> None:
self.wakeup_sock.close()
self.write_sock.close()
if self.old_wakeup_fd is not None:
signal.set_wakeup_fd(self.old_wakeup_fd)

View File

@ -0,0 +1,520 @@
from __future__ import annotations
import enum
import re
from typing import TYPE_CHECKING, NewType, NoReturn, Protocol, cast
if TYPE_CHECKING:
from typing_extensions import TypeAlias
import cffi
################################################################
# Functions and types
################################################################
LIB = """
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa383751(v=vs.85).aspx
typedef int BOOL;
typedef unsigned char BYTE;
typedef BYTE BOOLEAN;
typedef void* PVOID;
typedef PVOID HANDLE;
typedef unsigned long DWORD;
typedef unsigned long ULONG;
typedef unsigned int NTSTATUS;
typedef unsigned long u_long;
typedef ULONG *PULONG;
typedef const void *LPCVOID;
typedef void *LPVOID;
typedef const wchar_t *LPCWSTR;
typedef uintptr_t ULONG_PTR;
typedef uintptr_t UINT_PTR;
typedef UINT_PTR SOCKET;
typedef struct _OVERLAPPED {
ULONG_PTR Internal;
ULONG_PTR InternalHigh;
union {
struct {
DWORD Offset;
DWORD OffsetHigh;
} DUMMYSTRUCTNAME;
PVOID Pointer;
} DUMMYUNIONNAME;
HANDLE hEvent;
} OVERLAPPED, *LPOVERLAPPED;
typedef OVERLAPPED WSAOVERLAPPED;
typedef LPOVERLAPPED LPWSAOVERLAPPED;
typedef PVOID LPSECURITY_ATTRIBUTES;
typedef PVOID LPCSTR;
typedef struct _OVERLAPPED_ENTRY {
ULONG_PTR lpCompletionKey;
LPOVERLAPPED lpOverlapped;
ULONG_PTR Internal;
DWORD dwNumberOfBytesTransferred;
} OVERLAPPED_ENTRY, *LPOVERLAPPED_ENTRY;
// kernel32.dll
HANDLE WINAPI CreateIoCompletionPort(
_In_ HANDLE FileHandle,
_In_opt_ HANDLE ExistingCompletionPort,
_In_ ULONG_PTR CompletionKey,
_In_ DWORD NumberOfConcurrentThreads
);
BOOL SetFileCompletionNotificationModes(
HANDLE FileHandle,
UCHAR Flags
);
HANDLE CreateFileW(
LPCWSTR lpFileName,
DWORD dwDesiredAccess,
DWORD dwShareMode,
LPSECURITY_ATTRIBUTES lpSecurityAttributes,
DWORD dwCreationDisposition,
DWORD dwFlagsAndAttributes,
HANDLE hTemplateFile
);
BOOL WINAPI CloseHandle(
_In_ HANDLE hObject
);
BOOL WINAPI PostQueuedCompletionStatus(
_In_ HANDLE CompletionPort,
_In_ DWORD dwNumberOfBytesTransferred,
_In_ ULONG_PTR dwCompletionKey,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WINAPI GetQueuedCompletionStatusEx(
_In_ HANDLE CompletionPort,
_Out_ LPOVERLAPPED_ENTRY lpCompletionPortEntries,
_In_ ULONG ulCount,
_Out_ PULONG ulNumEntriesRemoved,
_In_ DWORD dwMilliseconds,
_In_ BOOL fAlertable
);
BOOL WINAPI CancelIoEx(
_In_ HANDLE hFile,
_In_opt_ LPOVERLAPPED lpOverlapped
);
BOOL WriteFile(
HANDLE hFile,
LPCVOID lpBuffer,
DWORD nNumberOfBytesToWrite,
LPDWORD lpNumberOfBytesWritten,
LPOVERLAPPED lpOverlapped
);
BOOL ReadFile(
HANDLE hFile,
LPVOID lpBuffer,
DWORD nNumberOfBytesToRead,
LPDWORD lpNumberOfBytesRead,
LPOVERLAPPED lpOverlapped
);
BOOL WINAPI SetConsoleCtrlHandler(
_In_opt_ void* HandlerRoutine,
_In_ BOOL Add
);
HANDLE CreateEventA(
LPSECURITY_ATTRIBUTES lpEventAttributes,
BOOL bManualReset,
BOOL bInitialState,
LPCSTR lpName
);
BOOL SetEvent(
HANDLE hEvent
);
BOOL ResetEvent(
HANDLE hEvent
);
DWORD WaitForSingleObject(
HANDLE hHandle,
DWORD dwMilliseconds
);
DWORD WaitForMultipleObjects(
DWORD nCount,
HANDLE *lpHandles,
BOOL bWaitAll,
DWORD dwMilliseconds
);
ULONG RtlNtStatusToDosError(
NTSTATUS Status
);
int WSAIoctl(
SOCKET s,
DWORD dwIoControlCode,
LPVOID lpvInBuffer,
DWORD cbInBuffer,
LPVOID lpvOutBuffer,
DWORD cbOutBuffer,
LPDWORD lpcbBytesReturned,
LPWSAOVERLAPPED lpOverlapped,
// actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
void* lpCompletionRoutine
);
int WSAGetLastError();
BOOL DeviceIoControl(
HANDLE hDevice,
DWORD dwIoControlCode,
LPVOID lpInBuffer,
DWORD nInBufferSize,
LPVOID lpOutBuffer,
DWORD nOutBufferSize,
LPDWORD lpBytesReturned,
LPOVERLAPPED lpOverlapped
);
// From https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
typedef struct _AFD_POLL_HANDLE_INFO {
HANDLE Handle;
ULONG Events;
NTSTATUS Status;
} AFD_POLL_HANDLE_INFO, *PAFD_POLL_HANDLE_INFO;
// This is really defined as a messy union to allow stuff like
// i.DUMMYSTRUCTNAME.LowPart, but we don't need those complications.
// Under all that it's just an int64.
typedef int64_t LARGE_INTEGER;
typedef struct _AFD_POLL_INFO {
LARGE_INTEGER Timeout;
ULONG NumberOfHandles;
ULONG Exclusive;
AFD_POLL_HANDLE_INFO Handles[1];
} AFD_POLL_INFO, *PAFD_POLL_INFO;
"""
# cribbed from pywincffi
# programmatically strips out those annotations MSDN likes, like _In_
REGEX_SAL_ANNOTATION = re.compile(
r"\b(_In_|_Inout_|_Out_|_Outptr_|_Reserved_)(opt_)?\b",
)
LIB = REGEX_SAL_ANNOTATION.sub(" ", LIB)
# Other fixups:
# - get rid of FAR, cffi doesn't like it
LIB = re.sub(r"\bFAR\b", " ", LIB)
# - PASCAL is apparently an alias for __stdcall (on modern compilers - modern
# being _MSC_VER >= 800)
LIB = re.sub(r"\bPASCAL\b", "__stdcall", LIB)
ffi = cffi.api.FFI()
ffi.cdef(LIB)
CData: TypeAlias = cffi.api.FFI.CData
CType: TypeAlias = cffi.api.FFI.CType
AlwaysNull: TypeAlias = CType # We currently always pass ffi.NULL here.
Handle = NewType("Handle", CData)
HandleArray = NewType("HandleArray", CData)
class _Kernel32(Protocol):
"""Statically typed version of the kernel32.dll functions we use."""
def CreateIoCompletionPort(
self,
FileHandle: Handle,
ExistingCompletionPort: CData | AlwaysNull,
CompletionKey: int,
NumberOfConcurrentThreads: int,
/,
) -> Handle: ...
def CreateEventA(
self,
lpEventAttributes: AlwaysNull,
bManualReset: bool,
bInitialState: bool,
lpName: AlwaysNull,
/,
) -> Handle: ...
def SetFileCompletionNotificationModes(
self,
handle: Handle,
flags: CompletionModes,
/,
) -> int: ...
def PostQueuedCompletionStatus(
self,
CompletionPort: Handle,
dwNumberOfBytesTransferred: int,
dwCompletionKey: int,
lpOverlapped: CData | AlwaysNull,
/,
) -> bool: ...
def CancelIoEx(
self,
hFile: Handle,
lpOverlapped: CData | AlwaysNull,
/,
) -> bool: ...
def WriteFile(
self,
hFile: Handle,
# not sure about this type
lpBuffer: CData,
nNumberOfBytesToWrite: int,
lpNumberOfBytesWritten: AlwaysNull,
lpOverlapped: _Overlapped,
/,
) -> bool: ...
def ReadFile(
self,
hFile: Handle,
# not sure about this type
lpBuffer: CData,
nNumberOfBytesToRead: int,
lpNumberOfBytesRead: AlwaysNull,
lpOverlapped: _Overlapped,
/,
) -> bool: ...
def GetQueuedCompletionStatusEx(
self,
CompletionPort: Handle,
lpCompletionPortEntries: CData,
ulCount: int,
ulNumEntriesRemoved: CData,
dwMilliseconds: int,
fAlertable: bool | int,
/,
) -> CData: ...
def CreateFileW(
self,
lpFileName: CData,
dwDesiredAccess: FileFlags,
dwShareMode: FileFlags,
lpSecurityAttributes: AlwaysNull,
dwCreationDisposition: FileFlags,
dwFlagsAndAttributes: FileFlags,
hTemplateFile: AlwaysNull,
/,
) -> Handle: ...
def WaitForSingleObject(self, hHandle: Handle, dwMilliseconds: int, /) -> CData: ...
def WaitForMultipleObjects(
self,
nCount: int,
lpHandles: HandleArray,
bWaitAll: bool,
dwMilliseconds: int,
/,
) -> ErrorCodes: ...
def SetEvent(self, handle: Handle, /) -> None: ...
def CloseHandle(self, handle: Handle, /) -> bool: ...
def DeviceIoControl(
self,
hDevice: Handle,
dwIoControlCode: int,
# this is wrong (it's not always null)
lpInBuffer: AlwaysNull,
nInBufferSize: int,
# this is also wrong
lpOutBuffer: AlwaysNull,
nOutBufferSize: int,
lpBytesReturned: AlwaysNull,
lpOverlapped: CData,
/,
) -> bool: ...
class _Nt(Protocol):
"""Statically typed version of the dtdll.dll functions we use."""
def RtlNtStatusToDosError(self, status: int, /) -> ErrorCodes: ...
class _Ws2(Protocol):
"""Statically typed version of the ws2_32.dll functions we use."""
def WSAGetLastError(self) -> int: ...
def WSAIoctl(
self,
socket: CData,
dwIoControlCode: WSAIoctls,
lpvInBuffer: AlwaysNull,
cbInBuffer: int,
lpvOutBuffer: CData,
cbOutBuffer: int,
lpcbBytesReturned: CData, # int*
lpOverlapped: AlwaysNull,
# actually LPWSAOVERLAPPED_COMPLETION_ROUTINE
lpCompletionRoutine: AlwaysNull,
/,
) -> int: ...
class _DummyStruct(Protocol):
Offset: int
OffsetHigh: int
class _DummyUnion(Protocol):
DUMMYSTRUCTNAME: _DummyStruct
Pointer: object
class _Overlapped(Protocol):
Internal: int
InternalHigh: int
DUMMYUNIONNAME: _DummyUnion
hEvent: Handle
kernel32 = cast(_Kernel32, ffi.dlopen("kernel32.dll"))
ntdll = cast(_Nt, ffi.dlopen("ntdll.dll"))
ws2_32 = cast(_Ws2, ffi.dlopen("ws2_32.dll"))
################################################################
# Magic numbers
################################################################
# Here's a great resource for looking these up:
# https://www.magnumdb.com
# (Tip: check the box to see "Hex value")
INVALID_HANDLE_VALUE = Handle(ffi.cast("HANDLE", -1))
class ErrorCodes(enum.IntEnum):
STATUS_TIMEOUT = 0x102
WAIT_TIMEOUT = 0x102
WAIT_ABANDONED = 0x80
WAIT_OBJECT_0 = 0x00 # object is signaled
WAIT_FAILED = 0xFFFFFFFF
ERROR_IO_PENDING = 997
ERROR_OPERATION_ABORTED = 995
ERROR_ABANDONED_WAIT_0 = 735
ERROR_INVALID_HANDLE = 6
ERROR_INVALID_PARMETER = 87
ERROR_NOT_FOUND = 1168
ERROR_NOT_SOCKET = 10038
class FileFlags(enum.IntFlag):
GENERIC_READ = 0x80000000
SYNCHRONIZE = 0x00100000
FILE_FLAG_OVERLAPPED = 0x40000000
FILE_SHARE_READ = 1
FILE_SHARE_WRITE = 2
FILE_SHARE_DELETE = 4
CREATE_NEW = 1
CREATE_ALWAYS = 2
OPEN_EXISTING = 3
OPEN_ALWAYS = 4
TRUNCATE_EXISTING = 5
class AFDPollFlags(enum.IntFlag):
# These are drawn from a combination of:
# https://github.com/piscisaureus/wepoll/blob/master/src/afd.h
# https://github.com/reactos/reactos/blob/master/sdk/include/reactos/drivers/afd/shared.h
AFD_POLL_RECEIVE = 0x0001
AFD_POLL_RECEIVE_EXPEDITED = 0x0002 # OOB/urgent data
AFD_POLL_SEND = 0x0004
AFD_POLL_DISCONNECT = 0x0008 # received EOF (FIN)
AFD_POLL_ABORT = 0x0010 # received RST
AFD_POLL_LOCAL_CLOSE = 0x0020 # local socket object closed
AFD_POLL_CONNECT = 0x0040 # socket is successfully connected
AFD_POLL_ACCEPT = 0x0080 # you can call accept on this socket
AFD_POLL_CONNECT_FAIL = 0x0100 # connect() terminated unsuccessfully
# See WSAEventSelect docs for more details on these four:
AFD_POLL_QOS = 0x0200
AFD_POLL_GROUP_QOS = 0x0400
AFD_POLL_ROUTING_INTERFACE_CHANGE = 0x0800
AFD_POLL_EVENT_ADDRESS_LIST_CHANGE = 0x1000
class WSAIoctls(enum.IntEnum):
SIO_BASE_HANDLE = 0x48000022
SIO_BSP_HANDLE_SELECT = 0x4800001C
SIO_BSP_HANDLE_POLL = 0x4800001D
class CompletionModes(enum.IntFlag):
FILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 0x1
FILE_SKIP_SET_EVENT_ON_HANDLE = 0x2
class IoControlCodes(enum.IntEnum):
IOCTL_AFD_POLL = 0x00012024
################################################################
# Generic helpers
################################################################
def _handle(obj: int | CData) -> Handle:
# For now, represent handles as either cffi HANDLEs or as ints. If you
# try to pass in a file descriptor instead, it's not going to work
# out. (For that msvcrt.get_osfhandle does the trick, but I don't know if
# we'll actually need that for anything...) For sockets this doesn't
# matter, Python never allocates an fd. So let's wait until we actually
# encounter the problem before worrying about it.
if isinstance(obj, int):
return Handle(ffi.cast("HANDLE", obj))
return Handle(obj)
def handle_array(count: int) -> HandleArray:
"""Make an array of handles."""
return HandleArray(ffi.new(f"HANDLE[{count}]"))
def raise_winerror(
winerror: int | None = None,
*,
filename: str | None = None,
filename2: str | None = None,
) -> NoReturn:
# assert sys.platform == "win32" # TODO: make this work in MyPy
# ... in the meanwhile, ffi.getwinerror() is undefined on non-Windows, necessitating the type
# ignores.
if winerror is None:
err = ffi.getwinerror() # type: ignore[attr-defined,unused-ignore]
if err is None:
raise RuntimeError("No error set?")
winerror, msg = err
else:
err = ffi.getwinerror(winerror) # type: ignore[attr-defined,unused-ignore]
if err is None:
raise RuntimeError("No error set?")
_, msg = err
# https://docs.python.org/3/library/exceptions.html#OSError
raise OSError(0, msg, filename, winerror, filename2)