637 lines
23 KiB
Python
637 lines
23 KiB
Python
|
from collections import deque
|
||
|
from copy import copy
|
||
|
from dataclasses import dataclass, is_dataclass
|
||
|
from enum import Enum
|
||
|
from typing import (
|
||
|
Any,
|
||
|
Callable,
|
||
|
Deque,
|
||
|
Dict,
|
||
|
FrozenSet,
|
||
|
List,
|
||
|
Mapping,
|
||
|
Sequence,
|
||
|
Set,
|
||
|
Tuple,
|
||
|
Type,
|
||
|
Union,
|
||
|
)
|
||
|
|
||
|
from fastapi.exceptions import RequestErrorModel
|
||
|
from fastapi.types import IncEx, ModelNameMap, UnionType
|
||
|
from pydantic import BaseModel, create_model
|
||
|
from pydantic.version import VERSION as P_VERSION
|
||
|
from starlette.datastructures import UploadFile
|
||
|
from typing_extensions import Annotated, Literal, get_args, get_origin
|
||
|
|
||
|
# Re-bind the imported version string under a module-level name so that mypy
# treats it as an explicit re-export.
PYDANTIC_VERSION = P_VERSION
# True when the installed Pydantic version string begins with "2.".
PYDANTIC_V2 = PYDANTIC_VERSION[:2] == "2."
|
||
|
|
||
|
|
||
|
# Lookup table from a sequence annotation (either the ``typing`` alias or the
# builtin / ``collections`` class) to the concrete container used to build a
# value of that kind.
sequence_annotation_to_type = {
    # list-like
    Sequence: list,
    List: list,
    list: list,
    # tuples
    Tuple: tuple,
    tuple: tuple,
    # sets
    Set: set,
    set: set,
    FrozenSet: frozenset,
    frozenset: frozenset,
    # deques
    Deque: deque,
    deque: deque,
}

# Every key of the table above, in insertion order, usable as the second
# argument of ``isinstance`` / ``issubclass``.
sequence_types = tuple(sequence_annotation_to_type)
|
||
|
|
||
|
if PYDANTIC_V2:
|
||
|
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||
|
from pydantic import TypeAdapter
|
||
|
from pydantic import ValidationError as ValidationError
|
||
|
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
|
||
|
GetJsonSchemaHandler as GetJsonSchemaHandler,
|
||
|
)
|
||
|
from pydantic._internal._typing_extra import eval_type_lenient
|
||
|
from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
|
||
|
from pydantic.fields import FieldInfo
|
||
|
from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
|
||
|
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
|
||
|
from pydantic_core import CoreSchema as CoreSchema
|
||
|
from pydantic_core import PydanticUndefined, PydanticUndefinedType
|
||
|
from pydantic_core import Url as Url
|
||
|
|
||
|
try:
|
||
|
from pydantic_core.core_schema import (
|
||
|
with_info_plain_validator_function as with_info_plain_validator_function,
|
||
|
)
|
||
|
except ImportError: # pragma: no cover
|
||
|
from pydantic_core.core_schema import (
|
||
|
general_plain_validator_function as with_info_plain_validator_function, # noqa: F401
|
||
|
)
|
||
|
|
||
|
# Pydantic v2 branch: expose the names the rest of the code base expects.
# ``Required`` and ``Undefined`` are both bound to the same
# ``PydanticUndefined`` sentinel.
Required = Undefined = PydanticUndefined
UndefinedType = PydanticUndefinedType
# Forward references are evaluated with Pydantic's lenient evaluator.
evaluate_forwardref = eval_type_lenient
# In this branch ``Validator`` is only a placeholder type.
Validator = Any
|
||
|
|
||
|
class BaseConfig:
    """Empty placeholder class defined in the Pydantic v2 branch."""
|
||
|
|
||
|
class ErrorWrapper(Exception):
    """Empty exception subclass defined in the Pydantic v2 branch."""
|
||
|
|
||
|
@dataclass
class ModelField:
    """Field wrapper used in the Pydantic v2 branch.

    Couples a ``FieldInfo`` with the field's name and a ``TypeAdapter`` built
    from the field's annotation, and offers validation and serialization
    helpers on top of that adapter.
    """

    field_info: FieldInfo
    name: str
    mode: Literal["validation", "serialization"] = "validation"

    def __post_init__(self) -> None:
        """Build the ``TypeAdapter`` used by ``validate`` and ``serialize``."""
        annotated_type = Annotated[self.field_info.annotation, self.field_info]
        self._type_adapter: TypeAdapter[Any] = TypeAdapter(annotated_type)

    @property
    def alias(self) -> str:
        """The alias from ``field_info`` or, when that is ``None``, the name."""
        explicit_alias = self.field_info.alias
        if explicit_alias is None:
            return self.name
        return explicit_alias

    @property
    def required(self) -> bool:
        """Whether ``field_info`` reports the field as required."""
        return self.field_info.is_required()

    @property
    def default(self) -> Any:
        """Property form of ``get_default()``."""
        return self.get_default()

    @property
    def type_(self) -> Any:
        """The annotation stored on ``field_info``."""
        return self.field_info.annotation

    def get_default(self) -> Any:
        """Return ``Undefined`` for a required field, else its default.

        The default is obtained from ``field_info`` with
        ``call_default_factory=True``.
        """
        info = self.field_info
        if not info.is_required():
            return info.get_default(call_default_factory=True)
        return Undefined

    def validate(
        self,
        value: Any,
        values: Dict[str, Any] = {},  # noqa: B006
        *,
        loc: Tuple[Union[int, str], ...] = (),
    ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
        """Validate ``value`` with the field's type adapter.

        Returns ``(validated_value, None)`` on success, or ``(None, errors)``
        when a ``ValidationError`` is raised, where every error location is
        prefixed with ``loc``. ``values`` is not used by this implementation.
        """
        try:
            validated = self._type_adapter.validate_python(
                value, from_attributes=True
            )
        except ValidationError as exc:
            located_errors = _regenerate_error_with_loc(
                errors=exc.errors(include_url=False), loc_prefix=loc
            )
            return None, located_errors
        return validated, None

    def serialize(
        self,
        value: Any,
        *,
        mode: Literal["json", "python"] = "json",
        include: Union[IncEx, None] = None,
        exclude: Union[IncEx, None] = None,
        by_alias: bool = True,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
    ) -> Any:
        """Dump ``value`` through the type adapter, forwarding every option."""
        # Callers pass a value that has already gone through
        # self._type_adapter.validate_python(value)
        adapter = self._type_adapter
        return adapter.dump_python(
            value,
            mode=mode,
            include=include,
            exclude=exclude,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
        )

    def __hash__(self) -> int:
        # Identity-based hash: each ModelField is unique for our purposes, so
        # that a dict from ModelField to its JSON Schema can be built.
        return id(self)
|
||
|
|
||
|
def get_annotation_from_field_info(
    annotation: Any, field_info: FieldInfo, field_name: str
) -> Any:
    """Return ``annotation`` unchanged.

    In the Pydantic v2 branch ``field_info`` and ``field_name`` are accepted
    but not used.
    """
    return annotation
|
||
|
|
||
|
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
|
||
|
return errors # type: ignore[return-value]
|
||
|
|
||
|
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Call ``model_rebuild()`` on ``model`` (Pydantic v2 branch)."""
    model.model_rebuild()
|
||
|
|
||
|
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump ``model`` via ``model_dump``, forwarding ``mode`` and ``kwargs``."""
    dumped = model.model_dump(mode=mode, **kwargs)
    return dumped
|
||
|
|
||
|
def _get_model_config(model: BaseModel) -> Any:
    """Return the ``model_config`` attribute of ``model`` (v2 branch)."""
    config = model.model_config
    return config
|
||
|
|
||
|
def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Look up the JSON Schema already generated for ``field``.

    The schema is read from ``field_mapping`` under the key
    ``(field, mode)``, where ``mode`` is ``field.mode`` when
    ``separate_input_output_schemas`` is true and ``"validation"`` otherwise.
    When the schema has no ``"$ref"`` key, its ``"title"`` is set in place
    from ``field.field_info.title`` or, failing that, from the title-cased
    alias with underscores replaced by spaces.
    """
    lookup_mode = field.mode if separate_input_output_schemas else "validation"
    # This expects that GenerateJsonSchema was already used to generate the definitions
    json_schema = field_mapping[(field, lookup_mode)]
    if "$ref" in json_schema:
        return json_schema
    # TODO remove when deprecating Pydantic v1
    # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
    title = field.field_info.title or field.alias.title().replace("_", " ")
    json_schema["title"] = title
    return json_schema
|
||
|
|
||
|
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Return an empty mapping; ``fields`` is not inspected in the v2 branch."""
    empty_map: Dict[Any, str] = {}
    return empty_map
|
||
|
|
||
|
def get_definitions(
    *,
    fields: List[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    Dict[str, Dict[str, Any]],
]:
    """Generate JSON Schema definitions for ``fields`` (Pydantic v2 branch).

    One ``(field, mode, core_schema)`` triple is built per field and handed
    to ``schema_generator.generate_definitions``. ``mode`` is the field's own
    mode when ``separate_input_output_schemas`` is true, and ``"validation"``
    otherwise. ``model_name_map`` is not used here.

    Returns the ``(field_mapping, definitions)`` pair produced by the
    generator.
    """
    inputs = []
    for field in fields:
        mode = field.mode if separate_input_output_schemas else "validation"
        inputs.append((field, mode, field._type_adapter.core_schema))
    field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
    return field_mapping, definitions  # type: ignore[return-value]
|
||
|
|
||
|
def is_scalar_field(field: ModelField) -> bool:
    """True when the field's annotation is scalar and it is not a ``Body``."""
    # Imported locally, as in the original code.
    from fastapi import params

    info = field.field_info
    if not field_annotation_is_scalar(info.annotation):
        return False
    return not isinstance(info, params.Body)
|
||
|
|
||
|
def is_sequence_field(field: ModelField) -> bool:
    """Apply ``field_annotation_is_sequence`` to the field's annotation."""
    annotation = field.field_info.annotation
    return field_annotation_is_sequence(annotation)
|
||
|
|
||
|
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Apply ``field_annotation_is_scalar_sequence`` to the field's annotation."""
    annotation = field.field_info.annotation
    return field_annotation_is_scalar_sequence(annotation)
|
||
|
|
||
|
def is_bytes_field(field: ModelField) -> bool:
    """Apply ``is_bytes_or_nonable_bytes_annotation`` to ``field.type_``."""
    field_type = field.type_
    return is_bytes_or_nonable_bytes_annotation(field_type)
|
||
|
|
||
|
def is_bytes_sequence_field(field: ModelField) -> bool:
    """Apply ``is_bytes_sequence_annotation`` to ``field.type_``."""
    field_type = field.type_
    return is_bytes_sequence_annotation(field_type)
|
||
|
|
||
|
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Shallow-copy ``field_info`` and re-point it at ``annotation``.

    A fresh instance of the same class is built with
    ``from_annotation(annotation)``; its ``metadata`` and ``annotation`` are
    then written onto the copy. The original ``field_info`` is not modified.
    """
    from_annotation = type(field_info).from_annotation(annotation)
    duplicate = copy(field_info)
    duplicate.metadata = from_annotation.metadata
    duplicate.annotation = from_annotation.annotation
    return duplicate
|
||
|
|
||
|
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Convert ``value`` to the container type named by the field's annotation.

    The annotation's origin (or the annotation itself when it has no origin)
    is looked up in ``sequence_annotation_to_type`` and the resulting
    container is called with ``value``.

    When the annotation is a union (``Union[...]`` / ``X | Y``, for example
    ``Optional[List[bytes]]``), the first member that is not ``NoneType`` is
    used to pick the container. Without this step the origin would be
    ``Union`` itself, which is not a class and makes ``issubclass`` raise
    ``TypeError``.

    Raises:
        AssertionError: if the resolved type is not one of ``sequence_types``.
    """
    annotation = field.field_info.annotation
    origin_type = get_origin(annotation) or annotation
    if origin_type is Union or origin_type is UnionType:
        # Unwrap optional / union sequences to their first non-None member.
        for union_arg in get_args(annotation):
            if union_arg is type(None):
                continue
            origin_type = get_origin(union_arg) or union_arg
            break
    assert issubclass(origin_type, sequence_types)  # type: ignore[arg-type]
    return sequence_annotation_to_type[origin_type](value)  # type: ignore[no-any-return]
|
||
|
|
||
|
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Build a single ``"missing"`` error dict located at ``loc`` (v2 branch).

    The error is produced through ``ValidationError.from_exception_data`` and
    its ``"input"`` entry is then overwritten with ``None``.
    """
    line_errors = [{"type": "missing", "loc": loc, "input": {}}]
    validation_error = ValidationError.from_exception_data(
        "Field required", line_errors
    )
    error = validation_error.errors(include_url=False)[0]
    error["input"] = None
    return error  # type: ignore[return-value]
|
||
|
|
||
|
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create a model named ``model_name`` with one attribute per field.

    Each field contributes ``name -> (annotation, field_info)`` to the keyword
    arguments passed to ``create_model``.
    """
    field_params = {}
    for model_field in fields:
        info = model_field.field_info
        field_params[model_field.name] = (info.annotation, info)
    BodyModel: Type[BaseModel] = create_model(model_name, **field_params)  # type: ignore[call-overload]
    return BodyModel
|
||
|
|
||
|
else:
|
||
|
from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
|
||
|
from pydantic import AnyUrl as Url # noqa: F401
|
||
|
from pydantic import ( # type: ignore[assignment]
|
||
|
BaseConfig as BaseConfig, # noqa: F401
|
||
|
)
|
||
|
from pydantic import ValidationError as ValidationError # noqa: F401
|
||
|
from pydantic.class_validators import ( # type: ignore[no-redef]
|
||
|
Validator as Validator, # noqa: F401
|
||
|
)
|
||
|
from pydantic.error_wrappers import ( # type: ignore[no-redef]
|
||
|
ErrorWrapper as ErrorWrapper, # noqa: F401
|
||
|
)
|
||
|
from pydantic.errors import MissingError
|
||
|
from pydantic.fields import ( # type: ignore[attr-defined]
|
||
|
SHAPE_FROZENSET,
|
||
|
SHAPE_LIST,
|
||
|
SHAPE_SEQUENCE,
|
||
|
SHAPE_SET,
|
||
|
SHAPE_SINGLETON,
|
||
|
SHAPE_TUPLE,
|
||
|
SHAPE_TUPLE_ELLIPSIS,
|
||
|
)
|
||
|
from pydantic.fields import FieldInfo as FieldInfo
|
||
|
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||
|
ModelField as ModelField, # noqa: F401
|
||
|
)
|
||
|
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||
|
Required as Required, # noqa: F401
|
||
|
)
|
||
|
from pydantic.fields import ( # type: ignore[no-redef,attr-defined]
|
||
|
Undefined as Undefined,
|
||
|
)
|
||
|
from pydantic.fields import ( # type: ignore[no-redef, attr-defined]
|
||
|
UndefinedType as UndefinedType, # noqa: F401
|
||
|
)
|
||
|
from pydantic.schema import (
|
||
|
field_schema,
|
||
|
get_flat_models_from_fields,
|
||
|
get_model_name_map,
|
||
|
model_process_schema,
|
||
|
)
|
||
|
from pydantic.schema import ( # type: ignore[no-redef] # noqa: F401
|
||
|
get_annotation_from_field_info as get_annotation_from_field_info,
|
||
|
)
|
||
|
from pydantic.typing import ( # type: ignore[no-redef]
|
||
|
evaluate_forwardref as evaluate_forwardref, # noqa: F401
|
||
|
)
|
||
|
from pydantic.utils import ( # type: ignore[no-redef]
|
||
|
lenient_issubclass as lenient_issubclass, # noqa: F401
|
||
|
)
|
||
|
|
||
|
# Pydantic v1 branch: these names are bound to plain typing stand-ins here.
GetJsonSchemaHandler = CoreSchema = Any  # type: ignore[assignment,misc]
JsonSchemaValue = Dict[str, Any]  # type: ignore[misc]
|
||
|
|
||
|
# Pydantic v1 field shapes that are treated as sequences.
sequence_shapes = {
    SHAPE_LIST,
    SHAPE_SET,
    SHAPE_FROZENSET,
    SHAPE_TUPLE,
    SHAPE_SEQUENCE,
    SHAPE_TUPLE_ELLIPSIS,
}
# Concrete container used to rebuild a value for each sequence shape.
# Every member of ``sequence_shapes`` needs an entry here, otherwise
# ``sequence_shape_to_type[field.shape]`` raises ``KeyError``; the
# ``SHAPE_FROZENSET`` entry was previously missing.
sequence_shape_to_type = {
    SHAPE_LIST: list,
    SHAPE_SET: set,
    SHAPE_FROZENSET: frozenset,
    SHAPE_TUPLE: tuple,
    SHAPE_SEQUENCE: list,
    SHAPE_TUPLE_ELLIPSIS: list,
}
|
||
|
|
||
|
@dataclass
class GenerateJsonSchema:  # type: ignore[no-redef]
    """Stand-in defined in the Pydantic v1 branch; holds ``ref_template``."""

    ref_template: str
|
||
|
|
||
|
class PydanticSchemaGenerationError(Exception):  # type: ignore[no-redef]
    """Empty exception subclass defined in the Pydantic v1 branch."""
|
||
|
|
||
|
def with_info_plain_validator_function(  # type: ignore[misc]
    function: Callable[..., Any],
    *,
    ref: Union[str, None] = None,
    metadata: Any = None,
    serialization: Any = None,
) -> Any:
    """Stub for the Pydantic v1 branch: ignore all arguments, return ``{}``."""
    empty_schema: Dict[str, Any] = {}
    return empty_schema
|
||
|
|
||
|
def get_model_definitions(
    *,
    flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
    model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
) -> Dict[str, Any]:
    """Collect schema definitions for every model in ``flat_models``.

    Each model is run through ``model_process_schema``. The nested
    definitions it reports are merged in first, then the model's own schema
    is stored under its name from ``model_name_map``. A ``"description"``
    entry is cut at the first form-feed character (``"\\f"``).
    """
    definitions: Dict[str, Dict[str, Any]] = {}
    for model in flat_models:
        # The third returned item (nested models) is not needed here.
        m_schema, m_definitions, _ = model_process_schema(
            model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )
        definitions.update(m_definitions)
        model_name = model_name_map[model]
        if "description" in m_schema:
            # Keep only the text before the first form feed.
            m_schema["description"] = m_schema["description"].partition("\f")[0]
        definitions[model_name] = m_schema
    return definitions
|
||
|
|
||
|
def is_pv1_scalar_field(field: ModelField) -> bool:
    """Decide whether a Pydantic v1 field is a scalar.

    The field must have the singleton shape, its type must not be a
    ``BaseModel`` subclass, a ``dict`` subclass, a sequence annotation or a
    dataclass, and its ``field_info`` must not be a ``params.Body``. When the
    field has sub-fields, every one of them must be scalar as well.
    """
    # Imported locally, as in the original code.
    from fastapi import params

    field_info = field.field_info
    if field.shape != SHAPE_SINGLETON:  # type: ignore[attr-defined]
        return False
    if (
        lenient_issubclass(field.type_, BaseModel)
        or lenient_issubclass(field.type_, dict)
        or field_annotation_is_sequence(field.type_)
        or is_dataclass(field.type_)
        or isinstance(field_info, params.Body)
    ):
        return False
    sub_fields = field.sub_fields  # type: ignore[attr-defined]
    if not sub_fields:
        return True
    return all(is_pv1_scalar_field(sub_field) for sub_field in sub_fields)
|
||
|
|
||
|
def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
    """Decide whether a Pydantic v1 field is a sequence of scalars.

    A field with a sequence shape whose type is not a ``BaseModel`` subclass
    qualifies when it has no sub-fields (``None``) or when all sub-fields are
    scalar. Any other field qualifies only if its type itself is a sequence
    annotation.
    """
    has_plain_sequence_shape = (
        field.shape in sequence_shapes  # type: ignore[attr-defined]
    ) and not lenient_issubclass(field.type_, BaseModel)
    if not has_plain_sequence_shape:
        return _annotation_is_sequence(field.type_)
    sub_fields = field.sub_fields  # type: ignore[attr-defined]
    if sub_fields is None:
        return True
    return all(is_pv1_scalar_field(sub_field) for sub_field in sub_fields)
|
||
|
|
||
|
def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
    """Flatten ``errors`` into a single list (Pydantic v1 branch).

    ``ErrorWrapper`` items are expanded through a ``ValidationError`` built
    against ``RequestErrorModel``, nested lists are flattened recursively,
    and any other item is kept unchanged.
    """
    flattened: List[Any] = []
    for item in errors:
        if isinstance(item, ErrorWrapper):
            wrapped = ValidationError(  # type: ignore[call-arg]
                errors=[item], model=RequestErrorModel
            )
            flattened.extend(wrapped.errors())
        elif isinstance(item, list):
            flattened.extend(_normalize_errors(item))
        else:
            flattened.append(item)
    return flattened
|
||
|
|
||
|
def _model_rebuild(model: Type[BaseModel]) -> None:
    """Call ``update_forward_refs()`` on ``model`` (Pydantic v1 branch)."""
    model.update_forward_refs()
|
||
|
|
||
|
def _model_dump(
    model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
) -> Any:
    """Dump ``model`` via ``model.dict(**kwargs)`` (Pydantic v1 branch).

    ``mode`` is part of the signature but is not forwarded to ``dict``.
    """
    dumped = model.dict(**kwargs)
    return dumped
|
||
|
|
||
|
def _get_model_config(model: BaseModel) -> Any:
    """Return the ``__config__`` attribute of ``model`` (v1 branch)."""
    config = model.__config__  # type: ignore[attr-defined]
    return config
|
||
|
|
||
|
def get_schema_from_model_field(
    *,
    field: ModelField,
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    field_mapping: Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    separate_input_output_schemas: bool = True,
) -> Dict[str, Any]:
    """Return the schema of ``field`` as computed by ``field_schema``.

    In the Pydantic v1 branch only ``field`` and ``model_name_map`` are used;
    the remaining parameters are accepted but ignored. ``field_schema`` is
    called with ``ref_prefix=REF_PREFIX`` and the first item of its result is
    returned.
    """
    # This expects that GenerateJsonSchema was already used to generate the definitions
    schema_result = field_schema(
        field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
    )
    return schema_result[0]  # type: ignore[no-any-return]
|
||
|
|
||
|
def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
    """Build the model-name map for the flat models reachable from ``fields``."""
    flat_models = get_flat_models_from_fields(fields, known_models=set())
    name_map = get_model_name_map(flat_models)
    return name_map  # type: ignore[no-any-return]
|
||
|
|
||
|
def get_definitions(
    *,
    fields: List[ModelField],
    schema_generator: GenerateJsonSchema,
    model_name_map: ModelNameMap,
    separate_input_output_schemas: bool = True,
) -> Tuple[
    Dict[
        Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
    ],
    Dict[str, Dict[str, Any]],
]:
    """Return ``({}, definitions)`` for ``fields`` (Pydantic v1 branch).

    The field mapping is always empty here. The definitions come from
    ``get_model_definitions`` applied to the flat models found in ``fields``.
    ``schema_generator`` and ``separate_input_output_schemas`` are not used.
    """
    flat_models = get_flat_models_from_fields(fields, known_models=set())
    definitions = get_model_definitions(
        flat_models=flat_models, model_name_map=model_name_map
    )
    return {}, definitions
|
||
|
|
||
|
def is_scalar_field(field: ModelField) -> bool:
    """Pydantic v1 branch: delegate to ``is_pv1_scalar_field``."""
    result = is_pv1_scalar_field(field)
    return result
|
||
|
|
||
|
def is_sequence_field(field: ModelField) -> bool:
    """True if the field has a sequence shape or a sequence type (v1 branch)."""
    if field.shape in sequence_shapes:  # type: ignore[attr-defined]
        return True
    return _annotation_is_sequence(field.type_)
|
||
|
|
||
|
def is_scalar_sequence_field(field: ModelField) -> bool:
    """Pydantic v1 branch: delegate to ``is_pv1_scalar_sequence_field``."""
    result = is_pv1_scalar_sequence_field(field)
    return result
|
||
|
|
||
|
def is_bytes_field(field: ModelField) -> bool:
    """True if ``field.type_`` is a subclass of ``bytes`` (v1 branch)."""
    field_type = field.type_
    return lenient_issubclass(field_type, bytes)
|
||
|
|
||
|
def is_bytes_sequence_field(field: ModelField) -> bool:
    """True if the field has a sequence shape and a ``bytes`` type (v1 branch)."""
    if field.shape not in sequence_shapes:  # type: ignore[attr-defined]
        return False
    return lenient_issubclass(field.type_, bytes)
|
||
|
|
||
|
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
    """Return a shallow copy of ``field_info``; ``annotation`` is unused here."""
    duplicate = copy(field_info)
    return duplicate
|
||
|
|
||
|
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
    """Rebuild ``value`` with the container mapped to the field's shape.

    The container is looked up in ``sequence_shape_to_type`` by
    ``field.shape`` (Pydantic v1 branch).
    """
    container_type = sequence_shape_to_type[field.shape]  # type: ignore[attr-defined]
    return container_type(value)  # type: ignore[no-any-return]
|
||
|
|
||
|
def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
    """Build the error dict for a missing field at ``loc`` (v1 branch).

    A ``MissingError`` is wrapped in an ``ErrorWrapper``, put into a
    ``ValidationError`` against ``RequestErrorModel``, and the first entry of
    ``errors()`` is returned.
    """
    wrapper = ErrorWrapper(MissingError(), loc=loc)  # type: ignore[call-arg]
    validation_error = ValidationError([wrapper], RequestErrorModel)
    first_error = validation_error.errors()[0]
    return first_error  # type: ignore[return-value]
|
||
|
|
||
|
def create_body_model(
    *, fields: Sequence[ModelField], model_name: str
) -> Type[BaseModel]:
    """Create an empty model named ``model_name`` and register ``fields`` on it.

    Each field is stored in the model's ``__fields__`` mapping under its name
    (Pydantic v1 branch).
    """
    BodyModel = create_model(model_name)
    registry = BodyModel.__fields__
    for model_field in fields:
        registry[model_field.name] = model_field  # type: ignore[index]
    return BodyModel
|
||
|
|
||
|
|
||
|
def _regenerate_error_with_loc(
    *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
    """Return copies of the normalized errors with ``loc_prefix`` prepended.

    Every error is copied into a new dict whose ``"loc"`` is ``loc_prefix``
    followed by the error's own ``"loc"`` (an empty tuple when absent).
    """
    relocated: List[Any] = []
    for err in _normalize_errors(errors):
        full_loc = loc_prefix + err.get("loc", ())
        relocated.append({**err, "loc": full_loc})
    return relocated
|
||
|
|
||
|
|
||
|
def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    """True if ``annotation`` is a subclass of one of ``sequence_types``.

    ``str`` and ``bytes`` subclasses are explicitly excluded.
    """
    is_text_like = lenient_issubclass(annotation, (str, bytes))
    if not is_text_like:
        return lenient_issubclass(annotation, sequence_types)
    return False
|
||
|
|
||
|
|
||
|
def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    """True if ``annotation`` or its ``get_origin`` is a sequence annotation."""
    if _annotation_is_sequence(annotation):
        return True
    origin = get_origin(annotation)
    return _annotation_is_sequence(origin)
|
||
|
|
||
|
|
||
|
def value_is_sequence(value: Any) -> bool:
    """True if ``value`` is an instance of ``sequence_types`` but not text.

    ``str`` and ``bytes`` instances are rejected.
    """
    if not isinstance(value, sequence_types):  # type: ignore[arg-type]
        return False
    return not isinstance(value, (str, bytes))
|
||
|
|
||
|
|
||
|
def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    """True if ``annotation`` counts as a complex (non-scalar) type.

    That is: a subclass of ``BaseModel``, ``Mapping`` or ``UploadFile``, a
    sequence annotation, or a dataclass.
    """
    if lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile)):
        return True
    if _annotation_is_sequence(annotation):
        return True
    return is_dataclass(annotation)
|
||
|
|
||
|
|
||
|
def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    """True if ``annotation`` is complex, looking through unions and generics.

    For a union, the result is true when any member is complex. Otherwise the
    annotation and its origin are both checked with ``_annotation_is_complex``,
    and an origin carrying ``__pydantic_core_schema__`` or
    ``__get_pydantic_core_schema__`` also counts as complex.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        for member in get_args(annotation):
            if field_annotation_is_complex(member):
                return True
        return False

    if _annotation_is_complex(annotation) or _annotation_is_complex(origin):
        return True
    if hasattr(origin, "__pydantic_core_schema__"):
        return True
    return hasattr(origin, "__get_pydantic_core_schema__")
|
||
|
|
||
|
|
||
|
def field_annotation_is_scalar(annotation: Any) -> bool:
    """True if ``annotation`` is ``Ellipsis`` or is not a complex annotation."""
    # Ellipsis is accepted so that tuple[int, ...] works nicely.
    if annotation is Ellipsis:
        return True
    return not field_annotation_is_complex(annotation)
|
||
|
|
||
|
|
||
|
def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
    """True if ``annotation`` describes a sequence whose items are scalar.

    For a union, at least one member must be a scalar sequence and every
    other member must be scalar. For anything else, the annotation must be a
    sequence and all of its type arguments must be scalar.
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        found_scalar_sequence = False
        for member in get_args(annotation):
            if field_annotation_is_scalar_sequence(member):
                found_scalar_sequence = True
            elif not field_annotation_is_scalar(member):
                # A non-scalar, non-sequence member disqualifies the union.
                return False
        return found_scalar_sequence
    if not field_annotation_is_sequence(annotation):
        return False
    return all(field_annotation_is_scalar(item) for item in get_args(annotation))
|
||
|
|
||
|
|
||
|
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
    """True if ``annotation`` is ``bytes``-like or a union with such a member."""
    if lenient_issubclass(annotation, bytes):
        return True
    origin = get_origin(annotation)
    if origin is not Union and origin is not UnionType:
        return False
    return any(lenient_issubclass(member, bytes) for member in get_args(annotation))
|
||
|
|
||
|
|
||
|
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
    """True if ``annotation`` is an ``UploadFile`` subclass or a union with one."""
    if lenient_issubclass(annotation, UploadFile):
        return True
    origin = get_origin(annotation)
    if origin is not Union and origin is not UnionType:
        return False
    return any(
        lenient_issubclass(member, UploadFile) for member in get_args(annotation)
    )
|
||
|
|
||
|
|
||
|
def is_bytes_sequence_annotation(annotation: Any) -> bool:
    """True if ``annotation`` describes a sequence of ``bytes`` items.

    For a union, the result is true when at least one member is itself a
    bytes sequence. Otherwise the annotation must be a sequence whose type
    arguments are all ``bytes`` (or a union containing ``bytes``).
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        # A list (not a generator) so that every member is examined, exactly
        # as the original loop did.
        member_results = [
            is_bytes_sequence_annotation(member) for member in get_args(annotation)
        ]
        return any(member_results)
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        is_bytes_or_nonable_bytes_annotation(item) for item in get_args(annotation)
    )
|
||
|
|
||
|
|
||
|
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
    """True if ``annotation`` describes a sequence of ``UploadFile`` items.

    For a union, the result is true when at least one member is itself an
    upload-file sequence. Otherwise the annotation must be a sequence whose
    type arguments are all ``UploadFile`` (or a union containing it).
    """
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        # A list (not a generator) so that every member is examined, exactly
        # as the original loop did.
        member_results = [
            is_uploadfile_sequence_annotation(member)
            for member in get_args(annotation)
        ]
        return any(member_results)
    if not field_annotation_is_sequence(annotation):
        return False
    return all(
        is_uploadfile_or_nonable_uploadfile_annotation(item)
        for item in get_args(annotation)
    )
|