| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115 |
- from __future__ import annotations as _annotations
- import functools
- import inspect
- from functools import partial
- from typing import Any, Awaitable, Callable
- import pydantic_core
- from ..config import ConfigDict
- from ..plugin._schema_validator import create_schema_validator
- from ._config import ConfigWrapper
- from ._generate_schema import GenerateSchema, ValidateCallSupportedTypes
- from ._namespace_utils import MappingNamespace, NsResolver, ns_for_function
def extract_function_name(func: ValidateCallSupportedTypes) -> str:
    """Return a display name for a `ValidateCallSupportedTypes` object.

    A `functools.partial` object is rendered as `partial(<name>)`, where `<name>` is the
    `__name__` of the callable it wraps. Any other object yields its own `__name__`.
    """
    if not isinstance(func, functools.partial):
        return func.__name__
    inner_name = func.func.__name__
    return f'partial({inner_name})'
def extract_function_qualname(func: ValidateCallSupportedTypes) -> str:
    """Return a display qualname for a `ValidateCallSupportedTypes` object.

    A `functools.partial` object is rendered as `partial(<qualname>)`, where `<qualname>` is the
    `__qualname__` of the callable it wraps. Any other object yields its own `__qualname__`.
    """
    if isinstance(func, functools.partial):
        inner_qualname = func.func.__qualname__
        result = f'partial({inner_qualname})'
    else:
        result = func.__qualname__
    return result
def update_wrapper_attributes(wrapped: ValidateCallSupportedTypes, wrapper: Callable[..., Any]):
    """Build a function that forwards to `wrapper` while carrying the metadata of `wrapped`.

    The returned function is a coroutine function when `wrapped` is one (it awaits the result of
    `wrapper`), and a plain function otherwise. The original callable is exposed on the result
    as the `raw_function` attribute.
    """
    wrapped_is_async = inspect.iscoroutinefunction(wrapped)

    if not wrapped_is_async:

        @functools.wraps(wrapped)
        def wrapper_function(*args, **kwargs):
            return wrapper(*args, **kwargs)

    else:

        @functools.wraps(wrapped)
        async def wrapper_function(*args, **kwargs):  # type: ignore
            return await wrapper(*args, **kwargs)

    # `functools.wraps` cannot copy `__name__`/`__qualname__` from a `partial` object (it has
    # neither), so both are derived explicitly and assigned by hand.
    display_name = extract_function_name(wrapped)
    display_qualname = extract_function_qualname(wrapped)
    wrapper_function.__name__ = display_name
    wrapper_function.__qualname__ = display_qualname
    wrapper_function.raw_function = wrapped  # type: ignore

    return wrapper_function
class ValidateCallWrapper:
    """This is a wrapper around a function that validates the arguments passed to it, and optionally the return value."""

    # `__pydantic_validator__` validates the call arguments; `__return_pydantic_validator__` is a
    # callable applied to the result of that validation, or `None` when return validation is off.
    __slots__ = ('__pydantic_validator__', '__return_pydantic_validator__')

    def __init__(
        self,
        function: ValidateCallSupportedTypes,
        config: ConfigDict | None,
        validate_return: bool,
        parent_namespace: MappingNamespace | None,
    ) -> None:
        """Build the argument validator and, if requested, the return-value validator.

        Args:
            function: The callable to generate the schema from. For a `functools.partial`, the
                module and the namespace lookup are taken from the underlying `function.func`.
            config: Configuration handed to `ConfigWrapper`.
            validate_return: Whether to also build a validator for the return annotation.
            parent_namespace: Namespace forwarded to `ns_for_function` for annotation resolution.
        """
        if isinstance(function, partial):
            schema_type = function.func
            module = function.func.__module__
        else:
            schema_type = function
            module = function.__module__
        qualname = extract_function_qualname(function)

        ns_resolver = NsResolver(namespaces_tuple=ns_for_function(schema_type, parent_namespace=parent_namespace))

        config_wrapper = ConfigWrapper(config)
        gen_schema = GenerateSchema(config_wrapper, ns_resolver)
        schema = gen_schema.clean_schema(gen_schema.generate_schema(function))
        core_config = config_wrapper.core_config(title=qualname)

        def build_validator(core_schema: Any) -> Any:
            # Both validators share every argument except the schema itself.
            return create_schema_validator(
                core_schema,
                schema_type,
                module,
                qualname,
                'validate_call',
                core_config,
                config_wrapper.plugin_settings,
            )

        self.__pydantic_validator__ = build_validator(schema)

        if not validate_return:
            self.__return_pydantic_validator__ = None
            return

        signature = inspect.signature(function)
        # A missing return annotation is treated as `Any`.
        return_type = signature.return_annotation if signature.return_annotation is not signature.empty else Any
        # A fresh `GenerateSchema` instance is used for the return type, as in the argument case.
        gen_schema = GenerateSchema(config_wrapper, ns_resolver)
        schema = gen_schema.clean_schema(gen_schema.generate_schema(return_type))
        validator = build_validator(schema)

        if inspect.iscoroutinefunction(function):

            async def return_val_wrapper(aw: Awaitable[Any]) -> Any:
                # For coroutine functions the argument validator yields an awaitable; await it
                # first, then validate (and return) the awaited value.
                return validator.validate_python(await aw)

            self.__return_pydantic_validator__ = return_val_wrapper
        else:
            self.__return_pydantic_validator__ = validator.validate_python

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Validate `args`/`kwargs` and return the result, passed through the return validator if one is set.

        NOTE(review): the result of the argument validator is presumably the wrapped function's
        return value (the schema is generated from the function itself) -- confirm against
        `GenerateSchema`.
        """
        res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
        if self.__return_pydantic_validator__ is not None:
            return self.__return_pydantic_validator__(res)
        return res
|