99 lines
3.0 KiB
Python
99 lines
3.0 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import NoReturn
|
|
|
|
import attrs
|
|
import pytest
|
|
|
|
from .._highlevel_generic import StapledStream
|
|
from ..abc import ReceiveStream, SendStream
|
|
|
|
|
|
@attrs.define(slots=False)
class RecordSendStream(SendStream):
    """A SendStream stub that logs every call instead of sending anything.

    Each method appends one entry to ``record``:

    * ``aclose()`` -> ``"aclose"``
    * ``send_all(data)`` -> ``("send_all", data)``
    * ``wait_send_all_might_not_block()`` -> ``"wait_send_all_might_not_block"``

    ``slots=False`` keeps the instance ``__dict__`` available, so a test can
    attach extra attributes (e.g. a ``send_eof`` coroutine) to an instance.
    """

    # Calls observed so far, oldest first.
    record: list[str | tuple[str, object]] = attrs.field(factory=list)

    async def aclose(self) -> None:
        self.record.append("aclose")

    async def send_all(self, data: object) -> None:
        entry = ("send_all", data)
        self.record.append(entry)

    async def wait_send_all_might_not_block(self) -> None:
        self.record.append("wait_send_all_might_not_block")
|
|
|
|
|
|
@attrs.define(slots=False)
class RecordReceiveStream(ReceiveStream):
    """A ReceiveStream stub that logs every call instead of receiving data.

    ``receive_some(max_bytes)`` appends ``("receive_some", max_bytes)`` and
    returns ``b""``. ``aclose()`` appends ``"aclose"``.
    """

    # Calls observed so far, oldest first.
    record: list[str | tuple[str, int | None]] = attrs.field(factory=list)

    async def aclose(self) -> None:
        self.record.append("aclose")

    async def receive_some(self, max_bytes: int | None = None) -> bytes:
        entry = ("receive_some", max_bytes)
        self.record.append(entry)
        return b""
|
|
|
|
|
|
async def test_StapledStream() -> None:
    """Check that StapledStream forwards each operation to the right half.

    Send-side calls must reach only the send stream. Receive-side calls must
    reach only the receive stream. ``aclose`` must reach both.
    """
    sender = RecordSendStream()
    receiver = RecordReceiveStream()
    stapled = StapledStream(sender, receiver)

    # The wrapped halves are exposed unchanged.
    assert stapled.send_stream is sender
    assert stapled.receive_stream is receiver

    # Send-side methods are forwarded verbatim, in call order.
    await stapled.send_all(b"foo")
    await stapled.wait_send_all_might_not_block()
    expected_sends = [
        ("send_all", b"foo"),
        "wait_send_all_might_not_block",
    ]
    assert sender.record == expected_sends
    sender.record.clear()

    # The send half has no send_eof of its own, so send_eof ends up as aclose.
    await stapled.send_eof()
    assert sender.record == ["aclose"]
    sender.record.clear()

    # Once the send half exposes a send_eof, the stapled stream calls it.
    async def fake_send_eof() -> None:
        sender.record.append("send_eof")

    sender.send_eof = fake_send_eof  # type: ignore[attr-defined]
    await stapled.send_eof()
    assert sender.record == ["send_eof"]
    sender.record.clear()

    # Nothing so far may have touched the receive half.
    assert receiver.record == []

    # Receive-side calls go only to the receive half.
    await stapled.receive_some(1234)
    assert receiver.record == [("receive_some", 1234)]
    assert sender.record == []
    receiver.record.clear()

    # Closing the stapled stream closes both halves.
    await stapled.aclose()
    assert receiver.record == ["aclose"]
    assert sender.record == ["aclose"]
|
|
|
|
|
|
async def test_StapledStream_with_erroring_close() -> None:
    """Check that StapledStream.aclose closes both halves even when both fail.

    Each half records its ``aclose`` call and then raises. Both calls must
    still happen. The error that propagates must carry the other half's error
    as its ``__context__``, so neither failure is silently lost.
    """

    # Make sure that if one of the aclose methods errors out, then the other
    # one still gets called.
    class BrokenSendStream(RecordSendStream):
        async def aclose(self) -> NoReturn:
            await super().aclose()
            raise ValueError("send error")

    class BrokenReceiveStream(RecordReceiveStream):
        async def aclose(self) -> NoReturn:
            await super().aclose()
            raise ValueError("recv error")

    stapled = StapledStream(BrokenSendStream(), BrokenReceiveStream())

    with pytest.raises(ValueError, match="^(send|recv) error$") as excinfo:
        await stapled.aclose()

    # Whichever error surfaces, the other one must be chained as its context.
    # Comparing the messages (not just the type) proves that the context is
    # the *other* half's failure rather than an unrelated ValueError, without
    # depending on which half StapledStream closes first.
    context = excinfo.value.__context__
    assert isinstance(context, ValueError)
    assert {str(excinfo.value), str(context)} == {"send error", "recv error"}

    # Both halves were asked to close despite the failures.
    assert stapled.send_stream.record == ["aclose"]
    assert stapled.receive_stream.record == ["aclose"]
|