_generics.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. from __future__ import annotations
  2. import sys
  3. import types
  4. import typing
  5. from collections import ChainMap
  6. from contextlib import contextmanager
  7. from contextvars import ContextVar
  8. from types import prepare_class
  9. from typing import TYPE_CHECKING, Any, Iterator, Mapping, MutableMapping, Tuple, TypeVar
  10. from weakref import WeakValueDictionary
  11. import typing_extensions
  12. from . import _typing_extra
  13. from ._core_utils import get_type_ref
  14. from ._forward_ref import PydanticRecursiveRef
  15. from ._utils import all_identical, is_model_class
  16. if sys.version_info >= (3, 10):
  17. from typing import _UnionGenericAlias # type: ignore[attr-defined]
  18. if TYPE_CHECKING:
  19. from ..main import BaseModel
  20. GenericTypesCacheKey = Tuple[Any, Any, Tuple[Any, ...]]
  21. # Note: We want to remove LimitedDict, but to do this, we'd need to improve the handling of generics caching.
  22. # Right now, to handle recursive generics, we some types must remain cached for brief periods without references.
  23. # By chaining the WeakValuesDict with a LimitedDict, we have a way to retain caching for all types with references,
  24. # while also retaining a limited number of types even without references. This is generally enough to build
  25. # specific recursive generic models without losing required items out of the cache.
  26. KT = TypeVar('KT')
  27. VT = TypeVar('VT')
  28. _LIMITED_DICT_SIZE = 100
  29. if TYPE_CHECKING:
  30. class LimitedDict(dict, MutableMapping[KT, VT]):
  31. def __init__(self, size_limit: int = _LIMITED_DICT_SIZE): ...
  32. else:
  33. class LimitedDict(dict):
  34. """Limit the size/length of a dict used for caching to avoid unlimited increase in memory usage.
  35. Since the dict is ordered, and we always remove elements from the beginning, this is effectively a FIFO cache.
  36. """
  37. def __init__(self, size_limit: int = _LIMITED_DICT_SIZE):
  38. self.size_limit = size_limit
  39. super().__init__()
  40. def __setitem__(self, key: Any, value: Any, /) -> None:
  41. super().__setitem__(key, value)
  42. if len(self) > self.size_limit:
  43. excess = len(self) - self.size_limit + self.size_limit // 10
  44. to_remove = list(self.keys())[:excess]
  45. for k in to_remove:
  46. del self[k]
  47. # weak dictionaries allow the dynamically created parametrized versions of generic models to get collected
  48. # once they are no longer referenced by the caller.
  49. if sys.version_info >= (3, 9): # Typing for weak dictionaries available at 3.9
  50. GenericTypesCache = WeakValueDictionary[GenericTypesCacheKey, 'type[BaseModel]']
  51. else:
  52. GenericTypesCache = WeakValueDictionary
  53. if TYPE_CHECKING:
  54. class DeepChainMap(ChainMap[KT, VT]): # type: ignore
  55. ...
  56. else:
  57. class DeepChainMap(ChainMap):
  58. """Variant of ChainMap that allows direct updates to inner scopes.
  59. Taken from https://docs.python.org/3/library/collections.html#collections.ChainMap,
  60. with some light modifications for this use case.
  61. """
  62. def clear(self) -> None:
  63. for mapping in self.maps:
  64. mapping.clear()
  65. def __setitem__(self, key: KT, value: VT) -> None:
  66. for mapping in self.maps:
  67. mapping[key] = value
  68. def __delitem__(self, key: KT) -> None:
  69. hit = False
  70. for mapping in self.maps:
  71. if key in mapping:
  72. del mapping[key]
  73. hit = True
  74. if not hit:
  75. raise KeyError(key)
# Module-level cache of parametrized generic models. In this module it is read and written through the
# `get_cached_generic_type_early` / `get_cached_generic_type_late` / `set_cached_generic_type` helpers,
# under both "early" and "late" cache keys. Its values are held weakly (see `GenericTypesCache`).
# `LimitedDict` (and `DeepChainMap`) _seem_ to be no longer necessary, but they are deliberately kept so that
# the previous configuration can be restored without rebuilding that infrastructure. That configuration was:
# _GENERIC_TYPES_CACHE = DeepChainMap(GenericTypesCache(), LimitedDict())
_GENERIC_TYPES_CACHE = GenericTypesCache()
class PydanticGenericMetadata(typing_extensions.TypedDict):
    """Shape of the `__pydantic_generic_metadata__` attribute read by `get_args` / `get_origin` in this module.

    `create_generic_submodel` fills it in for each concrete parametrization it creates.
    """

    origin: type[BaseModel] | None  # analogous to typing._GenericAlias.__origin__
    args: tuple[Any, ...]  # analogous to typing._GenericAlias.__args__
    parameters: tuple[TypeVar, ...]  # analogous to typing.Generic.__parameters__
  84. def create_generic_submodel(
  85. model_name: str, origin: type[BaseModel], args: tuple[Any, ...], params: tuple[Any, ...]
  86. ) -> type[BaseModel]:
  87. """Dynamically create a submodel of a provided (generic) BaseModel.
  88. This is used when producing concrete parametrizations of generic models. This function
  89. only *creates* the new subclass; the schema/validators/serialization must be updated to
  90. reflect a concrete parametrization elsewhere.
  91. Args:
  92. model_name: The name of the newly created model.
  93. origin: The base class for the new model to inherit from.
  94. args: A tuple of generic metadata arguments.
  95. params: A tuple of generic metadata parameters.
  96. Returns:
  97. The created submodel.
  98. """
  99. namespace: dict[str, Any] = {'__module__': origin.__module__}
  100. bases = (origin,)
  101. meta, ns, kwds = prepare_class(model_name, bases)
  102. namespace.update(ns)
  103. created_model = meta(
  104. model_name,
  105. bases,
  106. namespace,
  107. __pydantic_generic_metadata__={
  108. 'origin': origin,
  109. 'args': args,
  110. 'parameters': params,
  111. },
  112. __pydantic_reset_parent_namespace__=False,
  113. **kwds,
  114. )
  115. model_module, called_globally = _get_caller_frame_info(depth=3)
  116. if called_globally: # create global reference and therefore allow pickling
  117. object_by_reference = None
  118. reference_name = model_name
  119. reference_module_globals = sys.modules[created_model.__module__].__dict__
  120. while object_by_reference is not created_model:
  121. object_by_reference = reference_module_globals.setdefault(reference_name, created_model)
  122. reference_name += '_'
  123. return created_model
  124. def _get_caller_frame_info(depth: int = 2) -> tuple[str | None, bool]:
  125. """Used inside a function to check whether it was called globally.
  126. Args:
  127. depth: The depth to get the frame.
  128. Returns:
  129. A tuple contains `module_name` and `called_globally`.
  130. Raises:
  131. RuntimeError: If the function is not called inside a function.
  132. """
  133. try:
  134. previous_caller_frame = sys._getframe(depth)
  135. except ValueError as e:
  136. raise RuntimeError('This function must be used inside another function') from e
  137. except AttributeError: # sys module does not have _getframe function, so there's nothing we can do about it
  138. return None, False
  139. frame_globals = previous_caller_frame.f_globals
  140. return frame_globals.get('__name__'), previous_caller_frame.f_locals is frame_globals
  141. DictValues: type[Any] = {}.values().__class__
  142. def iter_contained_typevars(v: Any) -> Iterator[TypeVar]:
  143. """Recursively iterate through all subtypes and type args of `v` and yield any typevars that are found.
  144. This is inspired as an alternative to directly accessing the `__parameters__` attribute of a GenericAlias,
  145. since __parameters__ of (nested) generic BaseModel subclasses won't show up in that list.
  146. """
  147. if isinstance(v, TypeVar):
  148. yield v
  149. elif is_model_class(v):
  150. yield from v.__pydantic_generic_metadata__['parameters']
  151. elif isinstance(v, (DictValues, list)):
  152. for var in v:
  153. yield from iter_contained_typevars(var)
  154. else:
  155. args = get_args(v)
  156. for arg in args:
  157. yield from iter_contained_typevars(arg)
  158. def get_args(v: Any) -> Any:
  159. pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
  160. if pydantic_generic_metadata:
  161. return pydantic_generic_metadata.get('args')
  162. return typing_extensions.get_args(v)
  163. def get_origin(v: Any) -> Any:
  164. pydantic_generic_metadata: PydanticGenericMetadata | None = getattr(v, '__pydantic_generic_metadata__', None)
  165. if pydantic_generic_metadata:
  166. return pydantic_generic_metadata.get('origin')
  167. return typing_extensions.get_origin(v)
  168. def get_standard_typevars_map(cls: Any) -> dict[TypeVar, Any] | None:
  169. """Package a generic type's typevars and parametrization (if present) into a dictionary compatible with the
  170. `replace_types` function. Specifically, this works with standard typing generics and typing._GenericAlias.
  171. """
  172. origin = get_origin(cls)
  173. if origin is None:
  174. return None
  175. if not hasattr(origin, '__parameters__'):
  176. return None
  177. # In this case, we know that cls is a _GenericAlias, and origin is the generic type
  178. # So it is safe to access cls.__args__ and origin.__parameters__
  179. args: tuple[Any, ...] = cls.__args__ # type: ignore
  180. parameters: tuple[TypeVar, ...] = origin.__parameters__
  181. return dict(zip(parameters, args))
  182. def get_model_typevars_map(cls: type[BaseModel]) -> dict[TypeVar, Any] | None:
  183. """Package a generic BaseModel's typevars and concrete parametrization (if present) into a dictionary compatible
  184. with the `replace_types` function.
  185. Since BaseModel.__class_getitem__ does not produce a typing._GenericAlias, and the BaseModel generic info is
  186. stored in the __pydantic_generic_metadata__ attribute, we need special handling here.
  187. """
  188. # TODO: This could be unified with `get_standard_typevars_map` if we stored the generic metadata
  189. # in the __origin__, __args__, and __parameters__ attributes of the model.
  190. generic_metadata = cls.__pydantic_generic_metadata__
  191. origin = generic_metadata['origin']
  192. args = generic_metadata['args']
  193. return dict(zip(iter_contained_typevars(origin), args))
def replace_types(type_: Any, type_map: Mapping[Any, Any] | None) -> Any:
    """Return type with all occurrences of `type_map` keys recursively replaced with their values.

    The cases are tried in this order:
    1. `Annotated[...]`: only the annotated type is rewritten; the metadata is re-applied.
    2. Anything with type args (typing special forms / generic aliases): rebuilt from the resolved args.
    3. Pydantic generic model classes that still have unbound `parameters`: re-parametrized via `type_[...]`.
    4. Plain lists, as used for `Callable` argument lists: rewritten element by element.
    5. Otherwise, a direct lookup of `type_` itself in `type_map`.

    Args:
        type_: The class or generic alias.
        type_map: Mapping from `TypeVar` instance to concrete types.

    Returns:
        A new type representing the basic structure of `type_` with all
        `type_map` keys recursively replaced. In the args, model and list branches, `type_` itself is
        returned when none of its components changed.

    Example:
        ```python
        from typing import List, Tuple, Union

        from pydantic._internal._generics import replace_types

        replace_types(Tuple[str, Union[List[str], float]], {str: int})
        #> Tuple[int, Union[List[int], float]]
        ```
    """
    if not type_map:
        return type_

    type_args = get_args(type_)

    if _typing_extra.is_annotated(type_):
        annotated_type, *annotations = type_args
        annotated = replace_types(annotated_type, type_map)
        # Re-wrap the resolved type with each piece of metadata, in the original order.
        for annotation in annotations:
            annotated = typing_extensions.Annotated[annotated, annotation]
        return annotated

    origin_type = get_origin(type_)

    # Having type args is a good indicator that this is a typing special form
    # instance or a generic alias of some sort.
    if type_args:
        resolved_type_args = tuple(replace_types(arg, type_map) for arg in type_args)
        if all_identical(type_args, resolved_type_args):
            # If all arguments are the same, there is no need to modify the
            # type or create a new object at all
            return type_

        if (
            origin_type is not None
            and isinstance(type_, _typing_extra.typing_base)
            and not isinstance(origin_type, _typing_extra.typing_base)
            and getattr(type_, '_name', None) is not None
        ):
            # In python < 3.9 generic aliases don't exist so any of these like `list`,
            # `type` or `collections.abc.Callable` need to be translated.
            # See: https://www.python.org/dev/peps/pep-0585
            # The `typing` alias with the same `_name` (e.g. `typing.List` for `list`) is used as the
            # subscriptable origin instead.
            origin_type = getattr(typing, type_._name)
        assert origin_type is not None

        if _typing_extra.origin_is_union(origin_type):
            if any(_typing_extra.is_any(arg) for arg in resolved_type_args):
                # `Any | T` ~ `Any`:
                resolved_type_args = (Any,)
            # `Never | T` ~ `T`:
            # NOTE(review): if every member resolves to Never/NoReturn this tuple ends up empty, and the
            # subscription below would then be `Union[()]`, which `typing` is expected to reject with a
            # TypeError -- confirm whether that case needs handling.
            resolved_type_args = tuple(
                arg
                for arg in resolved_type_args
                if not (_typing_extra.is_no_return(arg) or _typing_extra.is_never(arg))
            )

        # PEP-604 syntax (Ex.: list | str) is represented with a types.UnionType object that does not have __getitem__.
        # We also cannot use isinstance() since we have to compare types.
        if sys.version_info >= (3, 10) and origin_type is types.UnionType:
            return _UnionGenericAlias(origin_type, resolved_type_args)
        # NotRequired[T] and Required[T] don't support tuple type resolved_type_args, hence the condition below
        return origin_type[resolved_type_args[0] if len(resolved_type_args) == 1 else resolved_type_args]

    # We handle pydantic generic models separately as they don't have the same
    # semantics as "typing" classes or generic aliases
    if not origin_type and is_model_class(type_):
        parameters = type_.__pydantic_generic_metadata__['parameters']
        if not parameters:
            return type_
        resolved_type_args = tuple(replace_types(t, type_map) for t in parameters)
        if all_identical(parameters, resolved_type_args):
            return type_
        return type_[resolved_type_args]

    # Handle special case for typehints that can have lists as arguments.
    # `typing.Callable[[int, str], int]` is an example for this.
    if isinstance(type_, list):
        resolved_list = [replace_types(element, type_map) for element in type_]
        if all_identical(type_, resolved_list):
            return type_
        return resolved_list

    # If all else fails, we try to resolve the type directly and otherwise just
    # return the input with no modifications.
    return type_map.get(type_, type_)
  275. def has_instance_in_type(type_: Any, isinstance_target: Any) -> bool:
  276. """Checks if the type, or any of its arbitrary nested args, satisfy
  277. `isinstance(<type>, isinstance_target)`.
  278. """
  279. if isinstance(type_, isinstance_target):
  280. return True
  281. if _typing_extra.is_annotated(type_):
  282. return has_instance_in_type(type_.__origin__, isinstance_target)
  283. if _typing_extra.is_literal(type_):
  284. return False
  285. type_args = get_args(type_)
  286. # Having type args is a good indicator that this is a typing module
  287. # class instantiation or a generic alias of some sort.
  288. for arg in type_args:
  289. if has_instance_in_type(arg, isinstance_target):
  290. return True
  291. # Handle special case for typehints that can have lists as arguments.
  292. # `typing.Callable[[int, str], int]` is an example for this.
  293. if (
  294. isinstance(type_, list)
  295. # On Python < 3.10, typing_extensions implements `ParamSpec` as a subclass of `list`:
  296. and not isinstance(type_, typing_extensions.ParamSpec)
  297. ):
  298. for element in type_:
  299. if has_instance_in_type(element, isinstance_target):
  300. return True
  301. return False
  302. def check_parameters_count(cls: type[BaseModel], parameters: tuple[Any, ...]) -> None:
  303. """Check the generic model parameters count is equal.
  304. Args:
  305. cls: The generic model.
  306. parameters: A tuple of passed parameters to the generic model.
  307. Raises:
  308. TypeError: If the passed parameters count is not equal to generic model parameters count.
  309. """
  310. actual = len(parameters)
  311. expected = len(cls.__pydantic_generic_metadata__['parameters'])
  312. if actual != expected:
  313. description = 'many' if actual > expected else 'few'
  314. raise TypeError(f'Too {description} parameters for {cls}; actual {actual}, expected {expected}')
# Set of type refs (as produced by `get_type_ref`) for the generic parametrizations currently being built.
# It is `None` outside of any `generic_recursion_self_type` block. The outermost such block installs a fresh
# set and restores the previous value on exit. Keeping it in a ContextVar scopes the value to the current
# context.
_generic_recursion_cache: ContextVar[set[str] | None] = ContextVar('_generic_recursion_cache', default=None)
  316. @contextmanager
  317. def generic_recursion_self_type(
  318. origin: type[BaseModel], args: tuple[Any, ...]
  319. ) -> Iterator[PydanticRecursiveRef | None]:
  320. """This contextmanager should be placed around the recursive calls used to build a generic type,
  321. and accept as arguments the generic origin type and the type arguments being passed to it.
  322. If the same origin and arguments are observed twice, it implies that a self-reference placeholder
  323. can be used while building the core schema, and will produce a schema_ref that will be valid in the
  324. final parent schema.
  325. """
  326. previously_seen_type_refs = _generic_recursion_cache.get()
  327. if previously_seen_type_refs is None:
  328. previously_seen_type_refs = set()
  329. token = _generic_recursion_cache.set(previously_seen_type_refs)
  330. else:
  331. token = None
  332. try:
  333. type_ref = get_type_ref(origin, args_override=args)
  334. if type_ref in previously_seen_type_refs:
  335. self_type = PydanticRecursiveRef(type_ref=type_ref)
  336. yield self_type
  337. else:
  338. previously_seen_type_refs.add(type_ref)
  339. yield
  340. previously_seen_type_refs.remove(type_ref)
  341. finally:
  342. if token:
  343. _generic_recursion_cache.reset(token)
  344. def recursively_defined_type_refs() -> set[str]:
  345. visited = _generic_recursion_cache.get()
  346. if not visited:
  347. return set() # not in a generic recursion, so there are no types
  348. return visited.copy() # don't allow modifications
  349. def get_cached_generic_type_early(parent: type[BaseModel], typevar_values: Any) -> type[BaseModel] | None:
  350. """The use of a two-stage cache lookup approach was necessary to have the highest performance possible for
  351. repeated calls to `__class_getitem__` on generic types (which may happen in tighter loops during runtime),
  352. while still ensuring that certain alternative parametrizations ultimately resolve to the same type.
  353. As a concrete example, this approach was necessary to make Model[List[T]][int] equal to Model[List[int]].
  354. The approach could be modified to not use two different cache keys at different points, but the
  355. _early_cache_key is optimized to be as quick to compute as possible (for repeated-access speed), and the
  356. _late_cache_key is optimized to be as "correct" as possible, so that two types that will ultimately be the
  357. same after resolving the type arguments will always produce cache hits.
  358. If we wanted to move to only using a single cache key per type, we would either need to always use the
  359. slower/more computationally intensive logic associated with _late_cache_key, or would need to accept
  360. that Model[List[T]][int] is a different type than Model[List[T]][int]. Because we rely on subclass relationships
  361. during validation, I think it is worthwhile to ensure that types that are functionally equivalent are actually
  362. equal.
  363. """
  364. return _GENERIC_TYPES_CACHE.get(_early_cache_key(parent, typevar_values))
  365. def get_cached_generic_type_late(
  366. parent: type[BaseModel], typevar_values: Any, origin: type[BaseModel], args: tuple[Any, ...]
  367. ) -> type[BaseModel] | None:
  368. """See the docstring of `get_cached_generic_type_early` for more information about the two-stage cache lookup."""
  369. cached = _GENERIC_TYPES_CACHE.get(_late_cache_key(origin, args, typevar_values))
  370. if cached is not None:
  371. set_cached_generic_type(parent, typevar_values, cached, origin, args)
  372. return cached
  373. def set_cached_generic_type(
  374. parent: type[BaseModel],
  375. typevar_values: tuple[Any, ...],
  376. type_: type[BaseModel],
  377. origin: type[BaseModel] | None = None,
  378. args: tuple[Any, ...] | None = None,
  379. ) -> None:
  380. """See the docstring of `get_cached_generic_type_early` for more information about why items are cached with
  381. two different keys.
  382. """
  383. _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values)] = type_
  384. if len(typevar_values) == 1:
  385. _GENERIC_TYPES_CACHE[_early_cache_key(parent, typevar_values[0])] = type_
  386. if origin and args:
  387. _GENERIC_TYPES_CACHE[_late_cache_key(origin, args, typevar_values)] = type_
  388. def _union_orderings_key(typevar_values: Any) -> Any:
  389. """This is intended to help differentiate between Union types with the same arguments in different order.
  390. Thanks to caching internal to the `typing` module, it is not possible to distinguish between
  391. List[Union[int, float]] and List[Union[float, int]] (and similarly for other "parent" origins besides List)
  392. because `typing` considers Union[int, float] to be equal to Union[float, int].
  393. However, you _can_ distinguish between (top-level) Union[int, float] vs. Union[float, int].
  394. Because we parse items as the first Union type that is successful, we get slightly more consistent behavior
  395. if we make an effort to distinguish the ordering of items in a union. It would be best if we could _always_
  396. get the exact-correct order of items in the union, but that would require a change to the `typing` module itself.
  397. (See https://github.com/python/cpython/issues/86483 for reference.)
  398. """
  399. if isinstance(typevar_values, tuple):
  400. args_data = []
  401. for value in typevar_values:
  402. args_data.append(_union_orderings_key(value))
  403. return tuple(args_data)
  404. elif _typing_extra.is_union(typevar_values):
  405. return get_args(typevar_values)
  406. else:
  407. return ()
  408. def _early_cache_key(cls: type[BaseModel], typevar_values: Any) -> GenericTypesCacheKey:
  409. """This is intended for minimal computational overhead during lookups of cached types.
  410. Note that this is overly simplistic, and it's possible that two different cls/typevar_values
  411. inputs would ultimately result in the same type being created in BaseModel.__class_getitem__.
  412. To handle this, we have a fallback _late_cache_key that is checked later if the _early_cache_key
  413. lookup fails, and should result in a cache hit _precisely_ when the inputs to __class_getitem__
  414. would result in the same type.
  415. """
  416. return cls, typevar_values, _union_orderings_key(typevar_values)
  417. def _late_cache_key(origin: type[BaseModel], args: tuple[Any, ...], typevar_values: Any) -> GenericTypesCacheKey:
  418. """This is intended for use later in the process of creating a new type, when we have more information
  419. about the exact args that will be passed. If it turns out that a different set of inputs to
  420. __class_getitem__ resulted in the same inputs to the generic type creation process, we can still
  421. return the cached type, and update the cache with the _early_cache_key as well.
  422. """
  423. # The _union_orderings_key is placed at the start here to ensure there cannot be a collision with an
  424. # _early_cache_key, as that function will always produce a BaseModel subclass as the first item in the key,
  425. # whereas this function will always produce a tuple as the first item in the key.
  426. return _union_orderings_key(typevar_values), origin, args