"""Helpers that build an OpenAPI document (as plain dicts) from API routes."""

import http.client
import inspect
import warnings
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast

from fastapi import routing
from fastapi._compat import (
    GenerateJsonSchema,
    JsonSchemaValue,
    ModelField,
    Undefined,
    get_compat_model_name_map,
    get_definitions,
    get_schema_from_model_field,
    lenient_issubclass,
)
from fastapi.datastructures import DefaultPlaceholder
from fastapi.dependencies.models import Dependant
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
from fastapi.encoders import jsonable_encoder
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
from fastapi.openapi.models import OpenAPI
from fastapi.params import Body, Param
from fastapi.responses import Response
from fastapi.types import ModelNameMap
from fastapi.utils import (
    deep_dict_update,
    generate_operation_id_for_path,
    is_body_allowed_for_status_code,
)
from starlette.responses import JSONResponse
from starlette.routing import BaseRoute
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
from typing_extensions import Literal

# Schema for a single validation error entry; registered under the
# "ValidationError" key by get_openapi_path() when the 422 response is added.
validation_error_definition = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        "loc": {
            "title": "Location",
            "type": "array",
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}

# Schema for the body of the automatic 422 response: a "detail" array whose
# items reference the "ValidationError" schema above.
validation_error_response_definition = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": REF_PREFIX + "ValidationError"},
        }
    },
}

# Fallback descriptions for range / default response keys, looked up by the
# upper-cased status code key in get_openapi_path().
status_code_ranges: Dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}


def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect the security schemes and requirements of a flattened dependant.

    Args:
        flat_dependant: Dependant whose ``security_requirements`` are read.

    Returns:
        A 2-tuple ``(security_definitions, operation_security)``:
        ``security_definitions`` maps each scheme name to its JSON-encoded
        model (aliases applied, ``None`` values dropped), and
        ``operation_security`` holds one ``{scheme_name: scopes}`` dict per
        requirement, in iteration order.
    """
    security_definitions = {}
    operation_security = []
    for security_requirement in flat_dependant.security_requirements:
        security_definition = jsonable_encoder(
            security_requirement.security_scheme.model,
            by_alias=True,
            exclude_none=True,
        )
        security_name = security_requirement.security_scheme.scheme_name
        # A later requirement with the same scheme name overwrites the earlier
        # definition, but every requirement still gets its own security entry.
        security_definitions[security_name] = security_definition
        operation_security.append({security_name: security_requirement.scopes})
    return security_definitions, operation_security


def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
    """Build the ``parameters`` entries for an operation.

    Parameters whose field info has a falsy ``include_in_schema`` are skipped.

    Args:
        all_route_params: Fields to describe.
        schema_generator: Forwarded to ``get_schema_from_model_field``.
        model_name_map: Forwarded to ``get_schema_from_model_field``.
        field_mapping: Forwarded to ``get_schema_from_model_field``.
        separate_input_output_schemas: Forwarded to
            ``get_schema_from_model_field``.

    Returns:
        One dict per included parameter with the keys ``name``, ``in``,
        ``required`` and ``schema``, plus ``description``, ``examples`` or
        ``example``, and ``deprecated`` when the field info provides them.
    """
    parameters = []
    for param in all_route_params:
        field_info = param.field_info
        field_info = cast(Param, field_info)
        if not field_info.include_in_schema:
            continue
        param_schema = get_schema_from_model_field(
            field=param,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        parameter = {
            "name": param.alias,
            "in": field_info.in_.value,
            "required": param.required,
            "schema": param_schema,
        }
        if field_info.description:
            parameter["description"] = field_info.description
        # "examples" and "example" are mutually exclusive here: openapi_examples
        # wins, the single example is only emitted when it was explicitly set.
        if field_info.openapi_examples:
            parameter["examples"] = jsonable_encoder(field_info.openapi_examples)
        elif field_info.example != Undefined:
            parameter["example"] = jsonable_encoder(field_info.example)
        if field_info.deprecated:
            parameter["deprecated"] = True
        parameters.append(parameter)
    return parameters


def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
    """Build the ``requestBody`` object for an operation.

    Args:
        body_field: The body field of the route, or ``None`` / falsy when the
            route takes no body.
        schema_generator: Forwarded to ``get_schema_from_model_field``.
        model_name_map: Forwarded to ``get_schema_from_model_field``.
        field_mapping: Forwarded to ``get_schema_from_model_field``.
        separate_input_output_schemas: Forwarded to
            ``get_schema_from_model_field``.

    Returns:
        ``None`` when ``body_field`` is falsy; otherwise a dict with a
        ``content`` entry keyed by the field's media type and, only when the
        field is required, a ``required`` entry.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    body_schema = get_schema_from_model_field(
        field=body_field,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    field_info = cast(Body, body_field.field_info)
    request_media_type = field_info.media_type
    required = body_field.required
    request_body_oai: Dict[str, Any] = {}
    # "required" is omitted entirely (rather than set to False) for optional
    # bodies.
    if required:
        request_body_oai["required"] = required
    request_media_content: Dict[str, Any] = {"schema": body_schema}
    if field_info.openapi_examples:
        request_media_content["examples"] = jsonable_encoder(
            field_info.openapi_examples
        )
    elif field_info.example != Undefined:
        request_media_content["example"] = jsonable_encoder(field_info.example)
    request_body_oai["content"] = {request_media_type: request_media_content}
    return request_body_oai


def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Return the operation ID for ``route`` and ``method`` (deprecated).

    Always emits a ``DeprecationWarning``. Returns ``route.operation_id`` when
    it is set, otherwise an ID generated from the route name, its path format
    and the method.
    """
    warnings.warn(
        "fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    if route.operation_id:
        return route.operation_id
    path: str = route.path_format
    return generate_operation_id_for_path(name=route.name, path=path, method=method)


def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the summary for an operation.

    Uses ``route.summary`` when set; otherwise the route name with underscores
    replaced by spaces and title-cased. ``method`` is accepted but not used.
    """
    if route.summary:
        return route.summary
    return route.name.replace("_", " ").title()


def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
    """Build the metadata part of an operation object.

    Args:
        route: Route to describe.
        method: HTTP method, forwarded to ``generate_operation_summary``.
        operation_ids: Operation IDs seen so far. The ID of this operation is
            added to the set (the set is mutated in place).

    Returns:
        A dict with ``summary`` and ``operationId`` and, when the route
        provides them, ``tags``, ``description`` and ``deprecated``.

    Warns:
        UserWarning: When the operation ID is already in ``operation_ids``.
    """
    operation: Dict[str, Any] = {}
    if route.tags:
        operation["tags"] = route.tags
    operation["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        operation["description"] = route.description
    operation_id = route.operation_id or route.unique_id
    if operation_id in operation_ids:
        # Endpoints are not always plain functions: functools.partial objects
        # and callable instances have no __name__. Fall back to the class name
        # so that reporting a duplicate never raises AttributeError itself.
        endpoint_name = getattr(
            route.endpoint, "__name__", route.endpoint.__class__.__name__
        )
        message = (
            f"Duplicate Operation ID {operation_id} for function "
            + f"{endpoint_name}"
        )
        file_name = getattr(route.endpoint, "__globals__", {}).get("__file__")
        if file_name:
            message += f" at {file_name}"
        warnings.warn(message, stacklevel=1)
    operation_ids.add(operation_id)
    operation["operationId"] = operation_id
    if route.deprecated:
        operation["deprecated"] = route.deprecated
    return operation


def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: Set[str],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the path item for a route, one operation per HTTP method.

    Args:
        route: Route to describe. Nothing is generated when its
            ``include_in_schema`` is falsy.
        operation_ids: Operation IDs seen so far; mutated through
            ``get_openapi_operation_metadata``.
        schema_generator: Forwarded to the schema helpers.
        model_name_map: Forwarded to the schema helpers.
        field_mapping: Forwarded to the schema helpers.
        separate_input_output_schemas: Forwarded to the schema helpers.

    Returns:
        A 3-tuple ``(path, security_schemes, definitions)``: ``path`` maps each
        lower-cased method to its operation object, ``security_schemes`` maps
        scheme names to their definitions, and ``definitions`` holds the
        validation error schemas when an automatic 422 response was added.

    Raises:
        AssertionError: When ``route.methods`` is ``None``, no response class
            can be resolved, or an additional response is not a dict.
    """
    path = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    # The response class may be wrapped in a placeholder; unwrap it.
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            parameters: List[Dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            parameters.extend(operation_parameters)
            if parameters:
                # Deduplicate by (location, name); later entries replace
                # earlier ones.
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # Make sure required definitions of the same parameter take precedence
                # over non-required definitions
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    schema_generator=schema_generator,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        # Only the path item of the callback is used; its
                        # security schemes and definitions are discarded here.
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                # NOTE(review): if the response class __init__ has no
                # "status_code" parameter with an int default, status_code is
                # never assigned and the lookup below fails with
                # UnboundLocalError - confirm whether a fallback is wanted.
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                # Non-JSON response classes are documented as a plain string;
                # JSON ones use the response field schema, or an empty schema
                # when the route declares no response field.
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema = get_schema_from_model_field(
                            field=route.response_field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    # Work on a shallow copy so the "model" key can be removed
                    # without touching the top level of the route's own dict.
                    process_response = additional_response.copy()
                    process_response.pop("model", None)
                    # Range keys are emitted upper-cased ("4XX"), while the
                    # default key is emitted lower-cased ("default").
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # int() is only reached when the key is not one of the
                    # range / default keys in status_code_ranges.
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Add the automatic validation error response only when the route
            # takes input and no 422 / "4XX" / "default" response is declared.
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            # User-supplied extras are merged last, on top of everything
            # generated above.
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions


def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> List[ModelField]:
    """Collect the model fields referenced by the given routes.

    Only ``routing.APIRoute`` instances with a truthy ``include_in_schema``
    are inspected. Callbacks are processed recursively.

    Args:
        routes: Routes to inspect.

    Returns:
        Fields in this order: fields from callbacks, body fields, response
        fields, then parameter fields.

    Raises:
        AssertionError: When a route's ``body_field`` is not a ``ModelField``.
    """
    body_fields_from_routes: List[ModelField] = []
    responses_from_routes: List[ModelField] = []
    request_fields_from_routes: List[ModelField] = []
    callback_flat_models: List[ModelField] = []
    for route in routes:
        if getattr(route, "include_in_schema", None) and isinstance(
            route, routing.APIRoute
        ):
            if route.body_field:
                assert isinstance(
                    route.body_field, ModelField
                ), "A request body must be a Pydantic Field"
                body_fields_from_routes.append(route.body_field)
            if route.response_field:
                responses_from_routes.append(route.response_field)
            if route.response_fields:
                responses_from_routes.extend(route.response_fields.values())
            if route.callbacks:
                callback_flat_models.extend(get_fields_from_routes(route.callbacks))
            params = get_flat_params(route.dependant)
            request_fields_from_routes.extend(params)

    flat_models = callback_flat_models + list(
        body_fields_from_routes + responses_from_routes + request_fields_from_routes
    )
    return flat_models


def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: Optional[str] = None,
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    webhooks: Optional[Sequence[BaseRoute]] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Generate the OpenAPI document for a set of routes.

    Args:
        title: Value of ``info.title``.
        version: Value of ``info.version``.
        openapi_version: Value of the top-level ``openapi`` key.
        summary: Optional ``info.summary``.
        description: Optional ``info.description``.
        routes: Routes to document under ``paths``. Only
            ``routing.APIRoute`` instances are processed.
        webhooks: Optional routes to document under ``webhooks``.
        tags: Optional top-level ``tags`` list.
        servers: Optional top-level ``servers`` list.
        terms_of_service: Optional ``info.termsOfService``.
        contact: Optional ``info.contact``.
        license_info: Optional ``info.license``.
        separate_input_output_schemas: Forwarded to the schema helpers.

    Returns:
        The document as a dict: the assembled output is passed through the
        ``OpenAPI`` model and JSON-encoded with aliases applied and ``None``
        values excluded.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    if summary:
        info["summary"] = summary
    if description:
        info["description"] = description
    if terms_of_service:
        info["termsOfService"] = terms_of_service
    if contact:
        info["contact"] = contact
    if license_info:
        info["license"] = license_info
    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers
    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    webhook_paths: Dict[str, Dict[str, Any]] = {}
    # Shared by routes and webhooks so duplicate IDs are detected across both.
    operation_ids: Set[str] = set()
    all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
    model_name_map = get_compat_model_name_map(all_fields)
    schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    for route in routes or []:
        if isinstance(route, routing.APIRoute):
            result = get_openapi_path(
                route=route,
                operation_ids=operation_ids,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if result:
                path, security_schemes, path_definitions = result
                if path:
                    # Several routes may share a path format (one per method);
                    # merge their operations into one path item.
                    paths.setdefault(route.path_format, {}).update(path)
                if security_schemes:
                    components.setdefault("securitySchemes", {}).update(
                        security_schemes
                    )
                if path_definitions:
                    definitions.update(path_definitions)
    for webhook in webhooks or []:
        if isinstance(webhook, routing.APIRoute):
            result = get_openapi_path(
                route=webhook,
                operation_ids=operation_ids,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if result:
                path, security_schemes, path_definitions = result
                if path:
                    webhook_paths.setdefault(webhook.path_format, {}).update(path)
                if security_schemes:
                    components.setdefault("securitySchemes", {}).update(
                        security_schemes
                    )
                if path_definitions:
                    definitions.update(path_definitions)
    if definitions:
        # Emit schemas sorted by name.
        components["schemas"] = {k: definitions[k] for k in sorted(definitions)}
    if components:
        output["components"] = components
    # "paths" is always present, even when empty; "webhooks" only when any
    # webhook produced a path item.
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore