# Source-viewer metadata captured with this file: 573 lines, 16 KiB, Python.
from typing import Any, Callable, Generator, List
|
|
|
|
import pytest
|
|
|
|
from .._events import (
|
|
ConnectionClosed,
|
|
Data,
|
|
EndOfMessage,
|
|
Event,
|
|
InformationalResponse,
|
|
Request,
|
|
Response,
|
|
)
|
|
from .._headers import Headers, normalize_and_validate
|
|
from .._readers import (
|
|
_obsolete_line_fold,
|
|
ChunkedReader,
|
|
ContentLengthReader,
|
|
Http10Reader,
|
|
READERS,
|
|
)
|
|
from .._receivebuffer import ReceiveBuffer
|
|
from .._state import (
|
|
CLIENT,
|
|
CLOSED,
|
|
DONE,
|
|
IDLE,
|
|
MIGHT_SWITCH_PROTOCOL,
|
|
MUST_CLOSE,
|
|
SEND_BODY,
|
|
SEND_RESPONSE,
|
|
SERVER,
|
|
SWITCHED_PROTOCOL,
|
|
)
|
|
from .._util import LocalProtocolError
|
|
from .._writers import (
|
|
ChunkedWriter,
|
|
ContentLengthWriter,
|
|
Http10Writer,
|
|
write_any_response,
|
|
write_headers,
|
|
write_request,
|
|
WRITERS,
|
|
)
|
|
from .helpers import normalize_data_events
|
|
|
|
# Table of straightforward round-trip cases. Each entry is
#     ((role, state), event, wire_bytes)
# where (role, state) is the key used to look up the reader/writer under test,
# ``event`` is the parsed form and ``wire_bytes`` is the serialized form.
SIMPLE_CASES = [
    # Plain request with two headers.
    (
        (CLIENT, IDLE),
        Request(
            method="GET",
            target="/a",
            headers=[("Host", "foo"), ("Connection", "close")],
        ),
        b"GET /a HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n",
    ),
    # Final response with one header.
    (
        (SERVER, SEND_RESPONSE),
        Response(status_code=200, headers=[("Connection", "close")], reason=b"OK"),
        b"HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n",
    ),
    # Final response with no headers at all.
    (
        (SERVER, SEND_RESPONSE),
        Response(status_code=200, headers=[], reason=b"OK"),  # type: ignore[arg-type]
        b"HTTP/1.1 200 OK\r\n\r\n",
    ),
    # Informational (1xx) response with one header.
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(
            status_code=101, headers=[("Upgrade", "websocket")], reason=b"Upgrade"
        ),
        b"HTTP/1.1 101 Upgrade\r\nUpgrade: websocket\r\n\r\n",
    ),
    # Informational (1xx) response with no headers.
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(status_code=101, headers=[], reason=b"Upgrade"),  # type: ignore[arg-type]
        b"HTTP/1.1 101 Upgrade\r\n\r\n",
    ),
]
|
|
|
|
|
|
def dowrite(writer: Callable[..., None], obj: Any) -> bytes:
    """Invoke ``writer(obj, write)`` and return everything it wrote.

    ``writer`` is handed a ``write`` callback; each chunk passed to that
    callback is collected, and the chunks are concatenated in call order.
    """
    pieces: List[bytes] = []

    def collect(chunk: bytes) -> None:
        pieces.append(chunk)

    writer(obj, collect)
    return b"".join(pieces)
|
|
|
|
|
|
def tw(writer: Any, obj: Any, expected: Any) -> None:
    """Assert that ``writer`` serializes ``obj`` to exactly ``expected``."""
    actual = dowrite(writer, obj)
    assert actual == expected
|
|
|
|
|
|
def makebuf(data: bytes) -> ReceiveBuffer:
    """Return a new ReceiveBuffer that has had ``data`` appended to it."""
    filled = ReceiveBuffer()
    filled += data
    return filled
|
|
|
|
|
|
def tr(reader: Any, data: bytes, expected: Any) -> None:
    """Assert that ``reader`` turns ``data`` into ``expected`` however it is fed.

    The reader is exercised three ways: with the whole message at once, with
    the message arriving one byte at a time, and with extra bytes following
    the message.
    """

    def verify(event: Any) -> None:
        assert event == expected
        # Headers should always be returned as bytes, not e.g. bytearray
        # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478
        for name, value in getattr(event, "headers", []):
            assert type(name) is bytes
            assert type(value) is bytes

    # Simple: consume whole thing (and leave nothing behind in the buffer)
    whole = makebuf(data)
    verify(reader(whole))
    assert not whole

    # Incrementally growing buffer: until the final byte arrives the reader
    # must keep reporting "no event yet" (None)
    growing = ReceiveBuffer()
    for offset in range(len(data)):
        assert reader(growing) is None
        growing += data[offset : offset + 1]
    verify(reader(growing))

    # Trailing data: bytes after the message must be left in the buffer
    padded = makebuf(data)
    padded += b"trailing"
    verify(reader(padded))
    assert bytes(padded) == b"trailing"
|
|
|
|
|
|
def test_writers_simple() -> None:
    """Each SIMPLE_CASES event must serialize to its listed wire bytes."""
    for key, event, wire in SIMPLE_CASES:
        # ``key`` is the (role, state) pair that indexes WRITERS.
        tw(WRITERS[key], event, wire)
|
|
|
|
|
|
def test_readers_simple() -> None:
    """Each SIMPLE_CASES wire form must parse back to its listed event."""
    for key, event, wire in SIMPLE_CASES:
        # ``key`` is the (role, state) pair that indexes READERS.
        tr(READERS[key], wire, event)
|
|
|
|
|
|
def test_writers_unusual() -> None:
    """Cover write_headers directly and the refusal to emit HTTP/1.0."""
    # Simple test of the write_headers utility routine
    two_headers = normalize_and_validate([("foo", "bar"), ("baz", "quux")])
    tw(write_headers, two_headers, b"foo: bar\r\nbaz: quux\r\n\r\n")
    tw(write_headers, Headers([]), b"\r\n")

    # We understand HTTP/1.0, but we don't speak it
    with pytest.raises(LocalProtocolError):
        http10_request = Request(
            method="GET",
            target="/",
            headers=[("Host", "foo"), ("Connection", "close")],
            http_version="1.0",
        )
        tw(write_request, http10_request, None)
    with pytest.raises(LocalProtocolError):
        http10_response = Response(
            status_code=200, headers=[("Connection", "close")], http_version="1.0"
        )
        tw(write_any_response, http10_response, None)
|
|
|
|
|
|
def test_readers_unusual() -> None:
    """Exercise the message-head readers on unusual input.

    Covers HTTP/1.0 messages, odd-but-accepted header values, a status line
    without a reason phrase, bare-LF line endings, obsolete line folding, and
    finally a set of malformed header lines that must raise
    LocalProtocolError.
    """
    # Reading HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\nSome: header\r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[("Some", "header")],
            http_version="1.0",
        ),
    )

    # check no-headers, since it's only legal with HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\n\r\n",
        Request(method="HEAD", target="/foo", headers=[], http_version="1.0"),  # type: ignore[arg-type]
    )

    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\nSome: header\r\n\r\n",
        Response(
            status_code=200,
            headers=[("Some", "header")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # single-character header values (actually disallowed by the ABNF in RFC
    # 7230 -- this is a bug in the standard that we originally copied...)
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: a a a a a \r\n\r\n",
        Response(
            status_code=200,
            headers=[("Foo", "a a a a a")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # Empty headers -- also legal
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo:\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # A value consisting only of spaces/tabs also comes out empty
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: \t \t \r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # Tolerate broken servers that leave off the reason phrase
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200\r\n" b"Foo: bar\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "bar")], http_version="1.0", reason=b""
        ),
    )

    # Tolerate headers line endings (\r\n and \n)
    # \n\r\n between headers and body
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader: val\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader", "val")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # delimited only with \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\nSomeHeader1: val1\nSomeHeader2: val2\n\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # mixed \r\n and \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader1: val1\nSomeHeader2: val2\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason="OK",
        ),
    )

    # obsolete line folding
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Some: multi-line\r\n"
        b" header\r\n"
        b"\tnonsense\r\n"
        b"    \t   \t\tI guess\r\n"
        b"Connection: close\r\n"
        b"More-nonsense: in the\r\n"
        b"    last header  \r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[
                ("Host", "example.com"),
                ("Some", "multi-line header nonsense I guess"),
                ("Connection", "close"),
                ("More-nonsense", "in the last header"),
            ],
        ),
    )

    # Malformed header lines that must be rejected. (The tab-before-colon
    # case used to be listed twice verbatim; it is checked once here.)
    malformed_header_lines = (
        b"  folded: line",  # continuation line with no header before it
        b"foo : line",  # space between the field name and the colon
        b"foo\t: line",  # tab between the field name and the colon
        b": line",  # empty field name
    )
    for header_line in malformed_header_lines:
        with pytest.raises(LocalProtocolError):
            tr(
                READERS[CLIENT, IDLE],
                b"HEAD /foo HTTP/1.1\r\n" + header_line + b"\r\n\r\n",
                None,
            )
|
|
|
|
|
|
def test__obsolete_line_fold_bytes() -> None:
    # _obsolete_line_fold has a defensive cast to bytearray, which is
    # necessary to protect against O(n^2) behavior in case anyone ever passes
    # in regular bytestrings... but right now we never pass in regular
    # bytestrings. so this test just exists to get some coverage on that
    # defensive cast.
    raw_lines = [b"aaa", b"bbb", b"  ccc", b"ddd"]
    folded = list(_obsolete_line_fold(raw_lines))
    # The indented third line is joined onto the second one.
    assert folded == [b"aaa", bytearray(b"bbb ccc"), b"ddd"]
|
|
|
|
|
|
def _run_reader_iter(
|
|
reader: Any, buf: bytes, do_eof: bool
|
|
) -> Generator[Any, None, None]:
|
|
while True:
|
|
event = reader(buf)
|
|
if event is None:
|
|
break
|
|
yield event
|
|
# body readers have undefined behavior after returning EndOfMessage,
|
|
# because this changes the state so they don't get called again
|
|
if type(event) is EndOfMessage:
|
|
break
|
|
if do_eof:
|
|
assert not buf
|
|
yield reader.read_eof()
|
|
|
|
|
|
def _run_reader(*args: Any) -> List[Event]:
    """Collect everything _run_reader_iter(*args) yields into a list.

    The collected list is passed through normalize_data_events (imported from
    .helpers) before being returned.
    """
    collected = list(_run_reader_iter(*args))
    return normalize_data_events(collected)
|
|
|
|
|
|
def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
    """Check that a body reader produces ``expected`` events from ``data``.

    ``thunk`` is called with no arguments to build a fresh reader for each
    scenario. When ``do_eof`` is true, the reader's ``read_eof()`` result is
    included once the data has been consumed.
    """
    # Simple: consume whole thing
    print("Test 1")
    whole = makebuf(data)
    assert _run_reader(thunk(), whole, do_eof) == expected

    # Incrementally growing buffer
    print("Test 2")
    reader = thunk()
    growing = ReceiveBuffer()
    collected = []
    for pos in range(len(data)):
        collected += _run_reader(reader, growing, False)
        growing += data[pos : pos + 1]
    collected += _run_reader(reader, growing, do_eof)
    assert normalize_data_events(collected) == expected

    # Trailing data: only meaningful when the message terminates on its own
    # (an EndOfMessage is expected) rather than via EOF.
    is_complete = any(type(event) is EndOfMessage for event in expected)
    if do_eof or not is_complete:
        return
    padded = makebuf(data + b"trailing")
    assert _run_reader(thunk(), padded, False) == expected
|
|
|
|
|
|
def test_ContentLengthReader() -> None:
    """ContentLengthReader with a zero-length and a ten-byte body."""
    # Zero-length body: nothing but the end marker.
    t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()])

    # Ten-byte body: the data, then the end marker.
    body = b"0123456789"
    t_body_reader(
        lambda: ContentLengthReader(10),
        body,
        [Data(data=b"0123456789"), EndOfMessage()],
    )
|
|
|
|
|
|
def test_Http10Reader() -> None:
    """Http10Reader yields EndOfMessage only when EOF is signalled."""
    # (data, expected events, do_eof)
    scenarios = [
        (b"", [EndOfMessage()], True),
        (b"asdf", [Data(data=b"asdf")], False),
        (b"asdf", [Data(data=b"asdf"), EndOfMessage()], True),
    ]
    for data, expected, do_eof in scenarios:
        t_body_reader(Http10Reader, data, expected, do_eof=do_eof)
|
|
|
|
|
|
def test_ChunkedReader() -> None:
    """ChunkedReader on well-formed, unusual and invalid chunked bodies."""
    # Terminating chunk only
    t_body_reader(ChunkedReader, b"0\r\n\r\n", [EndOfMessage()])

    # Terminating chunk followed by a trailer header
    t_body_reader(
        ChunkedReader,
        b"0\r\nSome: header\r\n\r\n",
        [EndOfMessage(headers=[("Some", "header")])],
    )

    # Two data chunks, then a trailer header
    t_body_reader(
        ChunkedReader,
        b"5\r\n01234\r\n"
        b"10\r\n0123456789abcdef\r\n"
        b"0\r\n"
        b"Some: header\r\n\r\n",
        [
            Data(data=b"012340123456789abcdef"),
            EndOfMessage(headers=[("Some", "header")]),
        ],
    )

    # Two data chunks, no trailers
    t_body_reader(
        ChunkedReader,
        b"5\r\n01234\r\n" b"10\r\n0123456789abcdef\r\n" b"0\r\n\r\n",
        [Data(data=b"012340123456789abcdef"), EndOfMessage()],
    )

    # handles upper and lowercase hex
    payload = b"x" * 0xAA
    t_body_reader(
        ChunkedReader,
        b"aA\r\n" + payload + b"\r\n" + b"0\r\n\r\n",
        [Data(data=b"x" * 0xAA), EndOfMessage()],
    )

    # refuses arbitrarily long chunk integers
    with pytest.raises(LocalProtocolError):
        # Technically this is legal HTTP/1.1, but we refuse to process chunk
        # sizes that don't fit into 20 characters of hex
        t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])

    # refuses garbage in the chunk count
    with pytest.raises(LocalProtocolError):
        t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)

    # handles (and discards) "chunk extensions" omg wtf
    t_body_reader(
        ChunkedReader,
        b"5; hello=there\r\n"
        b"xxxxx"
        b"\r\n"
        b'0; random="junk"; some=more; canbe=lonnnnngg\r\n\r\n',
        [Data(data=b"xxxxx"), EndOfMessage()],
    )

    # whitespace after the chunk size is accepted
    t_body_reader(
        ChunkedReader,
        b"5   	 \r\n01234\r\n" b"0\r\n\r\n",
        [Data(data=b"01234"), EndOfMessage()],
    )
|
|
|
|
|
|
def test_ContentLengthWriter() -> None:
    """ContentLengthWriter passes data through and enforces the declared length."""
    # Exactly the declared 5 bytes, then a plain EndOfMessage: accepted.
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    assert dowrite(w, EndOfMessage()) == b""

    # A single write longer than the declared length.
    w = ContentLengthWriter(5)
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"123456"))

    # Two writes that together exceed the declared length.
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"456"))

    # Ending the message after only 3 of the 5 declared bytes.
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage())

    # Full body written, but EndOfMessage carries headers (trailers).
    # These two comparisons previously lacked ``assert`` and so checked
    # nothing; they mirror the first scenario above.
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
|
|
|
|
|
|
def test_ChunkedWriter() -> None:
    """ChunkedWriter frames data as hex-length chunks and writes trailers."""
    w = ChunkedWriter()

    # 3 bytes -> length "3"; 20 bytes -> length "14" (hex)
    assert dowrite(w, Data(data=b"aaa")) == b"3\r\naaa\r\n"
    twenty = b"a" * 20
    assert dowrite(w, Data(data=twenty)) == b"14\r\n" + b"a" * 20 + b"\r\n"

    # Empty data produces no output at all
    assert dowrite(w, Data(data=b"")) == b""

    # Terminating chunk without trailers
    assert dowrite(w, EndOfMessage()) == b"0\r\n\r\n"

    # Terminating chunk with trailer headers
    with_trailers = dowrite(w, EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")]))
    assert with_trailers == b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
|
|
|
|
|
|
def test_Http10Writer() -> None:
    """Http10Writer writes data verbatim and rejects EndOfMessage headers."""
    w = Http10Writer()

    written = dowrite(w, Data(data=b"1234"))
    assert written == b"1234"

    # A plain EndOfMessage writes nothing
    assert dowrite(w, EndOfMessage()) == b""

    # An EndOfMessage carrying headers is an error
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
|
|
|
|
|
|
def test_reject_garbage_after_request_line() -> None:
    # NOTE(review): despite the name, the input here is a status line fed to
    # the (SERVER, SEND_RESPONSE) reader; the name is kept so the test id
    # stays stable.
    garbage_after_status_line = b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[SERVER, SEND_RESPONSE], garbage_after_status_line, None)
|
|
|
|
|
|
def test_reject_garbage_after_response_line() -> None:
    # NOTE(review): despite the name, the input here is a request line fed to
    # the (CLIENT, IDLE) reader; the name is kept so the test id stays stable.
    garbage_after_request_line = (
        b"HEAD /foo HTTP/1.1 xxxxxx\r\n" b"Host: a\r\n\r\n"
    )
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], garbage_after_request_line, None)
|
|
|
|
|
|
def test_reject_garbage_in_header_line() -> None:
    """A NUL byte inside a header value must raise LocalProtocolError."""
    nul_in_host_value = b"HEAD /foo HTTP/1.1\r\n" b"Host: foo\x00bar\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], nul_in_host_value, None)
|
|
|
|
|
|
def test_reject_non_vchar_in_path() -> None:
    """Request targets containing NUL, space, DEL or 0xEE must be rejected."""
    # Iterating a bytes object yields ints, hence bytes([bad_char]).
    for bad_char in b"\x00\x20\x7f\xee":
        message = (
            bytearray(b"HEAD /")
            + bytes([bad_char])
            + b" HTTP/1.1\r\nHost: foobar\r\n\r\n"
        )
        with pytest.raises(LocalProtocolError):
            tr(READERS[CLIENT, IDLE], message, None)
|
|
|
|
|
|
# https://github.com/python-hyper/h11/issues/57
def test_allow_some_garbage_in_cookies() -> None:
    """A \\x01 byte inside a Set-Cookie value is accepted and preserved."""
    wire = (
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: foo\r\n"
        b"Set-Cookie: ___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900\r\n"
        b"\r\n"
    )
    parsed = Request(
        method="HEAD",
        target="/foo",
        headers=[
            ("Host", "foo"),
            ("Set-Cookie", "___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900"),
        ],
    )
    tr(READERS[CLIENT, IDLE], wire, parsed)
|
|
|
|
|
|
def test_host_comes_first() -> None:
    """write_headers emits Host first even when it was supplied last."""
    host_last = normalize_and_validate([("foo", "bar"), ("Host", "example.com")])
    tw(write_headers, host_last, b"Host: example.com\r\nfoo: bar\r\n\r\n")
|