| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139 |
- """Pluggable schema validator for pydantic."""
- from __future__ import annotations
- import functools
- from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeVar
- from pydantic_core import CoreConfig, CoreSchema, SchemaValidator, ValidationError
- from typing_extensions import Literal, ParamSpec
- if TYPE_CHECKING:
- from . import BaseValidateHandlerProtocol, PydanticPluginProtocol, SchemaKind, SchemaTypePath
- P = ParamSpec('P')
- R = TypeVar('R')
- Event = Literal['on_validate_python', 'on_validate_json', 'on_validate_strings']
- events: list[Event] = list(Event.__args__) # type: ignore
def create_schema_validator(
    schema: CoreSchema,
    schema_type: Any,
    schema_type_module: str,
    schema_type_name: str,
    schema_kind: SchemaKind,
    config: CoreConfig | None = None,
    plugin_settings: dict[str, Any] | None = None,
) -> SchemaValidator | PluggableSchemaValidator:
    """Build the validator for `schema`, adding plugin support only when plugins are found.

    Args:
        schema: Core schema handed to the validator.
        schema_type: The type the schema belongs to; forwarded to `PluggableSchemaValidator`.
        schema_type_module: Module part of the `SchemaTypePath` passed on to the plugins.
        schema_type_name: Name part of the `SchemaTypePath` passed on to the plugins.
        schema_kind: Kind of schema; forwarded to `PluggableSchemaValidator`.
        config: Optional core config handed to the validator.
        plugin_settings: Optional plugin settings; a falsy value is replaced by an empty dict.

    Returns:
        A `PluggableSchemaValidator` when `get_plugins()` returns something truthy,
        otherwise a plain `SchemaValidator`.
    """
    # Imported at call time rather than module import time (presumably to avoid a circular import).
    from . import SchemaTypePath
    from ._loader import get_plugins

    installed_plugins = get_plugins()
    if not installed_plugins:
        # Nothing to hook in: hand back the plain validator.
        return SchemaValidator(schema, config)

    type_path = SchemaTypePath(schema_type_module, schema_type_name)
    settings = plugin_settings or {}
    return PluggableSchemaValidator(
        schema,
        schema_type,
        type_path,
        schema_kind,
        config,
        installed_plugins,
        settings,
    )
class PluggableSchemaValidator:
    """Pluggable schema validator.

    Wraps a `SchemaValidator` and routes `validate_python`, `validate_json` and
    `validate_strings` through the event handlers returned by the given plugins.
    Every other attribute is looked up on the wrapped validator.
    """

    __slots__ = '_schema_validator', 'validate_json', 'validate_python', 'validate_strings'

    def __init__(
        self,
        schema: CoreSchema,
        schema_type: Any,
        schema_type_path: SchemaTypePath,
        schema_kind: SchemaKind,
        config: CoreConfig | None,
        plugins: Iterable[PydanticPluginProtocol],
        plugin_settings: dict[str, Any],
    ) -> None:
        """Create the wrapped validator and collect each plugin's event handlers.

        Args:
            schema: Core schema used to build the wrapped `SchemaValidator`; also passed to each plugin.
            schema_type: Passed through to each plugin's `new_schema_validator`.
            schema_type_path: Passed through to each plugin's `new_schema_validator`.
            schema_kind: Passed through to each plugin's `new_schema_validator`.
            config: Core config used to build the wrapped `SchemaValidator`; also passed to each plugin.
            plugins: Plugins asked for event handlers, in iteration order.
            plugin_settings: Passed through to each plugin's `new_schema_validator`.

        Raises:
            TypeError: If calling a plugin's `new_schema_validator` (or unpacking its result)
                raises `TypeError`; re-raised with the plugin's module and class name in the message.
        """
        self._schema_validator = SchemaValidator(schema, config)

        python_event_handlers: list[BaseValidateHandlerProtocol] = []
        json_event_handlers: list[BaseValidateHandlerProtocol] = []
        strings_event_handlers: list[BaseValidateHandlerProtocol] = []
        for plugin in plugins:
            try:
                # Each plugin returns one (possibly `None`) handler per validation mode.
                p, j, s = plugin.new_schema_validator(
                    schema, schema_type, schema_type_path, schema_kind, config, plugin_settings
                )
            except TypeError as e:  # pragma: no cover
                raise TypeError(f'Error using plugin `{plugin.__module__}:{plugin.__class__.__name__}`: {e}') from e
            if p is not None:
                python_event_handlers.append(p)
            if j is not None:
                json_event_handlers.append(j)
            if s is not None:
                strings_event_handlers.append(s)

        # `build_wrapper` returns the bare validator method when a mode has no handlers.
        self.validate_python = build_wrapper(self._schema_validator.validate_python, python_event_handlers)
        self.validate_json = build_wrapper(self._schema_validator.validate_json, json_event_handlers)
        self.validate_strings = build_wrapper(self._schema_validator.validate_strings, strings_event_handlers)

    def __getattr__(self, name: str) -> Any:
        """Forward attributes not found on this object to the wrapped validator.

        Raises:
            AttributeError: If the wrapped validator has not been set yet, or does not
                have the attribute either.
        """
        # `__getattr__` only runs after normal lookup failed. If the missing name is the
        # wrapped validator itself (slot not yet assigned), reading `self._schema_validator`
        # below would re-enter this method forever and end in `RecursionError`.
        if name == '_schema_validator':
            raise AttributeError(f'{type(self).__name__!r} object has no attribute {name!r}')
        return getattr(self._schema_validator, name)
def build_wrapper(func: Callable[P, R], event_handlers: list[BaseValidateHandlerProtocol]) -> Callable[P, R]:
    """Wrap `func` so the given handlers are notified around each call.

    Args:
        func: The callable to wrap.
        event_handlers: Handlers whose `on_enter`, `on_success`, `on_error` and
            `on_exception` methods are called, provided `filter_handlers` accepts them.

    Returns:
        `func` itself when `event_handlers` is empty; otherwise a wrapper that calls the
        `on_enter` hooks with the call arguments, then `func`, then either the
        `on_success` hooks with the result, the `on_error` hooks with a
        `ValidationError`, or the `on_exception` hooks with any other `Exception`.
        Exceptions from `func` are always re-raised.
    """
    if not event_handlers:
        return func

    def _collect(method_name: str) -> tuple[Any, ...]:
        # Keep only the hooks that `filter_handlers` accepts, in handler order.
        return tuple(getattr(h, method_name) for h in event_handlers if filter_handlers(h, method_name))

    enter_hooks = _collect('on_enter')
    success_hooks = _collect('on_success')
    error_hooks = _collect('on_error')
    exception_hooks = _collect('on_exception')

    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        for notify_enter in enter_hooks:
            notify_enter(*args, **kwargs)

        try:
            outcome = func(*args, **kwargs)
        except ValidationError as validation_error:
            for notify_error in error_hooks:
                notify_error(validation_error)
            raise
        except Exception as unexpected:
            for notify_exception in exception_hooks:
                notify_exception(unexpected)
            raise

        # Only reached when `func` returned normally; both `except` branches re-raise.
        for notify_success in success_hooks:
            notify_success(outcome)
        return outcome

    return wrapper
def filter_handlers(handler_cls: BaseValidateHandlerProtocol, method_name: str) -> bool:
    """Tell whether the plugin itself implements `method_name` on this handler.

    Args:
        handler_cls: The handler object to inspect.
        method_name: Name of the event method to look up.

    Returns:
        `False` when the attribute is missing (or `None`), or when its `__module__` is
        `'pydantic.plugin'`, i.e. it is the protocol's own method picked up through
        runtime inheritance, which should not be called. `True` otherwise.
    """
    method = getattr(handler_cls, method_name, None)
    return method is not None and method.__module__ != 'pydantic.plugin'
|