840 lines
30 KiB
Python
840 lines
30 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import contextlib
|
||
|
import inspect
|
||
|
import io
|
||
|
import json
|
||
|
import math
|
||
|
import queue
|
||
|
import sys
|
||
|
import typing
|
||
|
import warnings
|
||
|
from concurrent.futures import Future
|
||
|
from functools import cached_property
|
||
|
from types import GeneratorType
|
||
|
from urllib.parse import unquote, urljoin
|
||
|
|
||
|
import anyio
|
||
|
import anyio.abc
|
||
|
import anyio.from_thread
|
||
|
from anyio.abc import ObjectReceiveStream, ObjectSendStream
|
||
|
from anyio.streams.stapled import StapledObjectStream
|
||
|
|
||
|
from starlette._utils import is_async_callable
|
||
|
from starlette.types import ASGIApp, Message, Receive, Scope, Send
|
||
|
from starlette.websockets import WebSocketDisconnect
|
||
|
|
||
|
if sys.version_info >= (3, 10): # pragma: no cover
|
||
|
from typing import TypeGuard
|
||
|
else: # pragma: no cover
|
||
|
from typing_extensions import TypeGuard
|
||
|
|
||
|
try:
|
||
|
import httpx
|
||
|
except ModuleNotFoundError: # pragma: no cover
|
||
|
raise RuntimeError(
|
||
|
"The starlette.testclient module requires the httpx package to be installed.\n"
|
||
|
"You can install this with:\n"
|
||
|
" $ pip install httpx\n"
|
||
|
)
|
||
|
# Zero-argument factory returning a context manager that yields a blocking
# portal; the test client uses it to run the async ASGI app from sync code.
_PortalFactoryType = typing.Callable[[], typing.ContextManager[anyio.abc.BlockingPortal]]

# ASGI 2 style: app(scope) returns an instance, which is awaited with
# (receive, send).
ASGIInstance = typing.Callable[[Receive, Send], typing.Awaitable[None]]
ASGI2App = typing.Callable[[Scope], ASGIInstance]

# ASGI 3 style: app(scope, receive, send) is awaited directly.
ASGI3App = typing.Callable[[Scope, Receive, Send], typing.Awaitable[None]]

# Shape accepted by the `data=` argument of the request helpers: every key
# maps to a string, an iterable of strings, or raw bytes.
_RequestData = typing.Mapping[
    str,
    typing.Union[str, typing.Iterable[str], bytes],
]
|
||
|
|
||
|
|
||
|
def _is_asgi3(app: ASGI2App | ASGI3App) -> TypeGuard[ASGI3App]:
|
||
|
if inspect.isclass(app):
|
||
|
return hasattr(app, "__await__")
|
||
|
return is_async_callable(app)
|
||
|
|
||
|
|
||
|
class _WrapASGI2:
|
||
|
"""
|
||
|
Provide an ASGI3 interface onto an ASGI2 app.
|
||
|
"""
|
||
|
|
||
|
def __init__(self, app: ASGI2App) -> None:
|
||
|
self.app = app
|
||
|
|
||
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||
|
instance = self.app(scope)
|
||
|
await instance(receive, send)
|
||
|
|
||
|
|
||
|
class _AsyncBackend(typing.TypedDict):
|
||
|
backend: str
|
||
|
backend_options: dict[str, typing.Any]
|
||
|
|
||
|
|
||
|
class _Upgrade(Exception):
|
||
|
def __init__(self, session: WebSocketTestSession) -> None:
|
||
|
self.session = session
|
||
|
|
||
|
|
||
|
class WebSocketDenialResponse(httpx.Response, WebSocketDisconnect):  # type: ignore[misc]
    """
    Raised by the `TestClient` when a `WebSocket` is closed with
    `send_denial_response()` before it has been accepted.

    It is a special case of `WebSocketDisconnect`, and as an `httpx.Response`
    it also carries the denial's status code, headers and body.
    """
|
||
|
|
||
|
|
||
|
class WebSocketTestSession:
    """
    Synchronous handle on a websocket connection to an ASGI app under test.

    The app runs as a task on a blocking portal (see ``_run``). Messages travel
    between the test thread and the app through two thread-safe queues:

    * ``_receive_queue`` holds messages the test sends *to* the app.
    * ``_send_queue`` holds messages (or exceptions) coming *from* the app.
    """

    def __init__(
        self,
        app: ASGI3App,
        scope: Scope,
        portal_factory: _PortalFactoryType,
    ) -> None:
        self.app = app
        self.scope = scope
        self.accepted_subprotocol = None
        self.portal_factory = portal_factory
        self._receive_queue: queue.Queue[Message] = queue.Queue()
        self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
        self.extra_headers = None

    def __enter__(self) -> WebSocketTestSession:
        """
        Start the app, perform the connect handshake and return the session.

        If the app closes or denies the connection instead of accepting it,
        the portal is torn down and the resulting exception propagates.
        """
        self.exit_stack = contextlib.ExitStack()
        self.portal = self.exit_stack.enter_context(self.portal_factory())

        try:
            _: Future[None] = self.portal.start_task_soon(self._run)
            self.send({"type": "websocket.connect"})
            message = self.receive()
            self._raise_on_close(message)
        except Exception:
            self.exit_stack.close()
            raise
        self.accepted_subprotocol = message.get("subprotocol", None)
        self.extra_headers = message.get("headers", None)
        return self

    @cached_property
    def should_close(self) -> anyio.Event:
        # Created lazily on first access. Within this class that access happens
        # in coroutines run on the portal (``_run`` and ``_notify_close``).
        return anyio.Event()

    async def _notify_close(self) -> None:
        self.should_close.set()

    def __exit__(self, *args: typing.Any) -> None:
        """
        Send a normal-closure disconnect, stop the app and close the portal.

        Any exception the app queued that was not consumed by a ``receive*``
        call is re-raised here.
        """
        try:
            self.close(1000)
        finally:
            self.portal.start_task_soon(self._notify_close)
            self.exit_stack.close()
        while not self._send_queue.empty():
            message = self._send_queue.get()
            if isinstance(message, BaseException):
                raise message

    async def _run(self) -> None:
        """
        The sub-thread in which the websocket session runs.

        The app runs inside a task group that is cancelled when the app
        finishes or once ``should_close`` is set. Exceptions raised by the app
        are pushed onto ``_send_queue`` for the test thread and re-raised.
        """

        async def run_app(tg: anyio.abc.TaskGroup) -> None:
            try:
                await self.app(self.scope, self._asgi_receive, self._asgi_send)
            except anyio.get_cancelled_exc_class():
                # Cancellation is the expected shutdown path; not an app error.
                ...
            except BaseException as exc:
                self._send_queue.put(exc)
                raise
            finally:
                tg.cancel_scope.cancel()

        async with anyio.create_task_group() as tg:
            tg.start_soon(run_app, tg)
            await self.should_close.wait()
            tg.cancel_scope.cancel()

    async def _asgi_receive(self) -> Message:
        # Poll the thread-safe queue, yielding to the event loop between checks
        # so the loop (and cancellation) keeps running while the test is idle.
        while self._receive_queue.empty():
            await anyio.sleep(0)
        return self._receive_queue.get()

    async def _asgi_send(self, message: Message) -> None:
        self._send_queue.put(message)

    def _raise_on_close(self, message: Message) -> None:
        """
        Raise if *message* shows the app closed or denied the connection.

        * ``websocket.close`` raises ``WebSocketDisconnect`` with the close code
          (default 1000) and reason (default "").
        * ``websocket.http.response.start`` collects the following
          ``websocket.http.response.body`` messages and raises a
          ``WebSocketDenialResponse`` with the status, headers and body.

        Any other message type is ignored.
        """
        if message["type"] == "websocket.close":
            raise WebSocketDisconnect(
                code=message.get("code", 1000), reason=message.get("reason", "")
            )
        elif message["type"] == "websocket.http.response.start":
            status_code: int = message["status"]
            # "headers" and "body" are optional in the ASGI denial-response
            # messages; fall back to their defaults instead of raising KeyError.
            headers: list[tuple[bytes, bytes]] = message.get("headers", [])
            body: list[bytes] = []
            while True:
                message = self.receive()
                assert message["type"] == "websocket.http.response.body"
                body.append(message.get("body", b""))
                if not message.get("more_body", False):
                    break
            raise WebSocketDenialResponse(
                status_code=status_code,
                headers=headers,
                content=b"".join(body),
            )

    def send(self, message: Message) -> None:
        """Queue a raw ASGI message for the app's ``receive``."""
        self._receive_queue.put(message)

    def send_text(self, data: str) -> None:
        """Send a text frame to the app."""
        self.send({"type": "websocket.receive", "text": data})

    def send_bytes(self, data: bytes) -> None:
        """Send a binary frame to the app."""
        self.send({"type": "websocket.receive", "bytes": data})

    def send_json(
        self, data: typing.Any, mode: typing.Literal["text", "binary"] = "text"
    ) -> None:
        """
        Serialize *data* compactly (no spaces, non-ASCII kept as-is) and send
        it as a text frame, or as UTF-8 bytes when ``mode`` is ``"binary"``.
        """
        text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
        if mode == "text":
            self.send({"type": "websocket.receive", "text": text})
        else:
            self.send({"type": "websocket.receive", "bytes": text.encode("utf-8")})

    def close(self, code: int = 1000, reason: str | None = None) -> None:
        """Send a ``websocket.disconnect`` with *code* and *reason* to the app."""
        self.send({"type": "websocket.disconnect", "code": code, "reason": reason})

    def receive(self) -> Message:
        """
        Return the next message sent by the app, blocking until one arrives.

        If the app raised, that exception is re-raised here instead.
        """
        message = self._send_queue.get()
        if isinstance(message, BaseException):
            raise message
        return message

    def receive_text(self) -> str:
        """Return the text payload of the next message (raises on close/denial)."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(str, message["text"])

    def receive_bytes(self) -> bytes:
        """Return the bytes payload of the next message (raises on close/denial)."""
        message = self.receive()
        self._raise_on_close(message)
        return typing.cast(bytes, message["bytes"])

    def receive_json(
        self, mode: typing.Literal["text", "binary"] = "text"
    ) -> typing.Any:
        """
        Decode the next message as JSON, reading its ``text`` field, or its
        UTF-8 ``bytes`` field when ``mode`` is ``"binary"``.
        """
        message = self.receive()
        self._raise_on_close(message)
        if mode == "text":
            text = message["text"]
        else:
            text = message["bytes"].decode("utf-8")
        return json.loads(text)
|
||
|
|
||
|
|
||
|
class _TestClientTransport(httpx.BaseTransport):
    """
    httpx transport that dispatches requests to an in-process ASGI app.

    HTTP requests run to completion through a blocking portal and become an
    ``httpx.Response``. For ``ws``/``wss`` URLs a ``WebSocketTestSession`` is
    built and handed back to the caller by raising ``_Upgrade``.
    """

    def __init__(
        self,
        app: ASGI3App,
        portal_factory: _PortalFactoryType,
        raise_server_exceptions: bool = True,
        root_path: str = "",
        *,
        app_state: dict[str, typing.Any],
    ) -> None:
        self.app = app
        self.raise_server_exceptions = raise_server_exceptions
        self.root_path = root_path
        self.portal_factory = portal_factory
        self.app_state = app_state

    @staticmethod
    def _split_netloc(netloc: str, default_port: int) -> tuple[str, int]:
        """
        Split a ``host[:port]`` netloc into ``(host, port)``.

        Only the last colon can separate the port, so bracketed IPv6 literals
        such as ``[::1]`` or ``[::1]:8000`` are handled. A trailing ``]`` means
        the colon belongs to the address. With no port, *default_port* is used.
        """
        host, sep, port_string = netloc.rpartition(":")
        if not sep or port_string.endswith("]"):
            return netloc, default_port
        return host, int(port_string)

    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """
        Run *request* against the ASGI app and return the resulting response.

        For ``ws``/``wss`` URLs this raises ``_Upgrade`` carrying a
        ``WebSocketTestSession`` instead of returning.

        With ``raise_server_exceptions`` enabled, app exceptions propagate and
        the app must have started a response. Otherwise an app that fails
        before starting a response yields an empty 500 response.
        """
        scheme = request.url.scheme
        netloc = request.url.netloc.decode(encoding="ascii")
        path = request.url.path
        raw_path = request.url.raw_path
        query = request.url.query.decode(encoding="ascii")

        default_port = {"http": 80, "ws": 80, "https": 443, "wss": 443}[scheme]
        host, port = self._split_netloc(netloc, default_port)

        # Include the 'host' header.
        if "host" in request.headers:
            headers: list[tuple[bytes, bytes]] = []
        elif port == default_port:  # pragma: no cover
            headers = [(b"host", host.encode())]
        else:  # pragma: no cover
            headers = [(b"host", (f"{host}:{port}").encode())]

        # Include other request headers.
        headers += [
            (key.lower().encode(), value.encode())
            for key, value in request.headers.multi_items()
        ]

        scope: dict[str, typing.Any]

        if scheme in {"ws", "wss"}:
            subprotocol = request.headers.get("sec-websocket-protocol", None)
            if subprotocol is None:
                subprotocols: typing.Sequence[str] = []
            else:
                subprotocols = [value.strip() for value in subprotocol.split(",")]
            scope = {
                "type": "websocket",
                "path": unquote(path),
                "raw_path": raw_path,
                "root_path": self.root_path,
                "scheme": scheme,
                "query_string": query.encode(),
                "headers": headers,
                "client": ["testclient", 50000],
                "server": [host, port],
                "subprotocols": subprotocols,
                "state": self.app_state.copy(),
                "extensions": {"websocket.http.response": {}},
            }
            session = WebSocketTestSession(self.app, scope, self.portal_factory)
            raise _Upgrade(session)

        scope = {
            "type": "http",
            "http_version": "1.1",
            "method": request.method,
            "path": unquote(path),
            "raw_path": raw_path,
            "root_path": self.root_path,
            "scheme": scheme,
            "query_string": query.encode(),
            "headers": headers,
            "client": ["testclient", 50000],
            "server": [host, port],
            "extensions": {"http.response.debug": {}},
            "state": self.app_state.copy(),
        }

        request_complete = False
        response_started = False
        response_complete: anyio.Event
        raw_kwargs: dict[str, typing.Any] = {"stream": io.BytesIO()}
        template = None
        context = None

        async def receive() -> Message:
            # ASGI receive: deliver the request body, then report a disconnect
            # once the response has been fully sent.
            nonlocal request_complete

            if request_complete:
                if not response_complete.is_set():
                    await response_complete.wait()
                return {"type": "http.disconnect"}

            body = request.read()
            if isinstance(body, str):
                body_bytes: bytes = body.encode("utf-8")  # pragma: no cover
            elif body is None:
                body_bytes = b""  # pragma: no cover
            elif isinstance(body, GeneratorType):
                try:  # pragma: no cover
                    chunk = body.send(None)
                    if isinstance(chunk, str):
                        chunk = chunk.encode("utf-8")
                    return {"type": "http.request", "body": chunk, "more_body": True}
                except StopIteration:  # pragma: no cover
                    request_complete = True
                    return {"type": "http.request", "body": b""}
            else:
                body_bytes = body

            request_complete = True
            return {"type": "http.request", "body": body_bytes}

        async def send(message: Message) -> None:
            # ASGI send: accumulate status, headers and body into raw_kwargs and
            # capture template debug info when the app provides it.
            nonlocal raw_kwargs, response_started, template, context

            if message["type"] == "http.response.start":
                assert (
                    not response_started
                ), 'Received multiple "http.response.start" messages.'
                raw_kwargs["status_code"] = message["status"]
                raw_kwargs["headers"] = [
                    (key.decode(), value.decode())
                    for key, value in message.get("headers", [])
                ]
                response_started = True
            elif message["type"] == "http.response.body":
                assert (
                    response_started
                ), 'Received "http.response.body" without "http.response.start".'
                assert (
                    not response_complete.is_set()
                ), 'Received "http.response.body" after response completed.'
                body = message.get("body", b"")
                more_body = message.get("more_body", False)
                if request.method != "HEAD":
                    raw_kwargs["stream"].write(body)
                if not more_body:
                    raw_kwargs["stream"].seek(0)
                    response_complete.set()
            elif message["type"] == "http.response.debug":
                template = message["info"]["template"]
                context = message["info"]["context"]

        try:
            with self.portal_factory() as portal:
                response_complete = portal.call(anyio.Event)
                portal.call(self.app, scope, receive, send)
        except BaseException:
            if self.raise_server_exceptions:
                raise

        if self.raise_server_exceptions:
            assert response_started, "TestClient did not receive any response."
        elif not response_started:
            raw_kwargs = {
                "status_code": 500,
                "headers": [],
                "stream": io.BytesIO(),
            }

        raw_kwargs["stream"] = httpx.ByteStream(raw_kwargs["stream"].read())

        response = httpx.Response(**raw_kwargs, request=request)
        if template is not None:
            response.template = template  # type: ignore[attr-defined]
            response.context = context  # type: ignore[attr-defined]
        return response
|
||
|
|
||
|
|
||
|
class TestClient(httpx.Client):
    """
    Synchronous ``httpx.Client`` that sends its requests to an in-process ASGI app.

    As a context manager it also drives the app's lifespan: startup runs in
    ``__enter__`` and shutdown in ``__exit__``. Both run on a single blocking
    portal that is reused by every request made inside the ``with`` block.
    Outside a ``with`` block each request opens its own portal.
    """

    # Stops pytest from trying to collect this class (its name starts with "Test").
    __test__ = False
    task: Future[None]
    portal: anyio.abc.BlockingPortal | None = None

    def __init__(
        self,
        app: ASGIApp,
        base_url: str = "http://testserver",
        raise_server_exceptions: bool = True,
        root_path: str = "",
        backend: typing.Literal["asyncio", "trio"] = "asyncio",
        backend_options: dict[str, typing.Any] | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        headers: dict[str, str] | None = None,
        follow_redirects: bool = True,
    ) -> None:
        self.async_backend = _AsyncBackend(
            backend=backend, backend_options=backend_options or {}
        )
        if _is_asgi3(app):
            asgi_app = app
        else:
            app = typing.cast(ASGI2App, app)  # type: ignore[assignment]
            asgi_app = _WrapASGI2(app)  # type: ignore[arg-type]
        self.app = asgi_app
        self.app_state: dict[str, typing.Any] = {}
        transport = _TestClientTransport(
            self.app,
            portal_factory=self._portal_factory,
            raise_server_exceptions=raise_server_exceptions,
            root_path=root_path,
            app_state=self.app_state,
        )
        if headers is None:
            headers = {}
        else:
            # Copy so that ``setdefault`` below never mutates the caller's dict.
            headers = dict(headers)
        headers.setdefault("user-agent", "testclient")
        super().__init__(
            base_url=base_url,
            headers=headers,
            transport=transport,
            follow_redirects=follow_redirects,
            cookies=cookies,
        )

    @contextlib.contextmanager
    def _portal_factory(self) -> typing.Generator[anyio.abc.BlockingPortal, None, None]:
        """Yield the portal opened by ``__enter__``, or a fresh one when none exists."""
        if self.portal is not None:
            yield self.portal
        else:
            with anyio.from_thread.start_blocking_portal(
                **self.async_backend
            ) as portal:
                yield portal

    def _choose_redirect_arg(
        self, follow_redirects: bool | None, allow_redirects: bool | None
    ) -> bool | httpx._client.UseClientDefault:
        """
        Resolve the redirect setting from ``follow_redirects`` and the
        deprecated ``allow_redirects`` argument.

        Returns ``USE_CLIENT_DEFAULT`` when neither is given. Using
        ``allow_redirects`` emits a ``DeprecationWarning``. Passing both raises
        ``RuntimeError``.
        """
        if allow_redirects is not None and follow_redirects is not None:
            # Previously an unreachable branch; the conflict is now reported.
            raise RuntimeError(
                "Cannot use both `allow_redirects` and `follow_redirects`."
            )
        if allow_redirects is not None:
            message = (
                "The `allow_redirects` argument is deprecated. "
                "Use `follow_redirects` instead."
            )
            warnings.warn(message, DeprecationWarning)
            return allow_redirects
        if follow_redirects is not None:
            return follow_redirects
        return httpx._client.USE_CLIENT_DEFAULT

    def request(  # type: ignore[override]
        self,
        method: str,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a request; also accepts the deprecated ``allow_redirects``."""
        # Resolve the URL via httpx's private ``_merge_url`` before delegating.
        url = self._merge_url(url)
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().request(
            method,
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def get(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a ``GET`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().get(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def options(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send an ``OPTIONS`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().options(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def head(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a ``HEAD`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().head(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def post(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a ``POST`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().post(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def put(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a ``PUT`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().put(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def patch(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        content: httpx._types.RequestContent | None = None,
        data: _RequestData | None = None,
        files: httpx._types.RequestFiles | None = None,
        json: typing.Any = None,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a ``PATCH`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().patch(
            url,
            content=content,
            data=data,
            files=files,
            json=json,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def delete(  # type: ignore[override]
        self,
        url: httpx._types.URLTypes,
        *,
        params: httpx._types.QueryParamTypes | None = None,
        headers: httpx._types.HeaderTypes | None = None,
        cookies: httpx._types.CookieTypes | None = None,
        auth: httpx._types.AuthTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        follow_redirects: bool | None = None,
        allow_redirects: bool | None = None,
        timeout: httpx._types.TimeoutTypes
        | httpx._client.UseClientDefault = httpx._client.USE_CLIENT_DEFAULT,
        extensions: dict[str, typing.Any] | None = None,
    ) -> httpx.Response:
        """Send a ``DELETE`` request (accepts the deprecated ``allow_redirects``)."""
        redirect = self._choose_redirect_arg(follow_redirects, allow_redirects)
        return super().delete(
            url,
            params=params,
            headers=headers,
            cookies=cookies,
            auth=auth,
            follow_redirects=redirect,
            timeout=timeout,
            extensions=extensions,
        )

    def websocket_connect(
        self,
        url: str,
        subprotocols: typing.Sequence[str] | None = None,
        **kwargs: typing.Any,
    ) -> WebSocketTestSession:
        """
        Open a websocket to *url* (resolved against ``ws://testserver``).

        Returns an un-entered ``WebSocketTestSession``; use it in a ``with``
        block to perform the handshake. Extra ``kwargs`` go to the request.
        """
        url = urljoin("ws://testserver", url)
        # Build a fresh, case-insensitive header collection: the caller's object
        # is never mutated, ``headers=None`` or a list of pairs is accepted, and
        # user-supplied values in any letter case win over the defaults below.
        headers = httpx.Headers(kwargs.get("headers"))
        headers.setdefault("connection", "upgrade")
        headers.setdefault("sec-websocket-key", "testserver==")
        headers.setdefault("sec-websocket-version", "13")
        if subprotocols is not None:
            headers.setdefault("sec-websocket-protocol", ", ".join(subprotocols))
        kwargs["headers"] = headers
        try:
            super().request("GET", url, **kwargs)
        except _Upgrade as exc:
            session = exc.session
        else:
            raise RuntimeError("Expected WebSocket upgrade")  # pragma: no cover

        return session

    def __enter__(self) -> TestClient:
        """
        Open a shared portal and run the app's lifespan startup.

        The matching shutdown and portal teardown are registered on
        ``self.exit_stack`` and run by ``__exit__``.
        """
        with contextlib.ExitStack() as stack:
            self.portal = portal = stack.enter_context(
                anyio.from_thread.start_blocking_portal(**self.async_backend)
            )

            @stack.callback
            def reset_portal() -> None:
                self.portal = None

            send1: ObjectSendStream[typing.MutableMapping[str, typing.Any] | None]
            receive1: ObjectReceiveStream[typing.MutableMapping[str, typing.Any] | None]
            send2: ObjectSendStream[typing.MutableMapping[str, typing.Any]]
            receive2: ObjectReceiveStream[typing.MutableMapping[str, typing.Any]]
            send1, receive1 = anyio.create_memory_object_stream(math.inf)
            send2, receive2 = anyio.create_memory_object_stream(math.inf)
            # stream_send collects what the app sends during lifespan;
            # stream_receive feeds the app's receive (see ``lifespan``).
            self.stream_send = StapledObjectStream(send1, receive1)
            self.stream_receive = StapledObjectStream(send2, receive2)
            self.task = portal.start_task_soon(self.lifespan)
            portal.call(self.wait_startup)

            @stack.callback
            def wait_shutdown() -> None:
                portal.call(self.wait_shutdown)

            # Transfer the registered callbacks so they outlive this ``with``.
            self.exit_stack = stack.pop_all()

        return self

    def __exit__(self, *args: typing.Any) -> None:
        self.exit_stack.close()

    async def lifespan(self) -> None:
        """Run the app with a lifespan scope, signalling completion with ``None``."""
        scope = {"type": "lifespan", "state": self.app_state}
        try:
            await self.app(scope, self.stream_receive.receive, self.stream_send.send)
        finally:
            await self.stream_send.send(None)

    async def wait_startup(self) -> None:
        """Send ``lifespan.startup`` and wait for the app's reply."""
        await self.stream_receive.send({"type": "lifespan.startup"})

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task ended; surface any exception it raised.
                self.task.result()
            return message

        message = await receive()
        assert message["type"] in (
            "lifespan.startup.complete",
            "lifespan.startup.failed",
        )
        if message["type"] == "lifespan.startup.failed":
            await receive()

    async def wait_shutdown(self) -> None:
        """Send ``lifespan.shutdown`` and wait for the app's reply."""

        async def receive() -> typing.Any:
            message = await self.stream_send.receive()
            if message is None:
                # The lifespan task ended; surface any exception it raised.
                self.task.result()
            return message

        async with self.stream_send:
            await self.stream_receive.send({"type": "lifespan.shutdown"})
            message = await receive()
            assert message["type"] in (
                "lifespan.shutdown.complete",
                "lifespan.shutdown.failed",
            )
            if message["type"] == "lifespan.shutdown.failed":
                await receive()
|