from __future__ import annotations import signal from typing import TYPE_CHECKING, NoReturn import pytest import trio from trio.testing import RaisesGroup from .. import _core from .._signals import _signal_handler, get_pending_signal_count, open_signal_receiver from .._util import signal_raise if TYPE_CHECKING: from types import FrameType async def test_open_signal_receiver() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL) as receiver: # Raise it a few times, to exercise signal coalescing, both at the # call_soon level and at the SignalQueue level signal_raise(signal.SIGILL) signal_raise(signal.SIGILL) await _core.wait_all_tasks_blocked() signal_raise(signal.SIGILL) await _core.wait_all_tasks_blocked() async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break assert get_pending_signal_count(receiver) == 0 signal_raise(signal.SIGILL) async for signum in receiver: # pragma: no branch assert signum == signal.SIGILL break assert get_pending_signal_count(receiver) == 0 with pytest.raises(RuntimeError): await receiver.__anext__() assert signal.getsignal(signal.SIGILL) is orig async def test_open_signal_receiver_restore_handler_after_one_bad_signal() -> None: orig = signal.getsignal(signal.SIGILL) with pytest.raises( ValueError, match="(signal number out of range|invalid signal value)$", ): with open_signal_receiver(signal.SIGILL, 1234567): pass # pragma: no cover # Still restored even if we errored out assert signal.getsignal(signal.SIGILL) is orig async def test_open_signal_receiver_empty_fail() -> None: with pytest.raises(TypeError, match="No signals were provided"): with open_signal_receiver(): pass async def test_open_signal_receiver_restore_handler_after_duplicate_signal() -> None: orig = signal.getsignal(signal.SIGILL) with open_signal_receiver(signal.SIGILL, signal.SIGILL): pass # Still restored correctly assert signal.getsignal(signal.SIGILL) is orig async def test_catch_signals_wrong_thread() -> None: async def naughty() -> None: with open_signal_receiver(signal.SIGINT): pass # pragma: no cover with pytest.raises(RuntimeError): await trio.to_thread.run_sync(trio.run, naughty) async def test_open_signal_receiver_conflict() -> None: with RaisesGroup(trio.BusyResourceError): with open_signal_receiver(signal.SIGILL) as receiver: async with trio.open_nursery() as nursery: nursery.start_soon(receiver.__anext__) nursery.start_soon(receiver.__anext__) # Blocks until all previous calls to run_sync_soon(idempotent=True) have been # processed. async def wait_run_sync_soon_idempotent_queue_barrier() -> None: ev = trio.Event() token = _core.current_trio_token() token.run_sync_soon(ev.set, idempotent=True) await ev.wait() async def test_open_signal_receiver_no_starvation() -> None: # Set up a situation where there are always 2 pending signals available to # report, and make sure that instead of getting the same signal reported # over and over, it alternates between reporting both of them. with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: try: print(signal.getsignal(signal.SIGILL)) previous = None for _ in range(10): signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() if previous is None: previous = await receiver.__anext__() else: got = await receiver.__anext__() assert got in [signal.SIGILL, signal.SIGFPE] assert got != previous previous = got # Clear out the last signal so that it doesn't get redelivered while get_pending_signal_count(receiver) != 0: await receiver.__anext__() except BaseException: # pragma: no cover # If there's an unhandled exception above, then exiting the # open_signal_receiver block might cause the signal to be # redelivered and give us a core dump instead of a traceback... import traceback traceback.print_exc() async def test_catch_signals_race_condition_on_exit() -> None: delivered_directly: set[int] = set() def direct_handler(signo: int, frame: FrameType | None) -> None: delivered_directly.add(signo) print(1) # Test the version where the call_soon *doesn't* have a chance to run # before we exit the with block: with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert delivered_directly == {signal.SIGILL, signal.SIGFPE} delivered_directly.clear() print(2) # Test the version where the call_soon *does* have a chance to run before # we exit the with block: with _signal_handler({signal.SIGILL, signal.SIGFPE}, direct_handler): with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 2 assert delivered_directly == {signal.SIGILL, signal.SIGFPE} delivered_directly.clear() # Again, but with a SIG_IGN signal: print(3) with _signal_handler({signal.SIGILL}, signal.SIG_IGN): with open_signal_receiver(signal.SIGILL) as receiver: signal_raise(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() # test passes if the process reaches this point without dying print(4) with _signal_handler({signal.SIGILL}, signal.SIG_IGN): with open_signal_receiver(signal.SIGILL) as receiver: signal_raise(signal.SIGILL) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 1 # test passes if the process reaches this point without dying # Check exception chaining if there are multiple exception-raising # handlers def raise_handler(signum: int, frame: FrameType | None) -> NoReturn: raise RuntimeError(signum) with _signal_handler({signal.SIGILL, signal.SIGFPE}, raise_handler): with pytest.raises(RuntimeError) as excinfo: with open_signal_receiver(signal.SIGILL, signal.SIGFPE) as receiver: signal_raise(signal.SIGILL) signal_raise(signal.SIGFPE) await wait_run_sync_soon_idempotent_queue_barrier() assert get_pending_signal_count(receiver) == 2 exc = excinfo.value signums = {exc.args[0]} assert isinstance(exc.__context__, RuntimeError) signums.add(exc.__context__.args[0]) assert signums == {signal.SIGILL, signal.SIGFPE}