279 lines
10 KiB
Python
279 lines
10 KiB
Python
|
from __future__ import annotations
|
||
|
|
||
|
import typing
|
||
|
from dataclasses import dataclass, field
|
||
|
from enum import Enum
|
||
|
from tempfile import SpooledTemporaryFile
|
||
|
from urllib.parse import unquote_plus
|
||
|
|
||
|
from starlette.datastructures import FormData, Headers, UploadFile
|
||
|
|
||
|
try:
|
||
|
import multipart
|
||
|
from multipart.multipart import parse_options_header
|
||
|
except ModuleNotFoundError: # pragma: nocover
|
||
|
parse_options_header = None
|
||
|
multipart = None
|
||
|
|
||
|
|
||
|
class FormMessage(Enum):
    """Kinds of events queued by the ``FormParser`` callbacks during parsing.

    Each queued message is a ``(FormMessage, bytes)`` pair; only the
    ``FIELD_NAME`` and ``FIELD_DATA`` events carry a non-empty payload.
    """

    # A new field begins; the accumulated name and value are reset.
    FIELD_START = 1
    # A fragment of the current field's (still URL-encoded) name.
    FIELD_NAME = 2
    # A fragment of the current field's (still URL-encoded) value.
    FIELD_DATA = 3
    # The current field is complete and can be emitted.
    FIELD_END = 4
    # The underlying parser reported the end of the input.
    END = 5
|
||
|
|
||
|
|
||
|
@dataclass
class MultipartPart:
    """Mutable scratch record describing one part of a multipart body.

    ``MultiPartParser`` creates a fresh instance when a part begins and fills
    it in as the part's headers and data arrive.
    """

    # Raw value of the part's Content-Disposition header, if one was seen.
    content_disposition: bytes | None = None
    # Decoded ``name`` option taken from the Content-Disposition header.
    field_name: str = ""
    # Accumulated body bytes; only used when the part is not a file upload.
    data: bytes = b""
    # Upload wrapper for file parts; ``None`` for plain form fields.
    file: UploadFile | None = None
    # Every header of the part as (lower-cased name, raw value) pairs.
    item_headers: list[tuple[bytes, bytes]] = field(default_factory=list)
|
||
|
|
||
|
|
||
|
def _user_safe_decode(src: bytes, codec: str) -> str:
|
||
|
try:
|
||
|
return src.decode(codec)
|
||
|
except (UnicodeDecodeError, LookupError):
|
||
|
return src.decode("latin-1")
|
||
|
|
||
|
|
||
|
class MultiPartException(Exception):
    """Raised when a multipart body is malformed or exceeds a configured limit.

    Attributes are documented inline; ``message`` is kept for callers that
    read it directly.
    """

    def __init__(self, message: str) -> None:
        """Store ``message`` and pass it on to ``Exception``.

        Args:
            message: Human-readable description of the problem.
        """
        # Forwarding the message populates ``args`` so that ``str(exc)``,
        # ``repr(exc)`` and copy/pickle round-trips carry the text.
        super().__init__(message)
        self.message = message
|
||
|
|
||
|
|
||
|
class FormParser:
    """Parse a URL-encoded (query-string style) form body into ``FormData``.

    The synchronous callbacks below are handed to
    ``multipart.QuerystringParser``; each one queues a
    ``(FormMessage, bytes)`` event on ``self.messages``. ``parse()`` drains
    that queue after every chunk and assembles the decoded fields.
    """

    def __init__(
        self, headers: Headers, stream: typing.AsyncGenerator[bytes, None]
    ) -> None:
        """Remember the request headers and body stream.

        Args:
            headers: Request headers; stored but not read by this class.
            stream: Async generator yielding the request body in chunks.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        # Events queued by the parser callbacks, consumed by ``parse()``.
        self.messages: list[tuple[FormMessage, bytes]] = []

    def on_field_start(self) -> None:
        """Queue a FIELD_START event."""
        self.messages.append((FormMessage.FIELD_START, b""))

    def on_field_name(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of the field name."""
        self.messages.append((FormMessage.FIELD_NAME, data[start:end]))

    def on_field_data(self, data: bytes, start: int, end: int) -> None:
        """Queue ``data[start:end]`` as a fragment of the field value."""
        self.messages.append((FormMessage.FIELD_DATA, data[start:end]))

    def on_field_end(self) -> None:
        """Queue a FIELD_END event."""
        self.messages.append((FormMessage.FIELD_END, b""))

    def on_end(self) -> None:
        """Queue an END event."""
        self.messages.append((FormMessage.END, b""))

    async def parse(self) -> FormData:
        """Consume the body stream and return the parsed form fields.

        An empty chunk is treated as the end-of-body marker and finalizes the
        parser. If the stream ends without ever yielding an empty chunk, the
        parser is finalized once after the loop so the last field is not
        silently dropped.

        Returns:
            A ``FormData`` built from the ``(name, value)`` pairs in the order
            they were completed. Names and values are decoded as latin-1 and
            then passed through ``unquote_plus``.
        """
        callbacks = {
            "on_field_start": self.on_field_start,
            "on_field_name": self.on_field_name,
            "on_field_data": self.on_field_data,
            "on_field_end": self.on_field_end,
            "on_end": self.on_end,
        }
        parser = multipart.QuerystringParser(callbacks)

        # A field's name or value may be split across several chunks, so the
        # partial buffers have to survive between drains.
        field_name = b""
        field_value = b""
        items: list[tuple[str, str | UploadFile]] = []

        def drain_messages() -> None:
            """Fold all queued parser events into ``items``."""
            nonlocal field_name, field_value
            pending = list(self.messages)
            self.messages.clear()
            for message_type, message_bytes in pending:
                if message_type == FormMessage.FIELD_START:
                    field_name = b""
                    field_value = b""
                elif message_type == FormMessage.FIELD_NAME:
                    field_name += message_bytes
                elif message_type == FormMessage.FIELD_DATA:
                    field_value += message_bytes
                elif message_type == FormMessage.FIELD_END:
                    name = unquote_plus(field_name.decode("latin-1"))
                    value = unquote_plus(field_value.decode("latin-1"))
                    items.append((name, value))

        finalized = False
        # Feed the parser with data from the request.
        async for chunk in self.stream:
            if chunk:
                parser.write(chunk)
            else:
                parser.finalize()
                finalized = True
            drain_messages()

        if not finalized:
            # The stream ended without an empty terminating chunk; flush the
            # parser so a trailing, unterminated field is still reported.
            parser.finalize()
            drain_messages()

        return FormData(items)
|
||
|
|
||
|
|
||
|
class MultiPartParser:
    """Parse a multipart request body into ``FormData``.

    The synchronous callbacks below are handed to
    ``multipart.MultipartParser``. They only record state; file I/O is
    deferred to ``parse()``, which awaits the ``UploadFile`` methods.
    """

    # Passed as ``max_size`` to ``SpooledTemporaryFile`` for each upload.
    max_file_size = 1024 * 1024

    def __init__(
        self,
        headers: Headers,
        stream: typing.AsyncGenerator[bytes, None],
        *,
        max_files: int | float = 1000,
        max_fields: int | float = 1000,
    ) -> None:
        """Set up parser state.

        Args:
            headers: Request headers; ``Content-Type`` is read in ``parse()``.
            stream: Async generator yielding the request body in chunks.
            max_files: Maximum number of parts carrying a ``filename`` option.
            max_fields: Maximum number of parts without a ``filename`` option.
        """
        assert (
            multipart is not None
        ), "The `python-multipart` library must be installed to use form parsing."
        self.headers = headers
        self.stream = stream
        self.max_files = max_files
        self.max_fields = max_fields
        self.items: list[tuple[str, str | UploadFile]] = []
        self._current_files = 0
        self._current_fields = 0
        # Header name/value may arrive in fragments; these collect them.
        self._current_partial_header_name: bytes = b""
        self._current_partial_header_value: bytes = b""
        self._current_part = MultipartPart()
        self._charset = ""
        # Work queued by the sync callbacks and performed in ``parse()``.
        self._file_parts_to_write: list[tuple[MultipartPart, bytes]] = []
        self._file_parts_to_finish: list[MultipartPart] = []
        self._files_to_close_on_error: list[SpooledTemporaryFile[bytes]] = []

    def on_part_begin(self) -> None:
        """Start a fresh part record."""
        self._current_part = MultipartPart()

    def on_part_data(self, data: bytes, start: int, end: int) -> None:
        """Record ``data[start:end]`` as body bytes of the current part."""
        message_bytes = data[start:end]
        if self._current_part.file is None:
            # NOTE(review): plain field data is accumulated in memory and no
            # size cap is applied anywhere in this class.
            self._current_part.data += message_bytes
        else:
            # File writes are async, so queue them for ``parse()``.
            self._file_parts_to_write.append((self._current_part, message_bytes))

    def on_part_end(self) -> None:
        """Append the finished part to ``self.items``."""
        part = self._current_part
        if part.file is None:
            self.items.append(
                (part.field_name, _user_safe_decode(part.data, self._charset))
            )
        else:
            self._file_parts_to_finish.append(part)
            # The file can be added to the items right now even though it's
            # not finished yet, because it will be finished in the `parse()`
            # method, before self.items is used in the return value.
            self.items.append((part.field_name, part.file))

    def on_header_field(self, data: bytes, start: int, end: int) -> None:
        """Append ``data[start:end]`` to the header name being collected."""
        self._current_partial_header_name += data[start:end]

    def on_header_value(self, data: bytes, start: int, end: int) -> None:
        """Append ``data[start:end]`` to the header value being collected."""
        self._current_partial_header_value += data[start:end]

    def on_header_end(self) -> None:
        """Store the completed header on the current part and reset buffers."""
        header_name = self._current_partial_header_name.lower()
        header_value = self._current_partial_header_value
        if header_name == b"content-disposition":
            self._current_part.content_disposition = header_value
        self._current_part.item_headers.append((header_name, header_value))
        self._current_partial_header_name = b""
        self._current_partial_header_value = b""

    def on_headers_finished(self) -> None:
        """Classify the current part as a file or a field once headers end.

        Raises:
            MultiPartException: If the Content-Disposition header has no
                ``name`` option, or if the file/field limit is exceeded.
        """
        _, options = parse_options_header(self._current_part.content_disposition)
        try:
            raw_name = options[b"name"]
        except KeyError:
            raise MultiPartException(
                'The Content-Disposition header field "name" must be provided.'
            ) from None
        self._current_part.field_name = _user_safe_decode(raw_name, self._charset)

        if b"filename" in options:
            self._current_files += 1
            if self._current_files > self.max_files:
                raise MultiPartException(
                    f"Too many files. Maximum number of files is {self.max_files}."
                )
            filename = _user_safe_decode(options[b"filename"], self._charset)
            spooled = SpooledTemporaryFile(max_size=self.max_file_size)
            # Tracked so ``parse()`` can close it if parsing fails.
            self._files_to_close_on_error.append(spooled)
            self._current_part.file = UploadFile(
                file=spooled,  # type: ignore[arg-type]
                size=0,
                filename=filename,
                headers=Headers(raw=self._current_part.item_headers),
            )
        else:
            self._current_fields += 1
            if self._current_fields > self.max_fields:
                raise MultiPartException(
                    f"Too many fields. Maximum number of fields is {self.max_fields}."
                )
            self._current_part.file = None

    def on_end(self) -> None:
        """Nothing to do when the parser reports the end of the input."""
        pass

    async def parse(self) -> FormData:
        """Consume the body stream and return the parsed fields and files.

        Returns:
            A ``FormData`` built from ``self.items``.

        Raises:
            MultiPartException: If the Content-Type header has no boundary, or
                a callback rejects the body (missing name, too many parts).

        Any exception raised while the body is being processed causes every
        temporary file created so far to be closed before it propagates.
        """
        # Parse the Content-Type header to get the multipart boundary.
        _, params = parse_options_header(self.headers["Content-Type"])
        charset = params.get(b"charset", "utf-8")
        if isinstance(charset, bytes):
            charset = charset.decode("latin-1")
        self._charset = charset
        try:
            boundary = params[b"boundary"]
        except KeyError:
            raise MultiPartException("Missing boundary in multipart.") from None

        callbacks = {
            "on_part_begin": self.on_part_begin,
            "on_part_data": self.on_part_data,
            "on_part_end": self.on_part_end,
            "on_header_field": self.on_header_field,
            "on_header_value": self.on_header_value,
            "on_header_end": self.on_header_end,
            "on_headers_finished": self.on_headers_finished,
            "on_end": self.on_end,
        }
        parser = multipart.MultipartParser(boundary, callbacks)

        try:
            # Feed the parser with data from the request.
            async for chunk in self.stream:
                parser.write(chunk)
                # Write file data here rather than in the callbacks: the
                # callbacks are regular (non-async) functions, so doing file
                # I/O there would block the event loop, whereas the
                # UploadFile methods can be awaited.
                for part, data in self._file_parts_to_write:
                    assert part.file  # for type checkers
                    await part.file.write(data)
                for part in self._file_parts_to_finish:
                    assert part.file  # for type checkers
                    await part.file.seek(0)
                self._file_parts_to_write.clear()
                self._file_parts_to_finish.clear()
            parser.finalize()
        except BaseException:
            # Close all the files on *any* failure (parser errors, stream
            # errors, cancellation), not just MultiPartException, so spooled
            # temp files are not leaked; then re-raise unchanged.
            for file in self._files_to_close_on_error:
                file.close()
            raise

        return FormData(self.items)
|