819 lines
30 KiB
Python
819 lines
30 KiB
Python
import inspect
|
|
from contextlib import AsyncExitStack, contextmanager
|
|
from copy import copy, deepcopy
|
|
from typing import (
|
|
Any,
|
|
Callable,
|
|
Coroutine,
|
|
Dict,
|
|
ForwardRef,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Sequence,
|
|
Tuple,
|
|
Type,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
import anyio
|
|
from fastapi import params
|
|
from fastapi._compat import (
|
|
PYDANTIC_V2,
|
|
ErrorWrapper,
|
|
ModelField,
|
|
Required,
|
|
Undefined,
|
|
_regenerate_error_with_loc,
|
|
copy_field_info,
|
|
create_body_model,
|
|
evaluate_forwardref,
|
|
field_annotation_is_scalar,
|
|
get_annotation_from_field_info,
|
|
get_missing_field_error,
|
|
is_bytes_field,
|
|
is_bytes_sequence_field,
|
|
is_scalar_field,
|
|
is_scalar_sequence_field,
|
|
is_sequence_field,
|
|
is_uploadfile_or_nonable_uploadfile_annotation,
|
|
is_uploadfile_sequence_annotation,
|
|
lenient_issubclass,
|
|
sequence_types,
|
|
serialize_sequence_value,
|
|
value_is_sequence,
|
|
)
|
|
from fastapi.background import BackgroundTasks
|
|
from fastapi.concurrency import (
|
|
asynccontextmanager,
|
|
contextmanager_in_threadpool,
|
|
)
|
|
from fastapi.dependencies.models import Dependant, SecurityRequirement
|
|
from fastapi.logger import logger
|
|
from fastapi.security.base import SecurityBase
|
|
from fastapi.security.oauth2 import OAuth2, SecurityScopes
|
|
from fastapi.security.open_id_connect_url import OpenIdConnect
|
|
from fastapi.utils import create_response_field, get_path_param_names
|
|
from pydantic.fields import FieldInfo
|
|
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
|
|
from starlette.concurrency import run_in_threadpool
|
|
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
|
|
from starlette.requests import HTTPConnection, Request
|
|
from starlette.responses import Response
|
|
from starlette.websockets import WebSocket
|
|
from typing_extensions import Annotated, get_args, get_origin
|
|
|
|
# Message logged and raised when no "multipart" package can be imported at all.
multipart_not_installed_error = "".join(
    (
        'Form data requires "python-multipart" to be installed. \n',
        'You can install "python-multipart" with: \n\n',
        "pip install python-multipart\n",
    )
)

# Message logged and raised when a "multipart" module is importable but it is
# not the python-multipart distribution (it lacks multipart.multipart).
multipart_incorrect_install_error = "".join(
    (
        'Form data requires "python-multipart" to be installed. ',
        'It seems you installed "multipart" instead. \n',
        'You can remove "multipart" with: \n\n',
        "pip uninstall multipart\n\n",
        'And then install "python-multipart" with: \n\n',
        "pip install python-multipart\n",
    )
)
|
|
|
|
|
|
def check_file_field(field: ModelField) -> None:
    """Verify that python-multipart is usable when *field* is a form field.

    Non-``Form`` fields are ignored. For ``Form`` fields (``File`` included):

    Raises:
        RuntimeError: if no ``multipart`` package can be imported, or if the
            importable ``multipart`` is not python-multipart (it lacks
            ``multipart.multipart.parse_options_header``). The message is
            also logged via ``logger.error`` before raising.
    """
    if not isinstance(field.field_info, params.Form):
        return

    try:
        # __version__ is available in both multiparts, and can be mocked
        from multipart import __version__  # type: ignore

        assert __version__
    except ImportError:
        logger.error(multipart_not_installed_error)
        raise RuntimeError(multipart_not_installed_error) from None

    try:
        # parse_options_header is only available in the right multipart
        from multipart.multipart import parse_options_header  # type: ignore

        assert parse_options_header
    except ImportError:
        logger.error(multipart_incorrect_install_error)
        raise RuntimeError(multipart_incorrect_install_error) from None
|
|
|
|
|
|
def get_param_sub_dependant(
    *,
    param_name: str,
    depends: params.Depends,
    path: str,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the sub-dependant for a parameter declared with ``Depends``/``Security``.

    The parameter name becomes the sub-dependant's ``name`` so its solved
    value can later be passed to the parent call under that keyword.
    """
    dependency = depends.dependency
    assert dependency
    return get_sub_dependant(
        depends=depends,
        dependency=dependency,
        path=path,
        name=param_name,
        security_scopes=security_scopes,
    )
|
|
|
|
|
|
def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build an unnamed sub-dependant (e.g. from a ``dependencies=[...]`` list).

    Raises:
        AssertionError: if ``depends.dependency`` is not callable.
    """
    dependency = depends.dependency
    assert callable(
        dependency
    ), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=dependency, path=path)
|
|
|
|
|
|
def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the ``Dependant`` for *dependency* as used through *depends*.

    Args:
        depends: The ``Depends``/``Security`` marker. For ``Security`` its
            ``scopes`` are appended to the inherited scopes.
        dependency: The callable to analyze.
        path: Route path, forwarded to ``get_dependant``.
        name: Parameter name the solved value is bound to (``None`` when the
            dependency is parameter-less).
        security_scopes: Scopes inherited from the dependants above. This
            list is never modified.

    Returns:
        The sub-dependant. If *dependency* is a ``SecurityBase`` instance, a
        ``SecurityRequirement`` is appended to its ``security_requirements``.
        Only ``OAuth2``/``OpenIdConnect`` schemes carry the scopes.
    """
    # Copy before extending. The caller's list is also the parent Dependant's
    # ``security_scopes`` and is handed to every sibling, so extending it in
    # place would leak this dependency's scopes into the parent and siblings.
    security_scopes = list(security_scopes or [])
    if isinstance(depends, params.Security):
        security_scopes.extend(depends.scopes)

    security_requirement = None
    if isinstance(dependency, SecurityBase):
        use_scopes: List[str] = []
        if isinstance(dependency, (OAuth2, OpenIdConnect)):
            use_scopes = security_scopes
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )

    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        security_scopes=security_scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    return sub_dependant
|
|
|
|
|
|
# Type of ``Dependant.cache_key``: the dependency callable (may be None) paired
# with a tuple of strings. NOTE(review): the strings are presumably the
# dependant's security scopes — confirm against Dependant in models.py.
CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
|
|
|
|
|
|
def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[List[CacheKey]] = None,
) -> Dependant:
    """Merge *dependant* and all nested sub-dependencies into one ``Dependant``.

    The result holds copies of this dependant's path/query/header/cookie/body
    params and security requirements, extended depth-first with those of every
    sub-dependency. ``dependencies`` of the result is left empty.

    Args:
        skip_repeats: When true, a sub-dependant whose ``cache_key`` has
            already been seen is skipped.
        visited: Accumulator of seen cache keys shared across the recursion.
            It is appended to in place.
    """
    seen = [] if visited is None else visited
    seen.append(dependant.cache_key)

    flat_dependant = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_schemes=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )

    merged_groups = (
        "path_params",
        "query_params",
        "header_params",
        "cookie_params",
        "body_params",
        "security_requirements",
    )
    for sub_dependant in dependant.dependencies:
        if skip_repeats and sub_dependant.cache_key in seen:
            continue
        flat_sub = get_flat_dependant(
            sub_dependant, skip_repeats=skip_repeats, visited=seen
        )
        for group in merged_groups:
            getattr(flat_dependant, group).extend(getattr(flat_sub, group))
    return flat_dependant
|
|
|
|
|
|
def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """Return all non-body params of the whole dependency tree as a new list.

    Ordered path, then query, then header, then cookie params. Repeated
    sub-dependants are visited only once.
    """
    flat = get_flat_dependant(dependant, skip_repeats=True)
    return [
        *flat.path_params,
        *flat.query_params,
        *flat.header_params,
        *flat.cookie_params,
    ]
|
|
|
|
|
|
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return *call*'s signature with string annotations resolved.

    Forward references are evaluated against ``call.__globals__`` (an empty
    namespace if the callable has none). The return annotation is not carried
    over to the new signature.
    """
    globalns = getattr(call, "__globals__", {})
    typed_parameters = []
    for parameter in inspect.signature(call).parameters.values():
        resolved = get_typed_annotation(parameter.annotation, globalns)
        typed_parameters.append(
            inspect.Parameter(
                name=parameter.name,
                kind=parameter.kind,
                default=parameter.default,
                annotation=resolved,
            )
        )
    return inspect.Signature(typed_parameters)
|
|
|
|
|
|
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Resolve a string annotation via ``evaluate_forwardref``; pass others through.

    *globalns* is used as both the global and the local namespace.
    """
    if not isinstance(annotation, str):
        return annotation
    forward_ref = ForwardRef(annotation)
    return evaluate_forwardref(forward_ref, globalns, globalns)
|
|
|
|
|
|
def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return *call*'s resolved return annotation, or ``None`` if it has none."""
    return_annotation = inspect.signature(call).return_annotation
    if return_annotation is inspect.Signature.empty:
        return None
    namespace = getattr(call, "__globals__", {})
    return get_typed_annotation(return_annotation, namespace)
|
|
|
|
|
|
def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Analyze *call*'s parameters and build its ``Dependant`` tree.

    Each parameter, in signature order, is classified as one of:
      * a sub-dependency (``Depends``/``Security``), recursively analyzed;
      * a framework-injected object (Request, WebSocket, Response, ...);
      * a body param;
      * a path/query/header/cookie param.

    Raises:
        AssertionError: on conflicting or unsupported declarations.
    """
    path_param_names = get_path_param_names(path)
    signature_params = get_typed_signature(call).parameters
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        security_scopes=security_scopes,
        use_cache=use_cache,
    )
    for param_name, param in signature_params.items():
        in_path = param_name in path_param_names
        type_annotation, depends, param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=in_path,
        )

        if depends is not None:
            dependant.dependencies.append(
                get_param_sub_dependant(
                    param_name=param_name,
                    depends=depends,
                    path=path,
                    security_scopes=security_scopes,
                )
            )
            continue

        is_special = add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=type_annotation,
            dependant=dependant,
        )
        if is_special:
            assert (
                param_field is None
            ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
            continue

        assert param_field is not None
        if is_body_param(param_field=param_field, is_path_param=in_path):
            dependant.body_params.append(param_field)
        else:
            add_param_to_fields(field=param_field, dependant=dependant)
    return dependant
|
|
|
|
|
|
def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Register *param_name* as a framework-injected parameter if applicable.

    The first matching base class in the table below (checked in order, via
    ``lenient_issubclass``) decides which ``dependant.*_param_name`` attribute
    is set to *param_name*.

    Returns:
        ``True`` if a match was recorded, otherwise ``None``.
    """
    injected_kinds = (
        (Request, "request_param_name"),
        (WebSocket, "websocket_param_name"),
        (HTTPConnection, "http_connection_param_name"),
        (Response, "response_param_name"),
        (StarletteBackgroundTasks, "background_tasks_param_name"),
        (SecurityScopes, "security_scopes_param_name"),
    )
    for base_class, attribute in injected_kinds:
        if lenient_issubclass(type_annotation, base_class):
            setattr(dependant, attribute, param_name)
            return True
    return None
|
|
|
|
|
|
def _infer_field_info(
    *, use_annotation: Any, type_annotation: Any, value: Any, is_path_param: bool
) -> FieldInfo:
    """Pick the implicit ``Path``/``File``/``Body``/``Query`` info for a bare param."""
    default_value = value if value is not inspect.Signature.empty else Required
    if is_path_param:
        # We might check here that `default_value is Required`, but the fact is that the same
        # parameter might sometimes be a path parameter and sometimes not. See
        # `tests/test_infer_param_optionality.py` for an example.
        return params.Path(annotation=use_annotation)
    if is_uploadfile_or_nonable_uploadfile_annotation(
        type_annotation
    ) or is_uploadfile_sequence_annotation(type_annotation):
        return params.File(annotation=use_annotation, default=default_value)
    if not field_annotation_is_scalar(annotation=type_annotation):
        return params.Body(annotation=use_annotation, default=default_value)
    return params.Query(annotation=use_annotation, default=default_value)


def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
    """Classify one endpoint/dependency parameter.

    Args:
        param_name: Name of the parameter in the signature.
        annotation: Its (resolved) annotation, or ``inspect.Signature.empty``.
        value: Its default, or ``inspect.Signature.empty``.
        is_path_param: Whether the name appears in the route path.

    Returns:
        ``(type_annotation, depends, field)``. ``type_annotation`` is the bare
        type (first ``Annotated`` argument if annotated). ``depends`` is set
        for dependency params. ``field`` is the ``ModelField`` for request
        params. Both are ``None`` for framework-injected types.

    Raises:
        AssertionError: on conflicting or invalid declarations.
    """
    field_info = None
    depends = None
    has_annotation = annotation is not inspect.Signature.empty
    use_annotation: Any = annotation if has_annotation else Any
    type_annotation: Any = use_annotation

    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(use_annotation)
        type_annotation = annotated_args[0]
        # Only FastAPI's own markers count; when several are given, the last wins.
        fastapi_annotation: Union[FieldInfo, params.Depends, None] = None
        for extra in annotated_args[1:]:
            if isinstance(extra, (FieldInfo, params.Depends)) and isinstance(
                extra, (params.Param, params.Body, params.Depends)
            ):
                fastapi_annotation = extra

        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert field_info.default is Undefined or field_info.default is Required, (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            if value is inspect.Signature.empty:
                field_info.default = Required
            else:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation

    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation

    if depends is not None and depends.dependency is None:
        # `Depends()` with no argument means "depend on the annotated type";
        # copy first so the shared marker object is not mutated.
        depends = copy(depends)
        depends.dependency = type_annotation

    injected_types = (
        Request,
        WebSocket,
        HTTPConnection,
        Response,
        StarletteBackgroundTasks,
        SecurityScopes,
    )
    if lenient_issubclass(type_annotation, injected_types):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert (
            field_info is None
        ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
    elif field_info is None and depends is None:
        field_info = _infer_field_info(
            use_annotation=use_annotation,
            type_annotation=type_annotation,
            value=value,
            is_path_param=is_path_param,
        )

    if field_info is None:
        return type_annotation, depends, None

    if is_path_param:
        assert isinstance(field_info, params.Path), (
            f"Cannot use `{field_info.__class__.__name__}` for path param"
            f" {param_name!r}"
        )
    elif (
        isinstance(field_info, params.Param)
        and getattr(field_info, "in_", None) is None
    ):
        field_info.in_ = params.ParamTypes.query

    annotation_for_field = get_annotation_from_field_info(
        use_annotation,
        field_info,
        param_name,
    )
    if not field_info.alias and getattr(field_info, "convert_underscores", None):
        alias = param_name.replace("_", "-")
    else:
        alias = field_info.alias or param_name
    field_info.alias = alias
    field = create_response_field(
        name=param_name,
        type_=annotation_for_field,
        default=field_info.default,
        alias=alias,
        required=field_info.default in (Required, Undefined),
        field_info=field_info,
    )
    return type_annotation, depends, field
|
|
|
|
|
|
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Return whether *param_field* must be read from the request body.

    Path params and scalar params are never body params. Scalar sequences
    declared as ``Query``/``Header`` are not body params either. Everything
    else must be declared with ``Body`` (or a subclass).

    Raises:
        AssertionError: for a non-scalar path param, or a non-scalar param
            that is not a ``Body``.
    """
    if is_path_param:
        assert is_scalar_field(
            field=param_field
        ), "Path params must be of one of the supported types"
        return False
    if is_scalar_field(field=param_field):
        return False
    field_info = param_field.field_info
    if isinstance(field_info, (params.Query, params.Header)) and is_scalar_sequence_field(
        param_field
    ):
        return False
    assert isinstance(
        field_info, params.Body
    ), f"Param: {param_field.name} can only be a request body, using Body()"
    return True
|
|
|
|
|
|
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append *field* to the dependant's path/query/header/cookie param list.

    The list is selected by ``field.field_info.in_``.

    Raises:
        AssertionError: if ``in_`` is none of path, query, header or cookie.
    """
    location = getattr(field.field_info, "in_", None)
    for param_type, bucket in (
        (params.ParamTypes.path, dependant.path_params),
        (params.ParamTypes.query, dependant.query_params),
        (params.ParamTypes.header, dependant.header_params),
    ):
        if location == param_type:
            bucket.append(field)
            return
    assert (
        location == params.ParamTypes.cookie
    ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
    dependant.cookie_params.append(field)
|
|
|
|
|
|
def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Return whether calling *call* produces a coroutine.

    Functions and methods are checked directly. Classes are never treated as
    coroutine callables. Any other object is judged by its ``__call__``
    attribute (``None`` if absent).
    """
    if not inspect.isroutine(call):
        if inspect.isclass(call):
            return False
        # Callable instance: inspect its (bound) __call__ method instead.
        call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.iscoroutinefunction(call)
|
|
|
|
|
|
def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Return whether *call*, or its ``__call__`` attribute, is an async generator function."""
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.isasyncgenfunction(call) or inspect.isasyncgenfunction(
        dunder_call
    )
|
|
|
|
|
|
def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Return whether *call*, or its ``__call__`` attribute, is a generator function."""
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.isgeneratorfunction(call) or inspect.isgeneratorfunction(
        dunder_call
    )
|
|
|
|
|
|
async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
    """Enter a ``yield``-based dependency on *stack* and return its yielded value.

    Sync generators are wrapped with ``contextmanager`` and driven through
    ``contextmanager_in_threadpool``. Async generators are wrapped with
    ``asynccontextmanager``. The teardown part (after ``yield``) runs when
    *stack* is closed.

    Raises:
        TypeError: if *call* is neither a generator nor an async generator
            callable (previously this surfaced as an ``UnboundLocalError``).
    """
    if is_gen_callable(call):
        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(**sub_values)
    else:
        raise TypeError(
            f"Dependency {call!r} is not a generator or async generator callable"
        )
    return await stack.enter_async_context(cm)
|
|
|
|
|
|
async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
) -> Tuple[
    Dict[str, Any],
    List[Any],
    Optional[StarletteBackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Compute the keyword arguments needed to call ``dependant.call``.

    Sub-dependencies are solved first, recursively. Overrides are honored and
    results are cached by ``cache_key``. Then path/query/header/cookie and body
    params are validated, and the framework-injected objects are added.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``.
        If a sub-dependency had errors, it is not called and its errors are
        collected instead.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        # Placeholder response dependencies may modify; its default
        # content-length header and status code are cleared.
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    dependency_cache = dependency_cache or {}

    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        sub_call = cast(Callable[..., Any], sub_dependant.call)
        solving_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            overrides = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            )
            sub_call = overrides.get(sub_call, sub_call)
            solving_dependant = get_dependant(
                path=cast(str, sub_dependant.path),
                call=sub_call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )

        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = await solve_dependencies(
            request=request,
            dependant=solving_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
        )
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue

        cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        if sub_dependant.use_cache and cache_key in dependency_cache:
            solved = dependency_cache[cache_key]
        elif is_gen_callable(sub_call) or is_async_gen_callable(sub_call):
            solved = await solve_generator(
                call=sub_call, stack=async_exit_stack, sub_values=sub_values
            )
        elif is_coroutine_callable(sub_call):
            solved = await sub_call(**sub_values)
        else:
            solved = await run_in_threadpool(sub_call, **sub_values)

        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if cache_key not in dependency_cache:
            dependency_cache[cache_key] = solved

    # Later sources win on key collisions: path < query < header < cookie.
    param_sources = (
        (dependant.path_params, request.path_params),
        (dependant.query_params, request.query_params),
        (dependant.header_params, request.headers),
        (dependant.cookie_params, request.cookies),
    )
    for fields, received in param_sources:
        source_values, source_errors = request_params_to_args(fields, received)
        values.update(source_values)
        errors.extend(source_errors)

    if dependant.body_params:
        body_values, body_errors = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)

    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache
|
|
|
|
|
|
def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
    """Extract and validate non-body params from *received_params*.

    Scalar-sequence fields read every value via ``getlist`` when the source
    is ``QueryParams``/``Headers``. An empty list falls back to the field
    default. A missing value is an error for required fields; otherwise a
    deep copy of the default is used.

    Returns:
        ``(values, errors)``: validated values keyed by field name, plus the
        collected validation errors.
    """
    values = {}
    errors = []
    for field in required_params:
        if is_scalar_sequence_field(field) and isinstance(
            received_params, (QueryParams, Headers)
        ):
            raw_value = received_params.getlist(field.alias) or field.default
        else:
            raw_value = received_params.get(field.alias)

        field_info = field.field_info
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"
        loc = (field_info.in_.value, field.alias)

        if raw_value is None:
            if not field.required:
                values[field.name] = deepcopy(field.default)
            else:
                errors.append(get_missing_field_error(loc=loc))
            continue

        validated, field_errors = field.validate(raw_value, values, loc=loc)
        if isinstance(field_errors, ErrorWrapper):
            errors.append(field_errors)
        elif isinstance(field_errors, list):
            errors.extend(
                _regenerate_error_with_loc(errors=field_errors, loc_prefix=())
            )
        else:
            values[field.name] = validated
    return values, errors
|
|
|
|
|
|
async def _read_upload_files(items: Sequence[Any]) -> List[Any]:
    """Read every UploadFile in *items* concurrently, keeping input order.

    Items that are not ``UploadFile`` instances are returned unchanged, so
    validation (not an AttributeError) reports them.
    """
    results: List[Any] = list(items)

    async def read_into(index: int, upload: UploadFile) -> None:
        results[index] = await upload.read()

    async with anyio.create_task_group() as tg:
        for index, item in enumerate(items):
            if isinstance(item, UploadFile):
                tg.start_soon(read_into, index, item)
    return results


async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract and validate body params from *received_body*.

    When there is exactly one body param and it is not ``embed``-ed, the whole
    body is that param's value; otherwise each param is looked up by alias.
    Form fields treat ``""`` and empty lists as missing. ``File`` fields typed
    as bytes (or sequences of bytes) have their uploads read into memory.

    Returns:
        ``(values, errors)``: validated values keyed by field name, plus the
        collected validation errors.
    """
    values = {}
    errors: List[Dict[str, Any]] = []
    if not required_params:
        return values, errors

    first_field = required_params[0]
    embed = getattr(first_field.field_info, "embed", None)
    field_alias_omitted = len(required_params) == 1 and not embed
    if field_alias_omitted:
        received_body = {first_field.alias: received_body}

    for field in required_params:
        # Use each field's own info. Previously the first field's info was
        # reused for all fields, so e.g. `token: str = Form()` declared before
        # `file: bytes = File()` meant the upload was never read.
        field_info = field.field_info
        loc: Tuple[str, ...] = (
            ("body",) if field_alias_omitted else ("body", field.alias)
        )

        value: Optional[Any] = None
        if received_body is not None:
            if is_sequence_field(field) and isinstance(received_body, FormData):
                value = received_body.getlist(field.alias)
            else:
                try:
                    value = received_body.get(field.alias)
                except AttributeError:
                    errors.append(get_missing_field_error(loc))
                    continue

        is_form = isinstance(field_info, params.Form)
        if (
            value is None
            or (is_form and value == "")
            or (is_form and is_sequence_field(field) and len(value) == 0)
        ):
            if field.required:
                errors.append(get_missing_field_error(loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue

        is_file = isinstance(field_info, params.File)
        if is_file and is_bytes_field(field) and isinstance(value, UploadFile):
            value = await value.read()
        elif is_bytes_sequence_field(field) and is_file and value_is_sequence(value):
            # For types
            assert isinstance(value, sequence_types)  # type: ignore[arg-type]
            # Reads run concurrently; results keep the uploaded order.
            contents = await _read_upload_files(value)
            value = serialize_sequence_value(field=field, value=contents)

        validated, field_errors = field.validate(value, values, loc=loc)
        if isinstance(field_errors, list):
            errors.extend(field_errors)
        elif field_errors:
            errors.append(field_errors)
        else:
            values[field.name] = validated
    return values, errors
|
|
|
|
|
|
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Build the single request-body field for an endpoint.

    Returns ``None`` when there are no body params. A lone, non-embedded body
    param is returned as-is. Otherwise all body params are embedded into a
    generated ``Body_<name>`` model, wrapped in a ``File``/``Form``/``Body``
    field. That wrapper uses the most specific kind present among the params.
    """
    flat_dependant = get_flat_dependant(dependant)
    body_params = flat_dependant.body_params
    if not body_params:
        return None

    first_param = body_params[0]
    embed = getattr(first_param.field_info, "embed", None)
    distinct_names = {param.name for param in body_params}
    if len(distinct_names) == 1 and not embed:
        check_file_field(first_param)
        return first_param

    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for param in body_params:
        setattr(param.field_info, "embed", True)  # noqa: B010

    model_name = "Body_" + name
    body_model = create_body_model(fields=body_params, model_name=model_name)
    required = any(param.required for param in body_params)

    body_field_kwargs: Dict[str, Any] = {
        "annotation": body_model,
        "alias": "body",
    }
    if not required:
        body_field_kwargs["default"] = None

    if any(isinstance(param.field_info, params.File) for param in body_params):
        body_field_class: Type[params.Body] = params.File
    elif any(isinstance(param.field_info, params.Form) for param in body_params):
        body_field_class = params.Form
    else:
        body_field_class = params.Body

    media_types = [
        param.field_info.media_type
        for param in body_params
        if isinstance(param.field_info, params.Body)
    ]
    if len(set(media_types)) == 1:
        body_field_kwargs["media_type"] = media_types[0]

    final_field = create_response_field(
        name="body",
        type_=body_model,
        required=required,
        alias="body",
        field_info=body_field_class(**body_field_kwargs),
    )
    check_file_field(final_field)
    return final_field
|