# HES-Selenium/lib/python3.13/site-packages/trio/_tests/test_highlevel_generic.py
#
# 99 lines
# 3.0 KiB
# Python

from __future__ import annotations
from typing import NoReturn
import attrs
import pytest
from .._highlevel_generic import StapledStream
from ..abc import ReceiveStream, SendStream
# slots=False keeps a per-instance __dict__, so tests can attach extra
# attributes (such as a fake ``send_eof``) to an instance after creation.
@attrs.define(slots=False)
class RecordSendStream(SendStream):
    """A ``SendStream`` stand-in that logs every call instead of doing I/O.

    Each method appends an entry to ``record``, so a test can assert exactly
    which methods ran and in what order.
    """

    # Call log: a bare method name, or a ``(method_name, argument)`` tuple
    # for calls that carry a payload.
    record: list[str | tuple[str, object]] = attrs.field(factory=list)

    async def send_all(self, data: object) -> None:
        """Log the call together with the payload that was passed in."""
        entry = ("send_all", data)
        self.record.append(entry)

    async def wait_send_all_might_not_block(self) -> None:
        """Log the call; returns immediately without waiting."""
        self.record.append("wait_send_all_might_not_block")

    async def aclose(self) -> None:
        """Log the call; there is no underlying resource to release."""
        self.record.append("aclose")
# slots=False mirrors the send-side recorder and keeps a per-instance
# __dict__ available for ad-hoc attributes in tests.
@attrs.define(slots=False)
class RecordReceiveStream(ReceiveStream):
    """A ``ReceiveStream`` stand-in that logs every call instead of doing I/O.

    ``receive_some`` always returns an empty bytes object, which trio's
    ``ReceiveStream`` contract treats as end-of-stream.
    """

    # Call log: a bare method name, or a ``("receive_some", max_bytes)`` tuple.
    record: list[str | tuple[str, int | None]] = attrs.field(factory=list)

    async def receive_some(self, max_bytes: int | None = None) -> bytes:
        """Log the requested size and return ``b""``."""
        entry = ("receive_some", max_bytes)
        self.record.append(entry)
        return b""

    async def aclose(self) -> None:
        """Log the call; there is no underlying resource to release."""
        self.record.append("aclose")
async def test_StapledStream() -> None:
    """Check that StapledStream routes each call to the correct wrapped half."""
    sender = RecordSendStream()
    receiver = RecordReceiveStream()
    stapled = StapledStream(sender, receiver)

    # The wrapped streams are exposed as-is.
    assert stapled.send_stream is sender
    assert stapled.receive_stream is receiver

    # Send-side calls reach the send stream, in call order.
    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    expected_send_calls = [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    assert sender.record == expected_send_calls
    sender.record.clear()

    # The recorder defines no send_eof, so this test expects send_eof to
    # fall back to closing the send stream.
    await stapled.send_eof()
    assert sender.record == ["aclose"]
    sender.record.clear()

    # Once the send stream has a send_eof attribute, that one is expected to
    # be used instead of aclose.
    async def fake_send_eof() -> None:
        sender.record.append("send_eof")

    sender.send_eof = fake_send_eof  # type: ignore[attr-defined]
    await stapled.send_eof()
    assert sender.record == ["send_eof"]
    sender.record.clear()

    # Receive-side calls reach only the receive stream.
    assert receiver.record == []
    await stapled.receive_some(1234)
    assert receiver.record == [("receive_some", 1234)]
    assert sender.record == []
    receiver.record.clear()

    # Closing the stapled stream closes both halves.
    await stapled.aclose()
    assert receiver.record == ["aclose"]
    assert sender.record == ["aclose"]
async def test_StapledStream_with_erroring_close() -> None:
    """If one half's aclose raises, the other half must still be closed.

    Both halves raise here. StapledStream.aclose should surface one of the
    two errors and carry the other as its implicit ``__context__``. The order
    is not pinned, so the test accepts either half raising last.
    """

    class BrokenSendStream(RecordSendStream):
        async def aclose(self) -> NoReturn:
            await super().aclose()
            raise ValueError("send error")

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self) -> NoReturn:
            await super().aclose()
            raise ValueError("recv error")

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo:
        await stapled.aclose()
    context = excinfo.value.__context__
    assert isinstance(context, ValueError)
    # The chained exception must be the *other* half's error: together the
    # two messages account for both streams, whichever closed first.
    assert {str(excinfo.value), str(context)} == {"send error", "recv error"}

    # Both halves recorded their aclose call despite the errors.
    assert stapled.send_stream.record == ["aclose"]
    assert stapled.receive_stream.record == ["aclose"]