_fields.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. """Private logic related to fields (the `Field()` function and `FieldInfo` class), and arguments to `Annotated`."""
  2. from __future__ import annotations as _annotations
  3. import dataclasses
  4. import warnings
  5. from copy import copy
  6. from functools import lru_cache
  7. from inspect import Parameter, ismethoddescriptor, signature
  8. from typing import TYPE_CHECKING, Any, Callable, Pattern
  9. from pydantic_core import PydanticUndefined
  10. from typing_extensions import TypeIs
  11. from pydantic.errors import PydanticUserError
  12. from . import _typing_extra
  13. from ._config import ConfigWrapper
  14. from ._docs_extraction import extract_docstrings_from_cls
  15. from ._import_utils import import_cached_base_model, import_cached_field_info
  16. from ._namespace_utils import NsResolver
  17. from ._repr import Representation
  18. from ._utils import can_be_positional
  19. if TYPE_CHECKING:
  20. from annotated_types import BaseMetadata
  21. from ..fields import FieldInfo
  22. from ..main import BaseModel
  23. from ._dataclasses import StandardDataclass
  24. from ._decorators import DecoratorInfos
  25. class PydanticMetadata(Representation):
  26. """Base class for annotation markers like `Strict`."""
  27. __slots__ = ()
  28. def pydantic_general_metadata(**metadata: Any) -> BaseMetadata:
  29. """Create a new `_PydanticGeneralMetadata` class with the given metadata.
  30. Args:
  31. **metadata: The metadata to add.
  32. Returns:
  33. The new `_PydanticGeneralMetadata` class.
  34. """
  35. return _general_metadata_cls()(metadata) # type: ignore
  36. @lru_cache(maxsize=None)
  37. def _general_metadata_cls() -> type[BaseMetadata]:
  38. """Do it this way to avoid importing `annotated_types` at import time."""
  39. from annotated_types import BaseMetadata
  40. class _PydanticGeneralMetadata(PydanticMetadata, BaseMetadata):
  41. """Pydantic general metadata like `max_digits`."""
  42. def __init__(self, metadata: Any):
  43. self.__dict__ = metadata
  44. return _PydanticGeneralMetadata # type: ignore
  45. def _update_fields_from_docstrings(cls: type[Any], fields: dict[str, FieldInfo], config_wrapper: ConfigWrapper) -> None:
  46. if config_wrapper.use_attribute_docstrings:
  47. fields_docs = extract_docstrings_from_cls(cls)
  48. for ann_name, field_info in fields.items():
  49. if field_info.description is None and ann_name in fields_docs:
  50. field_info.description = fields_docs[ann_name]
  51. def collect_model_fields( # noqa: C901
  52. cls: type[BaseModel],
  53. bases: tuple[type[Any], ...],
  54. config_wrapper: ConfigWrapper,
  55. ns_resolver: NsResolver | None,
  56. *,
  57. typevars_map: dict[Any, Any] | None = None,
  58. ) -> tuple[dict[str, FieldInfo], set[str]]:
  59. """Collect the fields of a nascent pydantic model.
  60. Also collect the names of any ClassVars present in the type hints.
  61. The returned value is a tuple of two items: the fields dict, and the set of ClassVar names.
  62. Args:
  63. cls: BaseModel or dataclass.
  64. bases: Parents of the class, generally `cls.__bases__`.
  65. config_wrapper: The config wrapper instance.
  66. ns_resolver: Namespace resolver to use when getting model annotations.
  67. typevars_map: A dictionary mapping type variables to their concrete types.
  68. Returns:
  69. A tuple contains fields and class variables.
  70. Raises:
  71. NameError:
  72. - If there is a conflict between a field name and protected namespaces.
  73. - If there is a field other than `root` in `RootModel`.
  74. - If a field shadows an attribute in the parent model.
  75. """
  76. BaseModel = import_cached_base_model()
  77. FieldInfo_ = import_cached_field_info()
  78. parent_fields_lookup: dict[str, FieldInfo] = {}
  79. for base in reversed(bases):
  80. if model_fields := getattr(base, '__pydantic_fields__', None):
  81. parent_fields_lookup.update(model_fields)
  82. type_hints = _typing_extra.get_model_type_hints(cls, ns_resolver=ns_resolver)
  83. # https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
  84. # annotations is only used for finding fields in parent classes
  85. annotations = cls.__dict__.get('__annotations__', {})
  86. fields: dict[str, FieldInfo] = {}
  87. class_vars: set[str] = set()
  88. for ann_name, (ann_type, evaluated) in type_hints.items():
  89. if ann_name == 'model_config':
  90. # We never want to treat `model_config` as a field
  91. # Note: we may need to change this logic if/when we introduce a `BareModel` class with no
  92. # protected namespaces (where `model_config` might be allowed as a field name)
  93. continue
  94. for protected_namespace in config_wrapper.protected_namespaces:
  95. ns_violation: bool = False
  96. if isinstance(protected_namespace, Pattern):
  97. ns_violation = protected_namespace.match(ann_name) is not None
  98. elif isinstance(protected_namespace, str):
  99. ns_violation = ann_name.startswith(protected_namespace)
  100. if ns_violation:
  101. for b in bases:
  102. if hasattr(b, ann_name):
  103. if not (issubclass(b, BaseModel) and ann_name in getattr(b, '__pydantic_fields__', {})):
  104. raise NameError(
  105. f'Field "{ann_name}" conflicts with member {getattr(b, ann_name)}'
  106. f' of protected namespace "{protected_namespace}".'
  107. )
  108. else:
  109. valid_namespaces = ()
  110. for pn in config_wrapper.protected_namespaces:
  111. if isinstance(pn, Pattern):
  112. if not pn.match(ann_name):
  113. valid_namespaces += (f're.compile({pn.pattern})',)
  114. else:
  115. if not ann_name.startswith(pn):
  116. valid_namespaces += (pn,)
  117. warnings.warn(
  118. f'Field "{ann_name}" in {cls.__name__} has conflict with protected namespace "{protected_namespace}".'
  119. '\n\nYou may be able to resolve this warning by setting'
  120. f" `model_config['protected_namespaces'] = {valid_namespaces}`.",
  121. UserWarning,
  122. )
  123. if _typing_extra.is_classvar_annotation(ann_type):
  124. class_vars.add(ann_name)
  125. continue
  126. if _is_finalvar_with_default_val(ann_type, getattr(cls, ann_name, PydanticUndefined)):
  127. class_vars.add(ann_name)
  128. continue
  129. if not is_valid_field_name(ann_name):
  130. continue
  131. if cls.__pydantic_root_model__ and ann_name != 'root':
  132. raise NameError(
  133. f"Unexpected field with name {ann_name!r}; only 'root' is allowed as a field of a `RootModel`"
  134. )
  135. # when building a generic model with `MyModel[int]`, the generic_origin check makes sure we don't get
  136. # "... shadows an attribute" warnings
  137. generic_origin = getattr(cls, '__pydantic_generic_metadata__', {}).get('origin')
  138. for base in bases:
  139. dataclass_fields = {
  140. field.name for field in (dataclasses.fields(base) if dataclasses.is_dataclass(base) else ())
  141. }
  142. if hasattr(base, ann_name):
  143. if base is generic_origin:
  144. # Don't warn when "shadowing" of attributes in parametrized generics
  145. continue
  146. if ann_name in dataclass_fields:
  147. # Don't warn when inheriting stdlib dataclasses whose fields are "shadowed" by defaults being set
  148. # on the class instance.
  149. continue
  150. if ann_name not in annotations:
  151. # Don't warn when a field exists in a parent class but has not been defined in the current class
  152. continue
  153. warnings.warn(
  154. f'Field name "{ann_name}" in "{cls.__qualname__}" shadows an attribute in parent '
  155. f'"{base.__qualname__}"',
  156. UserWarning,
  157. )
  158. try:
  159. default = getattr(cls, ann_name, PydanticUndefined)
  160. if default is PydanticUndefined:
  161. raise AttributeError
  162. except AttributeError:
  163. if ann_name in annotations:
  164. field_info = FieldInfo_.from_annotation(ann_type)
  165. field_info.evaluated = evaluated
  166. else:
  167. # if field has no default value and is not in __annotations__ this means that it is
  168. # defined in a base class and we can take it from there
  169. if ann_name in parent_fields_lookup:
  170. # The field was present on one of the (possibly multiple) base classes
  171. # copy the field to make sure typevar substitutions don't cause issues with the base classes
  172. field_info = copy(parent_fields_lookup[ann_name])
  173. else:
  174. # The field was not found on any base classes; this seems to be caused by fields not getting
  175. # generated thanks to models not being fully defined while initializing recursive models.
  176. # Nothing stops us from just creating a new FieldInfo for this type hint, so we do this.
  177. field_info = FieldInfo_.from_annotation(ann_type)
  178. field_info.evaluated = evaluated
  179. else:
  180. _warn_on_nested_alias_in_annotation(ann_type, ann_name)
  181. if isinstance(default, FieldInfo_) and ismethoddescriptor(default.default):
  182. # the `getattr` call above triggers a call to `__get__` for descriptors, so we do
  183. # the same if the `= field(default=...)` form is used. Note that we only do this
  184. # for method descriptors for now, we might want to extend this to any descriptor
  185. # in the future (by simply checking for `hasattr(default.default, '__get__')`).
  186. default.default = default.default.__get__(None, cls)
  187. field_info = FieldInfo_.from_annotated_attribute(ann_type, default)
  188. field_info.evaluated = evaluated
  189. # attributes which are fields are removed from the class namespace:
  190. # 1. To match the behaviour of annotation-only fields
  191. # 2. To avoid false positives in the NameError check above
  192. try:
  193. delattr(cls, ann_name)
  194. except AttributeError:
  195. pass # indicates the attribute was on a parent class
  196. # Use cls.__dict__['__pydantic_decorators__'] instead of cls.__pydantic_decorators__
  197. # to make sure the decorators have already been built for this exact class
  198. decorators: DecoratorInfos = cls.__dict__['__pydantic_decorators__']
  199. if ann_name in decorators.computed_fields:
  200. raise ValueError("you can't override a field with a computed field")
  201. fields[ann_name] = field_info
  202. if typevars_map:
  203. for field in fields.values():
  204. field.apply_typevars_map(typevars_map)
  205. _update_fields_from_docstrings(cls, fields, config_wrapper)
  206. return fields, class_vars
  207. def _warn_on_nested_alias_in_annotation(ann_type: type[Any], ann_name: str) -> None:
  208. FieldInfo = import_cached_field_info()
  209. args = getattr(ann_type, '__args__', None)
  210. if args:
  211. for anno_arg in args:
  212. if _typing_extra.is_annotated(anno_arg):
  213. for anno_type_arg in _typing_extra.get_args(anno_arg):
  214. if isinstance(anno_type_arg, FieldInfo) and anno_type_arg.alias is not None:
  215. warnings.warn(
  216. f'`alias` specification on field "{ann_name}" must be set on outermost annotation to take effect.',
  217. UserWarning,
  218. )
  219. return
  220. def _is_finalvar_with_default_val(type_: type[Any], val: Any) -> bool:
  221. FieldInfo = import_cached_field_info()
  222. if not _typing_extra.is_finalvar(type_):
  223. return False
  224. elif val is PydanticUndefined:
  225. return False
  226. elif isinstance(val, FieldInfo) and (val.default is PydanticUndefined and val.default_factory is None):
  227. return False
  228. else:
  229. return True
  230. def collect_dataclass_fields(
  231. cls: type[StandardDataclass],
  232. *,
  233. ns_resolver: NsResolver | None = None,
  234. typevars_map: dict[Any, Any] | None = None,
  235. config_wrapper: ConfigWrapper | None = None,
  236. ) -> dict[str, FieldInfo]:
  237. """Collect the fields of a dataclass.
  238. Args:
  239. cls: dataclass.
  240. ns_resolver: Namespace resolver to use when getting dataclass annotations.
  241. Defaults to an empty instance.
  242. typevars_map: A dictionary mapping type variables to their concrete types.
  243. config_wrapper: The config wrapper instance.
  244. Returns:
  245. The dataclass fields.
  246. """
  247. FieldInfo_ = import_cached_field_info()
  248. fields: dict[str, FieldInfo] = {}
  249. ns_resolver = ns_resolver or NsResolver()
  250. dataclass_fields = cls.__dataclass_fields__
  251. # The logic here is similar to `_typing_extra.get_cls_type_hints`,
  252. # although we do it manually as stdlib dataclasses already have annotations
  253. # collected in each class:
  254. for base in reversed(cls.__mro__):
  255. if not dataclasses.is_dataclass(base):
  256. continue
  257. with ns_resolver.push(base):
  258. for ann_name, dataclass_field in dataclass_fields.items():
  259. if ann_name not in base.__dict__.get('__annotations__', {}):
  260. # `__dataclass_fields__`contains every field, even the ones from base classes.
  261. # Only collect the ones defined on `base`.
  262. continue
  263. globalns, localns = ns_resolver.types_namespace
  264. ann_type, _ = _typing_extra.try_eval_type(dataclass_field.type, globalns, localns)
  265. if _typing_extra.is_classvar_annotation(ann_type):
  266. continue
  267. if (
  268. not dataclass_field.init
  269. and dataclass_field.default is dataclasses.MISSING
  270. and dataclass_field.default_factory is dataclasses.MISSING
  271. ):
  272. # TODO: We should probably do something with this so that validate_assignment behaves properly
  273. # Issue: https://github.com/pydantic/pydantic/issues/5470
  274. continue
  275. if isinstance(dataclass_field.default, FieldInfo_):
  276. if dataclass_field.default.init_var:
  277. if dataclass_field.default.init is False:
  278. raise PydanticUserError(
  279. f'Dataclass field {ann_name} has init=False and init_var=True, but these are mutually exclusive.',
  280. code='clashing-init-and-init-var',
  281. )
  282. # TODO: same note as above re validate_assignment
  283. continue
  284. field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field.default)
  285. else:
  286. field_info = FieldInfo_.from_annotated_attribute(ann_type, dataclass_field)
  287. fields[ann_name] = field_info
  288. if field_info.default is not PydanticUndefined and isinstance(
  289. getattr(cls, ann_name, field_info), FieldInfo_
  290. ):
  291. # We need this to fix the default when the "default" from __dataclass_fields__ is a pydantic.FieldInfo
  292. setattr(cls, ann_name, field_info.default)
  293. if typevars_map:
  294. for field in fields.values():
  295. # We don't pass any ns, as `field.annotation`
  296. # was already evaluated. TODO: is this method relevant?
  297. # Can't we juste use `_generics.replace_types`?
  298. field.apply_typevars_map(typevars_map)
  299. if config_wrapper is not None:
  300. _update_fields_from_docstrings(cls, fields, config_wrapper)
  301. return fields
  302. def is_valid_field_name(name: str) -> bool:
  303. return not name.startswith('_')
  304. def is_valid_privateattr_name(name: str) -> bool:
  305. return name.startswith('_') and not name.startswith('__')
  306. def takes_validated_data_argument(
  307. default_factory: Callable[[], Any] | Callable[[dict[str, Any]], Any],
  308. ) -> TypeIs[Callable[[dict[str, Any]], Any]]:
  309. """Whether the provided default factory callable has a validated data parameter."""
  310. try:
  311. sig = signature(default_factory)
  312. except (ValueError, TypeError):
  313. # `inspect.signature` might not be able to infer a signature, e.g. with C objects.
  314. # In this case, we assume no data argument is present:
  315. return False
  316. parameters = list(sig.parameters.values())
  317. return len(parameters) == 1 and can_be_positional(parameters[0]) and parameters[0].default is Parameter.empty