_validate_call.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. from __future__ import annotations as _annotations
  2. import functools
  3. import inspect
  4. from functools import partial
  5. from typing import Any, Awaitable, Callable
  6. import pydantic_core
  7. from ..config import ConfigDict
  8. from ..plugin._schema_validator import create_schema_validator
  9. from ._config import ConfigWrapper
  10. from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
  11. from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
  12. def extract_function_name(func: ValidateCallSupportedTypes) -> str:
  13. """Extract the name of a `ValidateCallSupportedTypes` object."""
  14. return f'partial({func.func.__name__})' if isinstance(func, functools.partial) else func.__name__
  15. def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
  16. """Extract the qualname of a `ValidateCallSupportedTypes` object."""
  17. return f'partial({func.func.__qualname__})' if isinstance(func, functools.partial) else func.__qualname__
  18. def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
  19. """Update the `wrapper` function with the attributes of the `wrapped` function. Return the updated function."""
  20. if inspect.iscoroutinefunction(wrapped):
  21. @functools.wraps(wrapped)
  22. async def wrapper_function(*args, **kwargs): # type: ignore
  23. return await wrapper(*args, **kwargs)
  24. else:
  25. @functools.wraps(wrapped)
  26. def wrapper_function(*args, **kwargs):
  27. return wrapper(*args, **kwargs)
  28. # We need to manually update this because `partial` object has no `__name__` and `__qualname__`.
  29. wrapper_function.__name__ = extract_function_name(wrapped)
  30. wrapper_function.__qualname__ = extract_function_qualname(wrapped)
  31. wrapper_function.raw_function = wrapped # type: ignore
  32. return wrapper_function
  33. class ValidateCallWrapper:
  34. """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""
  35. __slots__ = ('__pydantic_validator__', '__return_pydantic_validator__')
  36. def __init__(
  37. self,
  38. function: ValidateCallSupportedTypes,
  39. config: ConfigDict | None,
  40. validate_return: bool,
  41. parent_namespace: MappingNamespace | None,
  42. ) -> None:
  43. if isinstance(function, partial):
  44. schema_type = function.func
  45. module = function.func.__module__
  46. else:
  47. schema_type = function
  48. module = function.__module__
  49. qualname = extract_function_qualname(function)
  50. ns_resolver = NsResolver(namespaces_tuple=ns_for_function(schema_type, parent_namespace=parent_namespace))
  51. config_wrapper = ConfigWrapper(config)
  52. gen_schema = GenerateSchema(config_wrapper, ns_resolver)
  53. schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
  54. core_config = config_wrapper.core_config(title=qualname)
  55. self.__pydantic_validator__ = create_schema_validator(
  56. schema,
  57. schema_type,
  58. module,
  59. qualname,
  60. 'validate_call',
  61. core_config,
  62. config_wrapper.plugin_settings,
  63. )
  64. if validate_return:
  65. signature = inspect.signature(function)
  66. return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
  67. gen_schema = GenerateSchema(config_wrapper, ns_resolver)
  68. schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
  69. validator = create_schema_validator(
  70. schema,
  71. schema_type,
  72. module,
  73. qualname,
  74. 'validate_call',
  75. core_config,
  76. config_wrapper.plugin_settings,
  77. )
  78. if inspect.iscoroutinefunction(function):
  79. async def return_val_wrapper(aw: Awaitable[Any]) -> None:
  80. return validator.validate_python(await aw)
  81. self.__return_pydantic_validator__ = return_val_wrapper
  82. else:
  83. self.__return_pydantic_validator__ = validator.validate_python
  84. else:
  85. self.__return_pydantic_validator__ = None
  86. def __call__(self, *args: Any, **kwargs: Any) -> Any:
  87. res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
  88. if self.__return_pydantic_validator__:
  89. return self.__return_pydantic_validator__(res)
  90. else:
  91. return res