utils.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980
  1. import inspect
  2. from contextlib import AsyncExitStack, contextmanager
  3. from copy import copy, deepcopy
  4. from dataclasses import dataclass
  5. from typing import (
  6. Any,
  7. Callable,
  8. Coroutine,
  9. Dict,
  10. ForwardRef,
  11. List,
  12. Mapping,
  13. Optional,
  14. Sequence,
  15. Tuple,
  16. Type,
  17. Union,
  18. cast,
  19. )
  20. import anyio
  21. from fastapi import params
  22. from fastapi._compat import (
  23. PYDANTIC_V2,
  24. ErrorWrapper,
  25. ModelField,
  26. RequiredParam,
  27. Undefined,
  28. _regenerate_error_with_loc,
  29. copy_field_info,
  30. create_body_model,
  31. evaluate_forwardref,
  32. field_annotation_is_scalar,
  33. get_annotation_from_field_info,
  34. get_cached_model_fields,
  35. get_missing_field_error,
  36. is_bytes_field,
  37. is_bytes_sequence_field,
  38. is_scalar_field,
  39. is_scalar_sequence_field,
  40. is_sequence_field,
  41. is_uploadfile_or_nonable_uploadfile_annotation,
  42. is_uploadfile_sequence_annotation,
  43. lenient_issubclass,
  44. sequence_types,
  45. serialize_sequence_value,
  46. value_is_sequence,
  47. )
  48. from fastapi.background import BackgroundTasks
  49. from fastapi.concurrency import (
  50. asynccontextmanager,
  51. contextmanager_in_threadpool,
  52. )
  53. from fastapi.dependencies.models import Dependant, SecurityRequirement
  54. from fastapi.logger import logger
  55. from fastapi.security.base import SecurityBase
  56. from fastapi.security.oauth2 import OAuth2, SecurityScopes
  57. from fastapi.security.open_id_connect_url import OpenIdConnect
  58. from fastapi.utils import create_model_field, get_path_param_names
  59. from pydantic import BaseModel
  60. from pydantic.fields import FieldInfo
  61. from starlette.background import BackgroundTasks as StarletteBackgroundTasks
  62. from starlette.concurrency import run_in_threadpool
  63. from starlette.datastructures import (
  64. FormData,
  65. Headers,
  66. ImmutableMultiDict,
  67. QueryParams,
  68. UploadFile,
  69. )
  70. from starlette.requests import HTTPConnection, Request
  71. from starlette.responses import Response
  72. from starlette.websockets import WebSocket
  73. from typing_extensions import Annotated, get_args, get_origin
  74. multipart_not_installed_error = (
  75. 'Form data requires "python-multipart" to be installed. \n'
  76. 'You can install "python-multipart" with: \n\n'
  77. "pip install python-multipart\n"
  78. )
  79. multipart_incorrect_install_error = (
  80. 'Form data requires "python-multipart" to be installed. '
  81. 'It seems you installed "multipart" instead. \n'
  82. 'You can remove "multipart" with: \n\n'
  83. "pip uninstall multipart\n\n"
  84. 'And then install "python-multipart" with: \n\n'
  85. "pip install python-multipart\n"
  86. )
def ensure_multipart_is_installed() -> None:
    """Verify that a usable multipart parser package can be imported.

    The check first tries ``python_multipart``; if that import fails (or its
    version assertion fails) it falls back to the ``multipart`` module name and
    probes it for ``parse_options_header``.

    Raises:
        RuntimeError: with ``multipart_incorrect_install_error`` when
            ``multipart`` imports but ``multipart.multipart`` does not provide
            ``parse_options_header``; with ``multipart_not_installed_error``
            when ``multipart`` cannot be imported at all. Both messages are
            also sent to ``logger.error`` first.
        AssertionError: if one of the fallback ``assert`` statements fails;
            only ``ImportError`` is handled on the fallback path.
    """
    try:
        from python_multipart import __version__

        # Import an attribute that can be mocked/deleted in testing
        # NOTE(review): this is a plain string comparison, so it is
        # lexicographic rather than a numeric version comparison.
        assert __version__ > "0.0.12"
    except (ImportError, AssertionError):
        try:
            # __version__ is available in both multiparts, and can be mocked
            from multipart import __version__  # type: ignore[no-redef,import-untyped]

            assert __version__
            try:
                # parse_options_header is only available in the right multipart
                from multipart.multipart import (  # type: ignore[import-untyped]
                    parse_options_header,
                )

                assert parse_options_header
            except ImportError:
                # `multipart` exists but lacks the expected parser module.
                logger.error(multipart_incorrect_install_error)
                # `from None` hides the ImportError chain from the traceback.
                raise RuntimeError(multipart_incorrect_install_error) from None
        except ImportError:
            # Neither module name could be imported.
            logger.error(multipart_not_installed_error)
            raise RuntimeError(multipart_not_installed_error) from None
  109. def get_param_sub_dependant(
  110. *,
  111. param_name: str,
  112. depends: params.Depends,
  113. path: str,
  114. security_scopes: Optional[List[str]] = None,
  115. ) -> Dependant:
  116. assert depends.dependency
  117. return get_sub_dependant(
  118. depends=depends,
  119. dependency=depends.dependency,
  120. path=path,
  121. name=param_name,
  122. security_scopes=security_scopes,
  123. )
  124. def get_parameterless_sub_dependant(*, depends: params.Depends, path: str) -> Dependant:
  125. assert callable(depends.dependency), (
  126. "A parameter-less dependency must have a callable dependency"
  127. )
  128. return get_sub_dependant(depends=depends, dependency=depends.dependency, path=path)
  129. def get_sub_dependant(
  130. *,
  131. depends: params.Depends,
  132. dependency: Callable[..., Any],
  133. path: str,
  134. name: Optional[str] = None,
  135. security_scopes: Optional[List[str]] = None,
  136. ) -> Dependant:
  137. security_requirement = None
  138. security_scopes = security_scopes or []
  139. if isinstance(depends, params.Security):
  140. dependency_scopes = depends.scopes
  141. security_scopes.extend(dependency_scopes)
  142. if isinstance(dependency, SecurityBase):
  143. use_scopes: List[str] = []
  144. if isinstance(dependency, (OAuth2, OpenIdConnect)):
  145. use_scopes = security_scopes
  146. security_requirement = SecurityRequirement(
  147. security_scheme=dependency, scopes=use_scopes
  148. )
  149. sub_dependant = get_dependant(
  150. path=path,
  151. call=dependency,
  152. name=name,
  153. security_scopes=security_scopes,
  154. use_cache=depends.use_cache,
  155. )
  156. if security_requirement:
  157. sub_dependant.security_requirements.append(security_requirement)
  158. return sub_dependant
  159. CacheKey = Tuple[Optional[Callable[..., Any]], Tuple[str, ...]]
  160. def get_flat_dependant(
  161. dependant: Dependant,
  162. *,
  163. skip_repeats: bool = False,
  164. visited: Optional[List[CacheKey]] = None,
  165. ) -> Dependant:
  166. if visited is None:
  167. visited = []
  168. visited.append(dependant.cache_key)
  169. flat_dependant = Dependant(
  170. path_params=dependant.path_params.copy(),
  171. query_params=dependant.query_params.copy(),
  172. header_params=dependant.header_params.copy(),
  173. cookie_params=dependant.cookie_params.copy(),
  174. body_params=dependant.body_params.copy(),
  175. security_requirements=dependant.security_requirements.copy(),
  176. use_cache=dependant.use_cache,
  177. path=dependant.path,
  178. )
  179. for sub_dependant in dependant.dependencies:
  180. if skip_repeats and sub_dependant.cache_key in visited:
  181. continue
  182. flat_sub = get_flat_dependant(
  183. sub_dependant, skip_repeats=skip_repeats, visited=visited
  184. )
  185. flat_dependant.path_params.extend(flat_sub.path_params)
  186. flat_dependant.query_params.extend(flat_sub.query_params)
  187. flat_dependant.header_params.extend(flat_sub.header_params)
  188. flat_dependant.cookie_params.extend(flat_sub.cookie_params)
  189. flat_dependant.body_params.extend(flat_sub.body_params)
  190. flat_dependant.security_requirements.extend(flat_sub.security_requirements)
  191. return flat_dependant
  192. def _get_flat_fields_from_params(fields: List[ModelField]) -> List[ModelField]:
  193. if not fields:
  194. return fields
  195. first_field = fields[0]
  196. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  197. fields_to_extract = get_cached_model_fields(first_field.type_)
  198. return fields_to_extract
  199. return fields
  200. def get_flat_params(dependant: Dependant) -> List[ModelField]:
  201. flat_dependant = get_flat_dependant(dependant, skip_repeats=True)
  202. path_params = _get_flat_fields_from_params(flat_dependant.path_params)
  203. query_params = _get_flat_fields_from_params(flat_dependant.query_params)
  204. header_params = _get_flat_fields_from_params(flat_dependant.header_params)
  205. cookie_params = _get_flat_fields_from_params(flat_dependant.cookie_params)
  206. return path_params + query_params + header_params + cookie_params
  207. def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
  208. signature = inspect.signature(call)
  209. globalns = getattr(call, "__globals__", {})
  210. typed_params = [
  211. inspect.Parameter(
  212. name=param.name,
  213. kind=param.kind,
  214. default=param.default,
  215. annotation=get_typed_annotation(param.annotation, globalns),
  216. )
  217. for param in signature.parameters.values()
  218. ]
  219. typed_signature = inspect.Signature(typed_params)
  220. return typed_signature
  221. def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
  222. if isinstance(annotation, str):
  223. annotation = ForwardRef(annotation)
  224. annotation = evaluate_forwardref(annotation, globalns, globalns)
  225. return annotation
  226. def get_typed_return_annotation(call: Callable[..., Any]) -> Any:
  227. signature = inspect.signature(call)
  228. annotation = signature.return_annotation
  229. if annotation is inspect.Signature.empty:
  230. return None
  231. globalns = getattr(call, "__globals__", {})
  232. return get_typed_annotation(annotation, globalns)
  233. def get_dependant(
  234. *,
  235. path: str,
  236. call: Callable[..., Any],
  237. name: Optional[str] = None,
  238. security_scopes: Optional[List[str]] = None,
  239. use_cache: bool = True,
  240. ) -> Dependant:
  241. path_param_names = get_path_param_names(path)
  242. endpoint_signature = get_typed_signature(call)
  243. signature_params = endpoint_signature.parameters
  244. dependant = Dependant(
  245. call=call,
  246. name=name,
  247. path=path,
  248. security_scopes=security_scopes,
  249. use_cache=use_cache,
  250. )
  251. for param_name, param in signature_params.items():
  252. is_path_param = param_name in path_param_names
  253. param_details = analyze_param(
  254. param_name=param_name,
  255. annotation=param.annotation,
  256. value=param.default,
  257. is_path_param=is_path_param,
  258. )
  259. if param_details.depends is not None:
  260. sub_dependant = get_param_sub_dependant(
  261. param_name=param_name,
  262. depends=param_details.depends,
  263. path=path,
  264. security_scopes=security_scopes,
  265. )
  266. dependant.dependencies.append(sub_dependant)
  267. continue
  268. if add_non_field_param_to_dependency(
  269. param_name=param_name,
  270. type_annotation=param_details.type_annotation,
  271. dependant=dependant,
  272. ):
  273. assert param_details.field is None, (
  274. f"Cannot specify multiple FastAPI annotations for {param_name!r}"
  275. )
  276. continue
  277. assert param_details.field is not None
  278. if isinstance(param_details.field.field_info, params.Body):
  279. dependant.body_params.append(param_details.field)
  280. else:
  281. add_param_to_fields(field=param_details.field, dependant=dependant)
  282. return dependant
  283. def add_non_field_param_to_dependency(
  284. *, param_name: str, type_annotation: Any, dependant: Dependant
  285. ) -> Optional[bool]:
  286. if lenient_issubclass(type_annotation, Request):
  287. dependant.request_param_name = param_name
  288. return True
  289. elif lenient_issubclass(type_annotation, WebSocket):
  290. dependant.websocket_param_name = param_name
  291. return True
  292. elif lenient_issubclass(type_annotation, HTTPConnection):
  293. dependant.http_connection_param_name = param_name
  294. return True
  295. elif lenient_issubclass(type_annotation, Response):
  296. dependant.response_param_name = param_name
  297. return True
  298. elif lenient_issubclass(type_annotation, StarletteBackgroundTasks):
  299. dependant.background_tasks_param_name = param_name
  300. return True
  301. elif lenient_issubclass(type_annotation, SecurityScopes):
  302. dependant.security_scopes_param_name = param_name
  303. return True
  304. return None
  305. @dataclass
  306. class ParamDetails:
  307. type_annotation: Any
  308. depends: Optional[params.Depends]
  309. field: Optional[ModelField]
  310. def analyze_param(
  311. *,
  312. param_name: str,
  313. annotation: Any,
  314. value: Any,
  315. is_path_param: bool,
  316. ) -> ParamDetails:
  317. field_info = None
  318. depends = None
  319. type_annotation: Any = Any
  320. use_annotation: Any = Any
  321. if annotation is not inspect.Signature.empty:
  322. use_annotation = annotation
  323. type_annotation = annotation
  324. # Extract Annotated info
  325. if get_origin(use_annotation) is Annotated:
  326. annotated_args = get_args(annotation)
  327. type_annotation = annotated_args[0]
  328. fastapi_annotations = [
  329. arg
  330. for arg in annotated_args[1:]
  331. if isinstance(arg, (FieldInfo, params.Depends))
  332. ]
  333. fastapi_specific_annotations = [
  334. arg
  335. for arg in fastapi_annotations
  336. if isinstance(arg, (params.Param, params.Body, params.Depends))
  337. ]
  338. if fastapi_specific_annotations:
  339. fastapi_annotation: Union[FieldInfo, params.Depends, None] = (
  340. fastapi_specific_annotations[-1]
  341. )
  342. else:
  343. fastapi_annotation = None
  344. # Set default for Annotated FieldInfo
  345. if isinstance(fastapi_annotation, FieldInfo):
  346. # Copy `field_info` because we mutate `field_info.default` below.
  347. field_info = copy_field_info(
  348. field_info=fastapi_annotation, annotation=use_annotation
  349. )
  350. assert (
  351. field_info.default is Undefined or field_info.default is RequiredParam
  352. ), (
  353. f"`{field_info.__class__.__name__}` default value cannot be set in"
  354. f" `Annotated` for {param_name!r}. Set the default value with `=` instead."
  355. )
  356. if value is not inspect.Signature.empty:
  357. assert not is_path_param, "Path parameters cannot have default values"
  358. field_info.default = value
  359. else:
  360. field_info.default = RequiredParam
  361. # Get Annotated Depends
  362. elif isinstance(fastapi_annotation, params.Depends):
  363. depends = fastapi_annotation
  364. # Get Depends from default value
  365. if isinstance(value, params.Depends):
  366. assert depends is None, (
  367. "Cannot specify `Depends` in `Annotated` and default value"
  368. f" together for {param_name!r}"
  369. )
  370. assert field_info is None, (
  371. "Cannot specify a FastAPI annotation in `Annotated` and `Depends` as a"
  372. f" default value together for {param_name!r}"
  373. )
  374. depends = value
  375. # Get FieldInfo from default value
  376. elif isinstance(value, FieldInfo):
  377. assert field_info is None, (
  378. "Cannot specify FastAPI annotations in `Annotated` and default value"
  379. f" together for {param_name!r}"
  380. )
  381. field_info = value
  382. if PYDANTIC_V2:
  383. field_info.annotation = type_annotation
  384. # Get Depends from type annotation
  385. if depends is not None and depends.dependency is None:
  386. # Copy `depends` before mutating it
  387. depends = copy(depends)
  388. depends.dependency = type_annotation
  389. # Handle non-param type annotations like Request
  390. if lenient_issubclass(
  391. type_annotation,
  392. (
  393. Request,
  394. WebSocket,
  395. HTTPConnection,
  396. Response,
  397. StarletteBackgroundTasks,
  398. SecurityScopes,
  399. ),
  400. ):
  401. assert depends is None, f"Cannot specify `Depends` for type {type_annotation!r}"
  402. assert field_info is None, (
  403. f"Cannot specify FastAPI annotation for type {type_annotation!r}"
  404. )
  405. # Handle default assignations, neither field_info nor depends was not found in Annotated nor default value
  406. elif field_info is None and depends is None:
  407. default_value = value if value is not inspect.Signature.empty else RequiredParam
  408. if is_path_param:
  409. # We might check here that `default_value is RequiredParam`, but the fact is that the same
  410. # parameter might sometimes be a path parameter and sometimes not. See
  411. # `tests/test_infer_param_optionality.py` for an example.
  412. field_info = params.Path(annotation=use_annotation)
  413. elif is_uploadfile_or_nonable_uploadfile_annotation(
  414. type_annotation
  415. ) or is_uploadfile_sequence_annotation(type_annotation):
  416. field_info = params.File(annotation=use_annotation, default=default_value)
  417. elif not field_annotation_is_scalar(annotation=type_annotation):
  418. field_info = params.Body(annotation=use_annotation, default=default_value)
  419. else:
  420. field_info = params.Query(annotation=use_annotation, default=default_value)
  421. field = None
  422. # It's a field_info, not a dependency
  423. if field_info is not None:
  424. # Handle field_info.in_
  425. if is_path_param:
  426. assert isinstance(field_info, params.Path), (
  427. f"Cannot use `{field_info.__class__.__name__}` for path param"
  428. f" {param_name!r}"
  429. )
  430. elif (
  431. isinstance(field_info, params.Param)
  432. and getattr(field_info, "in_", None) is None
  433. ):
  434. field_info.in_ = params.ParamTypes.query
  435. use_annotation_from_field_info = get_annotation_from_field_info(
  436. use_annotation,
  437. field_info,
  438. param_name,
  439. )
  440. if isinstance(field_info, params.Form):
  441. ensure_multipart_is_installed()
  442. if not field_info.alias and getattr(field_info, "convert_underscores", None):
  443. alias = param_name.replace("_", "-")
  444. else:
  445. alias = field_info.alias or param_name
  446. field_info.alias = alias
  447. field = create_model_field(
  448. name=param_name,
  449. type_=use_annotation_from_field_info,
  450. default=field_info.default,
  451. alias=alias,
  452. required=field_info.default in (RequiredParam, Undefined),
  453. field_info=field_info,
  454. )
  455. if is_path_param:
  456. assert is_scalar_field(field=field), (
  457. "Path params must be of one of the supported types"
  458. )
  459. elif isinstance(field_info, params.Query):
  460. assert (
  461. is_scalar_field(field)
  462. or is_scalar_sequence_field(field)
  463. or (
  464. lenient_issubclass(field.type_, BaseModel)
  465. # For Pydantic v1
  466. and getattr(field, "shape", 1) == 1
  467. )
  468. )
  469. return ParamDetails(type_annotation=type_annotation, depends=depends, field=field)
  470. def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
  471. field_info = field.field_info
  472. field_info_in = getattr(field_info, "in_", None)
  473. if field_info_in == params.ParamTypes.path:
  474. dependant.path_params.append(field)
  475. elif field_info_in == params.ParamTypes.query:
  476. dependant.query_params.append(field)
  477. elif field_info_in == params.ParamTypes.header:
  478. dependant.header_params.append(field)
  479. else:
  480. assert field_info_in == params.ParamTypes.cookie, (
  481. f"non-body parameters must be in path, query, header or cookie: {field.name}"
  482. )
  483. dependant.cookie_params.append(field)
  484. def is_coroutine_callable(call: Callable[..., Any]) -> bool:
  485. if inspect.isroutine(call):
  486. return inspect.iscoroutinefunction(call)
  487. if inspect.isclass(call):
  488. return False
  489. dunder_call = getattr(call, "__call__", None) # noqa: B004
  490. return inspect.iscoroutinefunction(dunder_call)
  491. def is_async_gen_callable(call: Callable[..., Any]) -> bool:
  492. if inspect.isasyncgenfunction(call):
  493. return True
  494. dunder_call = getattr(call, "__call__", None) # noqa: B004
  495. return inspect.isasyncgenfunction(dunder_call)
  496. def is_gen_callable(call: Callable[..., Any]) -> bool:
  497. if inspect.isgeneratorfunction(call):
  498. return True
  499. dunder_call = getattr(call, "__call__", None) # noqa: B004
  500. return inspect.isgeneratorfunction(dunder_call)
  501. async def solve_generator(
  502. *, call: Callable[..., Any], stack: AsyncExitStack, sub_values: Dict[str, Any]
  503. ) -> Any:
  504. if is_gen_callable(call):
  505. cm = contextmanager_in_threadpool(contextmanager(call)(**sub_values))
  506. elif is_async_gen_callable(call):
  507. cm = asynccontextmanager(call)(**sub_values)
  508. return await stack.enter_async_context(cm)
  509. @dataclass
  510. class SolvedDependency:
  511. values: Dict[str, Any]
  512. errors: List[Any]
  513. background_tasks: Optional[StarletteBackgroundTasks]
  514. response: Response
  515. dependency_cache: Dict[Tuple[Callable[..., Any], Tuple[str]], Any]
  516. async def solve_dependencies(
  517. *,
  518. request: Union[Request, WebSocket],
  519. dependant: Dependant,
  520. body: Optional[Union[Dict[str, Any], FormData]] = None,
  521. background_tasks: Optional[StarletteBackgroundTasks] = None,
  522. response: Optional[Response] = None,
  523. dependency_overrides_provider: Optional[Any] = None,
  524. dependency_cache: Optional[Dict[Tuple[Callable[..., Any], Tuple[str]], Any]] = None,
  525. async_exit_stack: AsyncExitStack,
  526. embed_body_fields: bool,
  527. ) -> SolvedDependency:
  528. values: Dict[str, Any] = {}
  529. errors: List[Any] = []
  530. if response is None:
  531. response = Response()
  532. del response.headers["content-length"]
  533. response.status_code = None # type: ignore
  534. dependency_cache = dependency_cache or {}
  535. sub_dependant: Dependant
  536. for sub_dependant in dependant.dependencies:
  537. sub_dependant.call = cast(Callable[..., Any], sub_dependant.call)
  538. sub_dependant.cache_key = cast(
  539. Tuple[Callable[..., Any], Tuple[str]], sub_dependant.cache_key
  540. )
  541. call = sub_dependant.call
  542. use_sub_dependant = sub_dependant
  543. if (
  544. dependency_overrides_provider
  545. and dependency_overrides_provider.dependency_overrides
  546. ):
  547. original_call = sub_dependant.call
  548. call = getattr(
  549. dependency_overrides_provider, "dependency_overrides", {}
  550. ).get(original_call, original_call)
  551. use_path: str = sub_dependant.path # type: ignore
  552. use_sub_dependant = get_dependant(
  553. path=use_path,
  554. call=call,
  555. name=sub_dependant.name,
  556. security_scopes=sub_dependant.security_scopes,
  557. )
  558. solved_result = await solve_dependencies(
  559. request=request,
  560. dependant=use_sub_dependant,
  561. body=body,
  562. background_tasks=background_tasks,
  563. response=response,
  564. dependency_overrides_provider=dependency_overrides_provider,
  565. dependency_cache=dependency_cache,
  566. async_exit_stack=async_exit_stack,
  567. embed_body_fields=embed_body_fields,
  568. )
  569. background_tasks = solved_result.background_tasks
  570. dependency_cache.update(solved_result.dependency_cache)
  571. if solved_result.errors:
  572. errors.extend(solved_result.errors)
  573. continue
  574. if sub_dependant.use_cache and sub_dependant.cache_key in dependency_cache:
  575. solved = dependency_cache[sub_dependant.cache_key]
  576. elif is_gen_callable(call) or is_async_gen_callable(call):
  577. solved = await solve_generator(
  578. call=call, stack=async_exit_stack, sub_values=solved_result.values
  579. )
  580. elif is_coroutine_callable(call):
  581. solved = await call(**solved_result.values)
  582. else:
  583. solved = await run_in_threadpool(call, **solved_result.values)
  584. if sub_dependant.name is not None:
  585. values[sub_dependant.name] = solved
  586. if sub_dependant.cache_key not in dependency_cache:
  587. dependency_cache[sub_dependant.cache_key] = solved
  588. path_values, path_errors = request_params_to_args(
  589. dependant.path_params, request.path_params
  590. )
  591. query_values, query_errors = request_params_to_args(
  592. dependant.query_params, request.query_params
  593. )
  594. header_values, header_errors = request_params_to_args(
  595. dependant.header_params, request.headers
  596. )
  597. cookie_values, cookie_errors = request_params_to_args(
  598. dependant.cookie_params, request.cookies
  599. )
  600. values.update(path_values)
  601. values.update(query_values)
  602. values.update(header_values)
  603. values.update(cookie_values)
  604. errors += path_errors + query_errors + header_errors + cookie_errors
  605. if dependant.body_params:
  606. (
  607. body_values,
  608. body_errors,
  609. ) = await request_body_to_args( # body_params checked above
  610. body_fields=dependant.body_params,
  611. received_body=body,
  612. embed_body_fields=embed_body_fields,
  613. )
  614. values.update(body_values)
  615. errors.extend(body_errors)
  616. if dependant.http_connection_param_name:
  617. values[dependant.http_connection_param_name] = request
  618. if dependant.request_param_name and isinstance(request, Request):
  619. values[dependant.request_param_name] = request
  620. elif dependant.websocket_param_name and isinstance(request, WebSocket):
  621. values[dependant.websocket_param_name] = request
  622. if dependant.background_tasks_param_name:
  623. if background_tasks is None:
  624. background_tasks = BackgroundTasks()
  625. values[dependant.background_tasks_param_name] = background_tasks
  626. if dependant.response_param_name:
  627. values[dependant.response_param_name] = response
  628. if dependant.security_scopes_param_name:
  629. values[dependant.security_scopes_param_name] = SecurityScopes(
  630. scopes=dependant.security_scopes
  631. )
  632. return SolvedDependency(
  633. values=values,
  634. errors=errors,
  635. background_tasks=background_tasks,
  636. response=response,
  637. dependency_cache=dependency_cache,
  638. )
  639. def _validate_value_with_model_field(
  640. *, field: ModelField, value: Any, values: Dict[str, Any], loc: Tuple[str, ...]
  641. ) -> Tuple[Any, List[Any]]:
  642. if value is None:
  643. if field.required:
  644. return None, [get_missing_field_error(loc=loc)]
  645. else:
  646. return deepcopy(field.default), []
  647. v_, errors_ = field.validate(value, values, loc=loc)
  648. if isinstance(errors_, ErrorWrapper):
  649. return None, [errors_]
  650. elif isinstance(errors_, list):
  651. new_errors = _regenerate_error_with_loc(errors=errors_, loc_prefix=())
  652. return None, new_errors
  653. else:
  654. return v_, []
  655. def _get_multidict_value(
  656. field: ModelField, values: Mapping[str, Any], alias: Union[str, None] = None
  657. ) -> Any:
  658. alias = alias or field.alias
  659. if is_sequence_field(field) and isinstance(values, (ImmutableMultiDict, Headers)):
  660. value = values.getlist(alias)
  661. else:
  662. value = values.get(alias, None)
  663. if (
  664. value is None
  665. or (
  666. isinstance(field.field_info, params.Form)
  667. and isinstance(value, str) # For type checks
  668. and value == ""
  669. )
  670. or (is_sequence_field(field) and len(value) == 0)
  671. ):
  672. if field.required:
  673. return
  674. else:
  675. return deepcopy(field.default)
  676. return value
  677. def request_params_to_args(
  678. fields: Sequence[ModelField],
  679. received_params: Union[Mapping[str, Any], QueryParams, Headers],
  680. ) -> Tuple[Dict[str, Any], List[Any]]:
  681. values: Dict[str, Any] = {}
  682. errors: List[Dict[str, Any]] = []
  683. if not fields:
  684. return values, errors
  685. first_field = fields[0]
  686. fields_to_extract = fields
  687. single_not_embedded_field = False
  688. default_convert_underscores = True
  689. if len(fields) == 1 and lenient_issubclass(first_field.type_, BaseModel):
  690. fields_to_extract = get_cached_model_fields(first_field.type_)
  691. single_not_embedded_field = True
  692. # If headers are in a Pydantic model, the way to disable convert_underscores
  693. # would be with Header(convert_underscores=False) at the Pydantic model level
  694. default_convert_underscores = getattr(
  695. first_field.field_info, "convert_underscores", True
  696. )
  697. params_to_process: Dict[str, Any] = {}
  698. processed_keys = set()
  699. for field in fields_to_extract:
  700. alias = None
  701. if isinstance(received_params, Headers):
  702. # Handle fields extracted from a Pydantic Model for a header, each field
  703. # doesn't have a FieldInfo of type Header with the default convert_underscores=True
  704. convert_underscores = getattr(
  705. field.field_info, "convert_underscores", default_convert_underscores
  706. )
  707. if convert_underscores:
  708. alias = (
  709. field.alias
  710. if field.alias != field.name
  711. else field.name.replace("_", "-")
  712. )
  713. value = _get_multidict_value(field, received_params, alias=alias)
  714. if value is not None:
  715. params_to_process[field.name] = value
  716. processed_keys.add(alias or field.alias)
  717. processed_keys.add(field.name)
  718. for key, value in received_params.items():
  719. if key not in processed_keys:
  720. params_to_process[key] = value
  721. if single_not_embedded_field:
  722. field_info = first_field.field_info
  723. assert isinstance(field_info, params.Param), (
  724. "Params must be subclasses of Param"
  725. )
  726. loc: Tuple[str, ...] = (field_info.in_.value,)
  727. v_, errors_ = _validate_value_with_model_field(
  728. field=first_field, value=params_to_process, values=values, loc=loc
  729. )
  730. return {first_field.name: v_}, errors_
  731. for field in fields:
  732. value = _get_multidict_value(field, received_params)
  733. field_info = field.field_info
  734. assert isinstance(field_info, params.Param), (
  735. "Params must be subclasses of Param"
  736. )
  737. loc = (field_info.in_.value, field.alias)
  738. v_, errors_ = _validate_value_with_model_field(
  739. field=field, value=value, values=values, loc=loc
  740. )
  741. if errors_:
  742. errors.extend(errors_)
  743. else:
  744. values[field.name] = v_
  745. return values, errors
  746. def _should_embed_body_fields(fields: List[ModelField]) -> bool:
  747. if not fields:
  748. return False
  749. # More than one dependency could have the same field, it would show up as multiple
  750. # fields but it's the same one, so count them by name
  751. body_param_names_set = {field.name for field in fields}
  752. # A top level field has to be a single field, not multiple
  753. if len(body_param_names_set) > 1:
  754. return True
  755. first_field = fields[0]
  756. # If it explicitly specifies it is embedded, it has to be embedded
  757. if getattr(first_field.field_info, "embed", None):
  758. return True
  759. # If it's a Form (or File) field, it has to be a BaseModel to be top level
  760. # otherwise it has to be embedded, so that the key value pair can be extracted
  761. if isinstance(first_field.field_info, params.Form) and not lenient_issubclass(
  762. first_field.type_, BaseModel
  763. ):
  764. return True
  765. return False
  766. async def _extract_form_body(
  767. body_fields: List[ModelField],
  768. received_body: FormData,
  769. ) -> Dict[str, Any]:
  770. values = {}
  771. first_field = body_fields[0]
  772. first_field_info = first_field.field_info
  773. for field in body_fields:
  774. value = _get_multidict_value(field, received_body)
  775. if (
  776. isinstance(first_field_info, params.File)
  777. and is_bytes_field(field)
  778. and isinstance(value, UploadFile)
  779. ):
  780. value = await value.read()
  781. elif (
  782. is_bytes_sequence_field(field)
  783. and isinstance(first_field_info, params.File)
  784. and value_is_sequence(value)
  785. ):
  786. # For types
  787. assert isinstance(value, sequence_types) # type: ignore[arg-type]
  788. results: List[Union[bytes, str]] = []
  789. async def process_fn(
  790. fn: Callable[[], Coroutine[Any, Any, Any]],
  791. ) -> None:
  792. result = await fn()
  793. results.append(result) # noqa: B023
  794. async with anyio.create_task_group() as tg:
  795. for sub_value in value:
  796. tg.start_soon(process_fn, sub_value.read)
  797. value = serialize_sequence_value(field=field, value=results)
  798. if value is not None:
  799. values[field.alias] = value
  800. for key, value in received_body.items():
  801. if key not in values:
  802. values[key] = value
  803. return values
  804. async def request_body_to_args(
  805. body_fields: List[ModelField],
  806. received_body: Optional[Union[Dict[str, Any], FormData]],
  807. embed_body_fields: bool,
  808. ) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
  809. values: Dict[str, Any] = {}
  810. errors: List[Dict[str, Any]] = []
  811. assert body_fields, "request_body_to_args() should be called with fields"
  812. single_not_embedded_field = len(body_fields) == 1 and not embed_body_fields
  813. first_field = body_fields[0]
  814. body_to_process = received_body
  815. fields_to_extract: List[ModelField] = body_fields
  816. if single_not_embedded_field and lenient_issubclass(first_field.type_, BaseModel):
  817. fields_to_extract = get_cached_model_fields(first_field.type_)
  818. if isinstance(received_body, FormData):
  819. body_to_process = await _extract_form_body(fields_to_extract, received_body)
  820. if single_not_embedded_field:
  821. loc: Tuple[str, ...] = ("body",)
  822. v_, errors_ = _validate_value_with_model_field(
  823. field=first_field, value=body_to_process, values=values, loc=loc
  824. )
  825. return {first_field.name: v_}, errors_
  826. for field in body_fields:
  827. loc = ("body", field.alias)
  828. value: Optional[Any] = None
  829. if body_to_process is not None:
  830. try:
  831. value = body_to_process.get(field.alias)
  832. # If the received body is a list, not a dict
  833. except AttributeError:
  834. errors.append(get_missing_field_error(loc))
  835. continue
  836. v_, errors_ = _validate_value_with_model_field(
  837. field=field, value=value, values=values, loc=loc
  838. )
  839. if errors_:
  840. errors.extend(errors_)
  841. else:
  842. values[field.name] = v_
  843. return values, errors
  844. def get_body_field(
  845. *, flat_dependant: Dependant, name: str, embed_body_fields: bool
  846. ) -> Optional[ModelField]:
  847. """
  848. Get a ModelField representing the request body for a path operation, combining
  849. all body parameters into a single field if necessary.
  850. Used to check if it's form data (with `isinstance(body_field, params.Form)`)
  851. or JSON and to generate the JSON Schema for a request body.
  852. This is **not** used to validate/parse the request body, that's done with each
  853. individual body parameter.
  854. """
  855. if not flat_dependant.body_params:
  856. return None
  857. first_param = flat_dependant.body_params[0]
  858. if not embed_body_fields:
  859. return first_param
  860. model_name = "Body_" + name
  861. BodyModel = create_body_model(
  862. fields=flat_dependant.body_params, model_name=model_name
  863. )
  864. required = any(True for f in flat_dependant.body_params if f.required)
  865. BodyFieldInfo_kwargs: Dict[str, Any] = {
  866. "annotation": BodyModel,
  867. "alias": "body",
  868. }
  869. if not required:
  870. BodyFieldInfo_kwargs["default"] = None
  871. if any(isinstance(f.field_info, params.File) for f in flat_dependant.body_params):
  872. BodyFieldInfo: Type[params.Body] = params.File
  873. elif any(isinstance(f.field_info, params.Form) for f in flat_dependant.body_params):
  874. BodyFieldInfo = params.Form
  875. else:
  876. BodyFieldInfo = params.Body
  877. body_param_media_types = [
  878. f.field_info.media_type
  879. for f in flat_dependant.body_params
  880. if isinstance(f.field_info, params.Body)
  881. ]
  882. if len(set(body_param_media_types)) == 1:
  883. BodyFieldInfo_kwargs["media_type"] = body_param_media_types[0]
  884. final_field = create_model_field(
  885. name="body",
  886. type_=BodyModel,
  887. required=required,
  888. alias="body",
  889. field_info=BodyFieldInfo(**BodyFieldInfo_kwargs),
  890. )
  891. return final_field