webdriver_template/telecli/lib/python3.11/site-packages/trio/_dtls.py

1374 lines
54 KiB
Python
Raw Normal View History

2024-08-10 14:48:21 +03:00
# Implementation of DTLS 1.2, using pyopenssl
# https://datatracker.ietf.org/doc/html/rfc6347
#
# OpenSSL's APIs for DTLS are extremely awkward and limited, which forces us to jump
# through a *lot* of hoops and implement important chunks of the protocol ourselves.
# Hopefully they fix this before implementing DTLS 1.3, because it's a very different
# protocol, and it's probably impossible to pull tricks like we do here.
from __future__ import annotations
import contextlib
import enum
import errno
import hmac
import os
import struct
import warnings
import weakref
from itertools import count
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Generic,
Iterable,
Iterator,
TypeVar,
Union,
)
from weakref import ReferenceType, WeakValueDictionary
import attrs
import trio
from ._util import NoPublicConstructor, final
if TYPE_CHECKING:
from types import TracebackType
# See DTLSEndpoint.__init__ for why this is imported here
from OpenSSL import SSL # noqa: TCH004
from typing_extensions import Self, TypeAlias, TypeVarTuple, Unpack
from trio.socket import SocketType
PosArgsT = TypeVarTuple("PosArgsT")
# Buffer size passed to recvfrom() when reading datagrams off the socket.
# NOTE(review): presumably 65535 minus the 8-byte UDP header -- confirm.
MAX_UDP_PACKET_SIZE = 65527
def packet_header_overhead(sock: SocketType) -> int:
    """Return the per-packet header overhead, in bytes, assumed for *sock*.

    ``AF_INET`` sockets are charged 28 bytes; every other address family is
    charged 48 bytes.
    """
    # NOTE(review): presumably IP header plus the 8-byte UDP header (20 + 8
    # for IPv4, 40 + 8 for IPv6) -- confirm.
    is_ipv4 = sock.family == trio.socket.AF_INET
    return 28 if is_ipv4 else 48
def worst_case_mtu(sock: SocketType) -> int:
    """Return the most pessimistic payload size to assume for *sock*.

    Starts from a 576-byte packet for ``AF_INET`` sockets and a 1280-byte
    packet for every other family, then subtracts the header overhead.
    """
    smallest_packet = 576 if sock.family == trio.socket.AF_INET else 1280
    return smallest_packet - packet_header_overhead(sock)
def best_guess_mtu(sock: SocketType) -> int:
    """Return the payload size to try first: a 1500-byte packet minus headers."""
    overhead = packet_header_overhead(sock)
    return 1500 - overhead
# There are a bunch of different RFCs that define these codes, so for a
# comprehensive collection look here:
# https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml
class ContentType(enum.IntEnum):
    """Values of the 1-byte content type field at the start of each record."""

    change_cipher_spec = 20
    alert = 21
    handshake = 22
    application_data = 23
    heartbeat = 24
class HandshakeType(enum.IntEnum):
    """Values of the 1-byte message type field in a handshake message header.

    The numbers follow the IANA TLS parameters registry ("TLS HandshakeType").
    """

    hello_request = 0
    client_hello = 1
    server_hello = 2
    hello_verify_request = 3
    new_session_ticket = 4
    # IANA assigns 5 to end_of_early_data. Reusing 4 would make this name a
    # mere alias of new_session_ticket and leave HandshakeType(5) invalid.
    end_of_early_data = 5
    encrypted_extensions = 8
    certificate = 11
    server_key_exchange = 12
    certificate_request = 13
    server_hello_done = 14
    certificate_verify = 15
    client_key_exchange = 16
    finished = 20
    certificate_url = 21
    certificate_status = 22
    supplemental_data = 23
    key_update = 24
    compressed_certificate = 25
    ekt_key = 26
    message_hash = 254
class ProtocolVersion:
    """Two-byte wire encodings of the DTLS protocol versions used here."""

    DTLS10 = b"\xfe\xff"
    DTLS12 = b"\xfe\xfd"
# Selects the top 16 bits of the combined 64-bit epoch+seqno value, i.e. the
# epoch. A non-zero result means the record belongs to an epoch after 0.
EPOCH_MASK = 0xFFFF << (6 * 8)
# Conventions:
# - All functions that handle network data end in _untrusted.
# - All functions that end in _untrusted MUST make sure that bad data from the
# network can *only* cause BadPacket to be raised. No IndexError or
# struct.error or whatever.
class BadPacket(Exception):
    """Raised by the ``*_untrusted`` parsers when network data is malformed."""
# This checks that the DTLS 'epoch' field is 0, which is true iff we're in the
# initial handshake. It doesn't check the ContentType, because not all
# handshake messages have ContentType==handshake -- for example,
# ChangeCipherSpec is used during the handshake but has its own ContentType.
#
# Cannot fail.
def part_of_handshake_untrusted(packet: bytes) -> bool:
    """Report whether the first record in *packet* has epoch 0. Cannot fail."""
    # Bytes 3-4 hold the epoch. Slicing a too-short packet simply yields fewer
    # than two bytes, which can never equal the two-byte zero epoch.
    epoch_bytes = packet[3:5]
    return epoch_bytes == bytes(2)
# Cannot fail
def is_client_hello_untrusted(packet: bytes) -> bool:
    """Cheaply check whether *packet* looks like a ClientHello. Cannot fail."""
    # Byte 0 is the record's content type and byte 13 is the first byte after
    # the 13-byte record header, i.e. the handshake message type. A packet too
    # short to hold both is not a valid DTLS record.
    if len(packet) < 14:
        return False
    if packet[0] != ContentType.handshake:
        return False
    return packet[13] == HandshakeType.client_hello
# DTLS records are:
# - 1 byte content type
# - 2 bytes version
# - 8 bytes epoch+seqno
# Technically this is 2 bytes epoch then 6 bytes seqno, but we treat it as
# a single 8-byte integer, where epoch changes are represented as jumping
# forward by 2**(6*8).
# - 2 bytes payload length (unsigned big-endian)
# - payload
# Network byte order: B = content type, 2s = version, Q = epoch+seqno,
# H = payload length. 13 bytes in total.
RECORD_HEADER = struct.Struct("!B2sQH")
def to_hex(data: bytes) -> str:  # pragma: no cover
    """Render *data* as a hex string; used as the repr for raw byte fields."""
    hex_text = data.hex()
    return hex_text
@attrs.frozen
class Record:
    """One DTLS record: the header fields plus the raw payload."""

    # Content type byte from the record header.
    content_type: int
    # The 2-byte version field, kept as raw bytes; shown as hex in the repr.
    version: bytes = attrs.field(repr=to_hex)
    # Epoch and sequence number combined into a single 8-byte integer.
    epoch_seqno: int
    # The record payload; shown as hex in the repr.
    payload: bytes = attrs.field(repr=to_hex)
def records_untrusted(packet: bytes) -> Iterator[Record]:
    """Lazily split *packet* into the DTLS records it contains.

    Raises:
        BadPacket: if a record header cannot be unpacked, or if a record's
            declared payload length runs past the end of the packet.
    """
    offset = 0
    while offset < len(packet):
        try:
            header_fields = RECORD_HEADER.unpack_from(packet, offset)
        # Marked as no-cover because at time of writing, this code is unreachable
        # (records_untrusted only gets called on packets that are either trusted or that
        # have passed is_client_hello_untrusted, which filters out short packets)
        except struct.error as exc:  # pragma: no cover
            raise BadPacket("invalid record header") from exc
        content_type, version, epoch_seqno, declared_len = header_fields
        body_start = offset + RECORD_HEADER.size
        body_end = body_start + declared_len
        body = packet[body_start:body_end]
        # A slice silently stops at the end of the packet, so a short slice
        # means the declared length was a lie.
        if len(body) != declared_len:
            raise BadPacket("short record")
        offset = body_end
        yield Record(content_type, version, epoch_seqno, body)
def encode_record(record: Record) -> bytes:
    """Serialize *record*: the packed record header followed by the payload."""
    payload = record.payload
    return (
        RECORD_HEADER.pack(
            record.content_type,
            record.version,
            record.epoch_seqno,
            len(payload),
        )
        + payload
    )
# Handshake messages are:
# - 1 byte message type
# - 3 bytes total message length
# - 2 bytes message sequence number
# - 3 bytes fragment offset
# - 3 bytes fragment length
# Network byte order: B = message type, 3s = total length, H = message seq,
# 3s = fragment offset, 3s = fragment length. 12 bytes in total. The 3-byte
# fields are unpacked as raw bytes and converted to ints by the callers.
HANDSHAKE_MESSAGE_HEADER = struct.Struct("!B3sH3s3s")
@attrs.frozen
class HandshakeFragment:
    """One handshake message fragment: the decoded header plus fragment data."""

    # Handshake message type byte.
    msg_type: int
    # Total length of the complete (unfragmented) handshake message.
    msg_len: int
    # Handshake message sequence number.
    msg_seq: int
    # Offset of this fragment within the complete message.
    frag_offset: int
    # Length of this fragment's data.
    frag_len: int
    # The fragment data itself; shown as hex in the repr.
    frag: bytes = attrs.field(repr=to_hex)
def decode_handshake_fragment_untrusted(payload: bytes) -> HandshakeFragment:
    """Parse a handshake record payload into a `HandshakeFragment`.

    Raises:
        BadPacket: if the header cannot be unpacked, or if the data after the
            header is not exactly as long as the header's fragment length.
    """
    try:
        header = HANDSHAKE_MESSAGE_HEADER.unpack_from(payload)
    except struct.error as exc:
        raise BadPacket("bad handshake message header") from exc
    msg_type, raw_msg_len, msg_seq, raw_frag_offset, raw_frag_len = header
    # 'struct' has no 24-bit integer format, so the three 3-byte fields are
    # converted by hand. These conversions can't fail.
    msg_len, frag_offset, frag_len = (
        int.from_bytes(raw, "big")
        for raw in (raw_msg_len, raw_frag_offset, raw_frag_len)
    )
    frag = payload[HANDSHAKE_MESSAGE_HEADER.size :]
    if len(frag) != frag_len:
        raise BadPacket("handshake fragment length doesn't match record length")
    return HandshakeFragment(
        msg_type=msg_type,
        msg_len=msg_len,
        msg_seq=msg_seq,
        frag_offset=frag_offset,
        frag_len=frag_len,
        frag=frag,
    )
def encode_handshake_fragment(hsf: HandshakeFragment) -> bytes:
    """Serialize *hsf*: the packed handshake header followed by the fragment."""

    def as_uint24(value: int) -> bytes:
        # 'struct' has no 24-bit integer format, so encode those by hand.
        return value.to_bytes(3, "big")

    header = HANDSHAKE_MESSAGE_HEADER.pack(
        hsf.msg_type,
        as_uint24(hsf.msg_len),
        hsf.msg_seq,
        as_uint24(hsf.frag_offset),
        as_uint24(hsf.frag_len),
    )
    return header + hsf.frag
def decode_client_hello_untrusted(packet: bytes) -> tuple[int, bytes, bytes]:
    """Extract the cookie and the cookie-independent contents of a ClientHello.

    Returns:
        A tuple ``(record epoch_seqno, cookie from the packet, data that
        should be hashed into the cookie)``.

    Raises:
        BadPacket: if parsing fails.
    """
    try:
        # ClientHello has to be the first record in the packet
        first_record = next(records_untrusted(packet))
        # no-cover because at time of writing, this is unreachable:
        # decode_client_hello_untrusted is only called on packets that have passed
        # is_client_hello_untrusted, which confirms the content type.
        if first_record.content_type != ContentType.handshake:  # pragma: no cover
            raise BadPacket("not a handshake record")
        hello = decode_handshake_fragment_untrusted(first_record.payload)
        if hello.msg_type != HandshakeType.client_hello:
            raise BadPacket("not a ClientHello")
        # ClientHello can't be fragmented, because reassembly requires holding
        # per-connection state, and we refuse to allocate per-connection state
        # until after we get a valid ClientHello.
        if hello.frag_offset != 0 or hello.frag_len != hello.msg_len:
            raise BadPacket("fragmented ClientHello")

        # As per RFC 6347:
        #
        #   When responding to a HelloVerifyRequest, the client MUST use the
        #   same parameter values (version, random, session_id, cipher_suites,
        #   compression_method) as it did in the original ClientHello. The
        #   server SHOULD use those values to generate its cookie and verify that
        #   they are correct upon cookie receipt.
        #
        # However, the record-layer framing can and will change (e.g. the
        # second ClientHello will have a new record-layer sequence number). So
        # we pull out the handshake message alone, discarding the record-layer
        # stuff, and hash all of it *except* the cookie.
        body = hello.frag

        # ClientHello is:
        #
        # - 2 bytes client_version
        # - 32 bytes random
        # - 1 byte session_id length
        # - session_id
        # - 1 byte cookie length
        # - cookie
        # - everything else
        #
        # So to find the cookie, we need to figure out how long the session_id
        # is and skip past it.
        fixed_prefix_len = 2 + 32
        session_id_len = body[fixed_prefix_len]
        cookie_len_at = fixed_prefix_len + 1 + session_id_len
        cookie_len = body[cookie_len_at]
        cookie_from = cookie_len_at + 1
        cookie_to = cookie_from + cookie_len

        cookie = body[cookie_from:cookie_to]
        if len(cookie) != cookie_len:
            raise BadPacket("short cookie")
        hashable = body[:cookie_len_at] + body[cookie_to:]
        return (first_record.epoch_seqno, cookie, hashable)
    except (struct.error, IndexError) as exc:
        raise BadPacket("bad ClientHello") from exc
@attrs.frozen
class HandshakeMessage:
    """A complete handshake message, reassembled from one or more fragments."""

    # Version bytes of the record the message arrived in; hex in the repr.
    record_version: bytes = attrs.field(repr=to_hex)
    msg_type: HandshakeType
    # Handshake message sequence number.
    msg_seq: int
    # The full message body. It is a bytearray so that fragments can be
    # written into it in place while decoding a volley.
    body: bytearray = attrs.field(repr=to_hex)
# ChangeCipherSpec is part of the handshake, but it's not a "handshake
# message" and can't be fragmented the same way. Sigh.
@attrs.frozen
class PseudoHandshakeMessage:
    """A record sent during the handshake that is not a handshake message.

    The payload is re-sent whole inside a freshly numbered record.
    """

    # Version bytes of the original record; hex in the repr.
    record_version: bytes = attrs.field(repr=to_hex)
    # Content type of the original record.
    content_type: int
    # The record payload, passed through as-is; hex in the repr.
    payload: bytes = attrs.field(repr=to_hex)
# The final record in a handshake is Finished, which is encrypted, can't be fragmented
# (at least by us), and keeps its record number (because it's in a new epoch). So we
# just pass it through unchanged. (Fortunately, the payload is only a single hash value,
# so the largest it will ever be is 64 bytes for a 512-bit hash. Which is small enough
# that it never requires fragmenting to fit into a UDP packet.)
@attrs.frozen
class OpaqueHandshakeMessage:
    """A record that is re-sent exactly as-is, keeping its record number."""

    # The original record, re-encoded unchanged when the volley is sent.
    record: Record
# Any of the three kinds of item that can appear in a handshake volley.
_AnyHandshakeMessage: TypeAlias = Union[
    HandshakeMessage, PseudoHandshakeMessage, OpaqueHandshakeMessage
]
# This takes a raw outgoing handshake volley that openssl generated, and
# reconstructs the handshake messages inside it, so that we can repack them
# into records while retransmitting. So the data ought to be well-behaved --
# it's not coming from the network.
def decode_volley_trusted(
    volley: bytes,
) -> list[_AnyHandshakeMessage]:
    """Reconstruct the handshake messages inside an outgoing volley.

    Fragments that share a message sequence number are merged into a single
    `HandshakeMessage`. Messages are returned in order of first appearance.
    """
    decoded: list[_AnyHandshakeMessage] = []
    by_seq = {}
    for record in records_untrusted(volley):
        # ChangeCipherSpec isn't a handshake message, so it can't be fragmented.
        # Handshake messages with epoch > 0 are encrypted, so we can't fragment them
        # either. Fortunately, ChangeCipherSpec has a 1 byte payload, and the only
        # encrypted handshake message is Finished, whose payload is a single hash value
        # -- so 32 bytes for SHA-256, 64 for SHA-512, etc. Neither is going to be so
        # large that it has to be fragmented to fit into a single packet.
        if record.epoch_seqno & EPOCH_MASK:
            decoded.append(OpaqueHandshakeMessage(record))
            continue
        if record.content_type in (ContentType.change_cipher_spec, ContentType.alert):
            decoded.append(
                PseudoHandshakeMessage(
                    record.version, record.content_type, record.payload
                )
            )
            continue

        assert record.content_type == ContentType.handshake
        fragment = decode_handshake_fragment_untrusted(record.payload)
        msg_type = HandshakeType(fragment.msg_type)
        seq = fragment.msg_seq
        try:
            msg = by_seq[seq]
        except KeyError:
            # First fragment seen for this message: allocate a zero-filled
            # body of the full length, to be filled in fragment by fragment.
            msg = HandshakeMessage(
                record.version,
                msg_type,
                seq,
                bytearray(fragment.msg_len),
            )
            decoded.append(msg)
            by_seq[seq] = msg
        assert msg.msg_type == fragment.msg_type
        assert msg.msg_seq == fragment.msg_seq
        assert len(msg.body) == fragment.msg_len

        start = fragment.frag_offset
        msg.body[start : start + fragment.frag_len] = fragment.frag
    return decoded
class RecordEncoder:
    """Packs handshake messages into packets, assigning record numbers."""

    def __init__(self) -> None:
        # Source of record sequence numbers for the records we build ourselves.
        self._record_seq = count()

    def set_first_record_number(self, n: int) -> None:
        """Restart record numbering so that the next record built gets number *n*."""
        self._record_seq = count(n)

    def encode_volley(
        self,
        messages: Iterable[_AnyHandshakeMessage],
        mtu: int,
    ) -> list[bytearray]:
        """Encode *messages* into a list of packets of at most *mtu* bytes each.

        Opaque messages are re-encoded unchanged. Pseudo-handshake messages
        get a fresh record header around their payload. Real handshake
        messages are split into as many fragments as needed, each in its own
        record.
        """
        packets = []
        packet = bytearray()
        for message in messages:
            if isinstance(message, OpaqueHandshakeMessage):
                encoded = encode_record(message.record)
                # Not enough room left: flush the current packet and start anew.
                if mtu - len(packet) - len(encoded) <= 0:
                    packets.append(packet)
                    packet = bytearray()
                packet += encoded
                assert len(packet) <= mtu
            elif isinstance(message, PseudoHandshakeMessage):
                space = mtu - len(packet) - RECORD_HEADER.size - len(message.payload)
                if space <= 0:
                    packets.append(packet)
                    packet = bytearray()
                packet += RECORD_HEADER.pack(
                    message.content_type,
                    message.record_version,
                    next(self._record_seq),
                    len(message.payload),
                )
                packet += message.payload
                assert len(packet) <= mtu
            else:
                msg_len_bytes = len(message.body).to_bytes(3, "big")
                frag_offset = 0
                frags_encoded = 0
                # If message.body is empty, then we still want to encode it in one
                # fragment, not zero.
                while frag_offset < len(message.body) or not frags_encoded:
                    # Room left for fragment data once both headers are paid for.
                    space = (
                        mtu
                        - len(packet)
                        - RECORD_HEADER.size
                        - HANDSHAKE_MESSAGE_HEADER.size
                    )
                    if space <= 0:
                        # No room for even one byte of data: flush and retry
                        # this fragment in a fresh packet.
                        packets.append(packet)
                        packet = bytearray()
                        continue
                    frag = message.body[frag_offset : frag_offset + space]
                    frag_offset_bytes = frag_offset.to_bytes(3, "big")
                    frag_len_bytes = len(frag).to_bytes(3, "big")
                    frag_offset += len(frag)

                    packet += RECORD_HEADER.pack(
                        ContentType.handshake,
                        message.record_version,
                        next(self._record_seq),
                        HANDSHAKE_MESSAGE_HEADER.size + len(frag),
                    )

                    packet += HANDSHAKE_MESSAGE_HEADER.pack(
                        message.msg_type,
                        msg_len_bytes,
                        message.msg_seq,
                        frag_offset_bytes,
                        frag_len_bytes,
                    )

                    packet += frag

                    frags_encoded += 1
                    assert len(packet) <= mtu

        # Flush whatever is left in the final, partially filled packet.
        if packet:
            packets.append(packet)

        return packets
# This bit requires implementing a bona fide cryptographic protocol, so even though it's
# a simple one let's take a moment to discuss the design.
#
# Our goal is to force new incoming handshakes that claim to be coming from a
# given ip:port to prove that they can also receive packets sent to that
# ip:port. (There's nothing in UDP to stop someone from forging the return
# address, and it's often used for stuff like DoS reflection attacks, where
# an attacker tries to trick us into sending data at some innocent victim.)
# For more details, see:
#
# https://datatracker.ietf.org/doc/html/rfc6347#section-4.2.1
#
# To do this, when we receive an initial ClientHello, we calculate a magic
# cookie, and send it back as a HelloVerifyRequest. Then the client sends us a
# second ClientHello, this time with the magic cookie included, and after we
# check that this cookie is valid we go ahead and start the handshake proper.
#
# So the magic cookie needs the following properties:
# - No-one can forge it without knowing our secret key
# - It ensures that the ip, port, and ClientHello contents from the response
# match those in the challenge
# - It expires after a short-ish period (so that if an attacker manages to steal one, it
# won't be useful for long)
# - It doesn't require storing any peer-specific state on our side
#
# To do that, we take the ip/port/ClientHello data and compute an HMAC of them, using a
# secret key we generate on startup. We also include:
#
# - The current time (using Trio's clock), rounded to the nearest 30 seconds
# - A random salt
#
# Then the cookie is the salt and the HMAC digest concatenated together.
#
# When verifying a cookie, we use the salt + new ip/port/ClientHello data to recompute
# the HMAC digest, for both the current time and the current time minus 30 seconds, and
# if either of them match, we consider the cookie good.
#
# Including the rounded-off time like this means that each cookie is good for at least
# 30 seconds, and possibly as much as 60 seconds.
#
# The salt is probably not necessary -- I'm pretty sure that all it does is make it hard
# for an attacker to figure out when our clock ticks over a 30 second boundary. Which is
# probably pretty harmless? But it's easier to add the salt than to convince myself that
# it's *completely* harmless, so, salt it is.
COOKIE_REFRESH_INTERVAL = 30  # seconds
# Length of the HMAC key; _make_cookie asserts the key is exactly this long.
KEY_BYTES = 32
# Digest name passed to hmac.digest() when computing cookies.
COOKIE_HASH = "sha256"
# Length of the random salt that prefixes every cookie.
SALT_BYTES = 8
# 32 bytes was the maximum cookie length in DTLS 1.0. DTLS 1.2 raised it to 255. I doubt
# there are any DTLS 1.0 implementations still in the wild, but really 32 bytes is
# plenty, and it also gets rid of a confusing warning in Wireshark output.
#
# We truncate the cookie to 32 bytes, of which 8 bytes is salt, so that leaves 24 bytes
# of truncated HMAC = 192 bit security, which is still massive overkill. (TCP uses 32
# *bits* for this.) HMAC truncation is explicitly noted as safe in RFC 2104:
# https://datatracker.ietf.org/doc/html/rfc2104#section-5
COOKIE_LENGTH = 32
def _current_cookie_tick() -> int:
    """Return the index of the current cookie-refresh interval on Trio's clock."""
    now = trio.current_time()
    return int(now / COOKIE_REFRESH_INTERVAL)
# Simple deterministic and invertible serializer -- i.e., a useful tool for converting
# structured data into something we can cryptographically sign.
def _signable(*fields: bytes) -> bytes:
out = []
for field in fields:
out.append(struct.pack("!Q", len(field)))
out.append(field)
return b"".join(out)
def _make_cookie(
    key: bytes, salt: bytes, tick: int, address: Any, client_hello_bits: bytes
) -> bytes:
    """Compute the cookie for the given salt, tick, peer address and ClientHello.

    The result is the salt followed by the HMAC digest of all the inputs,
    truncated to ``COOKIE_LENGTH`` bytes.
    """
    assert len(salt) == SALT_BYTES
    assert len(key) == KEY_BYTES

    tick_field = struct.pack("!Q", tick)
    # address is a mix of strings and ints, and variable length, so pack
    # it into a single nested field
    address_field = _signable(*(str(part).encode() for part in address))
    to_sign = _signable(salt, tick_field, address_field, client_hello_bits)

    digest = hmac.digest(key, to_sign, COOKIE_HASH)
    return (salt + digest)[:COOKIE_LENGTH]
def valid_cookie(
    key: bytes, cookie: bytes, address: Any, client_hello_bits: bytes
) -> bool:
    """Check *cookie* against the cookies for the current and previous ticks.

    Returns False straight away if the cookie is too short to carry anything
    beyond the salt.
    """
    if len(cookie) <= SALT_BYTES:
        return False

    salt = cookie[:SALT_BYTES]
    now_tick = _current_cookie_tick()
    expected_now = _make_cookie(key, salt, now_tick, address, client_hello_bits)
    expected_prev = _make_cookie(
        key, salt, max(now_tick - 1, 0), address, client_hello_bits
    )

    matches_now = hmac.compare_digest(cookie, expected_now)
    matches_prev = hmac.compare_digest(cookie, expected_prev)
    # I doubt using a short-circuiting 'or' here would leak any meaningful
    # information, but why risk it when '|' is just as easy.
    return matches_now | matches_prev
def challenge_for(
    key: bytes, address: Any, epoch_seqno: int, client_hello_bits: bytes
) -> bytes:
    """Build the HelloVerifyRequest packet that challenges a ClientHello.

    A fresh cookie is generated from a new random salt and the current tick,
    then wrapped in a single-fragment handshake message inside one record
    that carries *epoch_seqno* as its record number.
    """
    cookie = _make_cookie(
        key,
        os.urandom(SALT_BYTES),
        _current_cookie_tick(),
        address,
        client_hello_bits,
    )

    # HelloVerifyRequest body is:
    # - 2 bytes version
    # - length-prefixed cookie
    #
    # The DTLS 1.2 spec says that for this message specifically we should use
    # the DTLS 1.0 version.
    #
    # (It also says the opposite of that, but that part is a mistake:
    #    https://www.rfc-editor.org/errata/eid4103
    # ).
    #
    # And I guess we use this for both the message-level and record-level
    # ProtocolVersions, since we haven't negotiated anything else yet?
    body = b"".join((ProtocolVersion.DTLS10, bytes([len(cookie)]), cookie))

    # RFC says have to copy the client's record number
    # Errata says it should be handshake message number
    # Openssl copies back record sequence number, and always sets message seq
    # number 0. So I guess we'll follow openssl.
    fragment = HandshakeFragment(
        msg_type=HandshakeType.hello_verify_request,
        msg_len=len(body),
        msg_seq=0,
        frag_offset=0,
        frag_len=len(body),
        frag=body,
    )
    record = Record(
        ContentType.handshake,
        ProtocolVersion.DTLS10,
        epoch_seqno,
        encode_handshake_fragment(fragment),
    )
    return encode_record(record)
_T = TypeVar("_T")


class _Queue(Generic[_T]):
    """Holds both ends of a memory channel: ``s`` to send, ``r`` to receive."""

    def __init__(self, incoming_packets_buffer: int | float):  # noqa: PYI041
        channel_ends = trio.open_memory_channel[_T](incoming_packets_buffer)
        self.s, self.r = channel_ends
def _read_loop(read_fn: Callable[[int], bytes]) -> bytes:
    """Drain *read_fn*: call it until it raises ``SSL.WantReadError``.

    Returns everything that was read, concatenated.
    """
    max_record_size = 2**14  # max TLS record size
    collected = []
    try:
        while True:
            collected.append(read_fn(max_record_size))
    except SSL.WantReadError:
        # Nothing more to read right now.
        return b"".join(collected)
async def handle_client_hello_untrusted(
    endpoint: DTLSEndpoint, address: Any, packet: bytes
) -> None:
    """Handle an incoming packet that looks like a ClientHello.

    Packets that fail to parse are dropped silently. A ClientHello without a
    valid cookie is answered with a HelloVerifyRequest challenge. One with a
    valid cookie gets a new `DTLSChannel`, which is registered on the endpoint
    and pushed onto its incoming-connections queue.
    """
    # it's trivial to write a simple function that directly calls this to
    # get code coverage, but it should maybe:
    # 1. be removed
    # 2. be asserted
    # 3. Write a complicated test case where this happens "organically"
    if endpoint._listening_context is None:  # pragma: no cover
        return

    try:
        epoch_seqno, cookie, bits = decode_client_hello_untrusted(packet)
    except BadPacket:
        return

    # The cookie-signing key is created lazily, on the first ClientHello.
    if endpoint._listening_key is None:
        endpoint._listening_key = os.urandom(KEY_BYTES)

    if not valid_cookie(endpoint._listening_key, cookie, address, bits):
        challenge_packet = challenge_for(
            endpoint._listening_key, address, epoch_seqno, bits
        )
        try:
            async with endpoint._send_lock:
                await endpoint.socket.sendto(challenge_packet, address)
        except (OSError, trio.ClosedResourceError):
            # Best effort only: a failed send is treated like a lost packet.
            pass
    else:
        # We got a real, valid ClientHello!
        stream = DTLSChannel._create(endpoint, address, endpoint._listening_context)
        # Our HelloVerifyRequest had some sequence number. We need our future sequence
        # numbers to be larger than it, so our peer knows that our future records aren't
        # stale/duplicates. But, we don't know what this sequence number was. What we do
        # know is:
        # - the HelloVerifyRequest seqno was copied from the initial ClientHello
        # - the new ClientHello has a higher seqno than the initial ClientHello
        # So, if we copy the new ClientHello's seqno into our first real handshake
        # record and increment from there, that should work.
        stream._record_encoder.set_first_record_number(epoch_seqno)
        # Process the ClientHello
        try:
            stream._ssl.bio_write(packet)
            stream._ssl.DTLSv1_listen()
        except SSL.Error:  # pragma: no cover
            # ...OpenSSL didn't like it, so I guess we didn't have a valid ClientHello
            # after all.
            return

        # Some old versions of OpenSSL have a bug with memory BIOs, where DTLSv1_listen
        # consumes the ClientHello out of the BIO, but then do_handshake expects the
        # ClientHello to still be in there (but not the one that ships with Ubuntu
        # 20.04). In particular, this is known to affect the OpenSSL v1.1.1 that ships
        # with Ubuntu 18.04. To work around this, we deliver a second copy of the
        # ClientHello after DTLSv1_listen has completed. This is safe to do
        # unconditionally, because on newer versions of OpenSSL, the second ClientHello
        # is treated as a duplicate packet, which is a normal thing that can happen over
        # UDP. For more details, see:
        #
        # https://github.com/pyca/pyopenssl/blob/e84e7b57d1838de70ab7a27089fbee78ce0d2106/tests/test_ssl.py#L4226-L4293
        #
        # This was fixed in v1.1.1a, and all later versions. So maybe in 2024 or so we
        # can delete this. The fix landed in OpenSSL master as 079ef6bd534d2, and then
        # was backported to the 1.1.1 branch as d1bfd8076e28.
        stream._ssl.bio_write(packet)

        # Check if we have an existing association
        old_stream = endpoint._streams.get(address)
        if old_stream is not None:
            if old_stream._client_hello == (cookie, bits):
                # ...This was just a duplicate of the last ClientHello, so never mind.
                return
            else:
                # Ok, this *really is* a new handshake; the old stream should go away.
                old_stream._set_replaced()
        stream._client_hello = (cookie, bits)
        endpoint._streams[address] = stream
        endpoint._incoming_connections_q.s.send_nowait(stream)
async def dtls_receive_loop(
    endpoint_ref: ReferenceType[DTLSEndpoint], sock: SocketType
) -> None:
    """Read packets from *sock* forever and dispatch them by sender.

    ClientHello-looking packets go to `handle_client_hello_untrusted`. Packets
    from an address with a known stream go to that stream. Everything else is
    dropped. The loop returns once the endpoint has been garbage collected or
    the socket has been closed.
    """
    try:
        while True:
            try:
                packet, address = await sock.recvfrom(MAX_UDP_PACKET_SIZE)
            except OSError as exc:
                if exc.errno == errno.ECONNRESET:
                    # Windows only: "On a UDP-datagram socket [ECONNRESET]
                    # indicates a previous send operation resulted in an ICMP Port
                    # Unreachable message" -- https://docs.microsoft.com/en-us/windows/win32/api/winsock/nf-winsock-recvfrom
                    #
                    # This is totally useless -- there's nothing we can do with this
                    # information. So we just ignore it and retry the recv.
                    continue
                else:
                    raise
            # Only a weak reference is held between packets; the strong
            # reference taken here is deleted again in the 'finally' below.
            endpoint = endpoint_ref()
            try:
                if endpoint is None:
                    return
                if is_client_hello_untrusted(packet):
                    await handle_client_hello_untrusted(endpoint, address, packet)
                elif address in endpoint._streams:
                    stream = endpoint._streams[address]
                    if stream._did_handshake and part_of_handshake_untrusted(packet):
                        # The peer just sent us more handshake messages, that aren't a
                        # ClientHello, and we thought the handshake was done. Some of
                        # the packets that we sent to finish the handshake must have
                        # gotten lost. So re-send them. We do this directly here instead
                        # of just putting it into the queue and letting the receiver do
                        # it, because there's no guarantee that anyone is reading from
                        # the queue, because we think the handshake is done!
                        await stream._resend_final_volley()
                    else:
                        try:
                            # mypy for some reason cannot determine type of _q
                            stream._q.s.send_nowait(packet)  # type:ignore[has-type]
                        except trio.WouldBlock:
                            # The stream's buffer is full: count the drop.
                            stream._packets_dropped_in_trio += 1
                else:
                    # Drop packet
                    pass
            finally:
                del endpoint
    except trio.ClosedResourceError:
        # socket was closed
        return
    except OSError as exc:
        if exc.errno in (errno.EBADF, errno.ENOTSOCK):
            # socket was closed
            return
        else:  # pragma: no cover
            # ??? shouldn't happen
            raise
@attrs.frozen
class DTLSChannelStatistics:
    """Currently this has only one attribute:

    - ``incoming_packets_dropped_in_trio`` (``int``): Gives a count of the number of
      incoming packets from this peer that Trio successfully received from the
      network, but then got dropped because the internal channel buffer was full. If
      this is non-zero, then you might want to call ``receive`` more often, or use a
      larger ``incoming_packets_buffer``, or just not worry about it because your
      UDP-based protocol should be able to handle the occasional lost packet, right?

    """

    incoming_packets_dropped_in_trio: int
@final
class DTLSChannel(trio.abc.Channel[bytes], metaclass=NoPublicConstructor):
"""A DTLS connection.
This class has no public constructor -- you get instances by calling
`DTLSEndpoint.serve` or `~DTLSEndpoint.connect`.
.. attribute:: endpoint
The `DTLSEndpoint` that this connection is using.
.. attribute:: peer_address
The IP/port of the remote peer that this connection is associated with.
"""
def __init__(
    self,
    endpoint: DTLSEndpoint,
    peer_address: Any,
    ctx: SSL.Context,
) -> None:
    # Sets up per-connection state and a fresh OpenSSL connection built from ctx.
    self.endpoint = endpoint
    self.peer_address = peer_address
    # Count of packets dropped because this connection's queue was full.
    self._packets_dropped_in_trio = 0
    # For server-side connections, handle_client_hello_untrusted stores the
    # (cookie, bits) of the accepted ClientHello here, to spot duplicates.
    self._client_hello = None
    self._did_handshake = False
    # These are mandatory for all DTLS connections. OP_NO_QUERY_MTU is required to
    # stop openssl from trying to query the memory BIO's MTU and then breaking, and
    # OP_NO_RENEGOTIATION disables renegotiation, which is too complex for us to
    # support and isn't useful anyway -- especially for DTLS where it's equivalent
    # to just performing a new handshake.
    ctx.set_options(
        SSL.OP_NO_QUERY_MTU | SSL.OP_NO_RENEGOTIATION  # type: ignore[attr-defined]
    )
    self._ssl = SSL.Connection(ctx)
    self._handshake_mtu = 0
    # This calls self._ssl.set_ciphertext_mtu, which is important, because if you
    # don't call it then openssl doesn't work.
    self.set_ciphertext_mtu(best_guess_mtu(self.endpoint.socket))
    self._replaced = False
    self._closed = False
    # Incoming packets for this peer; fed by dtls_receive_loop.
    self._q = _Queue[bytes](endpoint.incoming_packets_buffer)
    self._handshake_lock = trio.Lock()
    self._record_encoder: RecordEncoder = RecordEncoder()
    # The last volley sent by a completed handshake, kept so it can be re-sent.
    self._final_volley: list[_AnyHandshakeMessage] = []
def _set_replaced(self) -> None:
    """Mark this connection as replaced by a newer one from the same peer."""
    self._replaced = True
    # Any packets we already received could maybe possibly still be processed, but
    # there are no more coming. So we close this on the sender side.
    self._q.s.close()
def _check_replaced(self) -> None:
    """Raise `trio.BrokenResourceError` if this connection has been replaced."""
    if not self._replaced:
        return
    raise trio.BrokenResourceError(
        "peer tore down this connection to start a new one"
    )
# XX on systems where we can (maybe just Linux?) take advantage of the kernel's PMTU
# estimate
# XX should we send close-notify when closing? It seems particularly pointless for
# DTLS where packets are all independent and can be lost anyway. We do at least need
# to handle receiving it properly though, which might be easier if we send it...
def close(self) -> None:
    """Close this connection.

    `DTLSChannel`\\s don't actually own any OS-level resources -- the
    socket is owned by the `DTLSEndpoint`, not the individual connections. So
    you don't really *have* to call this. But it will interrupt any other tasks
    calling `receive` with a `ClosedResourceError`, and cause future attempts to use
    this connection to fail.

    You can also use this object as a synchronous or asynchronous context manager.

    """
    if not self._closed:
        self._closed = True
        streams = self.endpoint._streams
        # Only unregister ourselves; a newer connection to the same peer may
        # already have taken our slot.
        if streams.get(self.peer_address) is self:
            del streams[self.peer_address]
        # Will wake any tasks waiting on self._q.get with a
        # ClosedResourceError
        self._q.r.close()
def __enter__(self) -> Self:
    # Entering the context does nothing; only leaving it closes the connection.
    return self
def __exit__(
    self,
    exc_type: type[BaseException] | None,
    exc_value: BaseException | None,
    traceback: TracebackType | None,
) -> None:
    """Close the connection on leaving the ``with`` block; never suppresses."""
    self.close()
async def aclose(self) -> None:
    """Close this connection, but asynchronously.

    This is included to satisfy the `trio.abc.Channel` contract. It's
    identical to `close`, but async.

    """
    self.close()
    # close() is synchronous, so add an explicit checkpoint afterwards.
    await trio.lowlevel.checkpoint()
async def _send_volley(self, volley_messages: list[_AnyHandshakeMessage]) -> None:
    """Encode *volley_messages* at the handshake MTU and send them to the peer.

    The endpoint's send lock is taken separately for each packet.
    """
    endpoint = self.endpoint
    datagrams = self._record_encoder.encode_volley(
        volley_messages, self._handshake_mtu
    )
    for datagram in datagrams:
        async with endpoint._send_lock:
            await endpoint.socket.sendto(datagram, self.peer_address)
async def _resend_final_volley(self) -> None:
    """Re-send the volley saved when our side of the handshake completed."""
    await self._send_volley(self._final_volley)
async def do_handshake(self, *, initial_retransmit_timeout: float = 1.0) -> None:
    """Perform the handshake.

    Calling this is optional -- if you don't, then it will be automatically called
    the first time you call `send` or `receive`. But calling it explicitly can be
    useful in case you want to control the retransmit timeout, use a cancel scope to
    place an overall timeout on the handshake, or catch errors from the handshake
    specifically.

    It's safe to call this multiple times, or call it simultaneously from multiple
    tasks -- the first call will perform the handshake, and the rest will be no-ops.

    Args:

      initial_retransmit_timeout (float): Since UDP is an unreliable protocol, it's
        possible that some of the packets we send during the handshake will get
        lost. To handle this, DTLS uses a timer to automatically retransmit
        handshake packets that don't receive a response. This lets you set the
        timeout we use to detect packet loss. Ideally, it should be set to ~1.5
        times the round-trip time to your peer, but 1 second is a reasonable
        default. There's `some useful guidance here
        <https://tlswg.org/dtls13-spec/draft-ietf-tls-dtls13.html#name-timer-values>`__.

        This is the *initial* timeout, because if packets keep being lost then Trio
        will automatically back off to longer values, to avoid overloading the
        network.

    """
    async with self._handshake_lock:
        if self._did_handshake:
            return

        timeout = initial_retransmit_timeout
        volley_messages: list[_AnyHandshakeMessage] = []
        volley_failed_sends = 0

        def read_volley() -> list[_AnyHandshakeMessage]:
            # Drain whatever OpenSSL wants to send and decode it into messages.
            volley_bytes = _read_loop(self._ssl.bio_read)
            new_volley_messages = decode_volley_trusted(volley_bytes)
            if (
                new_volley_messages
                and volley_messages
                and isinstance(new_volley_messages[0], HandshakeMessage)
                and isinstance(volley_messages[0], HandshakeMessage)
                and new_volley_messages[0].msg_seq == volley_messages[0].msg_seq
            ):
                # openssl decided to retransmit; discard because we handle
                # retransmits ourselves
                return []
            else:
                return new_volley_messages

        # If we're a client, we send the initial volley. If we're a server, then
        # the initial ClientHello has already been inserted into self._ssl's
        # read BIO. So either way, we start by generating a new volley.
        with contextlib.suppress(SSL.WantReadError):
            self._ssl.do_handshake()
        volley_messages = read_volley()
        # If we don't have messages to send in our initial volley, then something
        # has gone very wrong. (I'm not sure this can actually happen without an
        # error from OpenSSL, but we check just in case.)
        if not volley_messages:  # pragma: no cover
            raise SSL.Error("something wrong with peer's ClientHello")

        while True:
            # -- at this point, we need to either send or re-send a volley --
            assert volley_messages
            self._check_replaced()
            await self._send_volley(volley_messages)
            # -- then this is where we wait for a reply --
            self.endpoint._ensure_receive_loop()
            with trio.move_on_after(timeout) as cscope:
                async for packet in self._q.r:
                    self._ssl.bio_write(packet)
                    try:
                        self._ssl.do_handshake()
                    # We ignore generic SSL.Error here, because you can get those
                    # from random invalid packets
                    except (SSL.WantReadError, SSL.Error):
                        pass
                    else:
                        # No exception -> the handshake is done, and we can
                        # switch into data transfer mode.
                        self._did_handshake = True
                        # Might be empty, but that's ok -- we'll just send no
                        # packets.
                        self._final_volley = read_volley()
                        await self._send_volley(self._final_volley)
                        return
                    maybe_volley = read_volley()
                    if maybe_volley:
                        if (
                            isinstance(maybe_volley[0], PseudoHandshakeMessage)
                            and maybe_volley[0].content_type == ContentType.alert
                        ):
                            # we're sending an alert (e.g. due to a corrupted
                            # packet). We want to send it once, but don't save it to
                            # retransmit -- keep the last volley as the current
                            # volley.
                            await self._send_volley(maybe_volley)
                        else:
                            # We managed to get all of the peer's volley and
                            # generate a new one ourselves! break out of the 'for'
                            # loop and restart the timer.
                            volley_messages = maybe_volley
                            # "Implementations SHOULD retain the current timer value
                            # until a transmission without loss occurs, at which
                            # time the value may be reset to the initial value."
                            if volley_failed_sends == 0:
                                timeout = initial_retransmit_timeout
                            volley_failed_sends = 0
                            break
                else:
                    # The receive channel ran dry, which only happens once the
                    # send side has been closed by _set_replaced().
                    assert self._replaced
                    self._check_replaced()
            if cscope.cancelled_caught:
                # Timeout expired. Double timeout for backoff, with a limit of 60
                # seconds (this matches what openssl does, and also the
                # recommendation in draft-ietf-tls-dtls13).
                timeout = min(2 * timeout, 60.0)
                volley_failed_sends += 1
                if volley_failed_sends == 2:
                    # We tried sending this twice and they both failed. Maybe our
                    # PMTU estimate is wrong? Let's try dropping it to the minimum
                    # and hope that helps.
                    self._handshake_mtu = min(
                        self._handshake_mtu, worst_case_mtu(self.endpoint.socket)
                    )
async def send(self, data: bytes) -> None:
    """Encrypt ``data`` and transmit it to the peer as one DTLS packet.

    If the handshake has not been performed yet, it is run first.

    Args:
      data (bytes): The cleartext payload. Must be non-empty.

    Raises:
      trio.ClosedResourceError: if this channel has already been closed.
      ValueError: if ``data`` is empty.
    """
    if self._closed:
        raise trio.ClosedResourceError
    if not data:
        raise ValueError("openssl doesn't support sending empty DTLS packets")

    handshake_pending = not self._did_handshake
    if handshake_pending:
        await self.do_handshake()

    # Bail out if another connection from the same peer has superseded us.
    self._check_replaced()

    # Hand the cleartext to openssl; the resulting ciphertext is drained from
    # its outgoing BIO below.
    self._ssl.write(data)

    endpoint = self.endpoint
    async with endpoint._send_lock:
        # Drain the BIO only once we hold the lock, so that reading and sending
        # happen together as a unit.
        ciphertext = _read_loop(self._ssl.bio_read)
        await endpoint.socket.sendto(ciphertext, self.peer_address)
async def receive(self) -> bytes:
    """Wait for and return the next data packet sent by this connection's peer.

    If the handshake has not been performed yet, it is run first.

    Calling this from several tasks at once is safe, should you need to. More
    importantly, it is cancellation-safe: cancelling a call to `receive` never
    loses a packet and never corrupts the underlying connection.

    Returns:
      bytes: The decrypted payload of the next valid packet.
    """
    if not self._did_handshake:
        await self.do_handshake()

    # openssl can decode a packet that isn't really valid (for example a
    # late-arriving handshake packet, or a duplicate of a data packet) to the
    # empty string. Keep pulling packets until one yields actual cleartext.
    decrypted = b""
    while not decrypted:
        try:
            ciphertext = await self._q.r.receive()
        except trio.EndOfChannel:
            # The incoming queue is only expected to end when this channel was
            # replaced; _check_replaced() then reports that to the caller.
            assert self._replaced
            self._check_replaced()
        self._ssl.bio_write(ciphertext)
        decrypted = _read_loop(self._ssl.read)
    return decrypted
def set_ciphertext_mtu(self, new_mtu: int) -> None:
    """Tells Trio the `largest amount of data that can be sent in a single packet to
    this peer <https://en.wikipedia.org/wiki/Maximum_transmission_unit>`__.

    Trio doesn't actually enforce this limit -- if you pass a huge packet to `send`,
    then we'll dutifully encrypt it and attempt to send it. Calling this method is
    still useful for two reasons:

    - When called before the handshake, Trio fragments handshake messages so that
      they fit within the given MTU. Trio may fragment them even further if it sees
      signs of packet loss, so setting this should never be *required* to connect
      successfully. However, that loss detection only kicks in after multiple
      timeouts have expired; if you already have reason to believe a smaller MTU is
      needed, setting it here skips those timeouts and connects faster.

    - It changes what `get_cleartext_mtu` returns. Given an estimate of the
      network-level MTU, you can use that to work out how much overhead DTLS needs
      for hashes/padding/etc., and how much room is left for application data.

    The MTU here is the largest UDP *payload* you think can be sent: the amount of
    encrypted data that can be handed to the operating system in a single call to
    `send`. It should *not* include IP/UDP headers. OS estimates of the MTU are
    often link-layer MTUs, in which case you have to subtract 28 bytes on IPv4 and
    48 bytes on IPv6 to get the ciphertext MTU.

    By default, Trio assumes an MTU of 1472 bytes on IPv4, and 1452 bytes on IPv6,
    which correspond to the common Ethernet MTU of 1500 bytes after accounting for
    IP/UDP overhead.

    Args:
      new_mtu (int): The new ciphertext MTU, in bytes.
    """
    ssl_connection = self._ssl
    # Used by our own handshake code ...
    self._handshake_mtu = new_mtu
    # ... and forwarded to openssl as well.
    ssl_connection.set_ciphertext_mtu(new_mtu)
def get_cleartext_mtu(self) -> int:
    """Report how many bytes fit in a single `send` call without exceeding the
    network-level MTU.

    See `set_ciphertext_mtu` for more details.

    Returns:
      int: The cleartext MTU as reported by openssl.

    Raises:
      trio.NeedHandshakeError: if the handshake has not been completed yet.
    """
    if self._did_handshake:
        return self._ssl.get_cleartext_mtu()  # type: ignore[no-any-return]
    raise trio.NeedHandshakeError
def statistics(self) -> DTLSChannelStatistics:
    """Return a `DTLSChannelStatistics` snapshot describing this connection.

    The snapshot is built from this channel's count of packets dropped in Trio.
    """
    dropped_in_trio = self._packets_dropped_in_trio
    return DTLSChannelStatistics(dropped_in_trio)
@final
class DTLSEndpoint:
    """A DTLS endpoint.

    A single UDP socket can handle arbitrarily many DTLS connections simultaneously,
    acting as a client or server as needed. A `DTLSEndpoint` object holds a UDP socket
    and manages these connections, which are represented as `DTLSChannel` objects.

    Args:
      socket (trio.socket.SocketType): A ``SOCK_DGRAM`` socket. If you want to accept
        incoming connections in server mode, then you should probably bind the socket to
        some known port.
      incoming_packets_buffer (int): Each `DTLSChannel` using this socket has its own
        buffer that holds incoming packets until you call `~DTLSChannel.receive` to read
        them. This lets you adjust the size of this buffer. `~DTLSChannel.statistics`
        lets you check if the buffer has overflowed.

    .. attribute:: socket
                   incoming_packets_buffer

       Both constructor arguments are also exposed as attributes, in case you need to
       access them later.

    """

    def __init__(
        self,
        socket: SocketType,
        *,
        incoming_packets_buffer: int = 10,
    ) -> None:
        # for __del__: this must be set before anything below can raise (including
        # the lazy OpenSSL import), so that __del__ can tell a half-constructed
        # object apart from a fully constructed one.
        self._initialized: bool = False

        # We do this lazily on first construction, so only people who actually use DTLS
        # have to install PyOpenSSL.
        global SSL
        from OpenSSL import SSL

        if socket.type != trio.socket.SOCK_DGRAM:
            raise ValueError("DTLS requires a SOCK_DGRAM socket")
        self._initialized = True
        self.socket: SocketType = socket

        self.incoming_packets_buffer = incoming_packets_buffer
        self._token = trio.lowlevel.current_trio_token()
        # We don't need to track handshaking vs non-handshake connections
        # separately. We only keep one connection per remote address; as soon
        # as a peer provides a valid cookie, we can immediately tear down the
        # old connection.
        # {remote address: DTLSChannel}
        self._streams: WeakValueDictionary[Any, DTLSChannel] = WeakValueDictionary()
        self._listening_context: SSL.Context | None = None
        self._listening_key: bytes | None = None
        self._incoming_connections_q = _Queue[DTLSChannel](float("inf"))
        self._send_lock = trio.Lock()
        self._closed = False
        self._receive_loop_spawned = False

    def _ensure_receive_loop(self) -> None:
        """Spawn the background receive loop for this endpoint, at most once."""
        # We have to spawn this lazily, because on Windows it will immediately error out
        # if the socket isn't already bound -- which for clients might not happen until
        # after we send our first packet.
        if not self._receive_loop_spawned:
            # The loop only gets a weak reference to us, so it does not keep the
            # endpoint alive on its own.
            trio.lowlevel.spawn_system_task(
                dtls_receive_loop, weakref.ref(self), self.socket
            )
            self._receive_loop_spawned = True

    def __del__(self) -> None:
        # Do nothing if this object was never fully constructed. Use getattr because
        # __del__ also runs when __init__ never got to execute its first statement
        # (e.g. it was called with invalid arguments); without the default we would
        # raise AttributeError here and Python would print an "Exception ignored"
        # message.
        if not getattr(self, "_initialized", False):
            return
        # Close the socket in Trio context (if our Trio context still exists), so that
        # the background task gets notified about the closure and can exit.
        if not self._closed:
            with contextlib.suppress(RuntimeError):
                self._token.run_sync_soon(self.close)
            # Do this last, because it might raise an exception
            warnings.warn(
                f"unclosed DTLS endpoint {self!r}",
                ResourceWarning,
                source=self,
                stacklevel=1,
            )

    def close(self) -> None:
        """Close this socket, and all associated DTLS connections.

        This object can also be used as a context manager.

        """
        self._closed = True
        self.socket.close()
        # Iterate over a snapshot of the channels rather than the live weak-value
        # view while closing them.
        for stream in list(self._streams.values()):
            stream.close()
        # Closing the send side ends the ``async for`` over incoming connections
        # in `serve`.
        self._incoming_connections_q.s.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        return self.close()

    def _check_closed(self) -> None:
        """Raise `trio.ClosedResourceError` if this endpoint has been closed."""
        if self._closed:
            raise trio.ClosedResourceError

    async def serve(
        self,
        ssl_context: SSL.Context,
        async_fn: Callable[[DTLSChannel, Unpack[PosArgsT]], Awaitable[object]],
        *args: Unpack[PosArgsT],
        task_status: trio.TaskStatus[None] = trio.TASK_STATUS_IGNORED,
    ) -> None:
        """Listen for incoming connections, and spawn a handler for each using an
        internal nursery.

        Similar to `~trio.serve_tcp`, this function never returns until cancelled, or
        the `DTLSEndpoint` is closed and all handlers have exited.

        Usage commonly looks like::

            async def handler(dtls_channel):
                ...

            async with trio.open_nursery() as nursery:
                await nursery.start(dtls_endpoint.serve, ssl_context, handler)
                # ... do other things here ...

        The ``dtls_channel`` passed into the handler function has already performed the
        "cookie exchange" part of the DTLS handshake, so the peer address is
        trustworthy. But the actual cryptographic handshake doesn't happen until you
        start using it, giving you a chance for any last minute configuration, and the
        option to catch and handle handshake errors.

        Args:
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            incoming connections.
          async_fn: The handler function that will be invoked for each incoming
            connection.
          *args: Additional arguments to pass to the handler function.

        Raises:
          trio.ClosedResourceError: if the endpoint has already been closed.
          trio.BusyResourceError: if another task is already serving on this endpoint.
          RuntimeError: if the underlying socket has not been bound yet.

        """
        self._check_closed()
        if self._listening_context is not None:
            raise trio.BusyResourceError("another task is already listening")
        try:
            # Used purely as a probe for "is this socket bound?"; the returned
            # address is discarded.
            self.socket.getsockname()
        except OSError:
            # TODO: Write test that triggers this
            raise RuntimeError(  # pragma: no cover
                "DTLS socket must be bound before it can serve"
            ) from None
        self._ensure_receive_loop()
        # We do cookie verification ourselves, so tell OpenSSL not to worry about it.
        # (See also _inject_client_hello_untrusted.)
        ssl_context.set_cookie_verify_callback(lambda *_: True)
        try:
            self._listening_context = ssl_context
            task_status.started()

            async def handler_wrapper(stream: DTLSChannel) -> None:
                # The channel is closed when the handler exits, however it exits.
                with stream:
                    await async_fn(stream, *args)

            async with trio.open_nursery() as nursery:
                async for stream in self._incoming_connections_q.r:  # pragma: no branch
                    nursery.start_soon(handler_wrapper, stream)
        finally:
            # Always release the "listening" slot so another task may serve later.
            self._listening_context = None

    def connect(
        self,
        address: tuple[str, int],
        ssl_context: SSL.Context,
    ) -> DTLSChannel:
        """Initiate an outgoing DTLS connection.

        Notice that this is a synchronous method. That's because it doesn't actually
        initiate any I/O -- it just sets up a `DTLSChannel` object. The actual handshake
        doesn't occur until you start using the `DTLSChannel`. This gives you a chance
        to do further configuration first, like setting MTU etc.

        Args:
          address: The address to connect to. Usually a (host, port) tuple, like
            ``("127.0.0.1", 12345)``.
          ssl_context (OpenSSL.SSL.Context): The PyOpenSSL context object to use for
            this connection.

        Returns:
          DTLSChannel

        Raises:
          trio.ClosedResourceError: if the endpoint has already been closed.

        """
        # it would be nice if we could detect when 'address' is our own endpoint (a
        # loopback connection), because that can't work
        # but I don't see how to do it reliably
        self._check_closed()
        channel = DTLSChannel._create(self, address, ssl_context)
        channel._ssl.set_connect_state()
        # Only one channel is kept per remote address: mark any existing one as
        # replaced before registering the new channel in its place.
        old_channel = self._streams.get(address)
        if old_channel is not None:
            old_channel._set_replaced()
        self._streams[address] = channel
        return channel