# webdriver_template/telecli/lib/python3.11/site-packages/starlette/datastructures.py
#
# 706 lines
# 22 KiB
# Python
# Raw Normal View History
#
# 2024-08-10 14:48:21 +03:00
from __future__ import annotations
import typing
from shlex import shlex
from urllib.parse import SplitResult, parse_qsl, urlencode, urlsplit
from starlette.concurrency import run_in_threadpool
from starlette.types import Scope
class Address(typing.NamedTuple):
    """A ``(host, port)`` pair, accessible by position or by field name."""

    # Host part of the address.
    host: str
    # Port part of the address.
    port: int
# Type variable for the key type of the multidict classes in this module.
_KeyType = typing.TypeVar("_KeyType")
# Mapping keys are invariant, but mapping values are covariant because a
# read-only mapping only ever hands values out;
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = typing.TypeVar("_CovariantValueType", covariant=True)
class URL:
    """
    A URL string with lazily parsed, read-only component accessors.

    The URL can be given directly as a string, derived from a connection
    ``scope`` mapping, or assembled from keyword ``components``. These three
    sources are mutually exclusive. Methods that "modify" the URL return a
    new instance and leave the original untouched.
    """

    def __init__(
        self,
        url: str = "",
        scope: Scope | None = None,
        **components: typing.Any,
    ) -> None:
        """
        Build the URL from exactly one of ``url``, ``scope`` or ``components``.

        With ``scope``, a ``host`` header (when present) takes precedence over
        ``scope["server"]``; when neither is available only the path is used.
        ``components`` accepts the same keywords as `replace()`.

        Raises:
            AssertionError: if more than one source is supplied.
        """
        if scope is not None:
            assert not url, 'Cannot set both "url" and "scope".'
            assert not components, 'Cannot set both "scope" and "**components".'
            scheme = scope.get("scheme", "http")
            server = scope.get("server", None)
            path = scope["path"]
            query_string = scope.get("query_string", b"")

            # Only the first "host" header is considered.
            host_header = None
            for key, value in scope["headers"]:
                if key == b"host":
                    host_header = value.decode("latin-1")
                    break

            if host_header is not None:
                url = f"{scheme}://{host_header}{path}"
            elif server is None:
                url = path
            else:
                host, port = server
                # Schemes outside this table have no known default port, so
                # their port is always written out (a plain [] lookup used to
                # raise KeyError for them).
                default_port = {
                    "http": 80,
                    "https": 443,
                    "ws": 80,
                    "wss": 443,
                }.get(scheme)
                if port == default_port:
                    url = f"{scheme}://{host}{path}"
                else:
                    url = f"{scheme}://{host}:{port}{path}"

            if query_string:
                url += "?" + query_string.decode()
        elif components:
            assert not url, 'Cannot set both "url" and "**components".'
            url = URL("").replace(**components).components.geturl()

        self._url = url

    @property
    def components(self) -> SplitResult:
        """The `urlsplit()` result for this URL, parsed once and then cached."""
        if not hasattr(self, "_components"):
            self._components = urlsplit(self._url)
        return self._components

    @property
    def scheme(self) -> str:
        return self.components.scheme

    @property
    def netloc(self) -> str:
        return self.components.netloc

    @property
    def path(self) -> str:
        return self.components.path

    @property
    def query(self) -> str:
        return self.components.query

    @property
    def fragment(self) -> str:
        return self.components.fragment

    @property
    def username(self) -> None | str:
        return self.components.username

    @property
    def password(self) -> None | str:
        return self.components.password

    @property
    def hostname(self) -> None | str:
        return self.components.hostname

    @property
    def port(self) -> int | None:
        return self.components.port

    @property
    def is_secure(self) -> bool:
        """True when the scheme is ``https`` or ``wss``."""
        return self.scheme in ("https", "wss")

    def replace(self, **kwargs: typing.Any) -> URL:
        """
        Return a new URL with the given components replaced.

        ``username``, ``password``, ``hostname`` and ``port`` are folded into
        a rebuilt ``netloc``; any other keyword is passed straight to
        ``SplitResult._replace`` (``scheme``, ``netloc``, ``path``, ``query``,
        ``fragment``).
        """
        netloc_keys = ("username", "password", "hostname", "port")
        if any(name in kwargs for name in netloc_keys):
            hostname = kwargs.pop("hostname", None)
            port = kwargs.pop("port", self.port)
            username = kwargs.pop("username", self.username)
            password = kwargs.pop("password", self.password)

            if hostname is None:
                # Recover the current host from the netloc: drop any
                # "user:pass@" prefix, then any ":port" suffix.
                netloc = self.netloc
                _, _, hostname = netloc.rpartition("@")

                # A trailing "]" means a bare IPv6 literal with no port to
                # strip. endswith() also copes with an empty netloc, where
                # indexing hostname[-1] would raise IndexError.
                if not hostname.endswith("]"):
                    hostname = hostname.rsplit(":", 1)[0]

            netloc = hostname
            if port is not None:
                netloc += f":{port}"
            if username is not None:
                userpass = username
                if password is not None:
                    userpass += f":{password}"
                netloc = f"{userpass}@{netloc}"

            kwargs["netloc"] = netloc

        components = self.components._replace(**kwargs)
        return self.__class__(components.geturl())

    def include_query_params(self, **kwargs: typing.Any) -> URL:
        """
        Return a new URL with the given query parameters set.

        Keys and values are converted with `str()`. Existing parameters with
        the same key are replaced; all other parameters are kept.
        """
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        params.update({str(key): str(value) for key, value in kwargs.items()})
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def replace_query_params(self, **kwargs: typing.Any) -> URL:
        """Return a new URL whose query string holds only the given parameters."""
        query = urlencode([(str(key), str(value)) for key, value in kwargs.items()])
        return self.replace(query=query)

    def remove_query_params(self, keys: str | typing.Sequence[str]) -> URL:
        """Return a new URL without the named query parameter(s)."""
        if isinstance(keys, str):
            keys = [keys]
        params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
        for key in keys:
            params.pop(key, None)
        query = urlencode(params.multi_items())
        return self.replace(query=query)

    def __eq__(self, other: typing.Any) -> bool:
        # Compared by string form, so a URL also equals the matching plain str.
        return str(self) == str(other)

    def __str__(self) -> str:
        return self._url

    def __repr__(self) -> str:
        url = str(self)
        if self.password:
            # Never show the real password in a repr.
            url = str(self.replace(password="********"))
        return f"{self.__class__.__name__}({repr(url)})"
class URLPath(str):
    """
    A path string that can additionally carry a protocol and a host.

    Used by the routing to return `url_path_for` matches.
    """

    def __new__(cls, path: str, protocol: str = "", host: str = "") -> URLPath:
        # Only the path becomes the string value; protocol and host are
        # attached as attributes in __init__.
        assert protocol in ("http", "websocket", "")
        return str.__new__(cls, path)

    def __init__(self, path: str, protocol: str = "", host: str = "") -> None:
        self.protocol = protocol
        self.host = host

    def make_absolute_url(self, base_url: str | URL) -> URL:
        """
        Combine this path with ``base_url`` into an absolute URL.

        The scheme follows this path's protocol when one is set (choosing the
        secure variant if the base URL is secure), otherwise the base URL's
        scheme. This path's host, when set, overrides the base URL's netloc.
        """
        base = URL(base_url) if isinstance(base_url, str) else base_url

        if not self.protocol:
            scheme = base.scheme
        else:
            secure_scheme, plain_scheme = {
                "http": ("https", "http"),
                "websocket": ("wss", "ws"),
            }[self.protocol]
            scheme = secure_scheme if base.is_secure else plain_scheme

        target_netloc = self.host or base.netloc
        joined_path = base.path.rstrip("/") + str(self)
        return URL(scheme=scheme, netloc=target_netloc, path=joined_path)
class Secret:
    """
    Wraps a string so that its `repr()` never shows the real value.

    Cast to `str` at the point where the underlying value is needed.
    """

    def __init__(self, value: str):
        self._value = value

    def __repr__(self) -> str:
        # The mask has a fixed width, whatever the real value's length is.
        return f"{self.__class__.__name__}('**********')"

    def __str__(self) -> str:
        return self._value

    def __bool__(self) -> bool:
        # Truthiness follows the wrapped value.
        return True if self._value else False
class CommaSeparatedStrings(typing.Sequence[str]):
    """
    A sequence of strings built from a comma-separated string or a sequence.

    String input is tokenised with `shlex` in POSIX mode using "," as the
    only separator, and each token is stripped of surrounding whitespace.
    """

    def __init__(self, value: str | typing.Sequence[str]):
        if not isinstance(value, str):
            self._items = list(value)
            return

        lexer = shlex(value, posix=True)
        lexer.whitespace = ","
        lexer.whitespace_split = True
        self._items = [token.strip() for token in lexer]

    def __len__(self) -> int:
        return len(self._items)

    def __getitem__(self, index: int | slice) -> typing.Any:
        return self._items[index]

    def __iter__(self) -> typing.Iterator[str]:
        return iter(self._items)

    def __repr__(self) -> str:
        entries = list(self)
        return f"{self.__class__.__name__}({entries!r})"

    def __str__(self) -> str:
        # Each item is shown in its repr() form, joined by ", ".
        return ", ".join(map(repr, self))
class ImmutableMultiDict(typing.Mapping[_KeyType, _CovariantValueType]):
    """
    A read-only mapping that can hold several values per key.

    All ``(key, value)`` pairs are kept in order in ``_list``; ``_dict`` keeps
    the last value seen for each key and backs the plain mapping interface.
    """

    _dict: dict[_KeyType, _CovariantValueType]

    def __init__(
        self,
        *args: ImmutableMultiDict[_KeyType, _CovariantValueType]
        | typing.Mapping[_KeyType, _CovariantValueType]
        | typing.Iterable[tuple[_KeyType, _CovariantValueType]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Accept at most one positional source (a multidict, a mapping, or an
        iterable of pairs) plus optional keyword items appended after it.
        """
        assert len(args) < 2, "Too many arguments."

        source: typing.Any = args[0] if args else []
        if kwargs:
            # Flatten both the positional source and the keywords into pairs.
            source = (
                ImmutableMultiDict(source).multi_items()
                + ImmutableMultiDict(kwargs).multi_items()
            )

        pairs: list[tuple[typing.Any, typing.Any]]
        if not source:
            pairs = []
        elif hasattr(source, "multi_items"):
            source = typing.cast(
                ImmutableMultiDict[_KeyType, _CovariantValueType], source
            )
            pairs = list(source.multi_items())
        elif hasattr(source, "items"):
            source = typing.cast(typing.Mapping[_KeyType, _CovariantValueType], source)
            pairs = list(source.items())
        else:
            source = typing.cast("list[tuple[typing.Any, typing.Any]]", source)
            pairs = list(source)

        # Later pairs overwrite earlier ones, so the last value per key wins.
        last_values: dict[typing.Any, typing.Any] = {}
        for pair_key, pair_value in pairs:
            last_values[pair_key] = pair_value

        self._dict = last_values
        self._list = pairs

    def getlist(self, key: typing.Any) -> list[_CovariantValueType]:
        """Return every value stored under ``key``, in insertion order."""
        return [value for stored_key, value in self._list if stored_key == key]

    def keys(self) -> typing.KeysView[_KeyType]:
        return self._dict.keys()

    def values(self) -> typing.ValuesView[_CovariantValueType]:
        return self._dict.values()

    def items(self) -> typing.ItemsView[_KeyType, _CovariantValueType]:
        return self._dict.items()

    def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
        """Return a copy of all ``(key, value)`` pairs, duplicates included."""
        return list(self._list)

    def __getitem__(self, key: _KeyType) -> _CovariantValueType:
        return self._dict[key]

    def __contains__(self, key: typing.Any) -> bool:
        return key in self._dict

    def __iter__(self) -> typing.Iterator[_KeyType]:
        return iter(self.keys())

    def __len__(self) -> int:
        return len(self._dict)

    def __eq__(self, other: typing.Any) -> bool:
        # Equal when `other` is an instance of this class and both hold the
        # same pairs, regardless of order.
        return isinstance(other, self.__class__) and sorted(self._list) == sorted(
            other._list
        )

    def __repr__(self) -> str:
        pairs = self.multi_items()
        return f"{self.__class__.__name__}({pairs!r})"
class MultiDict(ImmutableMultiDict[typing.Any, typing.Any]):
    """
    A mutable multidict.

    Every mutator keeps the ordered pair list (``_list``) and the
    last-value-per-key dict (``_dict``) in step with each other.
    """

    def __setitem__(self, key: typing.Any, value: typing.Any) -> None:
        # Assigning replaces every existing value for the key.
        self.setlist(key, [value])

    def __delitem__(self, key: typing.Any) -> None:
        self._list = [(k, v) for k, v in self._list if k != key]
        del self._dict[key]

    def pop(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Remove all values for ``key``; return the dict's value or ``default``."""
        remaining = [(k, v) for k, v in self._list if k != key]
        self._list = remaining
        return self._dict.pop(key, default)

    def popitem(self) -> tuple[typing.Any, typing.Any]:
        """Remove one key as chosen by ``dict.popitem`` and return its pair."""
        popped_key, popped_value = self._dict.popitem()
        self._list = [(k, v) for k, v in self._list if k != popped_key]
        return popped_key, popped_value

    def poplist(self, key: typing.Any) -> list[typing.Any]:
        """Remove ``key`` and return all of the values it held."""
        removed = [v for k, v in self._list if k == key]
        self.pop(key)
        return removed

    def clear(self) -> None:
        self._dict.clear()
        self._list.clear()

    def setdefault(self, key: typing.Any, default: typing.Any = None) -> typing.Any:
        """Insert ``default`` for ``key`` when absent; return the current value."""
        if key in self:
            return self[key]
        self._dict[key] = default
        self._list.append((key, default))
        return self[key]

    def setlist(self, key: typing.Any, values: list[typing.Any]) -> None:
        """Replace all values for ``key``; an empty list removes the key."""
        if not values:
            self.pop(key, None)
            return
        others = [(k, v) for (k, v) in self._list if k != key]
        self._list = others + [(key, item) for item in values]
        self._dict[key] = values[-1]

    def append(self, key: typing.Any, value: typing.Any) -> None:
        """Add another value for ``key``, keeping the existing ones."""
        self._list.append((key, value))
        self._dict[key] = value

    def update(
        self,
        *args: MultiDict
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]],
        **kwargs: typing.Any,
    ) -> None:
        """
        Merge in new items; for each incoming key, all existing values are
        dropped and replaced by the incoming ones.
        """
        incoming = MultiDict(*args, **kwargs)
        kept = [(k, v) for (k, v) in self._list if k not in incoming]
        self._list = kept + incoming.multi_items()
        self._dict.update(incoming)
class QueryParams(ImmutableMultiDict[str, str]):
    """
    An immutable multidict of query parameters with `str` keys and values.
    """

    def __init__(
        self,
        *args: ImmutableMultiDict[typing.Any, typing.Any]
        | typing.Mapping[typing.Any, typing.Any]
        | list[tuple[typing.Any, typing.Any]]
        | str
        | bytes,
        **kwargs: typing.Any,
    ) -> None:
        """
        Accept a query string (`str`, or `bytes` decoded as latin-1) or any
        source the parent class accepts. Blank values in a query string are
        kept.
        """
        assert len(args) < 2, "Too many arguments."

        raw = args[0] if args else []
        if isinstance(raw, bytes):
            raw = raw.decode("latin-1")

        if isinstance(raw, str):
            super().__init__(parse_qsl(raw, keep_blank_values=True), **kwargs)
        else:
            super().__init__(*args, **kwargs)  # type: ignore[arg-type]

        # Coerce everything to str, whatever the source contained.
        self._list = [(str(name), str(val)) for name, val in self._list]
        self._dict = {str(name): str(val) for name, val in self._dict.items()}

    def __str__(self) -> str:
        return urlencode(self._list)

    def __repr__(self) -> str:
        encoded = str(self)
        return f"{self.__class__.__name__}({encoded!r})"
class UploadFile:
    """
    A file uploaded as part of the request data.

    File operations run inline while the file reports that it has not been
    rolled over (``_rolled`` is false); otherwise they are dispatched through
    `run_in_threadpool`.
    """

    def __init__(
        self,
        file: typing.BinaryIO,
        *,
        size: int | None = None,
        filename: str | None = None,
        headers: Headers | None = None,
    ) -> None:
        self.filename = filename
        self.file = file
        self.size = size
        self.headers = headers or Headers()

    @property
    def content_type(self) -> str | None:
        """The ``content-type`` header value, or None when it is absent."""
        return self.headers.get("content-type", None)

    @property
    def _in_memory(self) -> bool:
        # check for SpooledTemporaryFile._rolled; objects without the
        # attribute are treated as already rolled over.
        return not getattr(self.file, "_rolled", True)

    async def write(self, data: bytes) -> None:
        """Write ``data`` to the file, updating ``size`` when it is tracked."""
        if self.size is not None:
            self.size += len(data)

        if not self._in_memory:
            await run_in_threadpool(self.file.write, data)
            return
        self.file.write(data)

    async def read(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes; the default ``-1`` is passed through."""
        if not self._in_memory:
            return await run_in_threadpool(self.file.read, size)
        return self.file.read(size)

    async def seek(self, offset: int) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.seek, offset)
            return
        self.file.seek(offset)

    async def close(self) -> None:
        if not self._in_memory:
            await run_in_threadpool(self.file.close)
            return
        self.file.close()

    def __repr__(self) -> str:
        details = (
            f"filename={self.filename!r}, "
            f"size={self.size!r}, "
            f"headers={self.headers!r}"
        )
        return f"{self.__class__.__name__}({details})"
class FormData(ImmutableMultiDict[str, typing.Union[UploadFile, str]]):
    """
    An immutable multidict whose values are text fields or uploaded files.
    """

    def __init__(
        self,
        *args: FormData
        | typing.Mapping[str, str | UploadFile]
        | list[tuple[str, str | UploadFile]],
        **kwargs: str | UploadFile,
    ) -> None:
        super().__init__(*args, **kwargs)

    async def close(self) -> None:
        """Close every `UploadFile` value; plain string values are skipped."""
        for _, entry in self.multi_items():
            if not isinstance(entry, UploadFile):
                continue
            await entry.close()
class Headers(typing.Mapping[str, str]):
    """
    An immutable, case-insensitive multidict of headers.

    Headers are stored as a list of ``(name, value)`` byte pairs. Lookups
    lower-case the requested name and compare it with the stored names.
    """

    def __init__(
        self,
        headers: typing.Mapping[str, str] | None = None,
        raw: list[tuple[bytes, bytes]] | None = None,
        scope: typing.MutableMapping[str, typing.Any] | None = None,
    ) -> None:
        """
        Build from at most one of ``headers`` (a str mapping), ``raw`` (byte
        pairs, used as-is) or ``scope`` (a mapping with a ``"headers"`` entry).
        """
        self._list: list[tuple[bytes, bytes]] = []

        if headers is not None:
            assert raw is None, 'Cannot set both "headers" and "raw".'
            assert scope is None, 'Cannot set both "headers" and "scope".'
            self._list = [
                (name.lower().encode("latin-1"), text.encode("latin-1"))
                for name, text in headers.items()
            ]
        elif raw is not None:
            assert scope is None, 'Cannot set both "raw" and "scope".'
            self._list = raw
        elif scope is not None:
            # scope["headers"] isn't necessarily a list
            # it might be a tuple or other iterable
            as_list = list(scope["headers"])
            self._list = as_list
            scope["headers"] = as_list

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        """A shallow copy of the stored byte pairs."""
        return list(self._list)

    def keys(self) -> list[str]:  # type: ignore[override]
        return [name.decode("latin-1") for name, _ in self._list]

    def values(self) -> list[str]:  # type: ignore[override]
        return [text.decode("latin-1") for _, text in self._list]

    def items(self) -> list[tuple[str, str]]:  # type: ignore[override]
        return [
            (name.decode("latin-1"), text.decode("latin-1"))
            for name, text in self._list
        ]

    def getlist(self, key: str) -> list[str]:
        """Return every value stored for ``key``, in order."""
        wanted = key.lower().encode("latin-1")
        return [text.decode("latin-1") for name, text in self._list if name == wanted]

    def mutablecopy(self) -> MutableHeaders:
        """Return a `MutableHeaders` built from a copy of the pair list."""
        return MutableHeaders(raw=self._list[:])

    def __getitem__(self, key: str) -> str:
        # The first matching header wins.
        wanted = key.lower().encode("latin-1")
        for name, text in self._list:
            if name == wanted:
                return text.decode("latin-1")
        raise KeyError(key)

    def __contains__(self, key: typing.Any) -> bool:
        wanted = key.lower().encode("latin-1")
        return any(name == wanted for name, _ in self._list)

    def __iter__(self) -> typing.Iterator[typing.Any]:
        return iter(self.keys())

    def __len__(self) -> int:
        # Counts stored pairs, so repeated header names count each time.
        return len(self._list)

    def __eq__(self, other: typing.Any) -> bool:
        return isinstance(other, Headers) and sorted(self._list) == sorted(
            other._list
        )

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        flattened = dict(self.items())
        # The dict form would lose repeated names; fall back to raw pairs.
        if len(flattened) != len(self):
            return f"{class_name}(raw={self.raw!r})"
        return f"{class_name}({flattened!r})"
class MutableHeaders(Headers):
    """
    A mutable variant of `Headers`; item-level mutators edit the stored
    pair list in place.
    """

    def __setitem__(self, key: str, value: str) -> None:
        """
        Set the header `key` to `value`, removing any duplicate entries.
        Retains insertion order.
        """
        encoded_key = key.lower().encode("latin-1")
        encoded_value = value.encode("latin-1")

        matches = [
            position
            for position, (name, _) in enumerate(self._list)
            if name == encoded_key
        ]
        if not matches:
            self._list.append((encoded_key, encoded_value))
            return

        first, *duplicates = matches
        # Delete from the end so earlier indexes stay valid.
        for position in reversed(duplicates):
            del self._list[position]
        self._list[first] = (encoded_key, encoded_value)

    def __delitem__(self, key: str) -> None:
        """
        Remove the header `key`.
        """
        encoded_key = key.lower().encode("latin-1")
        matches = [
            position
            for position, (name, _) in enumerate(self._list)
            if name == encoded_key
        ]
        # Delete from the end so earlier indexes stay valid.
        for position in reversed(matches):
            del self._list[position]

    def __ior__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        self.update(other)
        return self

    def __or__(self, other: typing.Mapping[str, str]) -> MutableHeaders:
        if not isinstance(other, typing.Mapping):
            raise TypeError(f"Expected a mapping but got {other.__class__.__name__}")
        merged = self.mutablecopy()
        merged.update(other)
        return merged

    @property
    def raw(self) -> list[tuple[bytes, bytes]]:
        # Returns the stored list itself rather than a copy.
        return self._list

    def setdefault(self, key: str, value: str) -> str:
        """
        If the header `key` does not exist, then set it to `value`.
        Returns the header value.
        """
        encoded_key = key.lower().encode("latin-1")
        encoded_value = value.encode("latin-1")

        for name, current in self._list:
            if name == encoded_key:
                return current.decode("latin-1")
        self._list.append((encoded_key, encoded_value))
        return value

    def update(self, other: typing.Mapping[str, str]) -> None:
        """Set each header from ``other``, replacing existing entries."""
        for name, text in other.items():
            self[name] = text

    def append(self, key: str, value: str) -> None:
        """
        Append a header, preserving any duplicate entries.
        """
        pair = (key.lower().encode("latin-1"), value.encode("latin-1"))
        self._list.append(pair)

    def add_vary_header(self, vary: str) -> None:
        """Add ``vary`` to the ``vary`` header, extending any existing value."""
        current = self.get("vary")
        if current is not None:
            vary = ", ".join([current, vary])
        self["vary"] = vary
class State:
    """
    An attribute-style container for arbitrary state.

    Used for `request.state` and `app.state`. Attribute reads, writes and
    deletes are all redirected to the ``_state`` dictionary.
    """

    _state: dict[str, typing.Any]

    def __init__(self, state: dict[str, typing.Any] | None = None):
        backing = {} if state is None else state
        # Go through object.__setattr__ so `_state` itself is not routed
        # into the dictionary by the override below.
        super().__setattr__("_state", backing)

    def __setattr__(self, key: typing.Any, value: typing.Any) -> None:
        self._state[key] = value

    def __getattr__(self, key: typing.Any) -> typing.Any:
        try:
            return self._state[key]
        except KeyError:
            message = "'{}' object has no attribute '{}'"
            owner = self.__class__.__name__
            raise AttributeError(message.format(owner, key))

    def __delattr__(self, key: typing.Any) -> None:
        del self._state[key]