import inspect
from contextlib import AsyncExitStack, contextmanager
from copy import copy, deepcopy
from typing import (
    Any,
    Callable,
    Coroutine,
    Dict,
    ForwardRef,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

import anyio
from fastapi import params
from fastapi._compat import (
    PYDANTIC_V2,
    ErrorWrapper,
    ModelField,
    Required,
    Undefined,
    _regenerate_error_with_loc,
    copy_field_info,
    create_body_model,
    evaluate_forwardref,
    field_annotation_is_scalar,
    get_annotation_from_field_info,
    get_missing_field_error,
    is_bytes_field,
    is_bytes_sequence_field,
    is_scalar_field,
    is_scalar_sequence_field,
    is_sequence_field,
    is_uploadfile_or_nonable_uploadfile_annotation,
    is_uploadfile_sequence_annotation,
    lenient_issubclass,
    sequence_types,
    serialize_sequence_value,
    value_is_sequence,
)
from fastapi.background import BackgroundTasks
from fastapi.concurrency import (
    asynccontextmanager,
    contextmanager_in_threadpool,
)
from fastapi.dependencies.models import Dependant, SecurityRequirement
from fastapi.logger import logger
from fastapi.security.base import SecurityBase
from fastapi.security.oauth2 import OAuth2, SecurityScopes
from fastapi.security.open_id_connect_url import OpenIdConnect
from fastapi.utils import create_response_field, get_path_param_names
from pydantic.fields import FieldInfo
from starlette.background import BackgroundTasks as StarletteBackgroundTasks
from starlette.concurrency import run_in_threadpool
from starlette.datastructures import FormData, Headers, QueryParams, UploadFile
from starlette.requests import HTTPConnection, Request
from starlette.responses import Response
from starlette.websockets import WebSocket
from typing_extensions import Annotated, get_args, get_origin

# Error shown when no ``multipart`` module can be imported at all.
multipart_not_installed_error = (
    'Form data requires "python-multipart" to be installed. \n'
    'You can install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)
# Error shown when a ``multipart`` module imports but lacks
# ``multipart.multipart.parse_options_header``.
multipart_incorrect_install_error = (
    'Form data requires "python-multipart" to be installed. '
    'It seems you installed "multipart" instead. \n'
    'You can remove "multipart" with: \n\n'
    "pip uninstall multipart\n\n"
    'And then install "python-multipart" with: \n\n'
    "pip install python-multipart\n"
)


def check_file_field(field: ModelField) -> None:
    """Verify that form parsing support is importable for a form field.

    Does nothing unless ``field.field_info`` is a ``params.Form`` instance.

    Raises:
        RuntimeError: if ``multipart`` cannot be imported, or if it can be
            imported but ``multipart.multipart.parse_options_header`` cannot.
            The matching message is logged via ``logger.error`` first.
    """
    field_info = field.field_info
    if isinstance(field_info, params.Form):
        try:
            # __version__ is available in both multiparts, and can be mocked
            from multipart import __version__  # type: ignore

            assert __version__
            try:
                # parse_options_header is only available in the right multipart
                from multipart.multipart import parse_options_header  # type: ignore

                assert parse_options_header
            except ImportError:
                logger.error(multipart_incorrect_install_error)
                raise RuntimeError(multipart_incorrect_install_error) from None
        except ImportError:
            logger.error(multipart_not_installed_error)
            raise RuntimeError(multipart_not_installed_error) from None


def get_param_sub_dependant(
    *,
    param_name: str,
    depends: params.Depends,
    path: str,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the sub-``Dependant`` for a parameter declared with ``Depends``.

    The sub-dependant is named after ``param_name``. ``depends.dependency``
    must already be set (asserted).
    """
    assert depends.dependency
    return get_sub_dependant(
        depends=depends,
        dependency=depends.dependency,
        path=path,
        name=param_name,
        security_scopes=security_scopes,
    )


def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
    """Build an unnamed sub-``Dependant`` for a ``Depends`` not bound to a parameter.

    ``depends.dependency`` must be callable (asserted).
    """
    assert callable(
        depends.dependency
    ), "A parameter-less dependency must have a callable dependency"
    return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)


def get_sub_dependant(
    *,
    depends: params.Depends,
    dependency: Callable[..., Any],
    path: str,
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
) -> Dependant:
    """Build the ``Dependant`` for ``dependency`` as declared by ``depends``.

    Args:
        depends: the ``Depends``/``Security`` declaration; its ``use_cache``
            is forwarded, and for ``params.Security`` its ``scopes`` are
            appended to the inherited scopes.
        dependency: the callable to analyze.
        path: the path forwarded to ``get_dependant``.
        name: the parameter name the solved value is stored under, if any.
        security_scopes: scopes inherited from the caller. This list is not
            modified.

    Returns:
        The sub-dependant. If ``dependency`` is a ``SecurityBase`` instance, a
        ``SecurityRequirement`` is appended to it; the requirement carries the
        accumulated scopes only for ``OAuth2``/``OpenIdConnect`` schemes and an
        empty list otherwise.
    """
    security_requirement = None
    # Work on a copy: extending the caller's list in place would leak this
    # dependency's scopes into the parent Dependant (which holds the same list)
    # and into sibling sub-dependencies analyzed afterwards.
    security_scopes = list(security_scopes or [])
    if isinstance(depends, params.Security):
        dependency_scopes = depends.scopes
        security_scopes.extend(dependency_scopes)
    if isinstance(dependency, SecurityBase):
        use_scopes: List[str] = []
        if isinstance(dependency, (OAuth2, OpenIdConnect)):
            use_scopes = security_scopes
        security_requirement = SecurityRequirement(
            security_scheme=dependency, scopes=use_scopes
        )
    sub_dependant = get_dependant(
        path=path,
        call=dependency,
        name=name,
        security_scopes=security_scopes,
        use_cache=depends.use_cache,
    )
    if security_requirement:
        sub_dependant.security_requirements.append(security_requirement)
    return sub_dependant


CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]


def get_flat_dependant(
    dependant: Dependant,
    *,
    skip_repeats: bool = False,
    visited: Optional[List[CacheKey]] = None,
) -> Dependant:
    """Merge a dependant and all of its sub-dependencies into one ``Dependant``.

    The result holds copies of ``dependant``'s path/query/header/cookie/body
    parameter lists and security requirements, extended with those of every
    sub-dependency, recursively.

    Args:
        dependant: the root of the tree to flatten.
        skip_repeats: when true, a sub-dependant whose ``cache_key`` has
            already been visited is skipped.
        visited: cache keys seen so far; appended to in place. A fresh list is
            used when omitted.
    """
    if visited is None:
        visited = []
    visited.append(dependant.cache_key)

    flat_dependant = Dependant(
        path_params=dependant.path_params.copy(),
        query_params=dependant.query_params.copy(),
        header_params=dependant.header_params.copy(),
        cookie_params=dependant.cookie_params.copy(),
        body_params=dependant.body_params.copy(),
        security_schemes=dependant.security_requirements.copy(),
        use_cache=dependant.use_cache,
        path=dependant.path,
    )
    for sub_dependant in dependant.dependencies:
        if skip_repeats and sub_dependant.cache_key in visited:
            continue
        flat_sub = get_flat_dependant(
            sub_dependant, skip_repeats=skip_repeats, visited=visited
        )
        flat_dependant.path_params.extend(flat_sub.path_params)
        flat_dependant.query_params.extend(flat_sub.query_params)
        flat_dependant.header_params.extend(flat_sub.header_params)
        flat_dependant.cookie_params.extend(flat_sub.cookie_params)
        flat_dependant.body_params.extend(flat_sub.body_params)
        flat_dependant.security_requirements.extend(flat_sub.security_requirements)
    return flat_dependant


def get_flat_params(dependant: Dependant) -> List[ModelField]:
    """Return all path, query, header and cookie fields of the dependency tree.

    Repeated sub-dependencies are skipped; body params are not included.
    """
    flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
    return (
        flat_dependant.path_params
        + flat_dependant.query_params
        + flat_dependant.header_params
        + flat_dependant.cookie_params
    )


def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
    """Return ``call``'s signature with string annotations resolved.

    Annotations are resolved against ``call.__globals__`` (an empty namespace
    when the attribute is missing). Only parameters are carried over; the
    returned signature has no return annotation.
    """
    signature = inspect.signature(call)
    globalns = getattr(call, "__globals__", {})
    typed_params = [
        inspect.Parameter(
            name=param.name,
            kind=param.kind,
            default=param.default,
            annotation=get_typed_annotation(param.annotation, globalns),
        )
        for param in signature.parameters.values()
    ]
    typed_signature = inspect.Signature(typed_params)
    return typed_signature


def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
    """Resolve a string annotation through ``evaluate_forwardref``.

    Non-string annotations are returned unchanged. ``globalns`` is used as
    both the global and the local namespace.
    """
    if isinstance(annotation, str):
        annotation = ForwardRef(annotation)
        annotation = evaluate_forwardref(annotation, globalns, globalns)
    return annotation


def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
    """Return ``call``'s resolved return annotation, or ``None`` if it has none."""
    signature = inspect.signature(call)
    annotation = signature.return_annotation

    if annotation is inspect.Signature.empty:
        return None

    globalns = getattr(call, "__globals__", {})
    return get_typed_annotation(annotation, globalns)


def get_dependant(
    *,
    path: str,
    call: Callable[..., Any],
    name: Optional[str] = None,
    security_scopes: Optional[List[str]] = None,
    use_cache: bool = True,
) -> Dependant:
    """Analyze ``call``'s signature and build its ``Dependant``.

    Each parameter is classified by ``analyze_param`` and then recorded as
    exactly one of: a sub-dependency, a special non-field parameter (request,
    websocket, response, ...), a body param, or a path/query/header/cookie
    param.

    Args:
        path: the route path; names returned by ``get_path_param_names`` for
            it mark which parameters are path parameters.
        call: the callable to analyze.
        name: stored on the resulting ``Dependant``.
        security_scopes: stored on the ``Dependant`` and forwarded to
            sub-dependencies.
        use_cache: stored on the resulting ``Dependant``.
    """
    path_param_names = get_path_param_names(path)
    endpoint_signature = get_typed_signature(call)
    signature_params = endpoint_signature.parameters
    dependant = Dependant(
        call=call,
        name=name,
        path=path,
        security_scopes=security_scopes,
        use_cache=use_cache,
    )
    for param_name, param in signature_params.items():
        is_path_param = param_name in path_param_names
        type_annotation, depends, param_field = analyze_param(
            param_name=param_name,
            annotation=param.annotation,
            value=param.default,
            is_path_param=is_path_param,
        )
        if depends is not None:
            sub_dependant = get_param_sub_dependant(
                param_name=param_name,
                depends=depends,
                path=path,
                security_scopes=security_scopes,
            )
            dependant.dependencies.append(sub_dependant)
            continue
        if add_non_field_param_to_dependency(
            param_name=param_name,
            type_annotation=type_annotation,
            dependant=dependant,
        ):
            assert (
                param_field is None
            ), f"Cannot specify multiple FastAPI annotations for {param_name!r}"
            continue
        assert param_field is not None
        if is_body_param(param_field=param_field, is_path_param=is_path_param):
            dependant.body_params.append(param_field)
        else:
            add_param_to_fields(field=param_field, dependant=dependant)
    return dependant


def add_non_field_param_to_dependency(
    *, param_name: str, type_annotation: Any, dependant: Dependant
) -> Optional[bool]:
    """Record a special (non-field) parameter name on ``dependant``.

    ``type_annotation`` is tested, in this order, against ``Request``,
    ``WebSocket``, ``HTTPConnection``, ``Response``, Starlette's
    ``BackgroundTasks`` and ``SecurityScopes``; the first match stores
    ``param_name`` in the corresponding ``*_param_name`` attribute.

    Returns:
        ``True`` when a match was recorded, ``None`` otherwise.
    """
    if lenient_issubclass(type_annotation, Request):
        dependant.request_param_name = param_name
        return True
    elif lenient_issubclass(type_annotation, WebSocket):
        dependant.websocket_param_name = param_name
        return True
    elif lenient_issubclass(type_annotation, HTTPConnection):
        dependant.http_connection_param_name = param_name
        return True
    elif lenient_issubclass(type_annotation, Response):
        dependant.response_param_name = param_name
        return True
    elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
        dependant.background_tasks_param_name = param_name
        return True
    elif lenient_issubclass(type_annotation, SecurityScopes):
        dependant.security_scopes_param_name = param_name
        return True
    return None


def analyze_param(
    *,
    param_name: str,
    annotation: Any,
    value: Any,
    is_path_param: bool,
) -> Tuple[Any, Optional[params.Depends], Optional[ModelField]]:
    """Classify one signature parameter.

    FastAPI metadata may come from ``Annotated[...]`` (the last
    ``Param``/``Body``/``Depends`` argument wins) or from the default value,
    but not from both. When neither is given and the type is not one of the
    special request/response types, a ``Path``, ``File``, ``Body`` or ``Query``
    declaration is inferred from ``is_path_param`` and the type.

    Args:
        param_name: the parameter's name, used for the field name, the default
            alias and error messages.
        annotation: the parameter annotation, or ``inspect.Signature.empty``.
        value: the parameter default, or ``inspect.Signature.empty``.
        is_path_param: whether the name appears in the route path.

    Returns:
        ``(type_annotation, depends, field)``: the bare type (``Annotated``
        unwrapped, ``Any`` when unannotated), the ``Depends`` declaration if
        the parameter is a dependency, and the created ``ModelField`` if the
        parameter is a request field. ``depends`` and ``field`` are each
        ``None`` when not applicable.

    Raises:
        AssertionError: on conflicting or invalid declarations (see the
            individual assertion messages).
    """
    field_info = None
    depends = None
    type_annotation: Any = Any
    use_annotation: Any = Any
    if annotation is not inspect.Signature.empty:
        use_annotation = annotation
        type_annotation = annotation
    if get_origin(use_annotation) is Annotated:
        annotated_args = get_args(annotation)
        type_annotation = annotated_args[0]
        fastapi_annotations = [
            arg
            for arg in annotated_args[1:]
            if isinstance(arg, (FieldInfo, params.Depends))
        ]
        fastapi_specific_annotations = [
            arg
            for arg in fastapi_annotations
            if isinstance(arg, (params.Param, params.Body, params.Depends))
        ]
        if fastapi_specific_annotations:
            fastapi_annotation: Union[
                FieldInfo, params.Depends, None
            ] = fastapi_specific_annotations[-1]
        else:
            fastapi_annotation = None
        if isinstance(fastapi_annotation, FieldInfo):
            # Copy `field_info` because we mutate `field_info.default` below.
            field_info = copy_field_info(
                field_info=fastapi_annotation, annotation=use_annotation
            )
            assert field_info.default is Undefined or field_info.default is Required, (
                f"`{field_info.__class__.__name__}` default value cannot be set in"
                f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
            )
            if value is not inspect.Signature.empty:
                assert not is_path_param, "Path parameters cannot have default values"
                field_info.default = value
            else:
                field_info.default = Required
        elif isinstance(fastapi_annotation, params.Depends):
            depends = fastapi_annotation

    if isinstance(value, params.Depends):
        assert depends is None, (
            "Cannot specify `Depends` in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        assert field_info is None, (
            "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
            f" default value together for {param_name!r}"
        )
        depends = value
    elif isinstance(value, FieldInfo):
        assert field_info is None, (
            "Cannot specify FastAPI annotations in `Annotated` and default value"
            f" together for {param_name!r}"
        )
        field_info = value
        if PYDANTIC_V2:
            field_info.annotation = type_annotation

    if depends is not None and depends.dependency is None:
        # Copy `depends` before mutating it
        depends = copy(depends)
        depends.dependency = type_annotation

    if lenient_issubclass(
        type_annotation,
        (
            Request,
            WebSocket,
            HTTPConnection,
            Response,
            StarletteBackgroundTasks,
            SecurityScopes,
        ),
    ):
        assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
        assert (
            field_info is None
        ), f"Cannot specify FastAPI annotation for type {type_annotation!r}"
    elif field_info is None and depends is None:
        default_value = value if value is not inspect.Signature.empty else Required
        if is_path_param:
            # We might check here that `default_value is Required`, but the fact is that the same
            # parameter might sometimes be a path parameter and sometimes not. See
            # `tests/test_infer_param_optionality.py` for an example.
            field_info = params.Path(annotation=use_annotation)
        elif is_uploadfile_or_nonable_uploadfile_annotation(
            type_annotation
        ) or is_uploadfile_sequence_annotation(type_annotation):
            field_info = params.File(annotation=use_annotation, default=default_value)
        elif not field_annotation_is_scalar(annotation=type_annotation):
            field_info = params.Body(annotation=use_annotation, default=default_value)
        else:
            field_info = params.Query(annotation=use_annotation, default=default_value)

    field = None
    if field_info is not None:
        if is_path_param:
            assert isinstance(field_info, params.Path), (
                f"Cannot use `{field_info.__class__.__name__}` for path param"
                f" {param_name!r}"
            )
        elif (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            # A bare `Param` with no location is treated as a query parameter.
            field_info.in_ = params.ParamTypes.query
        use_annotation_from_field_info = get_annotation_from_field_info(
            use_annotation,
            field_info,
            param_name,
        )
        if not field_info.alias and getattr(field_info, "convert_underscores", None):
            alias = param_name.replace("_", "-")
        else:
            alias = field_info.alias or param_name
        field_info.alias = alias
        field = create_response_field(
            name=param_name,
            type_=use_annotation_from_field_info,
            default=field_info.default,
            alias=alias,
            required=field_info.default in (Required, Undefined),
            field_info=field_info,
        )

    return type_annotation, depends, field


def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
    """Return whether ``param_field`` must be read from the request body.

    Path params, scalar fields, and scalar-sequence fields declared as
    ``Query``/``Header`` are not body params. Anything else must be declared
    with a ``params.Body`` (sub)class.

    Raises:
        AssertionError: if a path param is not scalar, or if a non-scalar
            field is not declared as ``Body``.
    """
    if is_path_param:
        assert is_scalar_field(
            field=param_field
        ), "Path params must be of one of the supported types"
        return False
    elif is_scalar_field(field=param_field):
        return False
    elif isinstance(
        param_field.field_info, (params.Query, params.Header)
    ) and is_scalar_sequence_field(param_field):
        return False
    else:
        assert isinstance(
            param_field.field_info, params.Body
        ), f"Param: {param_field.name} can only be a request body, using Body()"
        return True


def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
    """Append ``field`` to the ``dependant`` list matching its ``in_`` location.

    Raises:
        AssertionError: if the location is not path, query, header or cookie.
    """
    field_info = field.field_info
    field_info_in = getattr(field_info, "in_", None)
    if field_info_in == params.ParamTypes.path:
        dependant.path_params.append(field)
    elif field_info_in == params.ParamTypes.query:
        dependant.query_params.append(field)
    elif field_info_in == params.ParamTypes.header:
        dependant.header_params.append(field)
    else:
        assert (
            field_info_in == params.ParamTypes.cookie
        ), f"non-body parameters must be in path, query, header or cookie: {field.name}"
        dependant.cookie_params.append(field)


def is_coroutine_callable(call: Callable[..., Any]) -> bool:
    """Return whether calling ``call`` yields a coroutine.

    Routines are checked directly, classes are never coroutine callables, and
    other objects are judged by their ``__call__`` attribute.
    """
    if inspect.isroutine(call):
        return inspect.iscoroutinefunction(call)
    if inspect.isclass(call):
        return False
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.iscoroutinefunction(dunder_call)


def is_async_gen_callable(call: Callable[..., Any]) -> bool:
    """Return whether ``call`` or its ``__call__`` is an async generator function."""
    if inspect.isasyncgenfunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.isasyncgenfunction(dunder_call)


def is_gen_callable(call: Callable[..., Any]) -> bool:
    """Return whether ``call`` or its ``__call__`` is a generator function."""
    if inspect.isgeneratorfunction(call):
        return True
    dunder_call = getattr(call, "__call__", None)  # noqa: B004
    return inspect.isgeneratorfunction(dunder_call)


async def solve_generator(
    *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
) -> Any:
    """Enter a generator dependency as a context manager on ``stack``.

    A sync generator is wrapped with ``contextmanager`` and
    ``contextmanager_in_threadpool``; an async generator is wrapped with
    ``asynccontextmanager``. The value produced on entering the context is
    returned.

    Raises:
        TypeError: if ``call`` is neither a generator nor an async generator
            callable.
    """
    if is_gen_callable(call):
        cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
    elif is_async_gen_callable(call):
        cm = asynccontextmanager(call)(**sub_values)
    else:
        # Fail with a clear message instead of an UnboundLocalError on `cm`.
        raise TypeError(
            "solve_generator() requires a generator or async generator callable,"
            f" got {call!r}"
        )
    return await stack.enter_async_context(cm)


async def solve_dependencies(
    *,
    request: Union[Request, WebSocket],
    dependant: Dependant,
    body: Optional[Union[Dict[str, Any], FormData]] = None,
    background_tasks: Optional[StarletteBackgroundTasks] = None,
    response: Optional[Response] = None,
    dependency_overrides_provider: Optional[Any] = None,
    dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
    async_exit_stack: AsyncExitStack,
) -> Tuple[
    Dict[str, Any],
    List[Any],
    Optional[StarletteBackgroundTasks],
    Response,
    Dict[Tuple[Callable[..., Any], Tuple[str]], Any],
]:
    """Resolve every argument needed to call ``dependant.call``.

    Sub-dependencies are solved recursively first (honouring
    ``dependency_overrides_provider.dependency_overrides`` and the per-request
    ``dependency_cache``), then path/query/header/cookie/body parameters are
    extracted from ``request``/``body`` and validated, and finally the special
    parameters (request, websocket, background tasks, response, security
    scopes) are injected.

    Args:
        request: the incoming request or websocket.
        dependant: the dependency tree node to solve.
        body: the already-parsed body, if any.
        background_tasks: shared background tasks; created on demand when a
            parameter asks for them.
        response: shared response; when omitted a fresh ``Response`` is created
            with its ``content-length`` header removed and ``status_code`` set
            to ``None``.
        dependency_overrides_provider: object whose ``dependency_overrides``
            mapping replaces dependency callables.
        dependency_cache: solved values keyed by ``Dependant.cache_key``.
        async_exit_stack: stack on which generator dependencies are entered.

    Returns:
        ``(values, errors, background_tasks, response, dependency_cache)``.
        A sub-dependency that reported errors is not called; its errors are
        collected instead.
    """
    values: Dict[str, Any] = {}
    errors: List[Any] = []
    if response is None:
        response = Response()
        del response.headers["content-length"]
        response.status_code = None  # type: ignore
    dependency_cache = dependency_cache or {}
    sub_dependant: Dependant
    for sub_dependant in dependant.dependencies:
        sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
        sub_dependant.cache_key = cast(
            Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
        )
        call = sub_dependant.call
        use_sub_dependant = sub_dependant
        if (
            dependency_overrides_provider
            and dependency_overrides_provider.dependency_overrides
        ):
            original_call = sub_dependant.call
            call = getattr(
                dependency_overrides_provider, "dependency_overrides", {}
            ).get(original_call, original_call)
            use_path: str = sub_dependant.path  # type: ignore
            # Re-analyze the (possibly overridden) callable so that its own
            # parameters and sub-dependencies are the ones being solved.
            use_sub_dependant = get_dependant(
                path=use_path,
                call=call,
                name=sub_dependant.name,
                security_scopes=sub_dependant.security_scopes,
            )

        solved_result = await solve_dependencies(
            request=request,
            dependant=use_sub_dependant,
            body=body,
            background_tasks=background_tasks,
            response=response,
            dependency_overrides_provider=dependency_overrides_provider,
            dependency_cache=dependency_cache,
            async_exit_stack=async_exit_stack,
        )
        (
            sub_values,
            sub_errors,
            background_tasks,
            _,  # the subdependency returns the same response we have
            sub_dependency_cache,
        ) = solved_result
        dependency_cache.update(sub_dependency_cache)
        if sub_errors:
            errors.extend(sub_errors)
            continue
        if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
            solved = dependency_cache[sub_dependant.cache_key]
        elif is_gen_callable(call) or is_async_gen_callable(call):
            solved = await solve_generator(
                call=call, stack=async_exit_stack, sub_values=sub_values
            )
        elif is_coroutine_callable(call):
            solved = await call(**sub_values)
        else:
            solved = await run_in_threadpool(call, **sub_values)
        if sub_dependant.name is not None:
            values[sub_dependant.name] = solved
        if sub_dependant.cache_key not in dependency_cache:
            dependency_cache[sub_dependant.cache_key] = solved
    path_values, path_errors = request_params_to_args(
        dependant.path_params, request.path_params
    )
    query_values, query_errors = request_params_to_args(
        dependant.query_params, request.query_params
    )
    header_values, header_errors = request_params_to_args(
        dependant.header_params, request.headers
    )
    cookie_values, cookie_errors = request_params_to_args(
        dependant.cookie_params, request.cookies
    )
    values.update(path_values)
    values.update(query_values)
    values.update(header_values)
    values.update(cookie_values)
    errors += path_errors + query_errors + header_errors + cookie_errors
    if dependant.body_params:
        (
            body_values,
            body_errors,
        ) = await request_body_to_args(  # body_params checked above
            required_params=dependant.body_params, received_body=body
        )
        values.update(body_values)
        errors.extend(body_errors)
    if dependant.http_connection_param_name:
        values[dependant.http_connection_param_name] = request
    if dependant.request_param_name and isinstance(request, Request):
        values[dependant.request_param_name] = request
    elif dependant.websocket_param_name and isinstance(request, WebSocket):
        values[dependant.websocket_param_name] = request
    if dependant.background_tasks_param_name:
        if background_tasks is None:
            background_tasks = BackgroundTasks()
        values[dependant.background_tasks_param_name] = background_tasks
    if dependant.response_param_name:
        values[dependant.response_param_name] = response
    if dependant.security_scopes_param_name:
        values[dependant.security_scopes_param_name] = SecurityScopes(
            scopes=dependant.security_scopes
        )
    return values, errors, background_tasks, response, dependency_cache


def request_params_to_args(
    required_params: Sequence[ModelField],
    received_params: Union[Mapping[str, Any], QueryParams, Headers],
) -> Tuple[Dict[str, Any], List[Any]]:
    """Extract and validate non-body parameters from ``received_params``.

    Values are looked up by ``field.alias``. Scalar-sequence fields read from
    ``QueryParams``/``Headers`` use ``getlist`` and fall back to the field
    default when the list is empty. A missing value yields a missing-field
    error for required fields, or a deep copy of the default otherwise.

    Returns:
        ``(values, errors)`` where ``values`` maps field names to validated
        values and ``errors`` collects validation errors.
    """
    values = {}
    errors = []
    for field in required_params:
        if is_scalar_sequence_field(field) and isinstance(
            received_params, (QueryParams, Headers)
        ):
            value = received_params.getlist(field.alias) or field.default
        else:
            value = received_params.get(field.alias)
        field_info = field.field_info
        assert isinstance(
            field_info, params.Param
        ), "Params must be subclasses of Param"
        loc = (field_info.in_.value, field.alias)
        if value is None:
            if field.required:
                errors.append(get_missing_field_error(loc=loc))
            else:
                values[field.name] = deepcopy(field.default)
            continue
        v_, errors_ = field.validate(value, values, loc=loc)
        if isinstance(errors_, ErrorWrapper):
            errors.append(errors_)
        elif isinstance(errors_, list):
            new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
            errors.extend(new_errors)
        else:
            values[field.name] = v_
    return values, errors


async def request_body_to_args(
    required_params: List[ModelField],
    received_body: Optional[Union[Dict[str, Any], FormData]],
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Extract and validate body parameters from ``received_body``.

    With exactly one body param that is not marked ``embed``, the whole body
    is the value of that param and errors are located at ``("body",)``;
    otherwise each param is looked up by alias and located at
    ``("body", alias)``. For ``Form`` fields an empty string or empty sequence
    counts as missing. ``File`` fields typed as bytes have their uploads read
    before validation.

    Returns:
        ``(values, errors)`` where ``values`` maps field names to validated
        values and ``errors`` collects validation errors.
    """
    values = {}
    errors: List[Dict[str, Any]] = []
    if required_params:
        first_field = required_params[0]
        embed = getattr(first_field.field_info, "embed", None)
        field_alias_omitted = len(required_params) == 1 and not embed
        if field_alias_omitted:
            received_body = {first_field.alias: received_body}

        for field in required_params:
            # Use each field's own declaration: body params may mix Form, File
            # and Body, so the first param's field_info is not representative.
            field_info = field.field_info
            loc: Tuple[str, ...]
            if field_alias_omitted:
                loc = ("body",)
            else:
                loc = ("body", field.alias)

            value: Optional[Any] = None
            if received_body is not None:
                if (is_sequence_field(field)) and isinstance(received_body, FormData):
                    value = received_body.getlist(field.alias)
                else:
                    try:
                        value = received_body.get(field.alias)
                    except AttributeError:
                        # The body is not a mapping (e.g. a JSON list or scalar).
                        errors.append(get_missing_field_error(loc))
                        continue
            if (
                value is None
                or (isinstance(field_info, params.Form) and value == "")
                or (
                    isinstance(field_info, params.Form)
                    and is_sequence_field(field)
                    and len(value) == 0
                )
            ):
                if field.required:
                    errors.append(get_missing_field_error(loc))
                else:
                    values[field.name] = deepcopy(field.default)
                continue
            if (
                isinstance(field_info, params.File)
                and is_bytes_field(field)
                and isinstance(value, UploadFile)
            ):
                value = await value.read()
            elif (
                is_bytes_sequence_field(field)
                and isinstance(field_info, params.File)
                and value_is_sequence(value)
            ):
                # For types
                assert isinstance(value, sequence_types)  # type: ignore[arg-type]
                results: List[Union[bytes, str]] = []

                async def process_fn(
                    fn: Callable[[], Coroutine[Any, Any, Any]],
                ) -> None:
                    result = await fn()
                    results.append(result)  # noqa: B023

                # NOTE(review): reads run concurrently and results are appended
                # as each finishes, so ordering presumably follows completion
                # rather than upload order - confirm whether that matters.
                async with anyio.create_task_group() as tg:
                    for sub_value in value:
                        tg.start_soon(process_fn, sub_value.read)
                value = serialize_sequence_value(field=field, value=results)

            v_, errors_ = field.validate(value, values, loc=loc)

            if isinstance(errors_, list):
                errors.extend(errors_)
            elif errors_:
                errors.append(errors_)
            else:
                values[field.name] = v_
    return values, errors


def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
    """Build the single body ``ModelField`` for a dependency tree.

    Returns ``None`` when the flattened tree has no body params. When all body
    params share one name and the first is not marked ``embed``, that first
    param is returned as is. Otherwise every body param is marked
    ``embed=True`` and combined into a generated model named
    ``"Body_" + name``; the wrapping field is declared as ``File`` if any
    param is a ``File``, else ``Form`` if any is a ``Form``, else ``Body``
    (inheriting the media type when all params agree on one).

    Raises:
        RuntimeError: from ``check_file_field`` when form support is required
            but not importable.
    """
    flat_dependant = get_flat_dependant(dependant)
    if not flat_dependant.body_params:
        return None
    first_param = flat_dependant.body_params[0]
    field_info = first_param.field_info
    embed = getattr(field_info, "embed", None)
    body_param_names_set = {param.name for param in flat_dependant.body_params}
    if len(body_param_names_set) == 1 and not embed:
        check_file_field(first_param)
        return first_param
    # If one field requires to embed, all have to be embedded
    # in case a sub-dependency is evaluated with a single unique body field
    # That is combined (embedded) with other body fields
    for param in flat_dependant.body_params:
        setattr(param.field_info, "embed", True)  # noqa: B010
    model_name = "Body_" + name
    BodyModel = create_body_model(
        fields=flat_dependant.body_params, model_name=model_name
    )
    required = any(f.required for f in flat_dependant.body_params)
    BodyFieldInfo_kwargs: Dict[str, Any] = {
        "annotation": BodyModel,
        "alias": "body",
    }
    if not required:
        BodyFieldInfo_kwargs["default"] = None
    if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
        BodyFieldInfo: Type[params.Body] = params.File
    elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
        BodyFieldInfo = params.Form
    else:
        BodyFieldInfo = params.Body

        body_param_media_types = [
            f.field_info.media_type
            for f in flat_dependant.body_params
            if isinstance(f.field_info, params.Body)
        ]
        if len(set(body_param_media_types)) == 1:
            BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
    final_field = create_response_field(
        name="body",
        type_=BodyModel,
        required=required,
        alias="body",
        field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
    )
    check_file_field(final_field)
    return final_field