image_processing_utils_fast.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # coding=utf-8
  2. # Copyright 2024 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import functools
  16. from dataclasses import dataclass
  17. from .image_processing_utils import BaseImageProcessor
  18. from .utils.import_utils import is_torchvision_available
  19. if is_torchvision_available():
  20. from torchvision.transforms import Compose
  21. @dataclass(frozen=True)
  22. class SizeDict:
  23. """
  24. Hashable dictionary to store image size information.
  25. """
  26. height: int = None
  27. width: int = None
  28. longest_edge: int = None
  29. shortest_edge: int = None
  30. max_height: int = None
  31. max_width: int = None
  32. def __getitem__(self, key):
  33. if hasattr(self, key):
  34. return getattr(self, key)
  35. raise KeyError(f"Key {key} not found in SizeDict.")
  36. class BaseImageProcessorFast(BaseImageProcessor):
  37. _transform_params = None
  38. def _build_transforms(self, **kwargs) -> "Compose":
  39. """
  40. Given the input settings e.g. do_resize, build the image transforms.
  41. """
  42. raise NotImplementedError
  43. def _validate_params(self, **kwargs) -> None:
  44. for k, v in kwargs.items():
  45. if k not in self._transform_params:
  46. raise ValueError(f"Invalid transform parameter {k}={v}.")
  47. @functools.lru_cache(maxsize=1)
  48. def get_transforms(self, **kwargs) -> "Compose":
  49. self._validate_params(**kwargs)
  50. return self._build_transforms(**kwargs)
  51. def to_dict(self):
  52. encoder_dict = super().to_dict()
  53. encoder_dict.pop("_transform_params", None)
  54. return encoder_dict