"""This module includes classes and functions designed specifically for use with the mypy plugin."""
|
||
|
|
||
|
from __future__ import annotations
|
||
|
|
||
|
import sys
|
||
|
from configparser import ConfigParser
|
||
|
from typing import Any, Callable, Iterator
|
||
|
|
||
|
from mypy.errorcodes import ErrorCode
|
||
|
from mypy.expandtype import expand_type, expand_type_by_instance
|
||
|
from mypy.nodes import (
|
||
|
ARG_NAMED,
|
||
|
ARG_NAMED_OPT,
|
||
|
ARG_OPT,
|
||
|
ARG_POS,
|
||
|
ARG_STAR2,
|
||
|
INVARIANT,
|
||
|
MDEF,
|
||
|
Argument,
|
||
|
AssignmentStmt,
|
||
|
Block,
|
||
|
CallExpr,
|
||
|
ClassDef,
|
||
|
Context,
|
||
|
Decorator,
|
||
|
DictExpr,
|
||
|
EllipsisExpr,
|
||
|
Expression,
|
||
|
FuncDef,
|
||
|
IfStmt,
|
||
|
JsonDict,
|
||
|
MemberExpr,
|
||
|
NameExpr,
|
||
|
PassStmt,
|
||
|
PlaceholderNode,
|
||
|
RefExpr,
|
||
|
Statement,
|
||
|
StrExpr,
|
||
|
SymbolTableNode,
|
||
|
TempNode,
|
||
|
TypeAlias,
|
||
|
TypeInfo,
|
||
|
Var,
|
||
|
)
|
||
|
from mypy.options import Options
|
||
|
from mypy.plugin import (
|
||
|
CheckerPluginInterface,
|
||
|
ClassDefContext,
|
||
|
FunctionContext,
|
||
|
MethodContext,
|
||
|
Plugin,
|
||
|
ReportConfigContext,
|
||
|
SemanticAnalyzerPluginInterface,
|
||
|
)
|
||
|
from mypy.plugins import dataclasses
|
||
|
from mypy.plugins.common import (
|
||
|
deserialize_and_fixup_type,
|
||
|
)
|
||
|
from mypy.semanal import set_callable_name
|
||
|
from mypy.server.trigger import make_wildcard_trigger
|
||
|
from mypy.state import state
|
||
|
from mypy.typeops import map_type_from_supertype
|
||
|
from mypy.types import (
|
||
|
AnyType,
|
||
|
CallableType,
|
||
|
Instance,
|
||
|
NoneType,
|
||
|
Overloaded,
|
||
|
Type,
|
||
|
TypeOfAny,
|
||
|
TypeType,
|
||
|
TypeVarType,
|
||
|
UnionType,
|
||
|
get_proper_type,
|
||
|
)
|
||
|
from mypy.typevars import fill_typevars
|
||
|
from mypy.util import get_unique_redefinition_name
|
||
|
from mypy.version import __version__ as mypy_version
|
||
|
|
||
|
from pydantic._internal import _fields
|
||
|
from pydantic.version import parse_mypy_version
|
||
|
|
||
|
try:
|
||
|
from mypy.types import TypeVarDef # type: ignore[attr-defined]
|
||
|
except ImportError: # pragma: no cover
|
||
|
# Backward-compatible with TypeVarDef from Mypy 0.930.
|
||
|
from mypy.types import TypeVarType as TypeVarDef
|
||
|
|
||
|

CONFIGFILE_KEY = 'pydantic-mypy'
METADATA_KEY = 'pydantic-mypy-metadata'
BASEMODEL_FULLNAME = 'pydantic.main.BaseModel'
BASESETTINGS_FULLNAME = 'pydantic_settings.main.BaseSettings'
ROOT_MODEL_FULLNAME = 'pydantic.root_model.RootModel'
MODEL_METACLASS_FULLNAME = 'pydantic._internal._model_construction.ModelMetaclass'
FIELD_FULLNAME = 'pydantic.fields.Field'
DATACLASS_FULLNAME = 'pydantic.dataclasses.dataclass'
MODEL_VALIDATOR_FULLNAME = 'pydantic.functional_validators.model_validator'
DECORATOR_FULLNAMES = {
    'pydantic.functional_validators.field_validator',
    'pydantic.functional_validators.model_validator',
    'pydantic.functional_serializers.serializer',
    'pydantic.functional_serializers.model_serializer',
    'pydantic.deprecated.class_validators.validator',
    'pydantic.deprecated.class_validators.root_validator',
}


MYPY_VERSION_TUPLE = parse_mypy_version(mypy_version)
BUILTINS_NAME = 'builtins' if MYPY_VERSION_TUPLE >= (0, 930) else '__builtins__'

# Increment version if plugin changes and mypy caches should be invalidated
__version__ = 2


def plugin(version: str) -> type[Plugin]:
    """`version` is the mypy version string.

    We might want to use this to print a warning if the mypy version being used is
    newer, or especially older, than we expect (or need).

    Args:
        version: The mypy version string.

    Returns:
        The Pydantic mypy plugin type.
    """
    return PydanticPlugin
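

# The plugin is enabled by adding `pydantic.mypy` to mypy's plugin list. An
# illustrative configuration (the option names mirror
# `PydanticPluginConfig.__slots__` below):
#
#     # mypy.ini
#     [mypy]
#     plugins = pydantic.mypy
#
#     [pydantic-mypy]
#     init_forbid_extra = true
#     init_typed = true
#     warn_required_dynamic_aliases = true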


class PydanticPlugin(Plugin):
    """The Pydantic mypy plugin."""

    def __init__(self, options: Options) -> None:
        self.plugin_config = PydanticPluginConfig(options)
        self._plugin_data = self.plugin_config.to_data()
        super().__init__(options)

    def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], bool] | None:
        """Update Pydantic model class."""
        sym = self.lookup_fully_qualified(fullname)
        if sym and isinstance(sym.node, TypeInfo):  # pragma: no branch
            # No branching may occur if the mypy cache has not been cleared
            if any(base.fullname == BASEMODEL_FULLNAME for base in sym.node.mro):
                return self._pydantic_model_class_maker_callback
        return None

    def get_metaclass_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Update Pydantic `ModelMetaclass` definition."""
        if fullname == MODEL_METACLASS_FULLNAME:
            return self._pydantic_model_metaclass_marker_callback
        return None

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Adjust the return type of the `Field` function."""
        sym = self.lookup_fully_qualified(fullname)
        if sym and sym.fullname == FIELD_FULLNAME:
            return self._pydantic_field_callback
        return None

    def get_method_hook(self, fullname: str) -> Callable[[MethodContext], Type] | None:
        """Adjust the return type of `from_orm` method calls."""
        if fullname.endswith('.from_orm'):
            return from_attributes_callback
        return None

    def get_class_decorator_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        """Mark pydantic.dataclasses as dataclass.

        Mypy version 1.1.1 added support for the `@dataclass_transform` decorator.
        """
        if fullname == DATACLASS_FULLNAME and MYPY_VERSION_TUPLE < (1, 1):
            return dataclasses.dataclass_class_maker_callback  # type: ignore[return-value]
        return None

    def report_config_data(self, ctx: ReportConfigContext) -> dict[str, Any]:
        """Return all plugin config data.

        Used by mypy to determine if the cache needs to be discarded.
        """
        return self._plugin_data

    def _pydantic_model_class_maker_callback(self, ctx: ClassDefContext) -> bool:
        transformer = PydanticModelTransformer(ctx.cls, ctx.reason, ctx.api, self.plugin_config)
        return transformer.transform()

    def _pydantic_model_metaclass_marker_callback(self, ctx: ClassDefContext) -> None:
        """Reset the `dataclass_transform_spec` attribute of `ModelMetaclass`.

        Let the plugin handle it. This behavior can be disabled
        if `debug_dataclass_transform` is set to `True`, for testing purposes.
        """
        if self.plugin_config.debug_dataclass_transform:
            return
        info_metaclass = ctx.cls.info.declared_metaclass
        assert info_metaclass, "callback not passed from 'get_metaclass_hook'"
        if getattr(info_metaclass.type, 'dataclass_transform_spec', None):
            info_metaclass.type.dataclass_transform_spec = None

    def _pydantic_field_callback(self, ctx: FunctionContext) -> Type:
        """Extract the type of the `default` argument from the Field function, and use it as the return type.

        In particular:
        * Check whether the `default` and `default_factory` arguments are specified.
        * Output an error if both are specified.
        * Retrieve the type of whichever argument is specified, and use it as the return type for the function.
        """
        default_any_type = ctx.default_return_type

        assert ctx.callee_arg_names[0] == 'default', '"default" is no longer first argument in Field()'
        assert ctx.callee_arg_names[1] == 'default_factory', '"default_factory" is no longer second argument in Field()'
        default_args = ctx.args[0]
        default_factory_args = ctx.args[1]

        if default_args and default_factory_args:
            error_default_and_default_factory_specified(ctx.api, ctx.context)
            return default_any_type

        if default_args:
            default_type = ctx.arg_types[0][0]
            default_arg = default_args[0]

            # Fallback to default Any type if the field is required
            if not isinstance(default_arg, EllipsisExpr):
                return default_type

        elif default_factory_args:
            default_factory_type = ctx.arg_types[1][0]

            # Functions which use `ParamSpec` can be overloaded, exposing the callable's types as a parameter.
            # Pydantic calls the default factory without any argument, so we retrieve the first item
            if isinstance(default_factory_type, Overloaded):
                default_factory_type = default_factory_type.items[0]

            if isinstance(default_factory_type, CallableType):
                ret_type = default_factory_type.ret_type
                # mypy doesn't think `ret_type` has `args`, even though you'd expect it should;
                # add this check in case it varies by version
                args = getattr(ret_type, 'args', None)
                if args:
                    if all(isinstance(arg, TypeVarType) for arg in args):
                        # Looks like the default factory is a type like `list` or `dict`, replace all args with `Any`
                        ret_type.args = tuple(default_any_type for _ in args)  # type: ignore[attr-defined]
                return ret_type

        return default_any_type


class PydanticPluginConfig:
    """A Pydantic mypy plugin config holder.

    Attributes:
        init_forbid_extra: Whether to forbid extra arguments (i.e., omit the trailing `**kwargs`)
            in the generated `__init__` signature.
        init_typed: Whether to annotate fields in the generated `__init__`.
        warn_required_dynamic_aliases: Whether to raise an error for required dynamic aliases.
        debug_dataclass_transform: Whether to skip resetting the `dataclass_transform_spec` attribute
            of `ModelMetaclass`, for testing purposes.
    """

    __slots__ = (
        'init_forbid_extra',
        'init_typed',
        'warn_required_dynamic_aliases',
        'debug_dataclass_transform',
    )
    init_forbid_extra: bool
    init_typed: bool
    warn_required_dynamic_aliases: bool
    debug_dataclass_transform: bool  # undocumented

    def __init__(self, options: Options) -> None:
        if options.config_file is None:  # pragma: no cover
            return

        toml_config = parse_toml(options.config_file)
        if toml_config is not None:
            config = toml_config.get('tool', {}).get('pydantic-mypy', {})
            for key in self.__slots__:
                setting = config.get(key, False)
                if not isinstance(setting, bool):
                    raise ValueError(f'Configuration value must be a boolean for key: {key}')
                setattr(self, key, setting)
        else:
            plugin_config = ConfigParser()
            plugin_config.read(options.config_file)
            for key in self.__slots__:
                setting = plugin_config.getboolean(CONFIGFILE_KEY, key, fallback=False)
                setattr(self, key, setting)

    def to_data(self) -> dict[str, Any]:
        """Returns a dict of config names to their values."""
        return {key: getattr(self, key) for key in self.__slots__}


def from_attributes_callback(ctx: MethodContext) -> Type:
    """Raise an error if `from_attributes` is not enabled."""
    model_type: Instance
    ctx_type = ctx.type
    if isinstance(ctx_type, TypeType):
        ctx_type = ctx_type.item
    if isinstance(ctx_type, CallableType) and isinstance(ctx_type.ret_type, Instance):
        model_type = ctx_type.ret_type  # called on the class
    elif isinstance(ctx_type, Instance):
        model_type = ctx_type  # called on an instance (unusual, but still valid)
    else:  # pragma: no cover
        detail = f'ctx.type: {ctx_type} (of type {ctx_type.__class__.__name__})'
        error_unexpected_behavior(detail, ctx.api, ctx.context)
        return ctx.default_return_type
    pydantic_metadata = model_type.type.metadata.get(METADATA_KEY)
    if pydantic_metadata is None:
        return ctx.default_return_type
    from_attributes = pydantic_metadata.get('config', {}).get('from_attributes')
    if from_attributes is not True:
        error_from_attributes(model_type.type.name, ctx.api, ctx.context)
    return ctx.default_return_type
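

# Illustration of the check in `from_attributes_callback` above (assuming a
# hypothetical `Model`): `Model.from_orm(obj)` is flagged with the `pydantic-orm`
# error code unless the model's collected config contains `from_attributes=True`,
# e.g. via:
#
#     class Model(BaseModel):
#         model_config = ConfigDict(from_attributes=True)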


class PydanticModelField:
    """Based on mypy.plugins.dataclasses.DataclassAttribute."""

    def __init__(
        self,
        name: str,
        alias: str | None,
        has_dynamic_alias: bool,
        has_default: bool,
        line: int,
        column: int,
        type: Type | None,
        info: TypeInfo,
    ):
        self.name = name
        self.alias = alias
        self.has_dynamic_alias = has_dynamic_alias
        self.has_default = has_default
        self.line = line
        self.column = column
        self.type = type
        self.info = info

    def to_argument(
        self,
        current_info: TypeInfo,
        typed: bool,
        force_optional: bool,
        use_alias: bool,
        api: SemanticAnalyzerPluginInterface,
        force_typevars_invariant: bool,
    ) -> Argument:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.to_argument."""
        variable = self.to_var(current_info, api, use_alias, force_typevars_invariant)
        type_annotation = self.expand_type(current_info, api) if typed else AnyType(TypeOfAny.explicit)
        return Argument(
            variable=variable,
            type_annotation=type_annotation,
            initializer=None,
            kind=ARG_NAMED_OPT if force_optional or self.has_default else ARG_NAMED,
        )

    def expand_type(
        self, current_info: TypeInfo, api: SemanticAnalyzerPluginInterface, force_typevars_invariant: bool = False
    ) -> Type | None:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.expand_type."""
        if force_typevars_invariant:
            # In some cases, mypy will emit an error "Cannot use a covariant type variable as a parameter".
            # To prevent that, we add an option to replace typevars with invariant ones while building certain
            # method signatures (in particular, `__init__`). There may be a better way to do this; if this causes
            # us problems in the future, we should look into why the dataclasses plugin doesn't have this issue.
            if isinstance(self.type, TypeVarType):
                modified_type = self.type.copy_modified()
                modified_type.variance = INVARIANT
                self.type = modified_type

        # The getattr in the next line is used to prevent errors in legacy versions of mypy without this attribute
        if self.type is not None and getattr(self.info, 'self_type', None) is not None:
            # In general, it is not safe to call `expand_type()` during semantic analysis,
            # however this plugin is called very late, so all types should be fully ready.
            # Also, it is tricky to avoid eager expansion of Self types here (e.g. because
            # we serialize attributes).
            with state.strict_optional_set(api.options.strict_optional):
                filled_with_typevars = fill_typevars(current_info)
                if force_typevars_invariant:
                    for arg in filled_with_typevars.args:
                        if isinstance(arg, TypeVarType):
                            arg.variance = INVARIANT
                return expand_type(self.type, {self.info.self_type.id: filled_with_typevars})
        return self.type

    def to_var(
        self,
        current_info: TypeInfo,
        api: SemanticAnalyzerPluginInterface,
        use_alias: bool,
        force_typevars_invariant: bool = False,
    ) -> Var:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.to_var."""
        if use_alias and self.alias is not None:
            name = self.alias
        else:
            name = self.name

        return Var(name, self.expand_type(current_info, api, force_typevars_invariant))

    def serialize(self) -> JsonDict:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
        assert self.type
        return {
            'name': self.name,
            'alias': self.alias,
            'has_dynamic_alias': self.has_dynamic_alias,
            'has_default': self.has_default,
            'line': self.line,
            'column': self.column,
            'type': self.type.serialize(),
        }

    @classmethod
    def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface) -> PydanticModelField:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize."""
        data = data.copy()
        typ = deserialize_and_fixup_type(data.pop('type'), api)
        return cls(type=typ, info=info, **data)

    def expand_typevar_from_subtype(self, sub_type: TypeInfo, api: SemanticAnalyzerPluginInterface) -> None:
        """Expands type vars in the context of a subtype when an attribute is inherited
        from a generic super type.
        """
        if self.type is not None:
            with state.strict_optional_set(api.options.strict_optional):
                self.type = map_type_from_supertype(self.type, sub_type, self.info)


class PydanticModelClassVar:
    """Based on mypy.plugins.dataclasses.DataclassAttribute.

    ClassVars are ignored by subclasses.

    Attributes:
        name: the ClassVar name
    """

    def __init__(self, name):
        self.name = name

    @classmethod
    def deserialize(cls, data: JsonDict) -> PydanticModelClassVar:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.deserialize."""
        data = data.copy()
        return cls(**data)

    def serialize(self) -> JsonDict:
        """Based on mypy.plugins.dataclasses.DataclassAttribute.serialize."""
        return {
            'name': self.name,
        }


class PydanticModelTransformer:
    """Transform the BaseModel subclass according to the plugin settings.

    Attributes:
        tracked_config_fields: A set of config fields whose values the plugin has to track.
    """

    tracked_config_fields: set[str] = {
        'extra',
        'frozen',
        'from_attributes',
        'populate_by_name',
        'alias_generator',
    }

    def __init__(
        self,
        cls: ClassDef,
        reason: Expression | Statement,
        api: SemanticAnalyzerPluginInterface,
        plugin_config: PydanticPluginConfig,
    ) -> None:
        self._cls = cls
        self._reason = reason
        self._api = api

        self.plugin_config = plugin_config

    def transform(self) -> bool:
        """Configures the BaseModel subclass according to the plugin settings.

        In particular:

        * determines the model config and fields,
        * adds a fields-aware signature for the initializer and construct methods,
        * freezes the class if `frozen = True`,
        * stores the fields, config, and whether the class is a settings class in the mypy metadata
          for access by subclasses.
        """
        info = self._cls.info
        is_root_model = any(ROOT_MODEL_FULLNAME in base.fullname for base in info.mro[:-1])
        config = self.collect_config()
        fields, class_vars = self.collect_fields_and_class_vars(config, is_root_model)
        if fields is None or class_vars is None:
            # Some definitions are not ready. We need another pass.
            return False
        for field in fields:
            if field.type is None:
                return False

        is_settings = any(base.fullname == BASESETTINGS_FULLNAME for base in info.mro[:-1])
        self.add_initializer(fields, config, is_settings, is_root_model)
        if not is_root_model:
            self.add_model_construct_method(fields, config, is_settings)
        self.set_frozen(fields, self._api, frozen=config.frozen is True)

        self.adjust_decorator_signatures()

        info.metadata[METADATA_KEY] = {
            'fields': {field.name: field.serialize() for field in fields},
            'class_vars': {class_var.name: class_var.serialize() for class_var in class_vars},
            'config': config.get_values_dict(),
        }

        return True
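
    # For reference, the metadata stored by `transform` above has, schematically,
    # the shape:
    #
    #     {
    #         'fields': {'<name>': <serialized PydanticModelField>},
    #         'class_vars': {'<name>': {'name': '<name>'}},
    #         'config': {'frozen': True, ...},  # only non-None config values
    #     }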

    def adjust_decorator_signatures(self) -> None:
        """When we decorate a function `f` with `pydantic.validator(...)`, `pydantic.field_validator`
        or `pydantic.serializer(...)`, mypy sees `f` as a regular method taking a `self` instance,
        even though pydantic internally wraps `f` with `classmethod` if necessary.

        Teach mypy this by marking any function whose outermost decorator is a `validator()`,
        `field_validator()` or `serializer()` call as a `classmethod`.
        """
        for name, sym in self._cls.info.names.items():
            if isinstance(sym.node, Decorator):
                first_dec = sym.node.original_decorators[0]
                if (
                    isinstance(first_dec, CallExpr)
                    and isinstance(first_dec.callee, NameExpr)
                    and first_dec.callee.fullname in DECORATOR_FULLNAMES
                    # @model_validator(mode="after") is an exception, it expects a regular method
                    and not (
                        first_dec.callee.fullname == MODEL_VALIDATOR_FULLNAME
                        and any(
                            first_dec.arg_names[i] == 'mode' and isinstance(arg, StrExpr) and arg.value == 'after'
                            for i, arg in enumerate(first_dec.args)
                        )
                    )
                ):
                    # TODO: Only do this if the first argument of the decorated function is `cls`
                    sym.node.func.is_class = True
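
    # As an illustration of `adjust_decorator_signatures` above, in
    #
    #     class Model(BaseModel):
    #         @field_validator('x')
    #         def check_x(cls, v): ...
    #
    # pydantic wraps `check_x` in `classmethod` at runtime, so the plugin marks it
    # as a classmethod for mypy; a `@model_validator(mode='after')` method is left
    # alone because it expects a regular `self`.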

    def collect_config(self) -> ModelConfigData:  # noqa: C901 (ignore complexity)
        """Collects the values of the config attributes that are used by the plugin, accounting for parent classes."""
        cls = self._cls
        config = ModelConfigData()

        has_config_kwargs = False
        has_config_from_namespace = False

        # Handle `class MyModel(BaseModel, <name>=<expr>, ...):`
        for name, expr in cls.keywords.items():
            config_data = self.get_config_update(name, expr)
            if config_data:
                has_config_kwargs = True
                config.update(config_data)

        # Handle `model_config`
        stmt: Statement | None = None
        for stmt in cls.defs.body:
            if not isinstance(stmt, (AssignmentStmt, ClassDef)):
                continue

            if isinstance(stmt, AssignmentStmt):
                lhs = stmt.lvalues[0]
                if not isinstance(lhs, NameExpr) or lhs.name != 'model_config':
                    continue

                if isinstance(stmt.rvalue, CallExpr):  # calls to `dict` or `ConfigDict`
                    for arg_name, arg in zip(stmt.rvalue.arg_names, stmt.rvalue.args):
                        if arg_name is None:
                            continue
                        config.update(self.get_config_update(arg_name, arg, lax_extra=True))
                elif isinstance(stmt.rvalue, DictExpr):  # dict literals
                    for key_expr, value_expr in stmt.rvalue.items:
                        if not isinstance(key_expr, StrExpr):
                            continue
                        config.update(self.get_config_update(key_expr.value, value_expr))

            elif isinstance(stmt, ClassDef):
                if stmt.name != 'Config':  # 'deprecated' Config-class
                    continue
                for substmt in stmt.defs.body:
                    if not isinstance(substmt, AssignmentStmt):
                        continue
                    lhs = substmt.lvalues[0]
                    if not isinstance(lhs, NameExpr):
                        continue
                    config.update(self.get_config_update(lhs.name, substmt.rvalue))

            if has_config_kwargs:
                self._api.fail(
                    'Specifying config in two places is ambiguous, use either Config attribute or class kwargs',
                    cls,
                )
                break

            has_config_from_namespace = True

        if has_config_kwargs or has_config_from_namespace:
            if (
                stmt
                and config.has_alias_generator
                and not config.populate_by_name
                and self.plugin_config.warn_required_dynamic_aliases
            ):
                error_required_dynamic_aliases(self._api, stmt)

        for info in cls.info.mro[1:]:  # 0 is the current class
            if METADATA_KEY not in info.metadata:
                continue

            # Each class depends on the set of fields in its ancestors
            self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname))
            for name, value in info.metadata[METADATA_KEY]['config'].items():
                config.setdefault(name, value)
        return config

    def collect_fields_and_class_vars(
        self, model_config: ModelConfigData, is_root_model: bool
    ) -> tuple[list[PydanticModelField] | None, list[PydanticModelClassVar] | None]:
        """Collects the fields for the model, accounting for parent classes."""
        cls = self._cls

        # First, collect fields and ClassVars belonging to any class in the MRO, ignoring duplicates.
        #
        # We iterate through the MRO in reverse because attrs defined in the parent must appear
        # earlier in the attributes list than attrs defined in the child. See:
        # https://docs.python.org/3/library/dataclasses.html#inheritance
        #
        # However, we also want fields defined in the subtype to override ones defined
        # in the parent. We can implement this via a dict without disrupting the attr order
        # because dicts preserve insertion order in Python 3.7+.
        found_fields: dict[str, PydanticModelField] = {}
        found_class_vars: dict[str, PydanticModelClassVar] = {}
        for info in reversed(cls.info.mro[1:-1]):  # 0 is the current class, -2 is BaseModel, -1 is object
            # if BASEMODEL_METADATA_TAG_KEY in info.metadata and BASEMODEL_METADATA_KEY not in info.metadata:
            #     # We haven't processed the base class yet. Need another pass.
            #     return None, None
            if METADATA_KEY not in info.metadata:
                continue

            # Each class depends on the set of attributes in its dataclass ancestors.
            self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname))

            for name, data in info.metadata[METADATA_KEY]['fields'].items():
                field = PydanticModelField.deserialize(info, data, self._api)
                # (The following comment comes directly from the dataclasses plugin)
                # TODO: We shouldn't be performing type operations during the main
                #       semantic analysis pass, since some TypeInfo attributes might
                #       still be in flux. This should be performed in a later phase.
                field.expand_typevar_from_subtype(cls.info, self._api)
                found_fields[name] = field

                sym_node = cls.info.names.get(name)
                if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
                    self._api.fail(
                        'BaseModel field may only be overridden by another field',
                        sym_node.node,
                    )
            # Collect ClassVars
            for name, data in info.metadata[METADATA_KEY]['class_vars'].items():
                found_class_vars[name] = PydanticModelClassVar.deserialize(data)

        # Second, collect fields and ClassVars belonging to the current class.
        current_field_names: set[str] = set()
        current_class_vars_names: set[str] = set()
        for stmt in self._get_assignment_statements_from_block(cls.defs):
            maybe_field = self.collect_field_or_class_var_from_stmt(stmt, model_config, found_class_vars)
            if isinstance(maybe_field, PydanticModelField):
                lhs = stmt.lvalues[0]
                if is_root_model and lhs.name != 'root':
                    error_extra_fields_on_root_model(self._api, stmt)
                else:
                    current_field_names.add(lhs.name)
                    found_fields[lhs.name] = maybe_field
            elif isinstance(maybe_field, PydanticModelClassVar):
                lhs = stmt.lvalues[0]
                current_class_vars_names.add(lhs.name)
                found_class_vars[lhs.name] = maybe_field

        return list(found_fields.values()), list(found_class_vars.values())

    def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]:
        for body in stmt.body:
            if not body.is_unreachable:
                yield from self._get_assignment_statements_from_block(body)
        if stmt.else_body is not None and not stmt.else_body.is_unreachable:
            yield from self._get_assignment_statements_from_block(stmt.else_body)

    def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]:
        for stmt in block.body:
            if isinstance(stmt, AssignmentStmt):
                yield stmt
            elif isinstance(stmt, IfStmt):
                yield from self._get_assignment_statements_from_if_statement(stmt)

    def collect_field_or_class_var_from_stmt(  # noqa: C901
        self, stmt: AssignmentStmt, model_config: ModelConfigData, class_vars: dict[str, PydanticModelClassVar]
    ) -> PydanticModelField | PydanticModelClassVar | None:
        """Get a pydantic model field or class variable from a statement.

        Args:
            stmt: The statement.
            model_config: Configuration settings for the model.
            class_vars: ClassVars already known to be defined on the model.

        Returns:
            A pydantic model field or class variable if one could be extracted from the statement; otherwise `None`.
        """
        cls = self._cls

        lhs = stmt.lvalues[0]
        if not isinstance(lhs, NameExpr) or not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config':
            return None

        if not stmt.new_syntax:
            if (
                isinstance(stmt.rvalue, CallExpr)
                and isinstance(stmt.rvalue.callee, CallExpr)
                and isinstance(stmt.rvalue.callee.callee, NameExpr)
                and stmt.rvalue.callee.callee.fullname in DECORATOR_FULLNAMES
            ):
                # This is a (possibly-reused) validator or serializer, not a field
                # In particular, it looks something like: my_validator = validator('my_field')(f)
                # Eventually, we may want to attempt to respect model_config['ignored_types']
                return None

            if lhs.name in class_vars:
                # Class vars are not fields and are not required to be annotated
                return None

            # The assignment does not have an annotation, and it's not anything else we recognize
            error_untyped_fields(self._api, stmt)
            return None

        lhs = stmt.lvalues[0]
        if not isinstance(lhs, NameExpr):
            return None

        if not _fields.is_valid_field_name(lhs.name) or lhs.name == 'model_config':
            return None

        sym = cls.info.names.get(lhs.name)
        if sym is None:  # pragma: no cover
            # This is likely due to a star import (see the dataclasses plugin for a more detailed explanation)
            # This is the same logic used in the dataclasses plugin
            return None

        node = sym.node
        if isinstance(node, PlaceholderNode):  # pragma: no cover
            # See the PlaceholderNode docstring for more detail about how this can occur
            # Basically, it is an edge case when dealing with complex import logic

            # The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does..
            return None

        if isinstance(node, TypeAlias):
            self._api.fail(
                'Type aliases inside BaseModel definitions are not supported at runtime',
                node,
            )
            # Skip processing this node. This doesn't match the runtime behaviour,
            # but the only alternative would be to modify the SymbolTable,
            # and it's a little hairy to do that in a plugin.
            return None

        if not isinstance(node, Var):  # pragma: no cover
            # Don't know if this edge case still happens with the `is_valid_field` check above
            # but better safe than sorry

            # The dataclasses plugin now asserts this cannot happen, but I'd rather not error if it does..
            return None

        # x: ClassVar[int] is not a field
        if node.is_classvar:
            return PydanticModelClassVar(lhs.name)

        # x: InitVar[int] is not supported in BaseModel
        node_type = get_proper_type(node.type)
        if isinstance(node_type, Instance) and node_type.type.fullname == 'dataclasses.InitVar':
            self._api.fail(
                'InitVar is not supported in BaseModel',
                node,
            )

        has_default = self.get_has_default(stmt)

        if sym.type is None and node.is_final and node.is_inferred:
            # This follows the logic from the dataclasses plugin. The following comment is taken verbatim:
            #
            # This is a special case, assignment like x: Final = 42 is classified
            # as annotated above, but mypy strips the `Final` turning it into x = 42.
            # We do not support inferred types in dataclasses, so we can try inferring
            # type for simple literals, and otherwise require an explicit type
            # argument for Final[...].
            typ = self._api.analyze_simple_literal_type(stmt.rvalue, is_final=True)
            if typ:
                node.type = typ
            else:
                self._api.fail(
                    'Need type argument for Final[...] with non-literal default in BaseModel',
                    stmt,
                )
                node.type = AnyType(TypeOfAny.from_error)

        alias, has_dynamic_alias = self.get_alias_info(stmt)
        if has_dynamic_alias and not model_config.populate_by_name and self.plugin_config.warn_required_dynamic_aliases:
            error_required_dynamic_aliases(self._api, stmt)

        init_type = self._infer_dataclass_attr_init_type(sym, lhs.name, stmt)
        return PydanticModelField(
            name=lhs.name,
            has_dynamic_alias=has_dynamic_alias,
            has_default=has_default,
            alias=alias,
            line=stmt.line,
            column=stmt.column,
            type=init_type,
            info=cls.info,
        )

    def _infer_dataclass_attr_init_type(self, sym: SymbolTableNode, name: str, context: Context) -> Type | None:
        """Infer __init__ argument type for an attribute.

        In particular, possibly use the signature of __set__.
        """
        default = sym.type
        if sym.implicit:
            return default
        t = get_proper_type(sym.type)

        # Perform a simple-minded inference from the signature of __set__, if present.
        # We can't use mypy.checkmember here, since this plugin runs before type checking.
        # We only support some basic scenarios here, which is hopefully sufficient for
        # the vast majority of use cases.
        if not isinstance(t, Instance):
            return default
        setter = t.type.get('__set__')
        if setter:
            if isinstance(setter.node, FuncDef):
                super_info = t.type.get_containing_type_info('__set__')
                assert super_info
                if setter.type:
                    setter_type = get_proper_type(map_type_from_supertype(setter.type, t.type, super_info))
                else:
                    return AnyType(TypeOfAny.unannotated)
                if isinstance(setter_type, CallableType) and setter_type.arg_kinds == [
                    ARG_POS,
                    ARG_POS,
                    ARG_POS,
                ]:
                    return expand_type_by_instance(setter_type.arg_types[2], t)
                else:
                    self._api.fail(f'Unsupported signature for "__set__" in "{t.type.name}"', context)
            else:
                self._api.fail(f'Unsupported "__set__" in "{t.type.name}"', context)

        return default

    def add_initializer(
        self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool, is_root_model: bool
    ) -> None:
        """Adds a fields-aware `__init__` method to the class.

        The added `__init__` will be annotated with the field types, or with `Any` for every
        argument, depending on the plugin settings.
        """
        if '__init__' in self._cls.info.names and not self._cls.info.names['__init__'].plugin_generated:
            return  # Don't generate an __init__ if one already exists

        typed = self.plugin_config.init_typed
        use_alias = config.populate_by_name is not True
        requires_dynamic_aliases = bool(config.has_alias_generator and not config.populate_by_name)
        args = self.get_field_arguments(
            fields,
            typed=typed,
            requires_dynamic_aliases=requires_dynamic_aliases,
            use_alias=use_alias,
            is_settings=is_settings,
            force_typevars_invariant=True,
        )

        if is_root_model and MYPY_VERSION_TUPLE <= (1, 0, 1):
            # convert root argument to positional argument
            # This is needed because mypy support for `dataclass_transform` isn't complete on 1.0.1
            args[0].kind = ARG_POS if args[0].kind == ARG_NAMED else ARG_OPT

        if is_settings:
            base_settings_node = self._api.lookup_fully_qualified(BASESETTINGS_FULLNAME).node
            if '__init__' in base_settings_node.names:
                base_settings_init_node = base_settings_node.names['__init__'].node
                if base_settings_init_node is not None and base_settings_init_node.type is not None:
                    func_type = base_settings_init_node.type
                    for arg_idx, arg_name in enumerate(func_type.arg_names):
                        if arg_name.startswith('__') or not arg_name.startswith('_'):
                            continue
                        analyzed_variable_type = self._api.anal_type(func_type.arg_types[arg_idx])
                        variable = Var(arg_name, analyzed_variable_type)
                        args.append(Argument(variable, analyzed_variable_type, None, ARG_OPT))

        if not self.should_init_forbid_extra(fields, config):
            var = Var('kwargs')
            args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))

        add_method(self._api, self._cls, '__init__', args=args, return_type=NoneType())
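
    # Schematically (illustrative rendering, with `init_typed = true` and extra
    # kwargs allowed), `class Model(BaseModel): x: int = 1` ends up with a
    # signature roughly like:
    #
    #     def __init__(__pydantic_self__, *, x: int = 1, **kwargs: Any) -> None
    #
    # The `**kwargs` is omitted when `should_init_forbid_extra` returns True.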

    def add_model_construct_method(
        self, fields: list[PydanticModelField], config: ModelConfigData, is_settings: bool
    ) -> None:
        """Adds a fully typed `model_construct` classmethod to the class.

        Similar to the fields-aware `__init__` method, but always uses the field names (not aliases),
        and does not treat settings fields as optional.
        """
        set_str = self._api.named_type(f'{BUILTINS_NAME}.set', [self._api.named_type(f'{BUILTINS_NAME}.str')])
        optional_set_str = UnionType([set_str, NoneType()])
        fields_set_argument = Argument(Var('_fields_set', optional_set_str), optional_set_str, None, ARG_OPT)
        with state.strict_optional_set(self._api.options.strict_optional):
            args = self.get_field_arguments(
                fields, typed=True, requires_dynamic_aliases=False, use_alias=False, is_settings=is_settings
            )
        if not self.should_init_forbid_extra(fields, config):
            var = Var('kwargs')
            args.append(Argument(var, AnyType(TypeOfAny.explicit), None, ARG_STAR2))

        args = [fields_set_argument] + args

        add_method(
            self._api,
            self._cls,
            'model_construct',
            args=args,
            return_type=fill_typevars(self._cls.info),
            is_classmethod=True,
        )
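
    # Schematically (illustrative rendering), for `class Model(BaseModel): x: int`
    # the generated classmethod looks roughly like:
    #
    #     def model_construct(cls, _fields_set: set[str] | None = None, *, x: int) -> Model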

    def set_frozen(self, fields: list[PydanticModelField], api: SemanticAnalyzerPluginInterface, frozen: bool) -> None:
        """Marks all fields as properties so that attempts to set them trigger mypy errors.

        This is the same approach used by the attrs and dataclasses plugins.
        """
        info = self._cls.info
        for field in fields:
            sym_node = info.names.get(field.name)
            if sym_node is not None:
                var = sym_node.node
                if isinstance(var, Var):
                    var.is_property = frozen
                elif isinstance(var, PlaceholderNode) and not self._api.final_iteration:
                    # See https://github.com/pydantic/pydantic/issues/5191 to hit this branch for test coverage
                    self._api.defer()
                else:  # pragma: no cover
                    # I don't know whether it's possible to hit this branch, but I've added it for safety
                    try:
                        var_str = str(var)
                    except TypeError:
                        # This happens for PlaceholderNode; perhaps it will happen for other types in the future..
                        var_str = repr(var)
                    detail = f'sym_node.node: {var_str} (of type {var.__class__})'
                    error_unexpected_behavior(detail, self._api, self._cls)
            else:
                var = field.to_var(info, api, use_alias=False)
                var.info = info
                var.is_property = frozen
                var._fullname = info.fullname + '.' + var.name
                info.names[var.name] = SymbolTableNode(MDEF, var)

    def get_config_update(self, name: str, arg: Expression, lax_extra: bool = False) -> ModelConfigData | None:
        """Determines the config update due to a single kwarg in the ConfigDict definition.

        Warns if a tracked config attribute is set to a value the plugin doesn't know how to interpret (e.g., an int).
        """
        if name not in self.tracked_config_fields:
            return None
        if name == 'extra':
            if isinstance(arg, StrExpr):
                forbid_extra = arg.value == 'forbid'
            elif isinstance(arg, MemberExpr):
                forbid_extra = arg.name == 'forbid'
            else:
                if not lax_extra:
                    # Only emit an error for other types of `arg` (e.g., `NameExpr`, `ConditionalExpr`, etc.) when
                    # reading from a config class, etc. If a ConfigDict is used, then we don't want to emit an error
                    # because you'll get type checking from the ConfigDict itself.
                    #
                    # It would be nice if we could introspect the types better otherwise, but I don't know what the API
                    # is to evaluate an expr into its type and then check if that type is compatible with the expected
                    # type. Note that you can still get proper type checking via: `model_config = ConfigDict(...)`, just
                    # if you don't use an explicit string, the plugin won't be able to infer whether extra is forbidden.
                    error_invalid_config_value(name, self._api, arg)
                return None
            return ModelConfigData(forbid_extra=forbid_extra)
        if name == 'alias_generator':
            has_alias_generator = True
            if isinstance(arg, NameExpr) and arg.fullname == 'builtins.None':
                has_alias_generator = False
            return ModelConfigData(has_alias_generator=has_alias_generator)
        if isinstance(arg, NameExpr) and arg.fullname in ('builtins.True', 'builtins.False'):
            return ModelConfigData(**{name: arg.fullname == 'builtins.True'})
        error_invalid_config_value(name, self._api, arg)
        return None
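
    # Illustrative inputs handled by `get_config_update` above:
    #
    #     model_config = ConfigDict(extra='forbid')         # -> forbid_extra=True
    #     model_config = ConfigDict(alias_generator=None)   # -> has_alias_generator=False
    #     class Model(BaseModel, frozen=True): ...          # -> frozen=True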

    @staticmethod
    def get_has_default(stmt: AssignmentStmt) -> bool:
        """Returns a boolean indicating whether the field defined in `stmt` has a default value."""
        expr = stmt.rvalue
        if isinstance(expr, TempNode):
            # TempNode means annotation-only, so has no default
            return False
        if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME:
            # The "default value" is a call to `Field`; at this point, the field has a default if and only if:
            # * there is a positional argument that is not `...`
            # * there is a keyword argument named "default" that is not `...`
            # * there is a "default_factory" that is not `None`
            for arg, name in zip(expr.args, expr.arg_names):
                # If name is None, then this arg is the default because it is the only positional argument.
                if name is None or name == 'default':
                    return arg.__class__ is not EllipsisExpr
                if name == 'default_factory':
                    return not (isinstance(arg, NameExpr) and arg.fullname == 'builtins.None')
            return False
        # Has no default if the "default value" is Ellipsis (i.e., `field_name: Annotation = ...`)
        return not isinstance(expr, EllipsisExpr)
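
    # For example (illustrative), `get_has_default` maps:
    #
    #     a: int                                  # -> False (annotation-only TempNode)
    #     b: int = 1                              # -> True
    #     c: int = Field(...)                     # -> False (Ellipsis default)
    #     d: int = Field(default=1)               # -> True
    #     e: list = Field(default_factory=list)   # -> True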

    @staticmethod
    def get_alias_info(stmt: AssignmentStmt) -> tuple[str | None, bool]:
        """Returns a pair (alias, has_dynamic_alias), extracted from the declaration of the field defined in `stmt`.

        `has_dynamic_alias` is True if and only if an alias is provided, but not as a string literal.
        If `has_dynamic_alias` is True, `alias` will be None.
        """
        expr = stmt.rvalue
        if isinstance(expr, TempNode):
            # TempNode means annotation-only
            return None, False

        if not (
            isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr) and expr.callee.fullname == FIELD_FULLNAME
        ):
            # Assigned value is not a call to pydantic.fields.Field
            return None, False

        for i, arg_name in enumerate(expr.arg_names):
            if arg_name != 'alias':
                continue
            arg = expr.args[i]
            if isinstance(arg, StrExpr):
                return arg.value, False
            else:
                return None, True
        return None, False
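
    # For example (illustrative; `some_function` is a hypothetical helper),
    # `get_alias_info` maps:
    #
    #     x: int = Field(alias='x_alias')        # -> ('x_alias', False)
    #     y: int = Field(alias=some_function())  # -> (None, True), a dynamic alias
    #     z: int = 1                             # -> (None, False)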

    def get_field_arguments(
        self,
        fields: list[PydanticModelField],
        typed: bool,
        use_alias: bool,
        requires_dynamic_aliases: bool,
        is_settings: bool,
        force_typevars_invariant: bool = False,
    ) -> list[Argument]:
        """Helper function used during the construction of the `__init__` and `model_construct` method signatures.

        Returns a list of mypy Argument instances for use in the generated signatures.
        """
        info = self._cls.info
        arguments = [
            field.to_argument(
                info,
                typed=typed,
                force_optional=requires_dynamic_aliases or is_settings,
                use_alias=use_alias,
                api=self._api,
                force_typevars_invariant=force_typevars_invariant,
            )
            for field in fields
            if not (use_alias and field.has_dynamic_alias)
        ]
        return arguments

    def should_init_forbid_extra(self, fields: list[PydanticModelField], config: ModelConfigData) -> bool:
        """Indicates whether the generated `__init__` should get a `**kwargs` at the end of its signature.

        We disallow arbitrary kwargs if the extra config setting is "forbid", or if the plugin config says to,
        *unless* a required dynamic alias is present (since then we can't determine a valid signature).
        """
        if not config.populate_by_name:
            if self.is_dynamic_alias_present(fields, bool(config.has_alias_generator)):
                return False
        if config.forbid_extra:
            return True
        return self.plugin_config.init_forbid_extra

    @staticmethod
    def is_dynamic_alias_present(fields: list[PydanticModelField], has_alias_generator: bool) -> bool:
        """Returns whether any fields on the model have a "dynamic alias", i.e., an alias that cannot be
        determined during static analysis.
        """
        for field in fields:
            if field.has_dynamic_alias:
                return True
        if has_alias_generator:
            for field in fields:
                if field.alias is None:
                    return True
        return False


class ModelConfigData:
    """Pydantic mypy plugin model config class."""

    def __init__(
        self,
        forbid_extra: bool | None = None,
        frozen: bool | None = None,
        from_attributes: bool | None = None,
        populate_by_name: bool | None = None,
        has_alias_generator: bool | None = None,
    ):
        self.forbid_extra = forbid_extra
        self.frozen = frozen
        self.from_attributes = from_attributes
        self.populate_by_name = populate_by_name
        self.has_alias_generator = has_alias_generator

    def get_values_dict(self) -> dict[str, Any]:
        """Returns a dict of Pydantic model config names to their values.

        It only includes config values that are not `None`.
        """
        return {k: v for k, v in self.__dict__.items() if v is not None}

    def update(self, config: ModelConfigData | None) -> None:
        """Update Pydantic model config values."""
        if config is None:
            return
        for k, v in config.get_values_dict().items():
            setattr(self, k, v)

    def setdefault(self, key: str, value: Any) -> None:
        """Set a default value for a Pydantic model config if the config value is `None`."""
        if getattr(self, key) is None:
            setattr(self, key, value)


ERROR_ORM = ErrorCode('pydantic-orm', 'Invalid from_attributes call', 'Pydantic')
ERROR_CONFIG = ErrorCode('pydantic-config', 'Invalid config value', 'Pydantic')
ERROR_ALIAS = ErrorCode('pydantic-alias', 'Dynamic alias disallowed', 'Pydantic')
ERROR_UNEXPECTED = ErrorCode('pydantic-unexpected', 'Unexpected behavior', 'Pydantic')
ERROR_UNTYPED = ErrorCode('pydantic-field', 'Untyped field disallowed', 'Pydantic')
ERROR_FIELD_DEFAULTS = ErrorCode('pydantic-field', 'Invalid Field defaults', 'Pydantic')
ERROR_EXTRA_FIELD_ROOT_MODEL = ErrorCode('pydantic-field', 'Extra field on RootModel subclass', 'Pydantic')


def error_from_attributes(model_name: str, api: CheckerPluginInterface, context: Context) -> None:
    """Emits an error when the model does not have `from_attributes=True`."""
    api.fail(f'"{model_name}" does not have from_attributes=True', context, code=ERROR_ORM)


def error_invalid_config_value(name: str, api: SemanticAnalyzerPluginInterface, context: Context) -> None:
    """Emits an error when the config value is invalid."""
    api.fail(f'Invalid value for "Config.{name}"', context, code=ERROR_CONFIG)


def error_required_dynamic_aliases(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
    """Emits a required dynamic aliases error.

    This will be called when `warn_required_dynamic_aliases=True`.
    """
    api.fail('Required dynamic aliases disallowed', context, code=ERROR_ALIAS)


def error_unexpected_behavior(
    detail: str, api: CheckerPluginInterface | SemanticAnalyzerPluginInterface, context: Context
) -> None:  # pragma: no cover
    """Emits an unexpected behavior error."""
    # Can't think of a good way to test this, but I confirmed it renders as desired by adding to a non-error path
    link = 'https://github.com/pydantic/pydantic/issues/new/choose'
    full_message = f'The pydantic mypy plugin ran into unexpected behavior: {detail}\n'
    full_message += f'Please consider reporting this bug at {link} so we can try to fix it!'
    api.fail(full_message, context, code=ERROR_UNEXPECTED)


def error_untyped_fields(api: SemanticAnalyzerPluginInterface, context: Context) -> None:
    """Emits an error when there is an untyped field in the model."""
    api.fail('Untyped fields disallowed', context, code=ERROR_UNTYPED)


def error_extra_fields_on_root_model(api: CheckerPluginInterface, context: Context) -> None:
    """Emits an error when there is more than just a root field defined for a subclass of RootModel."""
    api.fail('Only `root` is allowed as a field of a `RootModel`', context, code=ERROR_EXTRA_FIELD_ROOT_MODEL)


def error_default_and_default_factory_specified(api: CheckerPluginInterface, context: Context) -> None:
    """Emits an error when `Field` has both `default` and `default_factory` together."""
    api.fail('Field default and default_factory cannot be specified together', context, code=ERROR_FIELD_DEFAULTS)


def add_method(
    api: SemanticAnalyzerPluginInterface | CheckerPluginInterface,
    cls: ClassDef,
    name: str,
    args: list[Argument],
    return_type: Type,
    self_type: Type | None = None,
    tvar_def: TypeVarDef | None = None,
    is_classmethod: bool = False,
) -> None:
    """Very closely related to `mypy.plugins.common.add_method_to_class`, with a few pydantic-specific changes."""
    info = cls.info

    # First remove any previously generated methods with the same name
    # to avoid clashes and problems in the semantic analyzer.
    if name in info.names:
        sym = info.names[name]
        if sym.plugin_generated and isinstance(sym.node, FuncDef):
            cls.defs.body.remove(sym.node)  # pragma: no cover

    if isinstance(api, SemanticAnalyzerPluginInterface):
        function_type = api.named_type('builtins.function')
    else:
        function_type = api.named_generic_type('builtins.function', [])

    if is_classmethod:
        self_type = self_type or TypeType(fill_typevars(info))
        first = [Argument(Var('_cls'), self_type, None, ARG_POS, True)]
    else:
        self_type = self_type or fill_typevars(info)
        # `self` is positional *ONLY* here, but this can't be expressed
        # fully in the mypy internal API. ARG_POS is the closest we can get.
        # Using ARG_POS will, however, give mypy errors if a `self` field
        # is present on a model:
        #
        #     Name "self" already defined (possibly by an import) [no-redef]
        #
        # As a workaround, we give this argument a name that will
        # never conflict. By its positional nature, this name will not
        # be used or exposed to users.
        first = [Argument(Var('__pydantic_self__'), self_type, None, ARG_POS)]
    args = first + args

    arg_types, arg_names, arg_kinds = [], [], []
    for arg in args:
        assert arg.type_annotation, 'All arguments must be fully typed.'
        arg_types.append(arg.type_annotation)
        arg_names.append(arg.variable.name)
        arg_kinds.append(arg.kind)

    signature = CallableType(arg_types, arg_kinds, arg_names, return_type, function_type)
    if tvar_def:
        signature.variables = [tvar_def]

    func = FuncDef(name, args, Block([PassStmt()]))
    func.info = info
    func.type = set_callable_name(signature, func)
    func.is_class = is_classmethod
    func._fullname = info.fullname + '.' + name
    func.line = info.line

    # NOTE: we would like the plugin generated node to dominate, but we still
    # need to keep any existing definitions so they get semantically analyzed.
    if name in info.names:
        # Get a nice unique name instead.
        r_name = get_unique_redefinition_name(name, info.names)
        info.names[r_name] = info.names[name]

    # Add decorator for is_classmethod
    # The dataclasses plugin claims this is unnecessary for classmethods, but not including it results in a
    # signature incompatible with the superclass, which causes mypy errors to occur for every subclass of BaseModel.
    if is_classmethod:
        func.is_decorated = True
        v = Var(name, func.type)
        v.info = info
        v._fullname = func._fullname
        v.is_classmethod = True
        dec = Decorator(func, [NameExpr('classmethod')], v)
        dec.line = info.line
        sym = SymbolTableNode(MDEF, dec)
    else:
        sym = SymbolTableNode(MDEF, func)
    sym.plugin_generated = True
    info.names[name] = sym

    info.defn.defs.body.append(func)


def parse_toml(config_file: str) -> dict[str, Any] | None:
    """Returns a dict of config keys to values.

    It reads the config from a TOML file and returns `None` if the file is not a TOML file.
    """
    if not config_file.endswith('.toml'):
        return None

    if sys.version_info >= (3, 11):
        import tomllib as toml_
    else:
        try:
            import tomli as toml_
        except ImportError:  # pragma: no cover
            import warnings

            warnings.warn('No TOML parser installed, cannot read configuration from `pyproject.toml`.')
            return None

    with open(config_file, 'rb') as rf:
        return toml_.load(rf)
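

# A TOML config consumed by `parse_toml` might contain, for example (illustrative;
# the keys correspond to `PydanticPluginConfig.__slots__`):
#
#     # pyproject.toml
#     [tool.pydantic-mypy]
#     init_typed = true
#     warn_required_dynamic_aliases = true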