torch_version.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # mypy: ignore-errors
  2. from typing import Any, Iterable
  3. from ._vendor.packaging.version import InvalidVersion, Version
  4. from .version import __version__ as internal_version
# Public names exported by this module's ``import *``.
__all__ = ["TorchVersion"]
  6. class TorchVersion(str):
  7. """A string with magic powers to compare to both Version and iterables!
  8. Prior to 1.10.0 torch.__version__ was stored as a str and so many did
  9. comparisons against torch.__version__ as if it were a str. In order to not
  10. break them we have TorchVersion which masquerades as a str while also
  11. having the ability to compare against both packaging.version.Version as
  12. well as tuples of values, eg. (1, 2, 1)
  13. Examples:
  14. Comparing a TorchVersion object to a Version object
  15. TorchVersion('1.10.0a') > Version('1.10.0a')
  16. Comparing a TorchVersion object to a Tuple object
  17. TorchVersion('1.10.0a') > (1, 2) # 1.2
  18. TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
  19. Comparing a TorchVersion object against a string
  20. TorchVersion('1.10.0a') > '1.2'
  21. TorchVersion('1.10.0a') > '1.2.1'
  22. """
  23. # fully qualified type names here to appease mypy
  24. def _convert_to_version(self, inp: Any) -> Any:
  25. if isinstance(inp, Version):
  26. return inp
  27. elif isinstance(inp, str):
  28. return Version(inp)
  29. elif isinstance(inp, Iterable):
  30. # Ideally this should work for most cases by attempting to group
  31. # the version tuple, assuming the tuple looks (MAJOR, MINOR, ?PATCH)
  32. # Examples:
  33. # * (1) -> Version("1")
  34. # * (1, 20) -> Version("1.20")
  35. # * (1, 20, 1) -> Version("1.20.1")
  36. return Version(".".join(str(item) for item in inp))
  37. else:
  38. raise InvalidVersion(inp)
  39. def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
  40. try:
  41. return getattr(Version(self), method)(self._convert_to_version(cmp))
  42. except BaseException as e:
  43. if not isinstance(e, InvalidVersion):
  44. raise
  45. # Fall back to regular string comparison if dealing with an invalid
  46. # version like 'parrot'
  47. return getattr(super(), method)(cmp)
  48. for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
  49. setattr(
  50. TorchVersion,
  51. cmp_method,
  52. lambda x, y, method=cmp_method: x._cmp_wrapper(y, method),
  53. )
# Module-level version string. It is wrapped in TorchVersion so it still
# behaves as a str while also comparing as a version against strings,
# Version objects and tuples.
__version__ = TorchVersion(internal_version)