from collections import deque
from copy import copy
from dataclasses import dataclass, is_dataclass
from enum import Enum
from functools import lru_cache
from typing import (
    Any,
    Callable,
    Deque,
    Dict,
    FrozenSet,
    List,
    Mapping,
    Sequence,
    Set,
    Tuple,
    Type,
    Union,
)

from fastapi.exceptions import RequestErrorModel
from fastapi.types import IncEx, ModelNameMap, UnionType
from pydantic import BaseModel, create_model
from pydantic.version import VERSION as PYDANTIC_VERSION
from starlette.datastructures import UploadFile
from typing_extensions import Annotated, Literal, get_args, get_origin

PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
PYDANTIC_V2 = PYDANTIC_VERSION_MINOR_TUPLE[0] == 2

sequence_annotation_to_type = {
    Sequence: list,
    List: list,
    list: list,
    Tuple: tuple,
    tuple: tuple,
    Set: set,
    set: set,
    FrozenSet: frozenset,
    frozenset: frozenset,
    Deque: deque,
    deque: deque,
}

sequence_types = tuple(sequence_annotation_to_type.keys())
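
# Illustrative note (not part of the upstream module): the mapping covers both the
# typing aliases and the builtin/collections types, so an annotation (or its origin)
# can be looked up directly and used as a constructor, e.g.:
#
#     sequence_annotation_to_type[Tuple]([1, 2, 3])  # -> (1, 2, 3)
#     sequence_annotation_to_type[set]([1, 2, 2])    # -> {1, 2}
#
# `sequence_types` holds the same keys as a tuple, suitable for the
# lenient_issubclass / isinstance style checks used further below.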

Url: Type[Any]

if PYDANTIC_V2:
    from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
    from pydantic import TypeAdapter
    from pydantic import ValidationError as ValidationError
    from pydantic._internal._schema_generation_shared import (  # type: ignore[attr-defined]
        GetJsonSchemaHandler as GetJsonSchemaHandler,
    )
    from pydantic._internal._typing_extra import eval_type_lenient
    from pydantic._internal._utils import lenient_issubclass as lenient_issubclass
    from pydantic.fields import FieldInfo
    from pydantic.json_schema import GenerateJsonSchema as GenerateJsonSchema
    from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
    from pydantic_core import CoreSchema as CoreSchema
    from pydantic_core import PydanticUndefined, PydanticUndefinedType
    from pydantic_core import Url as Url

    try:
        from pydantic_core.core_schema import (
            with_info_plain_validator_function as with_info_plain_validator_function,
        )
    except ImportError:  # pragma: no cover
        from pydantic_core.core_schema import (
            general_plain_validator_function as with_info_plain_validator_function,  # noqa: F401
        )

    RequiredParam = PydanticUndefined
    Undefined = PydanticUndefined
    UndefinedType = PydanticUndefinedType
    evaluate_forwardref = eval_type_lenient
    Validator = Any

    class BaseConfig:
        pass

    class ErrorWrapper(Exception):
        pass

    @dataclass
    class ModelField:
        field_info: FieldInfo
        name: str
        mode: Literal["validation", "serialization"] = "validation"

        @property
        def alias(self) -> str:
            a = self.field_info.alias
            return a if a is not None else self.name

        @property
        def required(self) -> bool:
            return self.field_info.is_required()

        @property
        def default(self) -> Any:
            return self.get_default()

        @property
        def type_(self) -> Any:
            return self.field_info.annotation

        def __post_init__(self) -> None:
            self._type_adapter: TypeAdapter[Any] = TypeAdapter(
                Annotated[self.field_info.annotation, self.field_info]
            )

        def get_default(self) -> Any:
            if self.field_info.is_required():
                return Undefined
            return self.field_info.get_default(call_default_factory=True)

        def validate(
            self,
            value: Any,
            values: Dict[str, Any] = {},  # noqa: B006
            *,
            loc: Tuple[Union[int, str], ...] = (),
        ) -> Tuple[Any, Union[List[Dict[str, Any]], None]]:
            try:
                return (
                    self._type_adapter.validate_python(value, from_attributes=True),
                    None,
                )
            except ValidationError as exc:
                return None, _regenerate_error_with_loc(
                    errors=exc.errors(include_url=False), loc_prefix=loc
                )

        def serialize(
            self,
            value: Any,
            *,
            mode: Literal["json", "python"] = "json",
            include: Union[IncEx, None] = None,
            exclude: Union[IncEx, None] = None,
            by_alias: bool = True,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
        ) -> Any:
            # The caller is expected to pass a value that has already been
            # validated with self._type_adapter.validate_python(value)
            return self._type_adapter.dump_python(
                value,
                mode=mode,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

        def __hash__(self) -> int:
            # Each ModelField is unique for our purposes, to allow making a dict from
            # ModelField to its JSON Schema.
            return id(self)
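
    # Illustrative usage sketch (not in the upstream file), assuming Pydantic v2 is
    # installed: a ModelField wraps a FieldInfo plus a name and validates through a
    # cached TypeAdapter, returning `(value, None)` on success or `(None, errors)`
    # on failure, e.g.:
    #
    #     field = ModelField(field_info=FieldInfo.from_annotation(int), name="item_id")
    #     field.validate("5", loc=("path", "item_id"))    # -> (5, None)
    #     field.validate("abc", loc=("path", "item_id"))  # -> (None, [{..., 'loc': ('path', 'item_id'), ...}])
    #     field.serialize(5)                              # -> 5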

    def get_annotation_from_field_info(
        annotation: Any, field_info: FieldInfo, field_name: str
    ) -> Any:
        return annotation

    def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
        return errors  # type: ignore[return-value]

    def _model_rebuild(model: Type[BaseModel]) -> None:
        model.model_rebuild()

    def _model_dump(
        model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
    ) -> Any:
        return model.model_dump(mode=mode, **kwargs)

    def _get_model_config(model: BaseModel) -> Any:
        return model.model_config

    def get_schema_from_model_field(
        *,
        field: ModelField,
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        field_mapping: Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        separate_input_output_schemas: bool = True,
    ) -> Dict[str, Any]:
        override_mode: Union[Literal["validation"], None] = (
            None if separate_input_output_schemas else "validation"
        )
        # This expects that GenerateJsonSchema was already used to generate the definitions
        json_schema = field_mapping[(field, override_mode or field.mode)]
        if "$ref" not in json_schema:
            # TODO remove when deprecating Pydantic v1
            # Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
            json_schema["title"] = (
                field.field_info.title or field.alias.title().replace("_", " ")
            )
        return json_schema

    def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
        return {}

    def get_definitions(
        *,
        fields: List[ModelField],
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        separate_input_output_schemas: bool = True,
    ) -> Tuple[
        Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        Dict[str, Dict[str, Any]],
    ]:
        override_mode: Union[Literal["validation"], None] = (
            None if separate_input_output_schemas else "validation"
        )
        inputs = [
            (field, override_mode or field.mode, field._type_adapter.core_schema)
            for field in fields
        ]
        field_mapping, definitions = schema_generator.generate_definitions(
            inputs=inputs
        )
        return field_mapping, definitions  # type: ignore[return-value]
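
    # Illustrative flow (not in the upstream file): with Pydantic v2 the JSON Schemas
    # for all fields are generated in one pass and then looked up per field. The ref
    # template below is only an example value; roughly:
    #
    #     schema_generator = GenerateJsonSchema(ref_template="#/components/schemas/{model}")
    #     field_mapping, definitions = get_definitions(
    #         fields=fields,
    #         schema_generator=schema_generator,
    #         model_name_map={},
    #     )
    #     schema = get_schema_from_model_field(
    #         field=fields[0],
    #         schema_generator=schema_generator,
    #         model_name_map={},
    #         field_mapping=field_mapping,
    #     )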

    def is_scalar_field(field: ModelField) -> bool:
        from fastapi import params

        return field_annotation_is_scalar(
            field.field_info.annotation
        ) and not isinstance(field.field_info, params.Body)

    def is_sequence_field(field: ModelField) -> bool:
        return field_annotation_is_sequence(field.field_info.annotation)

    def is_scalar_sequence_field(field: ModelField) -> bool:
        return field_annotation_is_scalar_sequence(field.field_info.annotation)

    def is_bytes_field(field: ModelField) -> bool:
        return is_bytes_or_nonable_bytes_annotation(field.type_)

    def is_bytes_sequence_field(field: ModelField) -> bool:
        return is_bytes_sequence_annotation(field.type_)

    def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
        cls = type(field_info)
        merged_field_info = cls.from_annotation(annotation)
        new_field_info = copy(field_info)
        new_field_info.metadata = merged_field_info.metadata
        new_field_info.annotation = merged_field_info.annotation
        return new_field_info

    def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
        origin_type = (
            get_origin(field.field_info.annotation) or field.field_info.annotation
        )
        assert issubclass(origin_type, sequence_types)  # type: ignore[arg-type]
        return sequence_annotation_to_type[origin_type](value)  # type: ignore[no-any-return]

    def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
        error = ValidationError.from_exception_data(
            "Field required", [{"type": "missing", "loc": loc, "input": {}}]
        ).errors(include_url=False)[0]
        error["input"] = None
        return error  # type: ignore[return-value]

    def create_body_model(
        *, fields: Sequence[ModelField], model_name: str
    ) -> Type[BaseModel]:
        field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
        BodyModel: Type[BaseModel] = create_model(model_name, **field_params)  # type: ignore[call-overload]
        return BodyModel

    def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
        return [
            ModelField(field_info=field_info, name=name)
            for name, field_info in model.model_fields.items()
        ]
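
    # Illustrative usage sketch (not in the upstream file), assuming Pydantic v2:
    # get_model_fields wraps each entry of `model_fields` in a ModelField, and
    # create_body_model builds a synthetic body model from such fields, roughly:
    #
    #     class Item(BaseModel):
    #         name: str
    #         price: float = 0.0
    #
    #     fields = get_model_fields(Item)    # [ModelField(name='name', ...), ...]
    #     Body_ = create_body_model(fields=fields, model_name="Body_create_item")
    #     Body_(name="hammer")               # validates like a regular BaseModel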

else:
    from fastapi.openapi.constants import REF_PREFIX as REF_PREFIX
    from pydantic import AnyUrl as Url  # noqa: F401
    from pydantic import (  # type: ignore[assignment]
        BaseConfig as BaseConfig,  # noqa: F401
    )
    from pydantic import ValidationError as ValidationError  # noqa: F401
    from pydantic.class_validators import (  # type: ignore[no-redef]
        Validator as Validator,  # noqa: F401
    )
    from pydantic.error_wrappers import (  # type: ignore[no-redef]
        ErrorWrapper as ErrorWrapper,  # noqa: F401
    )
    from pydantic.errors import MissingError
    from pydantic.fields import (  # type: ignore[attr-defined]
        SHAPE_FROZENSET,
        SHAPE_LIST,
        SHAPE_SEQUENCE,
        SHAPE_SET,
        SHAPE_SINGLETON,
        SHAPE_TUPLE,
        SHAPE_TUPLE_ELLIPSIS,
    )
    from pydantic.fields import FieldInfo as FieldInfo
    from pydantic.fields import (  # type: ignore[no-redef,attr-defined]
        ModelField as ModelField,  # noqa: F401
    )

    # Keeping old "Required" functionality from Pydantic V1, without
    # shadowing typing.Required.
    RequiredParam: Any = Ellipsis  # type: ignore[no-redef]
    from pydantic.fields import (  # type: ignore[no-redef,attr-defined]
        Undefined as Undefined,
    )
    from pydantic.fields import (  # type: ignore[no-redef, attr-defined]
        UndefinedType as UndefinedType,  # noqa: F401
    )
    from pydantic.schema import (
        field_schema,
        get_flat_models_from_fields,
        get_model_name_map,
        model_process_schema,
    )
    from pydantic.schema import (  # type: ignore[no-redef]  # noqa: F401
        get_annotation_from_field_info as get_annotation_from_field_info,
    )
    from pydantic.typing import (  # type: ignore[no-redef]
        evaluate_forwardref as evaluate_forwardref,  # noqa: F401
    )
    from pydantic.utils import (  # type: ignore[no-redef]
        lenient_issubclass as lenient_issubclass,  # noqa: F401
    )

    GetJsonSchemaHandler = Any  # type: ignore[assignment,misc]
    JsonSchemaValue = Dict[str, Any]  # type: ignore[misc]
    CoreSchema = Any  # type: ignore[assignment,misc]

    sequence_shapes = {
        SHAPE_LIST,
        SHAPE_SET,
        SHAPE_FROZENSET,
        SHAPE_TUPLE,
        SHAPE_SEQUENCE,
        SHAPE_TUPLE_ELLIPSIS,
    }
    sequence_shape_to_type = {
        SHAPE_LIST: list,
        SHAPE_SET: set,
        SHAPE_TUPLE: tuple,
        SHAPE_SEQUENCE: list,
        SHAPE_TUPLE_ELLIPSIS: list,
    }

    @dataclass
    class GenerateJsonSchema:  # type: ignore[no-redef]
        ref_template: str

    class PydanticSchemaGenerationError(Exception):  # type: ignore[no-redef]
        pass

    def with_info_plain_validator_function(  # type: ignore[misc]
        function: Callable[..., Any],
        *,
        ref: Union[str, None] = None,
        metadata: Any = None,
        serialization: Any = None,
    ) -> Any:
        return {}

    def get_model_definitions(
        *,
        flat_models: Set[Union[Type[BaseModel], Type[Enum]]],
        model_name_map: Dict[Union[Type[BaseModel], Type[Enum]], str],
    ) -> Dict[str, Any]:
        definitions: Dict[str, Dict[str, Any]] = {}
        for model in flat_models:
            m_schema, m_definitions, m_nested_models = model_process_schema(
                model, model_name_map=model_name_map, ref_prefix=REF_PREFIX
            )
            definitions.update(m_definitions)
            model_name = model_name_map[model]
            if "description" in m_schema:
                m_schema["description"] = m_schema["description"].split("\f")[0]
            definitions[model_name] = m_schema
        return definitions

    def is_pv1_scalar_field(field: ModelField) -> bool:
        from fastapi import params

        field_info = field.field_info
        if not (
            field.shape == SHAPE_SINGLETON  # type: ignore[attr-defined]
            and not lenient_issubclass(field.type_, BaseModel)
            and not lenient_issubclass(field.type_, dict)
            and not field_annotation_is_sequence(field.type_)
            and not is_dataclass(field.type_)
            and not isinstance(field_info, params.Body)
        ):
            return False
        if field.sub_fields:  # type: ignore[attr-defined]
            if not all(
                is_pv1_scalar_field(f)
                for f in field.sub_fields  # type: ignore[attr-defined]
            ):
                return False
        return True

    def is_pv1_scalar_sequence_field(field: ModelField) -> bool:
        if (field.shape in sequence_shapes) and not lenient_issubclass(  # type: ignore[attr-defined]
            field.type_, BaseModel
        ):
            if field.sub_fields is not None:  # type: ignore[attr-defined]
                for sub_field in field.sub_fields:  # type: ignore[attr-defined]
                    if not is_pv1_scalar_field(sub_field):
                        return False
            return True
        if _annotation_is_sequence(field.type_):
            return True
        return False

    def _normalize_errors(errors: Sequence[Any]) -> List[Dict[str, Any]]:
        use_errors: List[Any] = []
        for error in errors:
            if isinstance(error, ErrorWrapper):
                new_errors = ValidationError(  # type: ignore[call-arg]
                    errors=[error], model=RequestErrorModel
                ).errors()
                use_errors.extend(new_errors)
            elif isinstance(error, list):
                use_errors.extend(_normalize_errors(error))
            else:
                use_errors.append(error)
        return use_errors

    def _model_rebuild(model: Type[BaseModel]) -> None:
        model.update_forward_refs()

    def _model_dump(
        model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any
    ) -> Any:
        return model.dict(**kwargs)

    def _get_model_config(model: BaseModel) -> Any:
        return model.__config__  # type: ignore[attr-defined]

    def get_schema_from_model_field(
        *,
        field: ModelField,
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        field_mapping: Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        separate_input_output_schemas: bool = True,
    ) -> Dict[str, Any]:
        # This expects that GenerateJsonSchema was already used to generate the definitions
        return field_schema(  # type: ignore[no-any-return]
            field, model_name_map=model_name_map, ref_prefix=REF_PREFIX
        )[0]

    def get_compat_model_name_map(fields: List[ModelField]) -> ModelNameMap:
        models = get_flat_models_from_fields(fields, known_models=set())
        return get_model_name_map(models)  # type: ignore[no-any-return]

    def get_definitions(
        *,
        fields: List[ModelField],
        schema_generator: GenerateJsonSchema,
        model_name_map: ModelNameMap,
        separate_input_output_schemas: bool = True,
    ) -> Tuple[
        Dict[
            Tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
        ],
        Dict[str, Dict[str, Any]],
    ]:
        models = get_flat_models_from_fields(fields, known_models=set())
        return {}, get_model_definitions(
            flat_models=models, model_name_map=model_name_map
        )

    def is_scalar_field(field: ModelField) -> bool:
        return is_pv1_scalar_field(field)

    def is_sequence_field(field: ModelField) -> bool:
        return field.shape in sequence_shapes or _annotation_is_sequence(field.type_)  # type: ignore[attr-defined]

    def is_scalar_sequence_field(field: ModelField) -> bool:
        return is_pv1_scalar_sequence_field(field)

    def is_bytes_field(field: ModelField) -> bool:
        return lenient_issubclass(field.type_, bytes)

    def is_bytes_sequence_field(field: ModelField) -> bool:
        return field.shape in sequence_shapes and lenient_issubclass(field.type_, bytes)  # type: ignore[attr-defined]

    def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
        return copy(field_info)

    def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
        return sequence_shape_to_type[field.shape](value)  # type: ignore[no-any-return,attr-defined]

    def get_missing_field_error(loc: Tuple[str, ...]) -> Dict[str, Any]:
        missing_field_error = ErrorWrapper(MissingError(), loc=loc)  # type: ignore[call-arg]
        new_error = ValidationError([missing_field_error], RequestErrorModel)
        return new_error.errors()[0]  # type: ignore[return-value]

    def create_body_model(
        *, fields: Sequence[ModelField], model_name: str
    ) -> Type[BaseModel]:
        BodyModel = create_model(model_name)
        for f in fields:
            BodyModel.__fields__[f.name] = f  # type: ignore[index]
        return BodyModel

    def get_model_fields(model: Type[BaseModel]) -> List[ModelField]:
        return list(model.__fields__.values())  # type: ignore[attr-defined]


def _regenerate_error_with_loc(
    *, errors: Sequence[Any], loc_prefix: Tuple[Union[str, int], ...]
) -> List[Dict[str, Any]]:
    updated_loc_errors: List[Any] = [
        {**err, "loc": loc_prefix + err.get("loc", ())}
        for err in _normalize_errors(errors)
    ]
    return updated_loc_errors


def _annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    if lenient_issubclass(annotation, (str, bytes)):
        return False
    return lenient_issubclass(annotation, sequence_types)


def field_annotation_is_sequence(annotation: Union[Type[Any], None]) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        for arg in get_args(annotation):
            if field_annotation_is_sequence(arg):
                return True
        return False
    return _annotation_is_sequence(annotation) or _annotation_is_sequence(
        get_origin(annotation)
    )
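
# Illustrative behavior (not in the upstream file): str and bytes are deliberately
# not treated as sequences, and a Union counts as a sequence if any member is, e.g.:
#
#     field_annotation_is_sequence(List[int])               # -> True
#     field_annotation_is_sequence(str)                     # -> False
#     field_annotation_is_sequence(Union[List[int], None])  # -> True
#     field_annotation_is_sequence(int)                     # -> False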


def value_is_sequence(value: Any) -> bool:
    return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))  # type: ignore[arg-type]


def _annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    return (
        lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
        or _annotation_is_sequence(annotation)
        or is_dataclass(annotation)
    )


def field_annotation_is_complex(annotation: Union[Type[Any], None]) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        return any(field_annotation_is_complex(arg) for arg in get_args(annotation))

    return (
        _annotation_is_complex(annotation)
        or _annotation_is_complex(origin)
        or hasattr(origin, "__pydantic_core_schema__")
        or hasattr(origin, "__get_pydantic_core_schema__")
    )


def field_annotation_is_scalar(annotation: Any) -> bool:
    # handle Ellipsis here to make tuple[int, ...] work nicely
    return annotation is Ellipsis or not field_annotation_is_complex(annotation)


def field_annotation_is_scalar_sequence(annotation: Union[Type[Any], None]) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        at_least_one_scalar_sequence = False
        for arg in get_args(annotation):
            if field_annotation_is_scalar_sequence(arg):
                at_least_one_scalar_sequence = True
                continue
            elif not field_annotation_is_scalar(arg):
                return False
        return at_least_one_scalar_sequence
    return field_annotation_is_sequence(annotation) and all(
        field_annotation_is_scalar(sub_annotation)
        for sub_annotation in get_args(annotation)
    )
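
# Illustrative behavior (not in the upstream file): "scalar" roughly means the value
# can be parsed from a single simple parameter, while "scalar sequence" covers
# repeated simple values; complex annotations are handled elsewhere (e.g. as body
# fields), e.g.:
#
#     field_annotation_is_scalar(int)                        # -> True
#     field_annotation_is_scalar(BaseModel)                  # -> False (complex)
#     field_annotation_is_scalar_sequence(List[int])         # -> True
#     field_annotation_is_scalar_sequence(List[BaseModel])   # -> False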


def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
    if lenient_issubclass(annotation, bytes):
        return True
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        for arg in get_args(annotation):
            if lenient_issubclass(arg, bytes):
                return True
    return False


def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
    if lenient_issubclass(annotation, UploadFile):
        return True
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        for arg in get_args(annotation):
            if lenient_issubclass(arg, UploadFile):
                return True
    return False


def is_bytes_sequence_annotation(annotation: Any) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        at_least_one = False
        for arg in get_args(annotation):
            if is_bytes_sequence_annotation(arg):
                at_least_one = True
                continue
        return at_least_one
    return field_annotation_is_sequence(annotation) and all(
        is_bytes_or_nonable_bytes_annotation(sub_annotation)
        for sub_annotation in get_args(annotation)
    )


def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
    origin = get_origin(annotation)
    if origin is Union or origin is UnionType:
        at_least_one = False
        for arg in get_args(annotation):
            if is_uploadfile_sequence_annotation(arg):
                at_least_one = True
                continue
        return at_least_one
    return field_annotation_is_sequence(annotation) and all(
        is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
        for sub_annotation in get_args(annotation)
    )
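
# Illustrative behavior (not in the upstream file): "nonable" here means the bytes
# or UploadFile type may be wrapped in an Optional/Union, e.g.:
#
#     is_bytes_or_nonable_bytes_annotation(Union[bytes, None])  # -> True
#     is_bytes_sequence_annotation(List[bytes])                 # -> True
#     is_uploadfile_sequence_annotation(List[UploadFile])       # -> True
#     is_uploadfile_sequence_annotation(List[str])              # -> False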


@lru_cache
def get_cached_model_fields(model: Type[BaseModel]) -> List[ModelField]:
    return get_model_fields(model)
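
# Illustrative note (not in the upstream file): thanks to lru_cache, repeated calls
# with the same model class return the same list object instead of rebuilding the
# ModelField wrappers, e.g.:
#
#     fields_a = get_cached_model_fields(Item)
#     fields_b = get_cached_model_fields(Item)
#     assert fields_a is fields_b
#
# where `Item` stands for any BaseModel subclass already defined by the caller.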