# (Source listing: 331 lines, 11 KiB, Python.)
from __future__ import annotations

import errno
import socket as stdlib_socket
import sys
from typing import TYPE_CHECKING, Sequence, overload

import pytest

from .. import _core, socket as tsocket
from .._highlevel_socket import *
from ..testing import (
    assert_checkpoints,
    check_half_closeable_stream,
    wait_all_tasks_blocked,
)
from .test_socket import setsockopt_tests

if TYPE_CHECKING:
    # NOTE(review): `Buffer`, used in FakeSocket.setsockopt annotations, is
    # still unresolved for type checkers; it lives in typing_extensions (or
    # collections.abc on 3.12+) — confirm which the project wants here.
    from ..socket import SocketType
|
||
|
||
|
||
async def test_SocketStream_basics() -> None:
    """Check SocketStream's argument validation and socket-option passthrough.

    A stdlib socket and a SOCK_DGRAM Trio socket are both rejected. A Trio
    stream socket is wrapped and exposed as ``.socket``. Options are then
    exercised on a real loopback TCP connection.
    """
    # Plain stdlib sockets are refused, even when they are connected.
    std_left, std_right = stdlib_socket.socketpair()
    with std_left, std_right, pytest.raises(TypeError):
        SocketStream(std_left)  # type: ignore[arg-type]

    # Datagram sockets are refused with a specific message.
    with tsocket.socket(type=tsocket.SOCK_DGRAM) as dgram_sock, pytest.raises(
        ValueError,
        match="^SocketStream requires a SOCK_STREAM socket$",
    ):
        # TODO: does not raise an error?
        SocketStream(dgram_sock)

    pair_left, pair_right = tsocket.socketpair()
    with pair_left, pair_right:
        assert SocketStream(pair_left).socket is pair_left

    # socketpair() may hand back unix-domain sockets, which reject these TCP
    # options. So the option checks run over a genuine loopback TCP connection.
    with tsocket.socket() as server_sock:
        await server_sock.bind(("127.0.0.1", 0))
        server_sock.listen(1)
        with tsocket.socket() as peer_sock:
            await peer_sock.connect(server_sock.getsockname())

            stream = SocketStream(peer_sock)
            level, option = tsocket.IPPROTO_TCP, tsocket.TCP_NODELAY

            # SocketStream switches TCP_NODELAY on by default...
            assert stream.getsockopt(level, option)
            # ...but callers are free to switch it back off.
            stream.setsockopt(level, option, False)
            assert not stream.getsockopt(level, option)

            # Supplying a buflen yields the raw option value as bytes.
            raw_value = stream.getsockopt(level, option, 1)
            assert isinstance(raw_value, bytes)

            setsockopt_tests(stream)
|
||
|
||
|
||
async def test_SocketStream_send_all() -> None:
    """Push a large payload through send_all() and check it all arrives.

    Afterwards, a one-byte marker confirms that nothing extra was queued.
    send_eof() is then checked to produce an empty read on the peer.
    """
    TOTAL = 10_000_000

    sock_a, sock_b = tsocket.socketpair()
    with sock_a, sock_b:
        stream_a = SocketStream(sock_a)
        stream_b = SocketStream(sock_b)

        # The payload is big enough that send_all usually has to split it
        # into several send() calls. The exception is Windows, where each
        # send() either succeeds or fails as a whole.
        async def sender() -> None:
            payload = bytearray(TOTAL)
            await stream_a.send_all(payload)
            # send_all works through memoryviews, and a live memoryview
            # temporarily "locks" the object it views. If send_all leaked
            # one, later bytearray operations could raise -- a strange side
            # effect to inflict on users. Growing the bytearray forces its
            # buffer to be reallocated, which would trip over such a leak:
            payload += bytes(TOTAL)
            # Caveat: this is a weak check.
            # - On CPython, refcounting usually frees memoryviews promptly
            #   even if we are sloppy.
            # - On PyPy3 (as of 5.7.0), resizing never fails. When the
            #   bytearray's buffer moves, memoryview references are updated
            #   automatically.
            # See https://gist.github.com/njsmith/0ffd38ec05ad8e34004f34a7dc492227
            # It stays here as an early warning in case that ever changes and
            # our send_all implementation regresses.

        async def receiver() -> None:
            # Give the sender time to fill the kernel buffers and block.
            await wait_all_tasks_blocked()
            received = 0
            while received < TOTAL:
                chunk = await stream_b.receive_some(TOTAL)
                received += len(chunk)
            assert received == TOTAL

        async with _core.open_nursery() as nursery:
            nursery.start_soon(sender)
            nursery.start_soon(receiver)

        # Exactly TOTAL NUL bytes have been consumed. A fresh marker byte
        # shows nothing else was waiting behind them.
        await stream_a.send_all(b"e")
        assert await stream_b.receive_some(10) == b"e"
        await stream_a.send_eof()
        assert await stream_b.receive_some(10) == b""
|
||
|
||
|
||
async def fill_stream(s: SocketStream) -> None:
    """Keep writing to *s* until the test harness reports everything blocked.

    One task loops on ``send_all``. A second task waits for
    ``wait_all_tasks_blocked()`` and then cancels both. This leaves *s* with
    as much unsent/unread data as the writes managed to queue.
    """
    chunk = b"x" * 10000

    async def keep_sending() -> None:
        while True:
            await s.send_all(chunk)

    async def stop_when_idle(nursery: _core.Nursery) -> None:
        await wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()

    async with _core.open_nursery() as nursery:
        nursery.start_soon(keep_sending)
        nursery.start_soon(stop_when_idle, nursery)
|
||
|
||
|
||
async def test_SocketStream_generic() -> None:
    """Run the generic half-closeable stream checks against SocketStream."""

    async def stream_maker() -> tuple[SocketStream, SocketStream]:
        sock_left, sock_right = tsocket.socketpair()
        return SocketStream(sock_left), SocketStream(sock_right)

    async def clogged_stream_maker() -> tuple[SocketStream, SocketStream]:
        # Same as stream_maker, but with both directions pre-filled by
        # fill_stream (left first, then right).
        pair = await stream_maker()
        for stream in pair:
            await fill_stream(stream)
        return pair

    await check_half_closeable_stream(stream_maker, clogged_stream_maker)
|
||
|
||
|
||
async def test_SocketListener() -> None:
    """Check SocketListener's argument validation, accept() and aclose().

    Every socket is managed by a ``with`` block, and the accepted stream is
    closed in a ``finally``. This way nothing leaks (and no ResourceWarning
    fires) when an assertion fails part-way through.
    """
    # Not a Trio socket
    with stdlib_socket.socket() as plain_sock:
        plain_sock.bind(("127.0.0.1", 0))
        plain_sock.listen(10)
        with pytest.raises(TypeError):
            SocketListener(plain_sock)  # type: ignore[arg-type]

    # Not a SOCK_STREAM
    with tsocket.socket(type=tsocket.SOCK_DGRAM) as dgram_sock:
        await dgram_sock.bind(("127.0.0.1", 0))
        with pytest.raises(
            ValueError,
            match="^SocketListener requires a SOCK_STREAM socket$",
        ) as excinfo:
            SocketListener(dgram_sock)
        excinfo.match(r".*SOCK_STREAM")

    # Didn't call .listen()
    # macOS has no way to check for this, so skip testing it there.
    if sys.platform != "darwin":
        with tsocket.socket() as unlistened_sock:
            await unlistened_sock.bind(("127.0.0.1", 0))
            with pytest.raises(
                ValueError,
                match="^SocketListener requires a listening socket$",
            ) as excinfo:
                SocketListener(unlistened_sock)
            excinfo.match(r".*listen")

    # NOTE: listener.aclose() below closes listen_sock. Leaving the ``with``
    # then closes it a second time, which is assumed to be a no-op.
    with tsocket.socket() as listen_sock:
        await listen_sock.bind(("127.0.0.1", 0))
        listen_sock.listen(10)
        listener = SocketListener(listen_sock)

        assert listener.socket is listen_sock

        with tsocket.socket() as client_sock:
            await client_sock.connect(listen_sock.getsockname())
            with assert_checkpoints():
                server_stream = await listener.accept()
            try:
                assert isinstance(server_stream, SocketStream)
                assert server_stream.socket.getsockname() == listen_sock.getsockname()
                assert server_stream.socket.getpeername() == client_sock.getsockname()

                # aclose() checkpoints, and calling it twice is allowed.
                with assert_checkpoints():
                    await listener.aclose()

                with assert_checkpoints():
                    await listener.aclose()

                # accept() on a closed listener raises, and still checkpoints.
                with assert_checkpoints():
                    with pytest.raises(_core.ClosedResourceError):
                        await listener.accept()
            finally:
                await server_stream.aclose()
|
||
|
||
|
||
async def test_SocketListener_socket_closed_underfoot() -> None:
    """accept() raises ClosedResourceError if the raw socket was closed directly."""
    raw_sock = tsocket.socket()
    await raw_sock.bind(("127.0.0.1", 0))
    raw_sock.listen(10)
    listener = SocketListener(raw_sock)

    # Go around the listener and close the underlying socket ourselves.
    raw_sock.close()

    # The listener must still surface the proper error, and checkpoint.
    with assert_checkpoints(), pytest.raises(_core.ClosedResourceError):
        await listener.accept()
|
||
|
||
|
||
async def test_SocketListener_accept_errors() -> None:
    """accept() skips some OSErrors and propagates others.

    A fake socket replays a scripted sequence of accept() outcomes:
    - ECONNABORTED, EPERM and EPROTO are expected to be skipped, because the
      first accept() still yields a connection.
    - EMFILE, EFAULT and ENOBUFS are expected to reach the caller.
    """

    class FakeSocket(tsocket.SocketType):
        """Stand-in socket whose accept() replays a fixed list of outcomes."""

        def __init__(self, events: Sequence[SocketType | BaseException]) -> None:
            self._events = iter(events)

        type = tsocket.SOCK_STREAM

        # Report every option as set, which satisfies the SO_ACCEPTCONN check
        # in SocketListener.__init__.
        @overload
        def getsockopt(self, /, level: int, optname: int) -> int: ...

        @overload
        def getsockopt(  # noqa: F811
            self,
            /,
            level: int,
            optname: int,
            buflen: int,
        ) -> bytes: ...

        def getsockopt(  # noqa: F811
            self,
            /,
            level: int,
            optname: int,
            buflen: int | None = None,
        ) -> int | bytes:
            return True

        @overload
        def setsockopt(
            self,
            /,
            level: int,
            optname: int,
            value: int | Buffer,
        ) -> None: ...

        @overload
        def setsockopt(  # noqa: F811
            self,
            /,
            level: int,
            optname: int,
            value: None,
            optlen: int,
        ) -> None: ...

        def setsockopt(  # noqa: F811
            self,
            /,
            level: int,
            optname: int,
            value: int | Buffer | None,
            optlen: int | None = None,
        ) -> None:
            pass

        async def accept(self) -> tuple[SocketType, object]:
            await _core.checkpoint()
            outcome = next(self._events)
            if not isinstance(outcome, BaseException):
                return outcome, None
            raise outcome

    accepted_sock = FakeSocket([])

    scripted_listen_sock = FakeSocket(
        [
            # Expected to be skipped by SocketListener.accept()...
            OSError(errno.ECONNABORTED, "Connection aborted"),
            OSError(errno.EPERM, "Permission denied"),
            OSError(errno.EPROTO, "Bad protocol"),
            accepted_sock,
            # ...while these are expected to reach the caller.
            OSError(errno.EMFILE, "Out of file descriptors"),
            OSError(errno.EFAULT, "attempt to write to read-only memory"),
            OSError(errno.ENOBUFS, "out of buffers"),
            accepted_sock,
        ],
    )

    listener = SocketListener(scripted_listen_sock)

    # The first accept() has to get past the three skippable errors.
    with assert_checkpoints():
        stream = await listener.accept()
    assert stream.socket is accepted_sock

    fatal_errors = [
        (errno.EMFILE, r"\[\w+ \d+\] Out of file descriptors$"),
        (errno.EFAULT, r"\[\w+ \d+\] attempt to write to read-only memory$"),
        (errno.ENOBUFS, r"\[\w+ \d+\] out of buffers$"),
    ]
    for expected_errno, pattern in fatal_errors:
        with assert_checkpoints(), pytest.raises(OSError, match=pattern) as excinfo:
            await listener.accept()
        assert excinfo.value.errno == expected_errno

    # Once the fatal errors are used up, accept() succeeds again.
    with assert_checkpoints():
        stream = await listener.accept()
    assert stream.socket is accepted_sock
|
||
|
||
|
||
async def test_socket_stream_works_when_peer_has_already_closed() -> None:
    """Data sent before the peer closed is still readable, followed by EOF."""
    ours, theirs = tsocket.socketpair()
    with ours, theirs:
        # The peer writes a byte and hangs up before we wrap our end.
        await theirs.send(b"x")
        theirs.close()

        stream = SocketStream(ours)
        assert await stream.receive_some(1) == b"x"
        assert await stream.receive_some(1) == b""
|