Updated script that can be controled by Nodejs web app

This commit is contained in:
mac OS
2024-11-25 12:24:18 +07:00
parent c440eda1f4
commit 8b0ab2bd3a
8662 changed files with 1803808 additions and 34 deletions

View File

@ -0,0 +1,80 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
from typing import List, Tuple
import dns._features
import dns.asyncbackend
if dns._features.have("doq"):
import aioquic.quic.configuration # type: ignore
from dns._asyncbackend import NullContext
from dns.quic._asyncio import (
AsyncioQuicConnection,
AsyncioQuicManager,
AsyncioQuicStream,
)
from dns.quic._common import AsyncQuicConnection, AsyncQuicManager
from dns.quic._sync import SyncQuicConnection, SyncQuicManager, SyncQuicStream
have_quic = True
def null_factory(
*args, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
return NullContext(None)
def _asyncio_manager_factory(
context, *args, **kwargs # pylint: disable=unused-argument
):
return AsyncioQuicManager(*args, **kwargs)
# We have a context factory and a manager factory as for trio we need to have
# a nursery.
_async_factories = {"asyncio": (null_factory, _asyncio_manager_factory)}
if dns._features.have("trio"):
import trio
from dns.quic._trio import ( # pylint: disable=ungrouped-imports
TrioQuicConnection,
TrioQuicManager,
TrioQuicStream,
)
def _trio_context_factory():
return trio.open_nursery()
def _trio_manager_factory(context, *args, **kwargs):
return TrioQuicManager(context, *args, **kwargs)
_async_factories["trio"] = (_trio_context_factory, _trio_manager_factory)
def factories_for_backend(backend=None):
if backend is None:
backend = dns.asyncbackend.get_default_backend()
return _async_factories[backend.name()]
else: # pragma: no cover
have_quic = False
from typing import Any
class AsyncQuicStream: # type: ignore
pass
class AsyncQuicConnection: # type: ignore
async def make_stream(self) -> Any:
raise NotImplementedError
class SyncQuicStream: # type: ignore
pass
class SyncQuicConnection: # type: ignore
def make_stream(self) -> Any:
raise NotImplementedError
Headers = List[Tuple[bytes, bytes]]

View File

@ -0,0 +1,267 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import asyncio
import socket
import ssl
import struct
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.asyncbackend
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
AsyncQuicConnection,
AsyncQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
class AsyncioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = asyncio.Condition()
async def _wait_for_wake_up(self):
async with self._wake_up:
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount):
return
self._expecting = amount
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
self._expecting = 0
async def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
if self._buffer.seen_end():
return
try:
await asyncio.wait_for(self._wait_for_wake_up(), timeout)
except TimeoutError:
raise dns.exception.Timeout
async def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
await self.wait_for_end(expiration)
return self._buffer.get_all()
else:
await self.wait_for(2, expiration)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size, expiration)
return self._buffer.get(size)
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class AsyncioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = None
self._handshake_complete = asyncio.Event()
self._socket_created = asyncio.Event()
self._wake_timer = asyncio.Condition()
self._receiver_task = None
self._sender_task = None
self._wake_pending = False
async def _receiver(self):
try:
af = dns.inet.af_for_address(self._address)
backend = dns.asyncbackend.get_backend("asyncio")
# Note that peer is a low-level address tuple, but make_socket() wants
# a high-level address tuple, so we convert.
self._socket = await backend.make_socket(
af, socket.SOCK_DGRAM, 0, self._source, (self._peer[0], self._peer[1])
)
self._socket_created.set()
async with self._socket:
while not self._done:
(datagram, address) = await self._socket.recvfrom(
QUIC_MAX_DATAGRAM, None
)
if address[0] != self._peer[0] or address[1] != self._peer[1]:
continue
self._connection.receive_datagram(datagram, address, time.time())
# Wake up the timer in case the sender is sleeping, as there may be
# stuff to send now.
await self._wakeup()
except Exception:
pass
finally:
self._done = True
await self._wakeup()
self._handshake_complete.set()
async def _wakeup(self):
self._wake_pending = True
async with self._wake_timer:
self._wake_timer.notify_all()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
if not self._wake_pending:
await self._wake_timer.wait()
self._wake_pending = False
async def _sender(self):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, address in datagrams:
assert address == self._peer
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
try:
await asyncio.wait_for(self._wait_for_wake_timer(), interval)
except Exception:
pass
self._handle_timer(expiration)
await self._handle_events()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
self._receiver_task.cancel()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
count = 0
await asyncio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
await self._wakeup()
def run(self):
if self._closed:
return
self._receiver_task = asyncio.Task(self._receiver())
self._sender_task = asyncio.Task(self._sender())
async def make_stream(self, timeout=None):
try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout)
except TimeoutError:
raise dns.exception.Timeout
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
async def close(self):
if not self._closed:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
# sender might be blocked on this, so set it
self._socket_created.set()
await self._wakeup()
try:
await self._receiver_task
except asyncio.CancelledError:
pass
try:
await self._sender_task
except asyncio.CancelledError:
pass
await self._socket.close()
class AsyncioQuicManager(AsyncQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, AsyncioQuicConnection, server_name, h3)
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
connection.run()
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False

View File

@ -0,0 +1,339 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import base64
import copy
import functools
import socket
import struct
import time
import urllib
from typing import Any, Optional
import aioquic.h3.connection # type: ignore
import aioquic.h3.events # type: ignore
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import dns.inet
QUIC_MAX_DATAGRAM = 2048
MAX_SESSION_TICKETS = 8
# If we hit the max sessions limit we will delete this many of the oldest connections.
# The value must be a integer > 0 and <= MAX_SESSION_TICKETS.
SESSIONS_TO_DELETE = MAX_SESSION_TICKETS // 4
class UnexpectedEOF(Exception):
pass
class Buffer:
def __init__(self):
self._buffer = b""
self._seen_end = False
def put(self, data, is_end):
if self._seen_end:
return
self._buffer += data
if is_end:
self._seen_end = True
def have(self, amount):
if len(self._buffer) >= amount:
return True
if self._seen_end:
raise UnexpectedEOF
return False
def seen_end(self):
return self._seen_end
def get(self, amount):
assert self.have(amount)
data = self._buffer[:amount]
self._buffer = self._buffer[amount:]
return data
def get_all(self):
assert self.seen_end()
data = self._buffer
self._buffer = b""
return data
class BaseQuicStream:
def __init__(self, connection, stream_id):
self._connection = connection
self._stream_id = stream_id
self._buffer = Buffer()
self._expecting = 0
self._headers = None
self._trailers = None
def id(self):
return self._stream_id
def headers(self):
return self._headers
def trailers(self):
return self._trailers
def _expiration_from_timeout(self, timeout):
if timeout is not None:
expiration = time.time() + timeout
else:
expiration = None
return expiration
def _timeout_from_expiration(self, expiration):
if expiration is not None:
timeout = max(expiration - time.time(), 0.0)
else:
timeout = None
return timeout
# Subclass must implement receive() as sync / async and which returns a message
# or raises.
# Subclass must implement send() as sync / async and which takes a message and
# an EOF indicator.
def send_h3(self, url, datagram, post=True):
if not self._connection.is_h3():
raise SyntaxError("cannot send H3 to a non-H3 connection")
url_parts = urllib.parse.urlparse(url)
path = url_parts.path.encode()
if post:
method = b"POST"
else:
method = b"GET"
path += b"?dns=" + base64.urlsafe_b64encode(datagram).rstrip(b"=")
headers = [
(b":method", method),
(b":scheme", url_parts.scheme.encode()),
(b":authority", url_parts.netloc.encode()),
(b":path", path),
(b"accept", b"application/dns-message"),
]
if post:
headers.extend(
[
(b"content-type", b"application/dns-message"),
(b"content-length", str(len(datagram)).encode()),
]
)
self._connection.send_headers(self._stream_id, headers, not post)
if post:
self._connection.send_data(self._stream_id, datagram, True)
def _encapsulate(self, datagram):
if self._connection.is_h3():
return datagram
l = len(datagram)
return struct.pack("!H", l) + datagram
def _common_add_input(self, data, is_end):
self._buffer.put(data, is_end)
try:
return (
self._expecting > 0 and self._buffer.have(self._expecting)
) or self._buffer.seen_end
except UnexpectedEOF:
return True
def _close(self):
self._connection.close_stream(self._stream_id)
self._buffer.put(b"", True) # send EOF in case we haven't seen it.
class BaseQuicConnection:
def __init__(
self,
connection,
address,
port,
source=None,
source_port=0,
manager=None,
):
self._done = False
self._connection = connection
self._address = address
self._port = port
self._closed = False
self._manager = manager
self._streams = {}
if manager.is_h3():
self._h3_conn = aioquic.h3.connection.H3Connection(connection, False)
else:
self._h3_conn = None
self._af = dns.inet.af_for_address(address)
self._peer = dns.inet.low_level_address_tuple((address, port))
if source is None and source_port != 0:
if self._af == socket.AF_INET:
source = "0.0.0.0"
elif self._af == socket.AF_INET6:
source = "::"
else:
raise NotImplementedError
if source:
self._source = (source, source_port)
else:
self._source = None
def is_h3(self):
return self._h3_conn is not None
def close_stream(self, stream_id):
del self._streams[stream_id]
def send_headers(self, stream_id, headers, is_end=False):
self._h3_conn.send_headers(stream_id, headers, is_end)
def send_data(self, stream_id, data, is_end=False):
self._h3_conn.send_data(stream_id, data, is_end)
def _get_timer_values(self, closed_is_special=True):
now = time.time()
expiration = self._connection.get_timer()
if expiration is None:
expiration = now + 3600 # arbitrary "big" value
interval = max(expiration - now, 0)
if self._closed and closed_is_special:
# lower sleep interval to avoid a race in the closing process
# which can lead to higher latency closing due to sleeping when
# we have events.
interval = min(interval, 0.05)
return (expiration, interval)
def _handle_timer(self, expiration):
now = time.time()
if expiration <= now:
self._connection.handle_timer(now)
class AsyncQuicConnection(BaseQuicConnection):
async def make_stream(self, timeout: Optional[float] = None) -> Any:
pass
class BaseQuicManager:
def __init__(
self, conf, verify_mode, connection_factory, server_name=None, h3=False
):
self._connections = {}
self._connection_factory = connection_factory
self._session_tickets = {}
self._tokens = {}
self._h3 = h3
if conf is None:
verify_path = None
if isinstance(verify_mode, str):
verify_path = verify_mode
verify_mode = True
if h3:
alpn_protocols = ["h3"]
else:
alpn_protocols = ["doq", "doq-i03"]
conf = aioquic.quic.configuration.QuicConfiguration(
alpn_protocols=alpn_protocols,
verify_mode=verify_mode,
server_name=server_name,
)
if verify_path is not None:
conf.load_verify_locations(verify_path)
self._conf = conf
def _connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
connection = self._connections.get((address, port))
if connection is not None:
return (connection, False)
conf = self._conf
if want_session_ticket:
try:
session_ticket = self._session_tickets.pop((address, port))
# We found a session ticket, so make a configuration that uses it.
conf = copy.copy(conf)
conf.session_ticket = session_ticket
except KeyError:
# No session ticket.
pass
# Whether or not we found a session ticket, we want a handler to save
# one.
session_ticket_handler = functools.partial(
self.save_session_ticket, address, port
)
else:
session_ticket_handler = None
if want_token:
try:
token = self._tokens.pop((address, port))
# We found a token, so make a configuration that uses it.
conf = copy.copy(conf)
conf.token = token
except KeyError:
# No token
pass
# Whether or not we found a token, we want a handler to save # one.
token_handler = functools.partial(self.save_token, address, port)
else:
token_handler = None
qconn = aioquic.quic.connection.QuicConnection(
configuration=conf,
session_ticket_handler=session_ticket_handler,
token_handler=token_handler,
)
lladdress = dns.inet.low_level_address_tuple((address, port))
qconn.connect(lladdress, time.time())
connection = self._connection_factory(
qconn, address, port, source, source_port, self
)
self._connections[(address, port)] = connection
return (connection, True)
def closed(self, address, port):
try:
del self._connections[(address, port)]
except KeyError:
pass
def is_h3(self):
return self._h3
def save_session_ticket(self, address, port, ticket):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._session_tickets)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._session_tickets.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._session_tickets[key]
self._session_tickets[(address, port)] = ticket
def save_token(self, address, port, token):
# We rely on dictionaries keys() being in insertion order here. We
# can't just popitem() as that would be LIFO which is the opposite of
# what we want.
l = len(self._tokens)
if l >= MAX_SESSION_TICKETS:
keys_to_delete = list(self._tokens.keys())[0:SESSIONS_TO_DELETE]
for key in keys_to_delete:
del self._tokens[key]
self._tokens[(address, port)] = token
class AsyncQuicManager(BaseQuicManager):
def connect(self, address, port=853, source=None, source_port=0):
raise NotImplementedError

View File

@ -0,0 +1,295 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import selectors
import socket
import ssl
import struct
import threading
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import dns.exception
import dns.inet
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
BaseQuicConnection,
BaseQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
# Function used to create a socket. Can be overridden if needed in special
# situations.
socket_factory = socket.socket
class SyncQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = threading.Condition()
self._lock = threading.Lock()
def wait_for(self, amount, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.have(amount):
return
self._expecting = amount
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
self._expecting = 0
def wait_for_end(self, expiration):
while True:
timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.seen_end():
return
with self._wake_up:
if not self._wake_up.wait(timeout):
raise dns.exception.Timeout
def receive(self, timeout=None):
expiration = self._expiration_from_timeout(timeout)
if self._connection.is_h3():
self.wait_for_end(expiration)
with self._lock:
return self._buffer.get_all()
else:
self.wait_for(2, expiration)
with self._lock:
(size,) = struct.unpack("!H", self._buffer.get(2))
self.wait_for(size, expiration)
with self._lock:
return self._buffer.get(size)
def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
self._connection.write(self._stream_id, data, is_end)
def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
with self._wake_up:
self._wake_up.notify()
def close(self):
with self._lock:
self._close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
with self._wake_up:
self._wake_up.notify()
return False
class SyncQuicConnection(BaseQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = socket_factory(self._af, socket.SOCK_DGRAM, 0)
if self._source is not None:
try:
self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
except Exception:
self._socket.close()
raise
self._socket.connect(self._peer)
(self._send_wakeup, self._receive_wakeup) = socket.socketpair()
self._receive_wakeup.setblocking(False)
self._socket.setblocking(False)
self._handshake_complete = threading.Event()
self._worker_thread = None
self._lock = threading.Lock()
def _read(self):
count = 0
while count < 10:
count += 1
try:
datagram = self._socket.recv(QUIC_MAX_DATAGRAM)
except BlockingIOError:
return
with self._lock:
self._connection.receive_datagram(datagram, self._peer, time.time())
def _drain_wakeup(self):
while True:
try:
self._receive_wakeup.recv(32)
except BlockingIOError:
return
def _worker(self):
try:
sel = selectors.DefaultSelector()
sel.register(self._socket, selectors.EVENT_READ, self._read)
sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
items = sel.select(interval)
for key, _ in items:
key.data()
with self._lock:
self._handle_timer(expiration)
self._handle_events()
with self._lock:
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
try:
self._socket.send(datagram)
except BlockingIOError:
# we let QUIC handle any lossage
pass
finally:
with self._lock:
self._done = True
self._socket.close()
# Ensure anyone waiting for this gets woken up.
self._handshake_complete.set()
def _handle_events(self):
while True:
with self._lock:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(h3_event.data, h3_event.stream_ended)
else:
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
with self._lock:
self._done = True
elif isinstance(event, aioquic.quic.events.StreamReset):
with self._lock:
stream = self._streams.get(event.stream_id)
if stream:
stream._add_input(b"", True)
def write(self, stream, data, is_end=False):
with self._lock:
self._connection.send_stream_data(stream, data, is_end)
self._send_wakeup.send(b"\x01")
def send_headers(self, stream_id, headers, is_end=False):
with self._lock:
super().send_headers(stream_id, headers, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def send_data(self, stream_id, data, is_end=False):
with self._lock:
super().send_data(stream_id, data, is_end)
if is_end:
self._send_wakeup.send(b"\x01")
def run(self):
if self._closed:
return
self._worker_thread = threading.Thread(target=self._worker)
self._worker_thread.start()
def make_stream(self, timeout=None):
if not self._handshake_complete.wait(timeout):
raise dns.exception.Timeout
with self._lock:
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
def close_stream(self, stream_id):
with self._lock:
super().close_stream(stream_id)
def close(self):
with self._lock:
if self._closed:
return
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_wakeup.send(b"\x01")
self._worker_thread.join()
class SyncQuicManager(BaseQuicManager):
def __init__(
self, conf=None, verify_mode=ssl.CERT_REQUIRED, server_name=None, h3=False
):
super().__init__(conf, verify_mode, SyncQuicConnection, server_name, h3)
self._lock = threading.Lock()
def connect(
self,
address,
port=853,
source=None,
source_port=0,
want_session_ticket=True,
want_token=True,
):
with self._lock:
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket, want_token
)
if start:
connection.run()
return connection
def closed(self, address, port):
with self._lock:
super().closed(address, port)
def save_session_ticket(self, address, port, ticket):
with self._lock:
super().save_session_ticket(address, port, ticket)
def save_token(self, address, port, token):
with self._lock:
super().save_token(address, port, token)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
connection.close()
return False

View File

@ -0,0 +1,246 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import socket
import ssl
import struct
import time
import aioquic.quic.configuration # type: ignore
import aioquic.quic.connection # type: ignore
import aioquic.quic.events # type: ignore
import trio
import dns.exception
import dns.inet
from dns._asyncbackend import NullContext
from dns.quic._common import (
QUIC_MAX_DATAGRAM,
AsyncQuicConnection,
AsyncQuicManager,
BaseQuicStream,
UnexpectedEOF,
)
class TrioQuicStream(BaseQuicStream):
def __init__(self, connection, stream_id):
super().__init__(connection, stream_id)
self._wake_up = trio.Condition()
async def wait_for(self, amount):
while True:
if self._buffer.have(amount):
return
self._expecting = amount
async with self._wake_up:
await self._wake_up.wait()
self._expecting = 0
async def wait_for_end(self):
while True:
if self._buffer.seen_end():
return
async with self._wake_up:
await self._wake_up.wait()
async def receive(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
if self._connection.is_h3():
await self.wait_for_end()
return self._buffer.get_all()
else:
await self.wait_for(2)
(size,) = struct.unpack("!H", self._buffer.get(2))
await self.wait_for(size)
return self._buffer.get(size)
raise dns.exception.Timeout
async def send(self, datagram, is_end=False):
data = self._encapsulate(datagram)
await self._connection.write(self._stream_id, data, is_end)
async def _add_input(self, data, is_end):
if self._common_add_input(data, is_end):
async with self._wake_up:
self._wake_up.notify()
async def close(self):
self._close()
# Streams are async context managers
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async with self._wake_up:
self._wake_up.notify()
return False
class TrioQuicConnection(AsyncQuicConnection):
def __init__(self, connection, address, port, source, source_port, manager=None):
super().__init__(connection, address, port, source, source_port, manager)
self._socket = trio.socket.socket(self._af, socket.SOCK_DGRAM, 0)
self._handshake_complete = trio.Event()
self._run_done = trio.Event()
self._worker_scope = None
self._send_pending = False
async def _worker(self):
try:
if self._source:
await self._socket.bind(
dns.inet.low_level_address_tuple(self._source, self._af)
)
await self._socket.connect(self._peer)
while not self._done:
(expiration, interval) = self._get_timer_values(False)
if self._send_pending:
# Do not block forever if sends are pending. Even though we
# have a wake-up mechanism if we've already started the blocking
# read, the possibility of context switching in send means that
# more writes can happen while we have no wake up context, so
# we need self._send_pending to avoid (effectively) a "lost wakeup"
# race.
interval = 0.0
with trio.CancelScope(
deadline=trio.current_time() + interval
) as self._worker_scope:
datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
self._connection.receive_datagram(datagram, self._peer, time.time())
self._worker_scope = None
self._handle_timer(expiration)
await self._handle_events()
# We clear this now, before sending anything, as sending can cause
# context switches that do more sends. We want to know if that
# happens so we don't block a long time on the recv() above.
self._send_pending = False
datagrams = self._connection.datagrams_to_send(time.time())
for datagram, _ in datagrams:
await self._socket.send(datagram)
finally:
self._done = True
self._socket.close()
self._handshake_complete.set()
async def _handle_events(self):
count = 0
while True:
event = self._connection.next_event()
if event is None:
return
if isinstance(event, aioquic.quic.events.StreamDataReceived):
if self.is_h3():
h3_events = self._h3_conn.handle_event(event)
for h3_event in h3_events:
if isinstance(h3_event, aioquic.h3.events.HeadersReceived):
stream = self._streams.get(event.stream_id)
if stream:
if stream._headers is None:
stream._headers = h3_event.headers
elif stream._trailers is None:
stream._trailers = h3_event.headers
if h3_event.stream_ended:
await stream._add_input(b"", True)
elif isinstance(h3_event, aioquic.h3.events.DataReceived):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(
h3_event.data, h3_event.stream_ended
)
else:
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(event.data, event.end_stream)
elif isinstance(event, aioquic.quic.events.HandshakeCompleted):
self._handshake_complete.set()
elif isinstance(event, aioquic.quic.events.ConnectionTerminated):
self._done = True
self._socket.close()
elif isinstance(event, aioquic.quic.events.StreamReset):
stream = self._streams.get(event.stream_id)
if stream:
await stream._add_input(b"", True)
count += 1
if count > 10:
# yield
count = 0
await trio.sleep(0)
async def write(self, stream, data, is_end=False):
self._connection.send_stream_data(stream, data, is_end)
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
async def run(self):
if self._closed:
return
async with trio.open_nursery() as nursery:
nursery.start_soon(self._worker)
self._run_done.set()
async def make_stream(self, timeout=None):
if timeout is None:
context = NullContext(None)
else:
context = trio.move_on_after(timeout)
with context:
await self._handshake_complete.wait()
if self._done:
raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream
return stream
raise dns.exception.Timeout
async def close(self):
if not self._closed:
self._manager.closed(self._peer[0], self._peer[1])
self._closed = True
self._connection.close()
self._send_pending = True
if self._worker_scope is not None:
self._worker_scope.cancel()
await self._run_done.wait()
class TrioQuicManager(AsyncQuicManager):
def __init__(
self,
nursery,
conf=None,
verify_mode=ssl.CERT_REQUIRED,
server_name=None,
h3=False,
):
super().__init__(conf, verify_mode, TrioQuicConnection, server_name, h3)
self._nursery = nursery
def connect(
self, address, port=853, source=None, source_port=0, want_session_ticket=True
):
(connection, start) = self._connect(
address, port, source, source_port, want_session_ticket
)
if start:
self._nursery.start_soon(connection.run)
return connection
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
# Copy the iterator into a list as exiting things will mutate the connections
# table.
connections = list(self._connections.values())
for connection in connections:
await connection.close()
return False