531 lines
22 KiB
Python
531 lines
22 KiB
Python
import http.client
|
|
import inspect
|
|
import warnings
|
|
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type, Union, cast
|
|
|
|
from fastapi import routing
|
|
from fastapi._compat import (
|
|
GenerateJsonSchema,
|
|
JsonSchemaValue,
|
|
ModelField,
|
|
Undefined,
|
|
get_compat_model_name_map,
|
|
get_definitions,
|
|
get_schema_from_model_field,
|
|
lenient_issubclass,
|
|
)
|
|
from fastapi.datastructures import DefaultPlaceholder
|
|
from fastapi.dependencies.models import Dependant
|
|
from fastapi.dependencies.utils import get_flat_dependant, get_flat_params
|
|
from fastapi.encoders import jsonable_encoder
|
|
from fastapi.openapi.constants import METHODS_WITH_BODY, REF_PREFIX, REF_TEMPLATE
|
|
from fastapi.openapi.models import OpenAPI
|
|
from fastapi.params import Body, Param
|
|
from fastapi.responses import Response
|
|
from fastapi.types import ModelNameMap
|
|
from fastapi.utils import (
|
|
deep_dict_update,
|
|
generate_operation_id_for_path,
|
|
is_body_allowed_for_status_code,
|
|
)
|
|
from starlette.responses import JSONResponse
|
|
from starlette.routing import BaseRoute
|
|
from starlette.status import HTTP_422_UNPROCESSABLE_ENTITY
|
|
from typing_extensions import Literal
|
|
|
|
# JSON Schema for one entry of the list FastAPI returns when request
# validation fails: where the problem is ("loc"), a message and an error type.
validation_error_definition: Dict[str, Any] = {
    "title": "ValidationError",
    "type": "object",
    "properties": {
        # Path to the offending value; segments are field names or list indexes.
        "loc": {
            "title": "Location",
            "type": "array",
            "items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
        },
        "msg": {"title": "Message", "type": "string"},
        "type": {"title": "Error Type", "type": "string"},
    },
    "required": ["loc", "msg", "type"],
}

# JSON Schema for the 422 response body: an object whose "detail" is a list of
# ValidationError entries, referenced through the components/schemas prefix.
validation_error_response_definition: Dict[str, Any] = {
    "title": "HTTPValidationError",
    "type": "object",
    "properties": {
        "detail": {
            "title": "Detail",
            "type": "array",
            "items": {"$ref": f"{REF_PREFIX}ValidationError"},
        }
    },
}

# Fallback descriptions for additional-response keys that are status-code
# ranges ("1XX".."5XX") or the catch-all "DEFAULT" key.
status_code_ranges: Dict[str, str] = {
    "1XX": "Information",
    "2XX": "Success",
    "3XX": "Redirection",
    "4XX": "Client Error",
    "5XX": "Server Error",
    "DEFAULT": "Default Response",
}
|
|
|
|
|
|
def get_openapi_security_definitions(
    flat_dependant: Dependant,
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Collect security schemes and per-operation requirements from a dependant.

    Args:
        flat_dependant: A flattened dependant whose ``security_requirements``
            are inspected.

    Returns:
        A pair ``(definitions, requirements)``. ``definitions`` maps each
        scheme name to its JSON-encoded scheme model (aliases applied, ``None``
        values dropped); a later requirement with the same scheme name
        overwrites an earlier one. ``requirements`` holds one
        ``{scheme_name: scopes}`` dict per requirement, in order.
    """
    definitions_by_name: Dict[str, Any] = {}
    requirements: List[Dict[str, Any]] = []
    for requirement in flat_dependant.security_requirements:
        scheme = requirement.security_scheme
        definitions_by_name[scheme.scheme_name] = jsonable_encoder(
            scheme.model, by_alias=True, exclude_none=True
        )
        requirements.append({scheme.scheme_name: requirement.scopes})
    return definitions_by_name, requirements
|
|
|
|
|
|
def get_openapi_operation_parameters(
    *,
    all_route_params: Sequence[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> List[Dict[str, Any]]:
    """Build OpenAPI Parameter Objects for a route's non-body parameters.

    Fields whose ``field_info.include_in_schema`` is falsy are skipped.

    Each emitted object always carries ``name`` (the field alias), ``in``,
    ``required`` and ``schema``. It optionally adds ``description``, either
    ``examples`` (from ``openapi_examples``) or ``example``, and
    ``deprecated: True``.

    Returns:
        The parameter objects in the same order as ``all_route_params``.
    """
    result: List[Dict[str, Any]] = []
    for model_field in all_route_params:
        info = cast(Param, model_field.field_info)
        if not info.include_in_schema:
            continue
        schema = get_schema_from_model_field(
            field=model_field,
            schema_generator=schema_generator,
            model_name_map=model_name_map,
            field_mapping=field_mapping,
            separate_input_output_schemas=separate_input_output_schemas,
        )
        entry: Dict[str, Any] = {
            "name": model_field.alias,
            "in": info.in_.value,
            "required": model_field.required,
            "schema": schema,
        }
        if info.description:
            entry["description"] = info.description
        # ``openapi_examples`` takes precedence over the single ``example``.
        if info.openapi_examples:
            entry["examples"] = jsonable_encoder(info.openapi_examples)
        elif info.example != Undefined:
            entry["example"] = jsonable_encoder(info.example)
        if info.deprecated:
            entry["deprecated"] = True
        result.append(entry)
    return result
|
|
|
|
|
|
def get_openapi_operation_request_body(
    *,
    body_field: Optional[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Optional[Dict[str, Any]]:
    """Build the OpenAPI Request Body Object for a route's body field.

    Returns ``None`` when there is no (truthy) body field. Otherwise the
    result contains:

    * ``required``, only when the field is required.
    * ``content``, keyed by the ``Body`` field's media type, holding the
      field schema plus either ``examples`` (from ``openapi_examples``) or
      ``example``.
    """
    if not body_field:
        return None
    assert isinstance(body_field, ModelField)
    schema = get_schema_from_model_field(
        field=body_field,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        field_mapping=field_mapping,
        separate_input_output_schemas=separate_input_output_schemas,
    )
    body_info = cast(Body, body_field.field_info)
    media_type = body_info.media_type
    is_required = body_field.required

    request_body: Dict[str, Any] = {}
    if is_required:
        request_body["required"] = is_required

    media_content: Dict[str, Any] = {"schema": schema}
    # ``openapi_examples`` takes precedence over the single ``example``.
    if body_info.openapi_examples:
        media_content["examples"] = jsonable_encoder(body_info.openapi_examples)
    elif body_info.example != Undefined:
        media_content["example"] = jsonable_encoder(body_info.example)

    request_body["content"] = {media_type: media_content}
    return request_body
|
|
|
|
|
|
def generate_operation_id(
    *, route: routing.APIRoute, method: str
) -> str:  # pragma: nocover
    """Deprecated helper that returns an operation ID for ``route``.

    Emits a ``DeprecationWarning`` on every call. Returns ``route.operation_id``
    when it is truthy; otherwise derives one from the route name, its
    ``path_format`` and ``method`` via ``generate_operation_id_for_path``.
    """
    warnings.warn(
        "fastapi.openapi.utils.generate_operation_id() was deprecated, "
        "it is not used internally, and will be removed soon",
        DeprecationWarning,
        stacklevel=2,
    )
    return route.operation_id or generate_operation_id_for_path(
        name=route.name, path=route.path_format, method=method
    )
|
|
|
|
|
|
def generate_operation_summary(*, route: routing.APIRoute, method: str) -> str:
    """Return the OpenAPI ``summary`` for ``route``.

    A truthy ``route.summary`` is returned as-is. Otherwise the route name is
    converted by replacing underscores with spaces and title-casing
    (``read_items`` -> ``Read Items``). ``method`` is accepted but unused.
    """
    return route.summary or route.name.replace("_", " ").title()
|
|
|
|
|
|
def get_openapi_operation_metadata(
    *, route: routing.APIRoute, method: str, operation_ids: Set[str]
) -> Dict[str, Any]:
    """Build the descriptive part of an OpenAPI Operation Object.

    Keys, in insertion order:

    * ``tags`` (if any).
    * ``summary``.
    * ``description`` (if any).
    * ``operationId``: ``route.operation_id`` falling back to
      ``route.unique_id``.
    * ``deprecated`` (if set).

    ``operation_ids`` is updated in place with the chosen ID. If the ID was
    already present, a warning naming the endpoint function (and its source
    file, when discoverable) is emitted, but the duplicate is still used.
    """
    metadata: Dict[str, Any] = {}
    if route.tags:
        metadata["tags"] = route.tags
    metadata["summary"] = generate_operation_summary(route=route, method=method)
    if route.description:
        metadata["description"] = route.description

    op_id = route.operation_id or route.unique_id
    if op_id in operation_ids:
        endpoint = route.endpoint
        warning_text = f"Duplicate Operation ID {op_id} for function {endpoint.__name__}"
        # ``__globals__`` exists on plain functions; other callables may lack it.
        source_file = getattr(endpoint, "__globals__", {}).get("__file__")
        if source_file:
            warning_text += f" at {source_file}"
        warnings.warn(warning_text, stacklevel=1)
    operation_ids.add(op_id)
    metadata["operationId"] = op_id

    if route.deprecated:
        metadata["deprecated"] = route.deprecated
    return metadata
|
|
|
|
|
|
def get_openapi_path(
    *,
    route: routing.APIRoute,
    operation_ids: Set[str],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
    """Build the OpenAPI Path Item for one route, one operation per method.

    Args:
        route: The route to document. Nothing is emitted when
            ``route.include_in_schema`` is falsy.
        operation_ids: Operation IDs already in use; updated in place (also by
            callback routes) and used to warn about duplicates.
        schema_generator: Shared schema-generation state, forwarded to
            ``get_schema_from_model_field``.
        model_name_map: Shared schema-generation state, forwarded likewise.
        field_mapping: Shared schema-generation state, forwarded likewise.
        separate_input_output_schemas: Forwarded to
            ``get_schema_from_model_field``.

    Returns:
        A ``(path, security_schemes, definitions)`` tuple:

        * ``path``: operations keyed by the lower-cased HTTP method.
        * ``security_schemes``: the schemes referenced by this route's
          operations and by its callback routes.
        * ``definitions``: extra component schemas those operations
          ``$ref``, i.e. the ValidationError / HTTPValidationError models
          added for 422 responses.

    Raises:
        AssertionError: If ``route.methods`` is None or no response class can
            be resolved.
    """
    path: Dict[str, Any] = {}
    security_schemes: Dict[str, Any] = {}
    definitions: Dict[str, Any] = {}
    assert route.methods is not None, "Methods must be a list"
    if isinstance(route.response_class, DefaultPlaceholder):
        current_response_class: Type[Response] = route.response_class.value
    else:
        current_response_class = route.response_class
    assert current_response_class, "A response class is needed to generate OpenAPI"
    route_response_media_type: Optional[str] = current_response_class.media_type
    if route.include_in_schema:
        for method in route.methods:
            operation = get_openapi_operation_metadata(
                route=route, method=method, operation_ids=operation_ids
            )
            parameters: List[Dict[str, Any]] = []
            flat_dependant = get_flat_dependant(route.dependant, skip_repeats=True)
            security_definitions, operation_security = get_openapi_security_definitions(
                flat_dependant=flat_dependant
            )
            if operation_security:
                operation.setdefault("security", []).extend(operation_security)
            if security_definitions:
                security_schemes.update(security_definitions)
            all_route_params = get_flat_params(route.dependant)
            operation_parameters = get_openapi_operation_parameters(
                all_route_params=all_route_params,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            parameters.extend(operation_parameters)
            if parameters:
                # De-duplicate by (location, name); the last definition wins...
                all_parameters = {
                    (param["in"], param["name"]): param for param in parameters
                }
                required_parameters = {
                    (param["in"], param["name"]): param
                    for param in parameters
                    if param.get("required")
                }
                # ...except that required definitions of the same parameter
                # take precedence over non-required definitions.
                all_parameters.update(required_parameters)
                operation["parameters"] = list(all_parameters.values())
            if method in METHODS_WITH_BODY:
                request_body_oai = get_openapi_operation_request_body(
                    body_field=route.body_field,
                    schema_generator=schema_generator,
                    model_name_map=model_name_map,
                    field_mapping=field_mapping,
                    separate_input_output_schemas=separate_input_output_schemas,
                )
                if request_body_oai:
                    operation["requestBody"] = request_body_oai
            if route.callbacks:
                callbacks = {}
                for callback in route.callbacks:
                    if isinstance(callback, routing.APIRoute):
                        (
                            cb_path,
                            cb_security_schemes,
                            cb_definitions,
                        ) = get_openapi_path(
                            route=callback,
                            operation_ids=operation_ids,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        callbacks[callback.name] = {callback.path: cb_path}
                        # Callback operations can carry their own security
                        # requirements and 422 responses that $ref the
                        # validation-error schemas. Propagate both so the final
                        # document does not reference missing components.
                        security_schemes.update(cb_security_schemes)
                        definitions.update(cb_definitions)
                operation["callbacks"] = callbacks
            if route.status_code is not None:
                status_code = str(route.status_code)
            else:
                # It would probably make more sense for all response classes to have an
                # explicit default status_code, and to extract it from them, instead of
                # doing this inspection tricks, that would probably be in the future
                # TODO: probably make status_code a default class attribute for all
                # responses in Starlette
                # NOTE(review): if the response class's __init__ has no int
                # default for ``status_code``, ``status_code`` stays unbound and
                # the statement below raises UnboundLocalError.
                response_signature = inspect.signature(current_response_class.__init__)
                status_code_param = response_signature.parameters.get("status_code")
                if status_code_param is not None:
                    if isinstance(status_code_param.default, int):
                        status_code = str(status_code_param.default)
            operation.setdefault("responses", {}).setdefault(status_code, {})[
                "description"
            ] = route.response_description
            if route_response_media_type and is_body_allowed_for_status_code(
                route.status_code
            ):
                # Non-JSON response classes are documented as plain strings;
                # JSON responses use the response model schema, or an empty
                # (any) schema when no response model is set.
                response_schema = {"type": "string"}
                if lenient_issubclass(current_response_class, JSONResponse):
                    if route.response_field:
                        response_schema = get_schema_from_model_field(
                            field=route.response_field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                    else:
                        response_schema = {}
                operation.setdefault("responses", {}).setdefault(
                    status_code, {}
                ).setdefault("content", {}).setdefault(route_response_media_type, {})[
                    "schema"
                ] = response_schema
            if route.responses:
                operation_responses = operation.setdefault("responses", {})
                for (
                    additional_status_code,
                    additional_response,
                ) in route.responses.items():
                    # Shallow copy so popping "model" does not touch route.responses.
                    process_response = additional_response.copy()
                    process_response.pop("model", None)
                    status_code_key = str(additional_status_code).upper()
                    if status_code_key == "DEFAULT":
                        status_code_key = "default"
                    openapi_response = operation_responses.setdefault(
                        status_code_key, {}
                    )
                    assert isinstance(
                        process_response, dict
                    ), "An additional response must be a dict"
                    field = route.response_fields.get(additional_status_code)
                    additional_field_schema: Optional[Dict[str, Any]] = None
                    if field:
                        additional_field_schema = get_schema_from_model_field(
                            field=field,
                            schema_generator=schema_generator,
                            model_name_map=model_name_map,
                            field_mapping=field_mapping,
                            separate_input_output_schemas=separate_input_output_schemas,
                        )
                        media_type = route_response_media_type or "application/json"
                        additional_schema = (
                            process_response.setdefault("content", {})
                            .setdefault(media_type, {})
                            .setdefault("schema", {})
                        )
                        deep_dict_update(additional_schema, additional_field_schema)
                    # Description precedence: explicit > already present >
                    # range/standard HTTP reason phrase > generic fallback.
                    status_text: Optional[str] = status_code_ranges.get(
                        str(additional_status_code).upper()
                    ) or http.client.responses.get(int(additional_status_code))
                    description = (
                        process_response.get("description")
                        or openapi_response.get("description")
                        or status_text
                        or "Additional Response"
                    )
                    deep_dict_update(openapi_response, process_response)
                    openapi_response["description"] = description
            http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
            # Add the implicit validation-error response only when the
            # operation has inputs and the user did not document 422, 4XX or
            # default themselves.
            if (all_route_params or route.body_field) and not any(
                status in operation["responses"]
                for status in [http422, "4XX", "default"]
            ):
                operation["responses"][http422] = {
                    "description": "Validation Error",
                    "content": {
                        "application/json": {
                            "schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
                        }
                    },
                }
                if "ValidationError" not in definitions:
                    definitions.update(
                        {
                            "ValidationError": validation_error_definition,
                            "HTTPValidationError": validation_error_response_definition,
                        }
                    )
            if route.openapi_extra:
                deep_dict_update(operation, route.openapi_extra)
            path[method.lower()] = operation
    return path, security_schemes, definitions
|
|
|
|
|
|
def get_fields_from_routes(
    routes: Sequence[BaseRoute],
) -> List[ModelField]:
    """Collect every model field that needs a JSON Schema definition.

    Only ``routing.APIRoute`` instances with a truthy ``include_in_schema``
    contribute. Their callback routes are visited recursively.

    Returns:
        Fields from callbacks first, followed by request bodies, response
        models (default and additional responses) and flat request
        parameters, in that order.
    """
    callback_fields: List[ModelField] = []
    body_fields: List[ModelField] = []
    response_fields: List[ModelField] = []
    param_fields: List[ModelField] = []
    for route in routes:
        if not (
            getattr(route, "include_in_schema", None)
            and isinstance(route, routing.APIRoute)
        ):
            continue
        body_field = route.body_field
        if body_field:
            assert isinstance(
                body_field, ModelField
            ), "A request body must be a Pydantic Field"
            body_fields.append(body_field)
        if route.response_field:
            response_fields.append(route.response_field)
        if route.response_fields:
            response_fields.extend(route.response_fields.values())
        if route.callbacks:
            callback_fields.extend(get_fields_from_routes(route.callbacks))
        param_fields.extend(get_flat_params(route.dependant))
    return [*callback_fields, *body_fields, *response_fields, *param_fields]
|
|
|
|
|
|
def get_openapi(
    *,
    title: str,
    version: str,
    openapi_version: str = "3.1.0",
    summary: Optional[str] = None,
    description: Optional[str] = None,
    routes: Sequence[BaseRoute],
    webhooks: Optional[Sequence[BaseRoute]] = None,
    tags: Optional[List[Dict[str, Any]]] = None,
    servers: Optional[List[Dict[str, Union[str, Any]]]] = None,
    terms_of_service: Optional[str] = None,
    contact: Optional[Dict[str, Union[str, Any]]] = None,
    license_info: Optional[Dict[str, Union[str, Any]]] = None,
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Generate the complete OpenAPI document for ``routes`` and ``webhooks``.

    Only ``routing.APIRoute`` instances contribute operations. Regular routes
    populate ``paths`` and webhooks populate ``webhooks``; both share one set
    of operation IDs.

    Schemas:

    * Model schemas are generated once for the fields of every route and
      webhook.
    * They are merged with the validation-error schemas produced per path.
    * The result is emitted under ``components.schemas``, sorted by name.

    Optional metadata (summary, description, terms of service, contact,
    license, servers, tags) is included only when truthy.

    Returns:
        The document as JSON-compatible data, validated through the
        ``OpenAPI`` model and encoded with ``by_alias=True, exclude_none=True``.
    """
    info: Dict[str, Any] = {"title": title, "version": version}
    optional_info = (
        ("summary", summary),
        ("description", description),
        ("termsOfService", terms_of_service),
        ("contact", contact),
        ("license", license_info),
    )
    info.update({key: value for key, value in optional_info if value})

    output: Dict[str, Any] = {"openapi": openapi_version, "info": info}
    if servers:
        output["servers"] = servers

    components: Dict[str, Dict[str, Any]] = {}
    paths: Dict[str, Dict[str, Any]] = {}
    webhook_paths: Dict[str, Dict[str, Any]] = {}
    operation_ids: Set[str] = set()

    all_fields = get_fields_from_routes(list(routes or []) + list(webhooks or []))
    model_name_map = get_compat_model_name_map(all_fields)
    schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
    field_mapping, definitions = get_definitions(
        fields=all_fields,
        schema_generator=schema_generator,
        model_name_map=model_name_map,
        separate_input_output_schemas=separate_input_output_schemas,
    )

    # Routes are processed before webhooks so duplicate-operation-ID warnings
    # point at the webhook, not the route.
    for source, destination in ((routes, paths), (webhooks, webhook_paths)):
        for api_route in source or []:
            if not isinstance(api_route, routing.APIRoute):
                continue
            path_item, route_security, route_definitions = get_openapi_path(
                route=api_route,
                operation_ids=operation_ids,
                schema_generator=schema_generator,
                model_name_map=model_name_map,
                field_mapping=field_mapping,
                separate_input_output_schemas=separate_input_output_schemas,
            )
            if path_item:
                destination.setdefault(api_route.path_format, {}).update(path_item)
            if route_security:
                components.setdefault("securitySchemes", {}).update(route_security)
            if route_definitions:
                definitions.update(route_definitions)

    if definitions:
        # Keys are unique, so sorting (key, value) pairs never compares values.
        components["schemas"] = dict(sorted(definitions.items()))
    if components:
        output["components"] = components
    output["paths"] = paths
    if webhook_paths:
        output["webhooks"] = webhook_paths
    if tags:
        output["tags"] = tags
    return jsonable_encoder(OpenAPI(**output), by_alias=True, exclude_none=True)  # type: ignore
|