webdriver_template/telecli/lib/python3.11/site-packages/starlette/schemas.py

151 lines
5.1 KiB
Python
Raw Normal View History

2024-08-10 14:48:21 +03:00
from __future__ import annotations
import inspect
import re
import typing
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Host, Mount, Route
try:
import yaml
except ModuleNotFoundError: # pragma: nocover
yaml = None # type: ignore[assignment]
class OpenAPIResponse(Response):
media_type = "application/vnd.oai.openapi"
def render(self, content: typing.Any) -> bytes:
assert yaml is not None, "`pyyaml` must be installed to use OpenAPIResponse."
assert isinstance(
content, dict
), "The schema passed to OpenAPIResponse should be a dictionary."
return yaml.dump(content, default_flow_style=False).encode("utf-8")
class EndpointInfo(typing.NamedTuple):
path: str
http_method: str
func: typing.Callable[..., typing.Any]
class BaseSchemaGenerator:
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
raise NotImplementedError() # pragma: no cover
def get_endpoints(self, routes: list[BaseRoute]) -> list[EndpointInfo]:
"""
Given the routes, yields the following information:
- path
eg: /users/
- http_method
one of 'get', 'post', 'put', 'patch', 'delete', 'options'
- func
method ready to extract the docstring
"""
endpoints_info: list[EndpointInfo] = []
for route in routes:
if isinstance(route, (Mount, Host)):
routes = route.routes or []
if isinstance(route, Mount):
path = self._remove_converter(route.path)
else:
path = ""
sub_endpoints = [
EndpointInfo(
path="".join((path, sub_endpoint.path)),
http_method=sub_endpoint.http_method,
func=sub_endpoint.func,
)
for sub_endpoint in self.get_endpoints(routes)
]
endpoints_info.extend(sub_endpoints)
elif not isinstance(route, Route) or not route.include_in_schema:
continue
elif inspect.isfunction(route.endpoint) or inspect.ismethod(route.endpoint):
path = self._remove_converter(route.path)
for method in route.methods or ["GET"]:
if method == "HEAD":
continue
endpoints_info.append(
EndpointInfo(path, method.lower(), route.endpoint)
)
else:
path = self._remove_converter(route.path)
for method in ["get", "post", "put", "patch", "delete", "options"]:
if not hasattr(route.endpoint, method):
continue
func = getattr(route.endpoint, method)
endpoints_info.append(EndpointInfo(path, method.lower(), func))
return endpoints_info
def _remove_converter(self, path: str) -> str:
"""
Remove the converter from the path.
For example, a route like this:
Route("/users/{id:int}", endpoint=get_user, methods=["GET"])
Should be represented as `/users/{id}` in the OpenAPI schema.
"""
return re.sub(r":\w+}", "}", path)
def parse_docstring(
self, func_or_method: typing.Callable[..., typing.Any]
) -> dict[str, typing.Any]:
"""
Given a function, parse the docstring as YAML and return a dictionary of info.
"""
docstring = func_or_method.__doc__
if not docstring:
return {}
assert yaml is not None, "`pyyaml` must be installed to use parse_docstring."
# We support having regular docstrings before the schema
# definition. Here we return just the schema part from
# the docstring.
docstring = docstring.split("---")[-1]
parsed = yaml.safe_load(docstring)
if not isinstance(parsed, dict):
# A regular docstring (not yaml formatted) can return
# a simple string here, which wouldn't follow the schema.
return {}
return parsed
def OpenAPIResponse(self, request: Request) -> Response:
routes = request.app.routes
schema = self.get_schema(routes=routes)
return OpenAPIResponse(schema)
class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, base_schema: dict[str, typing.Any]) -> None:
self.base_schema = base_schema
def get_schema(self, routes: list[BaseRoute]) -> dict[str, typing.Any]:
schema = dict(self.base_schema)
schema.setdefault("paths", {})
endpoints_info = self.get_endpoints(routes)
for endpoint in endpoints_info:
parsed = self.parse_docstring(endpoint.func)
if not parsed:
continue
if endpoint.path not in schema["paths"]:
schema["paths"][endpoint.path] = {}
schema["paths"][endpoint.path][endpoint.http_method] = parsed
return schema