Updated script that can be controled by Nodejs web app
This commit is contained in:
94
lib/python3.13/site-packages/wsproto/__init__.py
Normal file
94
lib/python3.13/site-packages/wsproto/__init__.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""
|
||||
wsproto
|
||||
~~~~~~~
|
||||
|
||||
A WebSocket implementation.
|
||||
"""
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from .connection import Connection, ConnectionState, ConnectionType
|
||||
from .events import Event
|
||||
from .handshake import H11Handshake
|
||||
from .typing import Headers
|
||||
|
||||
__version__ = "1.2.0"
|
||||
|
||||
|
||||
class WSConnection:
|
||||
"""
|
||||
Represents the local end of a WebSocket connection to a remote peer.
|
||||
"""
|
||||
|
||||
def __init__(self, connection_type: ConnectionType) -> None:
|
||||
"""
|
||||
Constructor
|
||||
|
||||
:param wsproto.connection.ConnectionType connection_type: Controls
|
||||
whether the library behaves as a client or as a server.
|
||||
"""
|
||||
self.client = connection_type is ConnectionType.CLIENT
|
||||
self.handshake = H11Handshake(connection_type)
|
||||
self.connection: Optional[Connection] = None
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
"""
|
||||
:returns: Connection state
|
||||
:rtype: wsproto.connection.ConnectionState
|
||||
"""
|
||||
if self.connection is None:
|
||||
return self.handshake.state
|
||||
return self.connection.state
|
||||
|
||||
def initiate_upgrade_connection(
|
||||
self, headers: Headers, path: Union[bytes, str]
|
||||
) -> None:
|
||||
self.handshake.initiate_upgrade_connection(headers, path)
|
||||
|
||||
def send(self, event: Event) -> bytes:
|
||||
"""
|
||||
Generate network data for the specified event.
|
||||
|
||||
When you want to communicate with a WebSocket peer, you should construct
|
||||
an event and pass it to this method. This method will return the bytes
|
||||
that you should send to the peer.
|
||||
|
||||
:param wsproto.events.Event event: The event to generate data for
|
||||
:returns bytes: The data to send to the peer
|
||||
"""
|
||||
data = b""
|
||||
if self.connection is None:
|
||||
data += self.handshake.send(event)
|
||||
self.connection = self.handshake.connection
|
||||
else:
|
||||
data += self.connection.send(event)
|
||||
return data
|
||||
|
||||
def receive_data(self, data: Optional[bytes]) -> None:
|
||||
"""
|
||||
Feed network data into the connection instance.
|
||||
|
||||
After calling this method, you should call :meth:`events` to see if the
|
||||
received data triggered any new events.
|
||||
|
||||
:param bytes data: Data received from remote peer
|
||||
"""
|
||||
if self.connection is None:
|
||||
self.handshake.receive_data(data)
|
||||
self.connection = self.handshake.connection
|
||||
else:
|
||||
self.connection.receive_data(data)
|
||||
|
||||
def events(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
A generator that yields pending events.
|
||||
|
||||
Each event is an instance of a subclass of
|
||||
:class:`wsproto.events.Event`.
|
||||
"""
|
||||
yield from self.handshake.events()
|
||||
if self.connection is not None:
|
||||
yield from self.connection.events()
|
||||
|
||||
|
||||
__all__ = ("ConnectionType", "WSConnection")
|
189
lib/python3.13/site-packages/wsproto/connection.py
Normal file
189
lib/python3.13/site-packages/wsproto/connection.py
Normal file
@ -0,0 +1,189 @@
|
||||
"""
|
||||
wsproto/connection
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
An implementation of a WebSocket connection.
|
||||
"""
|
||||
|
||||
from collections import deque
|
||||
from enum import Enum
|
||||
from typing import Deque, Generator, List, Optional
|
||||
|
||||
from .events import (
|
||||
BytesMessage,
|
||||
CloseConnection,
|
||||
Event,
|
||||
Message,
|
||||
Ping,
|
||||
Pong,
|
||||
TextMessage,
|
||||
)
|
||||
from .extensions import Extension
|
||||
from .frame_protocol import CloseReason, FrameProtocol, Opcode, ParseFailed
|
||||
from .utilities import LocalProtocolError
|
||||
|
||||
|
||||
class ConnectionState(Enum):
|
||||
"""
|
||||
RFC 6455, Section 4 - Opening Handshake
|
||||
"""
|
||||
|
||||
#: The opening handshake is in progress.
|
||||
CONNECTING = 0
|
||||
#: The opening handshake is complete.
|
||||
OPEN = 1
|
||||
#: The remote WebSocket has initiated a connection close.
|
||||
REMOTE_CLOSING = 2
|
||||
#: The local WebSocket (i.e. this instance) has initiated a connection close.
|
||||
LOCAL_CLOSING = 3
|
||||
#: The closing handshake has completed.
|
||||
CLOSED = 4
|
||||
#: The connection was rejected during the opening handshake.
|
||||
REJECTING = 5
|
||||
|
||||
|
||||
class ConnectionType(Enum):
|
||||
"""An enumeration of connection types."""
|
||||
|
||||
#: This connection will act as client and talk to a remote server
|
||||
CLIENT = 1
|
||||
|
||||
#: This connection will as as server and waits for client connections
|
||||
SERVER = 2
|
||||
|
||||
|
||||
CLIENT = ConnectionType.CLIENT
|
||||
SERVER = ConnectionType.SERVER
|
||||
|
||||
|
||||
class Connection:
|
||||
"""
|
||||
A low-level WebSocket connection object.
|
||||
|
||||
This wraps two other protocol objects, an HTTP/1.1 protocol object used
|
||||
to do the initial HTTP upgrade handshake and a WebSocket frame protocol
|
||||
object used to exchange messages and other control frames.
|
||||
|
||||
:param conn_type: Whether this object is on the client- or server-side of
|
||||
a connection. To initialise as a client pass ``CLIENT`` otherwise
|
||||
pass ``SERVER``.
|
||||
:type conn_type: ``ConnectionType``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connection_type: ConnectionType,
|
||||
extensions: Optional[List[Extension]] = None,
|
||||
trailing_data: bytes = b"",
|
||||
) -> None:
|
||||
self.client = connection_type is ConnectionType.CLIENT
|
||||
self._events: Deque[Event] = deque()
|
||||
self._proto = FrameProtocol(self.client, extensions or [])
|
||||
self._state = ConnectionState.OPEN
|
||||
self.receive_data(trailing_data)
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
return self._state
|
||||
|
||||
def send(self, event: Event) -> bytes:
|
||||
data = b""
|
||||
if isinstance(event, Message) and self.state == ConnectionState.OPEN:
|
||||
data += self._proto.send_data(event.data, event.message_finished)
|
||||
elif isinstance(event, Ping) and self.state == ConnectionState.OPEN:
|
||||
data += self._proto.ping(event.payload)
|
||||
elif isinstance(event, Pong) and self.state == ConnectionState.OPEN:
|
||||
data += self._proto.pong(event.payload)
|
||||
elif isinstance(event, CloseConnection) and self.state in {
|
||||
ConnectionState.OPEN,
|
||||
ConnectionState.REMOTE_CLOSING,
|
||||
}:
|
||||
data += self._proto.close(event.code, event.reason)
|
||||
if self.state == ConnectionState.REMOTE_CLOSING:
|
||||
self._state = ConnectionState.CLOSED
|
||||
else:
|
||||
self._state = ConnectionState.LOCAL_CLOSING
|
||||
else:
|
||||
raise LocalProtocolError(
|
||||
f"Event {event} cannot be sent in state {self.state}."
|
||||
)
|
||||
return data
|
||||
|
||||
def receive_data(self, data: Optional[bytes]) -> None:
|
||||
"""
|
||||
Pass some received data to the connection for handling.
|
||||
|
||||
A list of events that the remote peer triggered by sending this data can
|
||||
be retrieved with :meth:`~wsproto.connection.Connection.events`.
|
||||
|
||||
:param data: The data received from the remote peer on the network.
|
||||
:type data: ``bytes``
|
||||
"""
|
||||
|
||||
if data is None:
|
||||
# "If _The WebSocket Connection is Closed_ and no Close control
|
||||
# frame was received by the endpoint (such as could occur if the
|
||||
# underlying transport connection is lost), _The WebSocket
|
||||
# Connection Close Code_ is considered to be 1006."
|
||||
self._events.append(CloseConnection(code=CloseReason.ABNORMAL_CLOSURE))
|
||||
self._state = ConnectionState.CLOSED
|
||||
return
|
||||
|
||||
if self.state in (ConnectionState.OPEN, ConnectionState.LOCAL_CLOSING):
|
||||
self._proto.receive_bytes(data)
|
||||
elif self.state is ConnectionState.CLOSED:
|
||||
raise LocalProtocolError("Connection already closed.")
|
||||
else:
|
||||
pass # pragma: no cover
|
||||
|
||||
def events(self) -> Generator[Event, None, None]:
|
||||
"""
|
||||
Return a generator that provides any events that have been generated
|
||||
by protocol activity.
|
||||
|
||||
:returns: generator of :class:`Event <wsproto.events.Event>` subclasses
|
||||
"""
|
||||
while self._events:
|
||||
yield self._events.popleft()
|
||||
|
||||
try:
|
||||
for frame in self._proto.received_frames():
|
||||
if frame.opcode is Opcode.PING:
|
||||
assert frame.frame_finished and frame.message_finished
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
yield Ping(payload=frame.payload)
|
||||
|
||||
elif frame.opcode is Opcode.PONG:
|
||||
assert frame.frame_finished and frame.message_finished
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
yield Pong(payload=frame.payload)
|
||||
|
||||
elif frame.opcode is Opcode.CLOSE:
|
||||
assert isinstance(frame.payload, tuple)
|
||||
code, reason = frame.payload
|
||||
if self.state is ConnectionState.LOCAL_CLOSING:
|
||||
self._state = ConnectionState.CLOSED
|
||||
else:
|
||||
self._state = ConnectionState.REMOTE_CLOSING
|
||||
yield CloseConnection(code=code, reason=reason)
|
||||
|
||||
elif frame.opcode is Opcode.TEXT:
|
||||
assert isinstance(frame.payload, str)
|
||||
yield TextMessage(
|
||||
data=frame.payload,
|
||||
frame_finished=frame.frame_finished,
|
||||
message_finished=frame.message_finished,
|
||||
)
|
||||
|
||||
elif frame.opcode is Opcode.BINARY:
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
yield BytesMessage(
|
||||
data=frame.payload,
|
||||
frame_finished=frame.frame_finished,
|
||||
message_finished=frame.message_finished,
|
||||
)
|
||||
|
||||
else:
|
||||
pass # pragma: no cover
|
||||
except ParseFailed as exc:
|
||||
yield CloseConnection(code=exc.code, reason=str(exc))
|
295
lib/python3.13/site-packages/wsproto/events.py
Normal file
295
lib/python3.13/site-packages/wsproto/events.py
Normal file
@ -0,0 +1,295 @@
|
||||
"""
|
||||
wsproto/events
|
||||
~~~~~~~~~~~~~~
|
||||
|
||||
Events that result from processing data on a WebSocket connection.
|
||||
"""
|
||||
from abc import ABC
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Generic, List, Optional, Sequence, TypeVar, Union
|
||||
|
||||
from .extensions import Extension
|
||||
from .typing import Headers
|
||||
|
||||
|
||||
class Event(ABC):
|
||||
"""
|
||||
Base class for wsproto events.
|
||||
"""
|
||||
|
||||
pass # noqa
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Request(Event):
|
||||
"""The beginning of a Websocket connection, the HTTP Upgrade request
|
||||
|
||||
This event is fired when a SERVER connection receives a WebSocket
|
||||
handshake request (HTTP with upgrade header).
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: host
|
||||
|
||||
(Required) The hostname, or host header value.
|
||||
|
||||
.. attribute:: target
|
||||
|
||||
(Required) The request target (path and query string)
|
||||
|
||||
.. attribute:: extensions
|
||||
|
||||
The proposed extensions.
|
||||
|
||||
.. attribute:: extra_headers
|
||||
|
||||
The additional request headers, excluding extensions, host, subprotocols,
|
||||
and version headers.
|
||||
|
||||
.. attribute:: subprotocols
|
||||
|
||||
A list of the subprotocols proposed in the request, as a list
|
||||
of strings.
|
||||
"""
|
||||
|
||||
host: str
|
||||
target: str
|
||||
extensions: Union[Sequence[Extension], Sequence[str]] = field( # type: ignore[assignment]
|
||||
default_factory=list
|
||||
)
|
||||
extra_headers: Headers = field(default_factory=list)
|
||||
subprotocols: List[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AcceptConnection(Event):
|
||||
"""The acceptance of a Websocket upgrade request.
|
||||
|
||||
This event is fired when a CLIENT receives an acceptance response
|
||||
from a server. It is also used to accept an upgrade request when
|
||||
acting as a SERVER.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: extra_headers
|
||||
|
||||
Any additional (non websocket related) headers present in the
|
||||
acceptance response.
|
||||
|
||||
.. attribute:: subprotocol
|
||||
|
||||
The accepted subprotocol to use.
|
||||
|
||||
"""
|
||||
|
||||
subprotocol: Optional[str] = None
|
||||
extensions: List[Extension] = field(default_factory=list)
|
||||
extra_headers: Headers = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RejectConnection(Event):
|
||||
"""The rejection of a Websocket upgrade request, the HTTP response.
|
||||
|
||||
The ``RejectConnection`` event sends the appropriate HTTP headers to
|
||||
communicate to the peer that the handshake has been rejected. You may also
|
||||
send an HTTP body by setting the ``has_body`` attribute to ``True`` and then
|
||||
sending one or more :class:`RejectData` events after this one. When sending
|
||||
a response body, the caller should set the ``Content-Length``,
|
||||
``Content-Type``, and/or ``Transfer-Encoding`` headers as appropriate.
|
||||
|
||||
When receiving a ``RejectConnection`` event, the ``has_body`` attribute will
|
||||
in almost all cases be ``True`` (even if the server set it to ``False``) and
|
||||
will be followed by at least one ``RejectData`` events, even though the data
|
||||
itself might be just ``b""``. (The only scenario in which the caller
|
||||
receives a ``RejectConnection`` with ``has_body == False`` is if the peer
|
||||
violates sends an informational status code (1xx) other than 101.)
|
||||
|
||||
The ``has_body`` attribute should only be used when receiving the event. (It
|
||||
has ) is False the headers must include a
|
||||
content-length or transfer encoding.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: headers (Headers)
|
||||
|
||||
The headers to send with the response.
|
||||
|
||||
.. attribute:: has_body
|
||||
|
||||
This defaults to False, but set to True if there is a body. See
|
||||
also :class:`~RejectData`.
|
||||
|
||||
.. attribute:: status_code
|
||||
|
||||
The response status code.
|
||||
|
||||
"""
|
||||
|
||||
status_code: int = 400
|
||||
headers: Headers = field(default_factory=list)
|
||||
has_body: bool = False
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RejectData(Event):
|
||||
"""The rejection HTTP response body.
|
||||
|
||||
The caller may send multiple ``RejectData`` events. The final event should
|
||||
have the ``body_finished`` attribute set to ``True``.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: body_finished
|
||||
|
||||
True if this is the final chunk of the body data.
|
||||
|
||||
.. attribute:: data (bytes)
|
||||
|
||||
(Required) The raw body data.
|
||||
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
body_finished: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CloseConnection(Event):
|
||||
|
||||
"""The end of a Websocket connection, represents a closure frame.
|
||||
|
||||
**wsproto does not automatically send a response to a close event.** To
|
||||
comply with the RFC you MUST send a close event back to the remote WebSocket
|
||||
if you have not already sent one. The :meth:`response` method provides a
|
||||
suitable event for this purpose, and you should check if a response needs
|
||||
to be sent by checking :func:`wsproto.WSConnection.state`.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: code
|
||||
|
||||
(Required) The integer close code to indicate why the connection
|
||||
has closed.
|
||||
|
||||
.. attribute:: reason
|
||||
|
||||
Additional reasoning for why the connection has closed.
|
||||
|
||||
"""
|
||||
|
||||
code: int
|
||||
reason: Optional[str] = None
|
||||
|
||||
def response(self) -> "CloseConnection":
|
||||
"""Generate an RFC-compliant close frame to send back to the peer."""
|
||||
return CloseConnection(code=self.code, reason=self.reason)
|
||||
|
||||
|
||||
T = TypeVar("T", bytes, str)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Message(Event, Generic[T]):
|
||||
"""The websocket data message.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: data
|
||||
|
||||
(Required) The message data as byte string, can be decoded as UTF-8 for
|
||||
TEXT messages. This only represents a single chunk of data and
|
||||
not a full WebSocket message. You need to buffer and
|
||||
reassemble these chunks to get the full message.
|
||||
|
||||
.. attribute:: frame_finished
|
||||
|
||||
This has no semantic content, but is provided just in case some
|
||||
weird edge case user wants to be able to reconstruct the
|
||||
fragmentation pattern of the original stream.
|
||||
|
||||
.. attribute:: message_finished
|
||||
|
||||
True if this frame is the last one of this message, False if
|
||||
more frames are expected.
|
||||
|
||||
"""
|
||||
|
||||
data: T
|
||||
frame_finished: bool = True
|
||||
message_finished: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TextMessage(Message[str]): # pylint: disable=unsubscriptable-object
|
||||
"""This event is fired when a data frame with TEXT payload is received.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: data
|
||||
|
||||
The message data as string, This only represents a single chunk
|
||||
of data and not a full WebSocket message. You need to buffer
|
||||
and reassemble these chunks to get the full message.
|
||||
|
||||
"""
|
||||
|
||||
# https://github.com/python/mypy/issues/5744
|
||||
data: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BytesMessage(Message[bytes]): # pylint: disable=unsubscriptable-object
|
||||
"""This event is fired when a data frame with BINARY payload is
|
||||
received.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: data
|
||||
|
||||
The message data as byte string, can be decoded as UTF-8 for
|
||||
TEXT messages. This only represents a single chunk of data and
|
||||
not a full WebSocket message. You need to buffer and
|
||||
reassemble these chunks to get the full message.
|
||||
"""
|
||||
|
||||
# https://github.com/python/mypy/issues/5744
|
||||
data: bytes
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Ping(Event):
|
||||
"""The Ping event can be sent to trigger a ping frame and is fired
|
||||
when a Ping is received.
|
||||
|
||||
**wsproto does not automatically send a pong response to a ping event.** To
|
||||
comply with the RFC you MUST send a pong even as soon as is practical. The
|
||||
:meth:`response` method provides a suitable event for this purpose.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: payload
|
||||
|
||||
An optional payload to emit with the ping frame.
|
||||
"""
|
||||
|
||||
payload: bytes = b""
|
||||
|
||||
def response(self) -> "Pong":
|
||||
"""Generate an RFC-compliant :class:`Pong` response to this ping."""
|
||||
return Pong(payload=self.payload)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Pong(Event):
|
||||
"""The Pong event is fired when a Pong is received.
|
||||
|
||||
Fields:
|
||||
|
||||
.. attribute:: payload
|
||||
|
||||
An optional payload to emit with the pong frame.
|
||||
|
||||
"""
|
||||
|
||||
payload: bytes = b""
|
315
lib/python3.13/site-packages/wsproto/extensions.py
Normal file
315
lib/python3.13/site-packages/wsproto/extensions.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""
|
||||
wsproto/extensions
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
WebSocket extensions.
|
||||
"""
|
||||
|
||||
import zlib
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from .frame_protocol import CloseReason, FrameDecoder, FrameProtocol, Opcode, RsvBits
|
||||
|
||||
|
||||
class Extension:
|
||||
name: str
|
||||
|
||||
def enabled(self) -> bool:
|
||||
return False
|
||||
|
||||
def offer(self) -> Union[bool, str]:
|
||||
pass
|
||||
|
||||
def accept(self, offer: str) -> Optional[Union[bool, str]]:
|
||||
pass
|
||||
|
||||
def finalize(self, offer: str) -> None:
|
||||
pass
|
||||
|
||||
def frame_inbound_header(
|
||||
self,
|
||||
proto: Union[FrameDecoder, FrameProtocol],
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
payload_length: int,
|
||||
) -> Union[CloseReason, RsvBits]:
|
||||
return RsvBits(False, False, False)
|
||||
|
||||
def frame_inbound_payload_data(
|
||||
self, proto: Union[FrameDecoder, FrameProtocol], data: bytes
|
||||
) -> Union[bytes, CloseReason]:
|
||||
return data
|
||||
|
||||
def frame_inbound_complete(
|
||||
self, proto: Union[FrameDecoder, FrameProtocol], fin: bool
|
||||
) -> Union[bytes, CloseReason, None]:
|
||||
pass
|
||||
|
||||
def frame_outbound(
|
||||
self,
|
||||
proto: Union[FrameDecoder, FrameProtocol],
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
data: bytes,
|
||||
fin: bool,
|
||||
) -> Tuple[RsvBits, bytes]:
|
||||
return (rsv, data)
|
||||
|
||||
|
||||
class PerMessageDeflate(Extension):
|
||||
name = "permessage-deflate"
|
||||
|
||||
DEFAULT_CLIENT_MAX_WINDOW_BITS = 15
|
||||
DEFAULT_SERVER_MAX_WINDOW_BITS = 15
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_no_context_takeover: bool = False,
|
||||
client_max_window_bits: Optional[int] = None,
|
||||
server_no_context_takeover: bool = False,
|
||||
server_max_window_bits: Optional[int] = None,
|
||||
) -> None:
|
||||
self.client_no_context_takeover = client_no_context_takeover
|
||||
self.server_no_context_takeover = server_no_context_takeover
|
||||
self._client_max_window_bits = self.DEFAULT_CLIENT_MAX_WINDOW_BITS
|
||||
self._server_max_window_bits = self.DEFAULT_SERVER_MAX_WINDOW_BITS
|
||||
if client_max_window_bits is not None:
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
if server_max_window_bits is not None:
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
|
||||
self._compressor: Optional[zlib._Compress] = None # noqa
|
||||
self._decompressor: Optional[zlib._Decompress] = None # noqa
|
||||
# This refers to the current frame
|
||||
self._inbound_is_compressible: Optional[bool] = None
|
||||
# This refers to the ongoing message (which might span multiple
|
||||
# frames). Only the first frame in a fragmented message is flagged for
|
||||
# compression, so this carries that bit forward.
|
||||
self._inbound_compressed: Optional[bool] = None
|
||||
|
||||
self._enabled = False
|
||||
|
||||
@property
|
||||
def client_max_window_bits(self) -> int:
|
||||
return self._client_max_window_bits
|
||||
|
||||
@client_max_window_bits.setter
|
||||
def client_max_window_bits(self, value: int) -> None:
|
||||
if value < 9 or value > 15:
|
||||
raise ValueError("Window size must be between 9 and 15 inclusive")
|
||||
self._client_max_window_bits = value
|
||||
|
||||
@property
|
||||
def server_max_window_bits(self) -> int:
|
||||
return self._server_max_window_bits
|
||||
|
||||
@server_max_window_bits.setter
|
||||
def server_max_window_bits(self, value: int) -> None:
|
||||
if value < 9 or value > 15:
|
||||
raise ValueError("Window size must be between 9 and 15 inclusive")
|
||||
self._server_max_window_bits = value
|
||||
|
||||
def _compressible_opcode(self, opcode: Opcode) -> bool:
|
||||
return opcode in (Opcode.TEXT, Opcode.BINARY, Opcode.CONTINUATION)
|
||||
|
||||
def enabled(self) -> bool:
|
||||
return self._enabled
|
||||
|
||||
def offer(self) -> Union[bool, str]:
|
||||
parameters = [
|
||||
"client_max_window_bits=%d" % self.client_max_window_bits,
|
||||
"server_max_window_bits=%d" % self.server_max_window_bits,
|
||||
]
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
parameters.append("client_no_context_takeover")
|
||||
if self.server_no_context_takeover:
|
||||
parameters.append("server_no_context_takeover")
|
||||
|
||||
return "; ".join(parameters)
|
||||
|
||||
def finalize(self, offer: str) -> None:
|
||||
bits = [b.strip() for b in offer.split(";")]
|
||||
for bit in bits[1:]:
|
||||
if bit.startswith("client_no_context_takeover"):
|
||||
self.client_no_context_takeover = True
|
||||
elif bit.startswith("server_no_context_takeover"):
|
||||
self.server_no_context_takeover = True
|
||||
elif bit.startswith("client_max_window_bits"):
|
||||
self.client_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
elif bit.startswith("server_max_window_bits"):
|
||||
self.server_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
|
||||
self._enabled = True
|
||||
|
||||
def _parse_params(self, params: str) -> Tuple[Optional[int], Optional[int]]:
|
||||
client_max_window_bits = None
|
||||
server_max_window_bits = None
|
||||
|
||||
bits = [b.strip() for b in params.split(";")]
|
||||
for bit in bits[1:]:
|
||||
if bit.startswith("client_no_context_takeover"):
|
||||
self.client_no_context_takeover = True
|
||||
elif bit.startswith("server_no_context_takeover"):
|
||||
self.server_no_context_takeover = True
|
||||
elif bit.startswith("client_max_window_bits"):
|
||||
if "=" in bit:
|
||||
client_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
else:
|
||||
client_max_window_bits = self.client_max_window_bits
|
||||
elif bit.startswith("server_max_window_bits"):
|
||||
if "=" in bit:
|
||||
server_max_window_bits = int(bit.split("=", 1)[1].strip())
|
||||
else:
|
||||
server_max_window_bits = self.server_max_window_bits
|
||||
|
||||
return client_max_window_bits, server_max_window_bits
|
||||
|
||||
def accept(self, offer: str) -> Union[bool, None, str]:
|
||||
client_max_window_bits, server_max_window_bits = self._parse_params(offer)
|
||||
|
||||
parameters = []
|
||||
|
||||
if self.client_no_context_takeover:
|
||||
parameters.append("client_no_context_takeover")
|
||||
if self.server_no_context_takeover:
|
||||
parameters.append("server_no_context_takeover")
|
||||
try:
|
||||
if client_max_window_bits is not None:
|
||||
parameters.append("client_max_window_bits=%d" % client_max_window_bits)
|
||||
self.client_max_window_bits = client_max_window_bits
|
||||
if server_max_window_bits is not None:
|
||||
parameters.append("server_max_window_bits=%d" % server_max_window_bits)
|
||||
self.server_max_window_bits = server_max_window_bits
|
||||
except ValueError:
|
||||
return None
|
||||
else:
|
||||
self._enabled = True
|
||||
return "; ".join(parameters)
|
||||
|
||||
def frame_inbound_header(
|
||||
self,
|
||||
proto: Union[FrameDecoder, FrameProtocol],
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
payload_length: int,
|
||||
) -> Union[CloseReason, RsvBits]:
|
||||
if rsv.rsv1 and opcode.iscontrol():
|
||||
return CloseReason.PROTOCOL_ERROR
|
||||
if rsv.rsv1 and opcode is Opcode.CONTINUATION:
|
||||
return CloseReason.PROTOCOL_ERROR
|
||||
|
||||
self._inbound_is_compressible = self._compressible_opcode(opcode)
|
||||
|
||||
if self._inbound_compressed is None:
|
||||
self._inbound_compressed = rsv.rsv1
|
||||
if self._inbound_compressed:
|
||||
assert self._inbound_is_compressible
|
||||
if proto.client:
|
||||
bits = self.server_max_window_bits
|
||||
else:
|
||||
bits = self.client_max_window_bits
|
||||
if self._decompressor is None:
|
||||
self._decompressor = zlib.decompressobj(-int(bits))
|
||||
|
||||
return RsvBits(True, False, False)
|
||||
|
||||
def frame_inbound_payload_data(
|
||||
self, proto: Union[FrameDecoder, FrameProtocol], data: bytes
|
||||
) -> Union[bytes, CloseReason]:
|
||||
if not self._inbound_compressed or not self._inbound_is_compressible:
|
||||
return data
|
||||
assert self._decompressor is not None
|
||||
|
||||
try:
|
||||
return self._decompressor.decompress(bytes(data))
|
||||
except zlib.error:
|
||||
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
|
||||
|
||||
def frame_inbound_complete(
|
||||
self, proto: Union[FrameDecoder, FrameProtocol], fin: bool
|
||||
) -> Union[bytes, CloseReason, None]:
|
||||
if not fin:
|
||||
return None
|
||||
if not self._inbound_is_compressible:
|
||||
self._inbound_compressed = None
|
||||
return None
|
||||
if not self._inbound_compressed:
|
||||
self._inbound_compressed = None
|
||||
return None
|
||||
assert self._decompressor is not None
|
||||
|
||||
try:
|
||||
data = self._decompressor.decompress(b"\x00\x00\xff\xff")
|
||||
data += self._decompressor.flush()
|
||||
except zlib.error:
|
||||
return CloseReason.INVALID_FRAME_PAYLOAD_DATA
|
||||
|
||||
if proto.client:
|
||||
no_context_takeover = self.server_no_context_takeover
|
||||
else:
|
||||
no_context_takeover = self.client_no_context_takeover
|
||||
|
||||
if no_context_takeover:
|
||||
self._decompressor = None
|
||||
|
||||
self._inbound_compressed = None
|
||||
|
||||
return data
|
||||
|
||||
def frame_outbound(
|
||||
self,
|
||||
proto: Union[FrameDecoder, FrameProtocol],
|
||||
opcode: Opcode,
|
||||
rsv: RsvBits,
|
||||
data: bytes,
|
||||
fin: bool,
|
||||
) -> Tuple[RsvBits, bytes]:
|
||||
if not self._compressible_opcode(opcode):
|
||||
return (rsv, data)
|
||||
|
||||
if opcode is not Opcode.CONTINUATION:
|
||||
rsv = RsvBits(True, *rsv[1:])
|
||||
|
||||
if self._compressor is None:
|
||||
assert opcode is not Opcode.CONTINUATION
|
||||
if proto.client:
|
||||
bits = self.client_max_window_bits
|
||||
else:
|
||||
bits = self.server_max_window_bits
|
||||
self._compressor = zlib.compressobj(
|
||||
zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits)
|
||||
)
|
||||
|
||||
data = self._compressor.compress(bytes(data))
|
||||
|
||||
if fin:
|
||||
data += self._compressor.flush(zlib.Z_SYNC_FLUSH)
|
||||
data = data[:-4]
|
||||
|
||||
if proto.client:
|
||||
no_context_takeover = self.client_no_context_takeover
|
||||
else:
|
||||
no_context_takeover = self.server_no_context_takeover
|
||||
|
||||
if no_context_takeover:
|
||||
self._compressor = None
|
||||
|
||||
return (rsv, data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
descr = ["client_max_window_bits=%d" % self.client_max_window_bits]
|
||||
if self.client_no_context_takeover:
|
||||
descr.append("client_no_context_takeover")
|
||||
descr.append("server_max_window_bits=%d" % self.server_max_window_bits)
|
||||
if self.server_no_context_takeover:
|
||||
descr.append("server_no_context_takeover")
|
||||
|
||||
return "<{} {}>".format(self.__class__.__name__, "; ".join(descr))
|
||||
|
||||
|
||||
#: SUPPORTED_EXTENSIONS maps all supported extension names to their class.
|
||||
#: This can be used to iterate all supported extensions of wsproto, instantiate
|
||||
#: new extensions based on their name, or check if a given extension is
|
||||
#: supported or not.
|
||||
SUPPORTED_EXTENSIONS = {PerMessageDeflate.name: PerMessageDeflate}
|
673
lib/python3.13/site-packages/wsproto/frame_protocol.py
Normal file
673
lib/python3.13/site-packages/wsproto/frame_protocol.py
Normal file
@ -0,0 +1,673 @@
|
||||
"""
|
||||
wsproto/frame_protocol
|
||||
~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
WebSocket frame protocol implementation.
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
from codecs import getincrementaldecoder, IncrementalDecoder
|
||||
from enum import IntEnum
|
||||
from typing import Generator, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .extensions import Extension # pragma: no cover
|
||||
|
||||
|
||||
_XOR_TABLE = [bytes(a ^ b for a in range(256)) for b in range(256)]
|
||||
|
||||
|
||||
class XorMaskerSimple:
|
||||
def __init__(self, masking_key: bytes) -> None:
|
||||
self._masking_key = masking_key
|
||||
|
||||
def process(self, data: bytes) -> bytes:
|
||||
if data:
|
||||
data_array = bytearray(data)
|
||||
a, b, c, d = (_XOR_TABLE[n] for n in self._masking_key)
|
||||
data_array[::4] = data_array[::4].translate(a)
|
||||
data_array[1::4] = data_array[1::4].translate(b)
|
||||
data_array[2::4] = data_array[2::4].translate(c)
|
||||
data_array[3::4] = data_array[3::4].translate(d)
|
||||
|
||||
# Rotate the masking key so that the next usage continues
|
||||
# with the next key element, rather than restarting.
|
||||
key_rotation = len(data) % 4
|
||||
self._masking_key = (
|
||||
self._masking_key[key_rotation:] + self._masking_key[:key_rotation]
|
||||
)
|
||||
|
||||
return bytes(data_array)
|
||||
return data
|
||||
|
||||
|
||||
class XorMaskerNull:
|
||||
def process(self, data: bytes) -> bytes:
|
||||
return data
|
||||
|
||||
|
||||
# RFC6455, Section 5.2 - Base Framing Protocol
|
||||
|
||||
# Payload length constants
|
||||
PAYLOAD_LENGTH_TWO_BYTE = 126
|
||||
PAYLOAD_LENGTH_EIGHT_BYTE = 127
|
||||
MAX_PAYLOAD_NORMAL = 125
|
||||
MAX_PAYLOAD_TWO_BYTE = 2**16 - 1
|
||||
MAX_PAYLOAD_EIGHT_BYTE = 2**64 - 1
|
||||
MAX_FRAME_PAYLOAD = MAX_PAYLOAD_EIGHT_BYTE
|
||||
|
||||
# MASK and PAYLOAD LEN are packed into a byte
|
||||
MASK_MASK = 0x80
|
||||
PAYLOAD_LEN_MASK = 0x7F
|
||||
|
||||
# FIN, RSV[123] and OPCODE are packed into a single byte
|
||||
FIN_MASK = 0x80
|
||||
RSV1_MASK = 0x40
|
||||
RSV2_MASK = 0x20
|
||||
RSV3_MASK = 0x10
|
||||
OPCODE_MASK = 0x0F
|
||||
|
||||
|
||||
class Opcode(IntEnum):
|
||||
"""
|
||||
RFC 6455, Section 5.2 - Base Framing Protocol
|
||||
"""
|
||||
|
||||
#: Continuation frame
|
||||
CONTINUATION = 0x0
|
||||
|
||||
#: Text message
|
||||
TEXT = 0x1
|
||||
|
||||
#: Binary message
|
||||
BINARY = 0x2
|
||||
|
||||
#: Close frame
|
||||
CLOSE = 0x8
|
||||
|
||||
#: Ping frame
|
||||
PING = 0x9
|
||||
|
||||
#: Pong frame
|
||||
PONG = 0xA
|
||||
|
||||
def iscontrol(self) -> bool:
|
||||
return bool(self & 0x08)
|
||||
|
||||
|
||||
class CloseReason(IntEnum):
|
||||
"""
|
||||
RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||
"""
|
||||
|
||||
#: indicates a normal closure, meaning that the purpose for
|
||||
#: which the connection was established has been fulfilled.
|
||||
NORMAL_CLOSURE = 1000
|
||||
|
||||
#: indicates that an endpoint is "going away", such as a server
|
||||
#: going down or a browser having navigated away from a page.
|
||||
GOING_AWAY = 1001
|
||||
|
||||
#: indicates that an endpoint is terminating the connection due
|
||||
#: to a protocol error.
|
||||
PROTOCOL_ERROR = 1002
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received a type of data it cannot accept (e.g., an
|
||||
#: endpoint that understands only text data MAY send this if it
|
||||
#: receives a binary message).
|
||||
UNSUPPORTED_DATA = 1003
|
||||
|
||||
#: Reserved. The specific meaning might be defined in the future.
|
||||
# DON'T DEFINE THIS: RESERVED_1004 = 1004
|
||||
|
||||
#: is a reserved value and MUST NOT be set as a status code in a
|
||||
#: Close control frame by an endpoint. It is designated for use in
|
||||
#: applications expecting a status code to indicate that no status
|
||||
#: code was actually present.
|
||||
NO_STATUS_RCVD = 1005
|
||||
|
||||
#: is a reserved value and MUST NOT be set as a status code in a
|
||||
#: Close control frame by an endpoint. It is designated for use in
|
||||
#: applications expecting a status code to indicate that the
|
||||
#: connection was closed abnormally, e.g., without sending or
|
||||
#: receiving a Close control frame.
|
||||
ABNORMAL_CLOSURE = 1006
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received data within a message that was not
|
||||
#: consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
|
||||
#: data within a text message).
|
||||
INVALID_FRAME_PAYLOAD_DATA = 1007
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received a message that violates its policy. This
|
||||
#: is a generic status code that can be returned when there is no
|
||||
#: other more suitable status code (e.g., 1003 or 1009) or if there
|
||||
#: is a need to hide specific details about the policy.
|
||||
POLICY_VIOLATION = 1008
|
||||
|
||||
#: indicates that an endpoint is terminating the connection
|
||||
#: because it has received a message that is too big for it to
|
||||
#: process.
|
||||
MESSAGE_TOO_BIG = 1009
|
||||
|
||||
#: indicates that an endpoint (client) is terminating the
|
||||
#: connection because it has expected the server to negotiate one or
|
||||
#: more extension, but the server didn't return them in the response
|
||||
#: message of the WebSocket handshake. The list of extensions that
|
||||
#: are needed SHOULD appear in the /reason/ part of the Close frame.
|
||||
#: Note that this status code is not used by the server, because it
|
||||
#: can fail the WebSocket handshake instead.
|
||||
MANDATORY_EXT = 1010
|
||||
|
||||
#: indicates that a server is terminating the connection because
|
||||
#: it encountered an unexpected condition that prevented it from
|
||||
#: fulfilling the request.
|
||||
INTERNAL_ERROR = 1011
|
||||
|
||||
#: Server/service is restarting
|
||||
#: (not part of RFC6455)
|
||||
SERVICE_RESTART = 1012
|
||||
|
||||
#: Temporary server condition forced blocking client's request
|
||||
#: (not part of RFC6455)
|
||||
TRY_AGAIN_LATER = 1013
|
||||
|
||||
#: is a reserved value and MUST NOT be set as a status code in a
|
||||
#: Close control frame by an endpoint. It is designated for use in
|
||||
#: applications expecting a status code to indicate that the
|
||||
#: connection was closed due to a failure to perform a TLS handshake
|
||||
#: (e.g., the server certificate can't be verified).
|
||||
TLS_HANDSHAKE_FAILED = 1015
|
||||
|
||||
|
||||
# RFC 6455, Section 7.4.1 - Defined Status Codes
|
||||
LOCAL_ONLY_CLOSE_REASONS = (
|
||||
CloseReason.NO_STATUS_RCVD,
|
||||
CloseReason.ABNORMAL_CLOSURE,
|
||||
CloseReason.TLS_HANDSHAKE_FAILED,
|
||||
)
|
||||
|
||||
|
||||
# RFC 6455, Section 7.4.2 - Status Code Ranges
|
||||
MIN_CLOSE_REASON = 1000
|
||||
MIN_PROTOCOL_CLOSE_REASON = 1000
|
||||
MAX_PROTOCOL_CLOSE_REASON = 2999
|
||||
MIN_LIBRARY_CLOSE_REASON = 3000
|
||||
MAX_LIBRARY_CLOSE_REASON = 3999
|
||||
MIN_PRIVATE_CLOSE_REASON = 4000
|
||||
MAX_PRIVATE_CLOSE_REASON = 4999
|
||||
MAX_CLOSE_REASON = 4999
|
||||
|
||||
|
||||
NULL_MASK = struct.pack("!I", 0)
|
||||
|
||||
|
||||
class ParseFailed(Exception):
|
||||
def __init__(
|
||||
self, msg: str, code: CloseReason = CloseReason.PROTOCOL_ERROR
|
||||
) -> None:
|
||||
super().__init__(msg)
|
||||
self.code = code
|
||||
|
||||
|
||||
class RsvBits(NamedTuple):
|
||||
rsv1: bool
|
||||
rsv2: bool
|
||||
rsv3: bool
|
||||
|
||||
|
||||
class Header(NamedTuple):
|
||||
fin: bool
|
||||
rsv: RsvBits
|
||||
opcode: Opcode
|
||||
payload_len: int
|
||||
masking_key: Optional[bytes]
|
||||
|
||||
|
||||
class Frame(NamedTuple):
|
||||
opcode: Opcode
|
||||
payload: Union[bytes, str, Tuple[int, str]]
|
||||
frame_finished: bool
|
||||
message_finished: bool
|
||||
|
||||
|
||||
def _truncate_utf8(data: bytes, nbytes: int) -> bytes:
|
||||
if len(data) <= nbytes:
|
||||
return data
|
||||
|
||||
# Truncate
|
||||
data = data[:nbytes]
|
||||
# But we might have cut a codepoint in half, in which case we want to
|
||||
# discard the partial character so the data is at least
|
||||
# well-formed. This is a little inefficient since it processes the
|
||||
# whole message twice when in theory we could just peek at the last
|
||||
# few characters, but since this is only used for close messages (max
|
||||
# length = 125 bytes) it really doesn't matter.
|
||||
data = data.decode("utf-8", errors="ignore").encode("utf-8")
|
||||
return data
|
||||
|
||||
|
||||
class Buffer:
|
||||
def __init__(self, initial_bytes: Optional[bytes] = None) -> None:
|
||||
self.buffer = bytearray()
|
||||
self.bytes_used = 0
|
||||
if initial_bytes:
|
||||
self.feed(initial_bytes)
|
||||
|
||||
def feed(self, new_bytes: bytes) -> None:
|
||||
self.buffer += new_bytes
|
||||
|
||||
def consume_at_most(self, nbytes: int) -> bytes:
|
||||
if not nbytes:
|
||||
return bytearray()
|
||||
|
||||
data = self.buffer[self.bytes_used : self.bytes_used + nbytes]
|
||||
self.bytes_used += len(data)
|
||||
return data
|
||||
|
||||
def consume_exactly(self, nbytes: int) -> Optional[bytes]:
|
||||
if len(self.buffer) - self.bytes_used < nbytes:
|
||||
return None
|
||||
|
||||
return self.consume_at_most(nbytes)
|
||||
|
||||
def commit(self) -> None:
|
||||
# In CPython 3.4+, del[:n] is amortized O(n), *not* quadratic
|
||||
del self.buffer[: self.bytes_used]
|
||||
self.bytes_used = 0
|
||||
|
||||
def rollback(self) -> None:
|
||||
self.bytes_used = 0
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.buffer)
|
||||
|
||||
|
||||
class MessageDecoder:
|
||||
def __init__(self) -> None:
|
||||
self.opcode: Optional[Opcode] = None
|
||||
self.decoder: Optional[IncrementalDecoder] = None
|
||||
|
||||
def process_frame(self, frame: Frame) -> Frame:
|
||||
assert not frame.opcode.iscontrol()
|
||||
|
||||
if self.opcode is None:
|
||||
if frame.opcode is Opcode.CONTINUATION:
|
||||
raise ParseFailed("unexpected CONTINUATION")
|
||||
self.opcode = frame.opcode
|
||||
elif frame.opcode is not Opcode.CONTINUATION:
|
||||
raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode)
|
||||
|
||||
if frame.opcode is Opcode.TEXT:
|
||||
self.decoder = getincrementaldecoder("utf-8")()
|
||||
|
||||
finished = frame.frame_finished and frame.message_finished
|
||||
|
||||
if self.decoder is None:
|
||||
data = frame.payload
|
||||
else:
|
||||
assert isinstance(frame.payload, (bytes, bytearray))
|
||||
try:
|
||||
data = self.decoder.decode(frame.payload, finished)
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ParseFailed(str(exc), CloseReason.INVALID_FRAME_PAYLOAD_DATA)
|
||||
|
||||
frame = Frame(self.opcode, data, frame.frame_finished, finished)
|
||||
|
||||
if finished:
|
||||
self.opcode = None
|
||||
self.decoder = None
|
||||
|
||||
return frame
|
||||
|
||||
|
||||
class FrameDecoder:
|
||||
def __init__(
|
||||
self, client: bool, extensions: Optional[List["Extension"]] = None
|
||||
) -> None:
|
||||
self.client = client
|
||||
self.extensions = extensions or []
|
||||
|
||||
self.buffer = Buffer()
|
||||
|
||||
self.header: Optional[Header] = None
|
||||
self.effective_opcode: Optional[Opcode] = None
|
||||
self.masker: Union[None, XorMaskerNull, XorMaskerSimple] = None
|
||||
self.payload_required = 0
|
||||
self.payload_consumed = 0
|
||||
|
||||
def receive_bytes(self, data: bytes) -> None:
|
||||
self.buffer.feed(data)
|
||||
|
||||
def process_buffer(self) -> Optional[Frame]:
|
||||
if not self.header:
|
||||
if not self.parse_header():
|
||||
return None
|
||||
# parse_header() sets these.
|
||||
assert self.header is not None
|
||||
assert self.masker is not None
|
||||
assert self.effective_opcode is not None
|
||||
|
||||
if len(self.buffer) < self.payload_required:
|
||||
return None
|
||||
|
||||
payload_remaining = self.header.payload_len - self.payload_consumed
|
||||
payload = self.buffer.consume_at_most(payload_remaining)
|
||||
if not payload and self.header.payload_len > 0:
|
||||
return None
|
||||
self.buffer.commit()
|
||||
|
||||
self.payload_consumed += len(payload)
|
||||
finished = self.payload_consumed == self.header.payload_len
|
||||
|
||||
payload = self.masker.process(payload)
|
||||
|
||||
for extension in self.extensions:
|
||||
payload_ = extension.frame_inbound_payload_data(self, payload)
|
||||
if isinstance(payload_, CloseReason):
|
||||
raise ParseFailed("error in extension", payload_)
|
||||
payload = payload_
|
||||
|
||||
if finished:
|
||||
final = bytearray()
|
||||
for extension in self.extensions:
|
||||
result = extension.frame_inbound_complete(self, self.header.fin)
|
||||
if isinstance(result, CloseReason):
|
||||
raise ParseFailed("error in extension", result)
|
||||
if result is not None:
|
||||
final += result
|
||||
payload += final
|
||||
|
||||
frame = Frame(self.effective_opcode, payload, finished, self.header.fin)
|
||||
|
||||
if finished:
|
||||
self.header = None
|
||||
self.effective_opcode = None
|
||||
self.masker = None
|
||||
else:
|
||||
self.effective_opcode = Opcode.CONTINUATION
|
||||
|
||||
return frame
|
||||
|
||||
def parse_header(self) -> bool:
|
||||
data = self.buffer.consume_exactly(2)
|
||||
if data is None:
|
||||
self.buffer.rollback()
|
||||
return False
|
||||
|
||||
fin = bool(data[0] & FIN_MASK)
|
||||
rsv = RsvBits(
|
||||
bool(data[0] & RSV1_MASK),
|
||||
bool(data[0] & RSV2_MASK),
|
||||
bool(data[0] & RSV3_MASK),
|
||||
)
|
||||
opcode = data[0] & OPCODE_MASK
|
||||
try:
|
||||
opcode = Opcode(opcode)
|
||||
except ValueError:
|
||||
raise ParseFailed(f"Invalid opcode {opcode:#x}")
|
||||
|
||||
if opcode.iscontrol() and not fin:
|
||||
raise ParseFailed("Invalid attempt to fragment control frame")
|
||||
|
||||
has_mask = bool(data[1] & MASK_MASK)
|
||||
payload_len_short = data[1] & PAYLOAD_LEN_MASK
|
||||
payload_len = self.parse_extended_payload_length(opcode, payload_len_short)
|
||||
if payload_len is None:
|
||||
self.buffer.rollback()
|
||||
return False
|
||||
|
||||
self.extension_processing(opcode, rsv, payload_len)
|
||||
|
||||
if has_mask and self.client:
|
||||
raise ParseFailed("client received unexpected masked frame")
|
||||
if not has_mask and not self.client:
|
||||
raise ParseFailed("server received unexpected unmasked frame")
|
||||
if has_mask:
|
||||
masking_key = self.buffer.consume_exactly(4)
|
||||
if masking_key is None:
|
||||
self.buffer.rollback()
|
||||
return False
|
||||
self.masker = XorMaskerSimple(masking_key)
|
||||
else:
|
||||
self.masker = XorMaskerNull()
|
||||
|
||||
self.buffer.commit()
|
||||
self.header = Header(fin, rsv, opcode, payload_len, None)
|
||||
self.effective_opcode = self.header.opcode
|
||||
if self.header.opcode.iscontrol():
|
||||
self.payload_required = payload_len
|
||||
else:
|
||||
self.payload_required = 0
|
||||
self.payload_consumed = 0
|
||||
return True
|
||||
|
||||
def parse_extended_payload_length(
|
||||
self, opcode: Opcode, payload_len: int
|
||||
) -> Optional[int]:
|
||||
if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL:
|
||||
raise ParseFailed("Control frame with payload len > 125")
|
||||
if payload_len == PAYLOAD_LENGTH_TWO_BYTE:
|
||||
data = self.buffer.consume_exactly(2)
|
||||
if data is None:
|
||||
return None
|
||||
(payload_len,) = struct.unpack("!H", data)
|
||||
if payload_len <= MAX_PAYLOAD_NORMAL:
|
||||
raise ParseFailed(
|
||||
"Payload length used 2 bytes when 1 would have sufficed"
|
||||
)
|
||||
elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE:
|
||||
data = self.buffer.consume_exactly(8)
|
||||
if data is None:
|
||||
return None
|
||||
(payload_len,) = struct.unpack("!Q", data)
|
||||
if payload_len <= MAX_PAYLOAD_TWO_BYTE:
|
||||
raise ParseFailed(
|
||||
"Payload length used 8 bytes when 2 would have sufficed"
|
||||
)
|
||||
if payload_len >> 63:
|
||||
# I'm not sure why this is illegal, but that's what the RFC
|
||||
# says, so...
|
||||
raise ParseFailed("8-byte payload length with non-zero MSB")
|
||||
|
||||
return payload_len
|
||||
|
||||
def extension_processing(
|
||||
self, opcode: Opcode, rsv: RsvBits, payload_len: int
|
||||
) -> None:
|
||||
rsv_used = [False, False, False]
|
||||
for extension in self.extensions:
|
||||
result = extension.frame_inbound_header(self, opcode, rsv, payload_len)
|
||||
if isinstance(result, CloseReason):
|
||||
raise ParseFailed("error in extension", result)
|
||||
for bit, used in enumerate(result):
|
||||
if used:
|
||||
rsv_used[bit] = True
|
||||
for expected, found in zip(rsv_used, rsv):
|
||||
if found and not expected:
|
||||
raise ParseFailed("Reserved bit set unexpectedly")
|
||||
|
||||
|
||||
class FrameProtocol:
|
||||
def __init__(self, client: bool, extensions: List["Extension"]) -> None:
|
||||
self.client = client
|
||||
self.extensions = [ext for ext in extensions if ext.enabled()]
|
||||
|
||||
# Global state
|
||||
self._frame_decoder = FrameDecoder(self.client, self.extensions)
|
||||
self._message_decoder = MessageDecoder()
|
||||
self._parse_more = self._parse_more_gen()
|
||||
|
||||
self._outbound_opcode: Optional[Opcode] = None
|
||||
|
||||
def _process_close(self, frame: Frame) -> Frame:
|
||||
data = frame.payload
|
||||
assert isinstance(data, (bytes, bytearray))
|
||||
|
||||
if not data:
|
||||
# "If this Close control frame contains no status code, _The
|
||||
# WebSocket Connection Close Code_ is considered to be 1005"
|
||||
data = (CloseReason.NO_STATUS_RCVD, "")
|
||||
elif len(data) == 1:
|
||||
raise ParseFailed("CLOSE with 1 byte payload")
|
||||
else:
|
||||
(code,) = struct.unpack("!H", data[:2])
|
||||
if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
|
||||
raise ParseFailed("CLOSE with invalid code")
|
||||
try:
|
||||
code = CloseReason(code)
|
||||
except ValueError:
|
||||
pass
|
||||
if code in LOCAL_ONLY_CLOSE_REASONS:
|
||||
raise ParseFailed("remote CLOSE with local-only reason")
|
||||
if not isinstance(code, CloseReason) and code <= MAX_PROTOCOL_CLOSE_REASON:
|
||||
raise ParseFailed("CLOSE with unknown reserved code")
|
||||
try:
|
||||
reason = data[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise ParseFailed(
|
||||
"Error decoding CLOSE reason: " + str(exc),
|
||||
CloseReason.INVALID_FRAME_PAYLOAD_DATA,
|
||||
)
|
||||
data = (code, reason)
|
||||
|
||||
return Frame(frame.opcode, data, frame.frame_finished, frame.message_finished)
|
||||
|
||||
def _parse_more_gen(self) -> Generator[Optional[Frame], None, None]:
|
||||
# Consume as much as we can from self._buffer, yielding events, and
|
||||
# then yield None when we need more data. Or raise ParseFailed.
|
||||
|
||||
# XX FIXME this should probably be refactored so that we never see
|
||||
# disabled extensions in the first place...
|
||||
self.extensions = [ext for ext in self.extensions if ext.enabled()]
|
||||
closed = False
|
||||
|
||||
while not closed:
|
||||
frame = self._frame_decoder.process_buffer()
|
||||
|
||||
if frame is not None:
|
||||
if not frame.opcode.iscontrol():
|
||||
frame = self._message_decoder.process_frame(frame)
|
||||
elif frame.opcode == Opcode.CLOSE:
|
||||
frame = self._process_close(frame)
|
||||
closed = True
|
||||
|
||||
yield frame
|
||||
|
||||
def receive_bytes(self, data: bytes) -> None:
|
||||
self._frame_decoder.receive_bytes(data)
|
||||
|
||||
def received_frames(self) -> Generator[Frame, None, None]:
|
||||
for event in self._parse_more:
|
||||
if event is None:
|
||||
break
|
||||
else:
|
||||
yield event
|
||||
|
||||
def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> bytes:
|
||||
payload = bytearray()
|
||||
if code is CloseReason.NO_STATUS_RCVD:
|
||||
code = None
|
||||
if code is None and reason:
|
||||
raise TypeError("cannot specify a reason without a code")
|
||||
if code in LOCAL_ONLY_CLOSE_REASONS:
|
||||
code = CloseReason.NORMAL_CLOSURE
|
||||
if code is not None:
|
||||
payload += bytearray(struct.pack("!H", code))
|
||||
if reason is not None:
|
||||
payload += _truncate_utf8(
|
||||
reason.encode("utf-8"), MAX_PAYLOAD_NORMAL - 2
|
||||
)
|
||||
|
||||
return self._serialize_frame(Opcode.CLOSE, payload)
|
||||
|
||||
def ping(self, payload: bytes = b"") -> bytes:
|
||||
return self._serialize_frame(Opcode.PING, payload)
|
||||
|
||||
def pong(self, payload: bytes = b"") -> bytes:
|
||||
return self._serialize_frame(Opcode.PONG, payload)
|
||||
|
||||
def send_data(
|
||||
self, payload: Union[bytes, bytearray, str] = b"", fin: bool = True
|
||||
) -> bytes:
|
||||
if isinstance(payload, (bytes, bytearray, memoryview)):
|
||||
opcode = Opcode.BINARY
|
||||
elif isinstance(payload, str):
|
||||
opcode = Opcode.TEXT
|
||||
payload = payload.encode("utf-8")
|
||||
else:
|
||||
raise ValueError("Must provide bytes or text")
|
||||
|
||||
if self._outbound_opcode is None:
|
||||
self._outbound_opcode = opcode
|
||||
elif self._outbound_opcode is not opcode:
|
||||
raise TypeError("Data type mismatch inside message")
|
||||
else:
|
||||
opcode = Opcode.CONTINUATION
|
||||
|
||||
if fin:
|
||||
self._outbound_opcode = None
|
||||
|
||||
return self._serialize_frame(opcode, payload, fin)
|
||||
|
||||
def _make_fin_rsv_opcode(self, fin: bool, rsv: RsvBits, opcode: Opcode) -> int:
|
||||
fin_bits = int(fin) << 7
|
||||
rsv_bits = (int(rsv.rsv1) << 6) + (int(rsv.rsv2) << 5) + (int(rsv.rsv3) << 4)
|
||||
opcode_bits = int(opcode)
|
||||
|
||||
return fin_bits | rsv_bits | opcode_bits
|
||||
|
||||
def _serialize_frame(
|
||||
self, opcode: Opcode, payload: bytes = b"", fin: bool = True
|
||||
) -> bytes:
|
||||
rsv = RsvBits(False, False, False)
|
||||
for extension in reversed(self.extensions):
|
||||
rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, fin)
|
||||
|
||||
fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode)
|
||||
|
||||
payload_length = len(payload)
|
||||
quad_payload = False
|
||||
if payload_length <= MAX_PAYLOAD_NORMAL:
|
||||
first_payload = payload_length
|
||||
second_payload = None
|
||||
elif payload_length <= MAX_PAYLOAD_TWO_BYTE:
|
||||
first_payload = PAYLOAD_LENGTH_TWO_BYTE
|
||||
second_payload = payload_length
|
||||
else:
|
||||
first_payload = PAYLOAD_LENGTH_EIGHT_BYTE
|
||||
second_payload = payload_length
|
||||
quad_payload = True
|
||||
|
||||
if self.client:
|
||||
first_payload |= 1 << 7
|
||||
|
||||
header = bytearray([fin_rsv_opcode, first_payload])
|
||||
if second_payload is not None:
|
||||
if opcode.iscontrol():
|
||||
raise ValueError("payload too long for control frame")
|
||||
if quad_payload:
|
||||
header += bytearray(struct.pack("!Q", second_payload))
|
||||
else:
|
||||
header += bytearray(struct.pack("!H", second_payload))
|
||||
|
||||
if self.client:
|
||||
# "The masking key is a 32-bit value chosen at random by the
|
||||
# client. When preparing a masked frame, the client MUST pick a
|
||||
# fresh masking key from the set of allowed 32-bit values. The
|
||||
# masking key needs to be unpredictable; thus, the masking key
|
||||
# MUST be derived from a strong source of entropy, and the masking
|
||||
# key for a given frame MUST NOT make it simple for a server/proxy
|
||||
# to predict the masking key for a subsequent frame. The
|
||||
# unpredictability of the masking key is essential to prevent
|
||||
# authors of malicious applications from selecting the bytes that
|
||||
# appear on the wire."
|
||||
# -- https://tools.ietf.org/html/rfc6455#section-5.3
|
||||
masking_key = os.urandom(4)
|
||||
masker = XorMaskerSimple(masking_key)
|
||||
return header + masking_key + masker.process(payload)
|
||||
|
||||
return header + payload
|
491
lib/python3.13/site-packages/wsproto/handshake.py
Normal file
491
lib/python3.13/site-packages/wsproto/handshake.py
Normal file
@ -0,0 +1,491 @@
|
||||
"""
|
||||
wsproto/handshake
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
An implementation of WebSocket handshakes.
|
||||
"""
|
||||
from collections import deque
|
||||
from typing import (
|
||||
cast,
|
||||
Deque,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import h11
|
||||
|
||||
from .connection import Connection, ConnectionState, ConnectionType
|
||||
from .events import AcceptConnection, Event, RejectConnection, RejectData, Request
|
||||
from .extensions import Extension
|
||||
from .typing import Headers
|
||||
from .utilities import (
|
||||
generate_accept_token,
|
||||
generate_nonce,
|
||||
LocalProtocolError,
|
||||
normed_header_dict,
|
||||
RemoteProtocolError,
|
||||
split_comma_header,
|
||||
)
|
||||
|
||||
# RFC6455, Section 4.2.1/6 - Reading the Client's Opening Handshake
|
||||
WEBSOCKET_VERSION = b"13"
|
||||
|
||||
|
||||
class H11Handshake:
|
||||
"""A Handshake implementation for HTTP/1.1 connections."""
|
||||
|
||||
def __init__(self, connection_type: ConnectionType) -> None:
|
||||
self.client = connection_type is ConnectionType.CLIENT
|
||||
self._state = ConnectionState.CONNECTING
|
||||
|
||||
if self.client:
|
||||
self._h11_connection = h11.Connection(h11.CLIENT)
|
||||
else:
|
||||
self._h11_connection = h11.Connection(h11.SERVER)
|
||||
|
||||
self._connection: Optional[Connection] = None
|
||||
self._events: Deque[Event] = deque()
|
||||
self._initiating_request: Optional[Request] = None
|
||||
self._nonce: Optional[bytes] = None
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def connection(self) -> Optional[Connection]:
|
||||
"""Return the established connection.
|
||||
|
||||
This will either return the connection or raise a
|
||||
LocalProtocolError if the connection has not yet been
|
||||
established.
|
||||
|
||||
:rtype: h11.Connection
|
||||
"""
|
||||
return self._connection
|
||||
|
||||
def initiate_upgrade_connection(
|
||||
self, headers: Headers, path: Union[bytes, str]
|
||||
) -> None:
|
||||
"""Initiate an upgrade connection.
|
||||
|
||||
This should be used if the request has already be received and
|
||||
parsed.
|
||||
|
||||
:param list headers: HTTP headers represented as a list of 2-tuples.
|
||||
:param str path: A URL path.
|
||||
"""
|
||||
if self.client:
|
||||
raise LocalProtocolError(
|
||||
"Cannot initiate an upgrade connection when acting as the client"
|
||||
)
|
||||
upgrade_request = h11.Request(method=b"GET", target=path, headers=headers)
|
||||
h11_client = h11.Connection(h11.CLIENT)
|
||||
self.receive_data(h11_client.send(upgrade_request))
|
||||
|
||||
def send(self, event: Event) -> bytes:
|
||||
"""Send an event to the remote.
|
||||
|
||||
This will return the bytes to send based on the event or raise
|
||||
a LocalProtocolError if the event is not valid given the
|
||||
state.
|
||||
|
||||
:returns: Data to send to the WebSocket peer.
|
||||
:rtype: bytes
|
||||
"""
|
||||
data = b""
|
||||
if isinstance(event, Request):
|
||||
data += self._initiate_connection(event)
|
||||
elif isinstance(event, AcceptConnection):
|
||||
data += self._accept(event)
|
||||
elif isinstance(event, RejectConnection):
|
||||
data += self._reject(event)
|
||||
elif isinstance(event, RejectData):
|
||||
data += self._send_reject_data(event)
|
||||
else:
|
||||
raise LocalProtocolError(
|
||||
f"Event {event} cannot be sent during the handshake"
|
||||
)
|
||||
return data
|
||||
|
||||
def receive_data(self, data: Optional[bytes]) -> None:
|
||||
"""Receive data from the remote.
|
||||
|
||||
A list of events that the remote peer triggered by sending
|
||||
this data can be retrieved with :meth:`events`.
|
||||
|
||||
:param bytes data: Data received from the WebSocket peer.
|
||||
"""
|
||||
self._h11_connection.receive_data(data or b"")
|
||||
while True:
|
||||
try:
|
||||
event = self._h11_connection.next_event()
|
||||
except h11.RemoteProtocolError:
|
||||
raise RemoteProtocolError(
|
||||
"Bad HTTP message", event_hint=RejectConnection()
|
||||
)
|
||||
if (
|
||||
isinstance(event, h11.ConnectionClosed)
|
||||
or event is h11.NEED_DATA
|
||||
or event is h11.PAUSED
|
||||
):
|
||||
break
|
||||
|
||||
if self.client:
|
||||
if isinstance(event, h11.InformationalResponse):
|
||||
if event.status_code == 101:
|
||||
self._events.append(self._establish_client_connection(event))
|
||||
else:
|
||||
self._events.append(
|
||||
RejectConnection(
|
||||
headers=list(event.headers),
|
||||
status_code=event.status_code,
|
||||
has_body=False,
|
||||
)
|
||||
)
|
||||
self._state = ConnectionState.CLOSED
|
||||
elif isinstance(event, h11.Response):
|
||||
self._state = ConnectionState.REJECTING
|
||||
self._events.append(
|
||||
RejectConnection(
|
||||
headers=list(event.headers),
|
||||
status_code=event.status_code,
|
||||
has_body=True,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, h11.Data):
|
||||
self._events.append(
|
||||
RejectData(data=event.data, body_finished=False)
|
||||
)
|
||||
elif isinstance(event, h11.EndOfMessage):
|
||||
self._events.append(RejectData(data=b"", body_finished=True))
|
||||
self._state = ConnectionState.CLOSED
|
||||
else:
|
||||
if isinstance(event, h11.Request):
|
||||
self._events.append(self._process_connection_request(event))
|
||||
|
||||
def events(self) -> Generator[Event, None, None]:
|
||||
"""Return a generator that provides any events that have been generated
|
||||
by protocol activity.
|
||||
|
||||
:returns: a generator that yields H11 events.
|
||||
"""
|
||||
while self._events:
|
||||
yield self._events.popleft()
|
||||
|
||||
# Server mode methods
|
||||
|
||||
def _process_connection_request( # noqa: MC0001
|
||||
self, event: h11.Request
|
||||
) -> Request:
|
||||
if event.method != b"GET":
|
||||
raise RemoteProtocolError(
|
||||
"Request method must be GET", event_hint=RejectConnection()
|
||||
)
|
||||
connection_tokens = None
|
||||
extensions: List[str] = []
|
||||
host = None
|
||||
key = None
|
||||
subprotocols: List[str] = []
|
||||
upgrade = b""
|
||||
version = None
|
||||
headers: Headers = []
|
||||
for name, value in event.headers:
|
||||
name = name.lower()
|
||||
if name == b"connection":
|
||||
connection_tokens = split_comma_header(value)
|
||||
elif name == b"host":
|
||||
host = value.decode("idna")
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-extensions":
|
||||
extensions.extend(split_comma_header(value))
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-key":
|
||||
key = value
|
||||
elif name == b"sec-websocket-protocol":
|
||||
subprotocols.extend(split_comma_header(value))
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-version":
|
||||
version = value
|
||||
elif name == b"upgrade":
|
||||
upgrade = value
|
||||
headers.append((name, value))
|
||||
if connection_tokens is None or not any(
|
||||
token.lower() == "upgrade" for token in connection_tokens
|
||||
):
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
|
||||
)
|
||||
if version != WEBSOCKET_VERSION:
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Sec-WebSocket-Version'",
|
||||
event_hint=RejectConnection(
|
||||
headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)],
|
||||
status_code=426 if version else 400,
|
||||
),
|
||||
)
|
||||
if key is None:
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection()
|
||||
)
|
||||
if upgrade.lower() != b"websocket":
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
|
||||
)
|
||||
if host is None:
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Host'", event_hint=RejectConnection()
|
||||
)
|
||||
|
||||
self._initiating_request = Request(
|
||||
extensions=extensions,
|
||||
extra_headers=headers,
|
||||
host=host,
|
||||
subprotocols=subprotocols,
|
||||
target=event.target.decode("ascii"),
|
||||
)
|
||||
return self._initiating_request
|
||||
|
||||
def _accept(self, event: AcceptConnection) -> bytes:
|
||||
# _accept is always called after _process_connection_request.
|
||||
assert self._initiating_request is not None
|
||||
request_headers = normed_header_dict(self._initiating_request.extra_headers)
|
||||
|
||||
nonce = request_headers[b"sec-websocket-key"]
|
||||
accept_token = generate_accept_token(nonce)
|
||||
|
||||
headers = [
|
||||
(b"Upgrade", b"WebSocket"),
|
||||
(b"Connection", b"Upgrade"),
|
||||
(b"Sec-WebSocket-Accept", accept_token),
|
||||
]
|
||||
|
||||
if event.subprotocol is not None:
|
||||
if event.subprotocol not in self._initiating_request.subprotocols:
|
||||
raise LocalProtocolError(f"unexpected subprotocol {event.subprotocol}")
|
||||
headers.append(
|
||||
(b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii"))
|
||||
)
|
||||
|
||||
if event.extensions:
|
||||
accepts = server_extensions_handshake(
|
||||
cast(Sequence[str], self._initiating_request.extensions),
|
||||
event.extensions,
|
||||
)
|
||||
if accepts:
|
||||
headers.append((b"Sec-WebSocket-Extensions", accepts))
|
||||
|
||||
response = h11.InformationalResponse(
|
||||
status_code=101, headers=headers + event.extra_headers
|
||||
)
|
||||
self._connection = Connection(
|
||||
ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
|
||||
event.extensions,
|
||||
)
|
||||
self._state = ConnectionState.OPEN
|
||||
return self._h11_connection.send(response) or b""
|
||||
|
||||
def _reject(self, event: RejectConnection) -> bytes:
|
||||
if self.state != ConnectionState.CONNECTING:
|
||||
raise LocalProtocolError(
|
||||
"Connection cannot be rejected in state %s" % self.state
|
||||
)
|
||||
|
||||
headers = list(event.headers)
|
||||
if not event.has_body:
|
||||
headers.append((b"content-length", b"0"))
|
||||
response = h11.Response(status_code=event.status_code, headers=headers)
|
||||
data = self._h11_connection.send(response) or b""
|
||||
self._state = ConnectionState.REJECTING
|
||||
if not event.has_body:
|
||||
data += self._h11_connection.send(h11.EndOfMessage()) or b""
|
||||
self._state = ConnectionState.CLOSED
|
||||
return data
|
||||
|
||||
def _send_reject_data(self, event: RejectData) -> bytes:
|
||||
if self.state != ConnectionState.REJECTING:
|
||||
raise LocalProtocolError(
|
||||
f"Cannot send rejection data in state {self.state}"
|
||||
)
|
||||
|
||||
data = self._h11_connection.send(h11.Data(data=event.data)) or b""
|
||||
if event.body_finished:
|
||||
data += self._h11_connection.send(h11.EndOfMessage()) or b""
|
||||
self._state = ConnectionState.CLOSED
|
||||
return data
|
||||
|
||||
# Client mode methods
|
||||
|
||||
def _initiate_connection(self, request: Request) -> bytes:
|
||||
self._initiating_request = request
|
||||
self._nonce = generate_nonce()
|
||||
|
||||
headers = [
|
||||
(b"Host", request.host.encode("idna")),
|
||||
(b"Upgrade", b"WebSocket"),
|
||||
(b"Connection", b"Upgrade"),
|
||||
(b"Sec-WebSocket-Key", self._nonce),
|
||||
(b"Sec-WebSocket-Version", WEBSOCKET_VERSION),
|
||||
]
|
||||
|
||||
if request.subprotocols:
|
||||
headers.append(
|
||||
(
|
||||
b"Sec-WebSocket-Protocol",
|
||||
(", ".join(request.subprotocols)).encode("ascii"),
|
||||
)
|
||||
)
|
||||
|
||||
if request.extensions:
|
||||
offers: Dict[str, Union[str, bool]] = {}
|
||||
for e in request.extensions:
|
||||
assert isinstance(e, Extension)
|
||||
offers[e.name] = e.offer()
|
||||
extensions = []
|
||||
for name, params in offers.items():
|
||||
bname = name.encode("ascii")
|
||||
if isinstance(params, bool):
|
||||
if params:
|
||||
extensions.append(bname)
|
||||
else:
|
||||
extensions.append(b"%s; %s" % (bname, params.encode("ascii")))
|
||||
if extensions:
|
||||
headers.append((b"Sec-WebSocket-Extensions", b", ".join(extensions)))
|
||||
|
||||
upgrade = h11.Request(
|
||||
method=b"GET",
|
||||
target=request.target.encode("ascii"),
|
||||
headers=headers + request.extra_headers,
|
||||
)
|
||||
return self._h11_connection.send(upgrade) or b""
|
||||
|
||||
def _establish_client_connection(
|
||||
self, event: h11.InformationalResponse
|
||||
) -> AcceptConnection: # noqa: MC0001
|
||||
# _establish_client_connection is always called after _initiate_connection.
|
||||
assert self._initiating_request is not None
|
||||
assert self._nonce is not None
|
||||
|
||||
accept = None
|
||||
connection_tokens = None
|
||||
accepts: List[str] = []
|
||||
subprotocol = None
|
||||
upgrade = b""
|
||||
headers: Headers = []
|
||||
for name, value in event.headers:
|
||||
name = name.lower()
|
||||
if name == b"connection":
|
||||
connection_tokens = split_comma_header(value)
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-extensions":
|
||||
accepts = split_comma_header(value)
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-accept":
|
||||
accept = value
|
||||
continue # Skip appending to headers
|
||||
elif name == b"sec-websocket-protocol":
|
||||
subprotocol = value.decode("ascii")
|
||||
continue # Skip appending to headers
|
||||
elif name == b"upgrade":
|
||||
upgrade = value
|
||||
continue # Skip appending to headers
|
||||
headers.append((name, value))
|
||||
|
||||
if connection_tokens is None or not any(
|
||||
token.lower() == "upgrade" for token in connection_tokens
|
||||
):
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Connection: Upgrade'", event_hint=RejectConnection()
|
||||
)
|
||||
if upgrade.lower() != b"websocket":
|
||||
raise RemoteProtocolError(
|
||||
"Missing header, 'Upgrade: WebSocket'", event_hint=RejectConnection()
|
||||
)
|
||||
accept_token = generate_accept_token(self._nonce)
|
||||
if accept != accept_token:
|
||||
raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection())
|
||||
if subprotocol is not None:
|
||||
if subprotocol not in self._initiating_request.subprotocols:
|
||||
raise RemoteProtocolError(
|
||||
f"unrecognized subprotocol {subprotocol}",
|
||||
event_hint=RejectConnection(),
|
||||
)
|
||||
extensions = client_extensions_handshake(
|
||||
accepts, cast(Sequence[Extension], self._initiating_request.extensions)
|
||||
)
|
||||
|
||||
self._connection = Connection(
|
||||
ConnectionType.CLIENT if self.client else ConnectionType.SERVER,
|
||||
extensions,
|
||||
self._h11_connection.trailing_data[0],
|
||||
)
|
||||
self._state = ConnectionState.OPEN
|
||||
return AcceptConnection(
|
||||
extensions=extensions, extra_headers=headers, subprotocol=subprotocol
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "{}(client={}, state={})".format(
|
||||
self.__class__.__name__, self.client, self.state
|
||||
)
|
||||
|
||||
|
||||
def server_extensions_handshake(
|
||||
requested: Iterable[str], supported: List[Extension]
|
||||
) -> Optional[bytes]:
|
||||
"""Agree on the extensions to use returning an appropriate header value.
|
||||
|
||||
This returns None if there are no agreed extensions
|
||||
"""
|
||||
accepts: Dict[str, Union[bool, bytes]] = {}
|
||||
for offer in requested:
|
||||
name = offer.split(";", 1)[0].strip()
|
||||
for extension in supported:
|
||||
if extension.name == name:
|
||||
accept = extension.accept(offer)
|
||||
if isinstance(accept, bool):
|
||||
if accept:
|
||||
accepts[extension.name] = True
|
||||
elif accept is not None:
|
||||
accepts[extension.name] = accept.encode("ascii")
|
||||
|
||||
if accepts:
|
||||
extensions: List[bytes] = []
|
||||
for name, params in accepts.items():
|
||||
name_bytes = name.encode("ascii")
|
||||
if isinstance(params, bool):
|
||||
assert params
|
||||
extensions.append(name_bytes)
|
||||
else:
|
||||
if params == b"":
|
||||
extensions.append(b"%s" % (name_bytes))
|
||||
else:
|
||||
extensions.append(b"%s; %s" % (name_bytes, params))
|
||||
return b", ".join(extensions)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def client_extensions_handshake(
|
||||
accepted: Iterable[str], supported: Sequence[Extension]
|
||||
) -> List[Extension]:
|
||||
# This raises RemoteProtocolError is the accepted extension is not
|
||||
# supported.
|
||||
extensions = []
|
||||
for accept in accepted:
|
||||
name = accept.split(";", 1)[0].strip()
|
||||
for extension in supported:
|
||||
if extension.name == name:
|
||||
extension.finalize(accept)
|
||||
extensions.append(extension)
|
||||
break
|
||||
else:
|
||||
raise RemoteProtocolError(
|
||||
f"unrecognized extension {name}", event_hint=RejectConnection()
|
||||
)
|
||||
return extensions
|
1
lib/python3.13/site-packages/wsproto/py.typed
Normal file
1
lib/python3.13/site-packages/wsproto/py.typed
Normal file
@ -0,0 +1 @@
|
||||
Marker
|
3
lib/python3.13/site-packages/wsproto/typing.py
Normal file
3
lib/python3.13/site-packages/wsproto/typing.py
Normal file
@ -0,0 +1,3 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
Headers = List[Tuple[bytes, bytes]]
|
88
lib/python3.13/site-packages/wsproto/utilities.py
Normal file
88
lib/python3.13/site-packages/wsproto/utilities.py
Normal file
@ -0,0 +1,88 @@
|
||||
"""
|
||||
wsproto/utilities
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
Utility functions that do not belong in a separate module.
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from h11._headers import Headers as H11Headers
|
||||
|
||||
from .events import Event
|
||||
from .typing import Headers
|
||||
|
||||
# RFC6455, Section 1.3 - Opening Handshake
|
||||
ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
class ProtocolError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class LocalProtocolError(ProtocolError):
|
||||
"""Indicates an error due to local/programming errors.
|
||||
|
||||
This is raised when the connection is asked to do something that
|
||||
is either incompatible with the state or the websocket standard.
|
||||
|
||||
"""
|
||||
|
||||
pass # noqa
|
||||
|
||||
|
||||
class RemoteProtocolError(ProtocolError):
|
||||
"""Indicates an error due to the remote's actions.
|
||||
|
||||
This is raised when processing the bytes from the remote if the
|
||||
remote has sent data that is incompatible with the websocket
|
||||
standard.
|
||||
|
||||
.. attribute:: event_hint
|
||||
|
||||
This is a suggested wsproto Event to send to the client based
|
||||
on the error. It could be None if no hint is available.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, message: str, event_hint: Optional[Event] = None) -> None:
|
||||
self.event_hint = event_hint
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
# Some convenience utilities for working with HTTP headers
|
||||
def normed_header_dict(h11_headers: Union[Headers, H11Headers]) -> Dict[bytes, bytes]:
|
||||
# This mangles Set-Cookie headers. But it happens that we don't care about
|
||||
# any of those, so it's OK. For every other HTTP header, if there are
|
||||
# multiple instances then you're allowed to join them together with
|
||||
# commas.
|
||||
name_to_values: Dict[bytes, List[bytes]] = {}
|
||||
for name, value in h11_headers:
|
||||
name_to_values.setdefault(name, []).append(value)
|
||||
name_to_normed_value = {}
|
||||
for name, values in name_to_values.items():
|
||||
name_to_normed_value[name] = b", ".join(values)
|
||||
return name_to_normed_value
|
||||
|
||||
|
||||
# We use this for parsing the proposed protocol list, and for parsing the
|
||||
# proposed and accepted extension lists. For the proposed protocol list it's
|
||||
# fine, because the ABNF is just 1#token. But for the extension lists, it's
|
||||
# wrong, because those can contain quoted strings, which can in turn contain
|
||||
# commas. XX FIXME
|
||||
def split_comma_header(value: bytes) -> List[str]:
|
||||
return [piece.decode("ascii").strip() for piece in value.split(b",")]
|
||||
|
||||
|
||||
def generate_nonce() -> bytes:
|
||||
# os.urandom may be overkill for this use case, but I don't think this
|
||||
# is a bottleneck, and better safe than sorry...
|
||||
return base64.b64encode(os.urandom(16))
|
||||
|
||||
|
||||
def generate_accept_token(token: bytes) -> bytes:
|
||||
accept_token = token + ACCEPT_GUID
|
||||
accept_token = hashlib.sha1(accept_token).digest()
|
||||
return base64.b64encode(accept_token)
|
Reference in New Issue
Block a user