_conversions.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. # mypy: allow-untyped-defs
  2. import torch
  3. import torch._prims_common as utils
  4. # Utilities should come BEFORE this import
  5. from torch._decomp import register_decomposition
  6. from torch._prims_common import TensorLikeType
  7. from torch._prims_common.wrappers import out_wrapper
  8. from torch._refs import _broadcast_shapes
  9. # Data conversion references.
  10. #
  11. # Note: this module breaks the usual _refs to torch naming scheme where
  12. # _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
  13. # part of _refs/__init__.py to avoid name clashes with Python builtin types
  14. # (like int).
# Public surface of this module: one conversion ref per Tensor dtype-conversion
# method, plus the two complex constructors. Several entries (bool, float, int,
# ...) deliberately shadow Python builtins, which is why this module is kept
# out of _refs/__init__.py (see the module note above).
__all__ = [
    # dtypes
    "bfloat16",
    "bool",
    "byte",
    "cdouble",
    "cfloat",
    "chalf",
    "char",
    "double",
    "float",
    "half",
    "int",
    "long",
    "short",
    # misc
    "complex",
    "polar",
]
  34. def _make_conversion_method(name: str, dtype: torch.dtype):
  35. def fn(
  36. self: TensorLikeType, memory_format: torch.memory_format = torch.preserve_format
  37. ) -> TensorLikeType:
  38. return self.to(dtype, memory_format=memory_format) # type: ignore[call-overload]
  39. fn.__name__ = name
  40. return fn
# One conversion ref per ATen Tensor conversion method, e.g. ``Tensor.byte()``
# is ``to(torch.uint8)``. All are produced by the factory above, so they share
# the same (dtype-fixed) signature. bool/float/int/... intentionally shadow
# the Python builtins within this module.
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
bool = _make_conversion_method("bool", torch.bool)
byte = _make_conversion_method("byte", torch.uint8)
cdouble = _make_conversion_method("cdouble", torch.cdouble)
cfloat = _make_conversion_method("cfloat", torch.cfloat)
chalf = _make_conversion_method("chalf", torch.complex32)
char = _make_conversion_method("char", torch.int8)
double = _make_conversion_method("double", torch.double)
float = _make_conversion_method("float", torch.float)
half = _make_conversion_method("half", torch.half)
int = _make_conversion_method("int", torch.int)
long = _make_conversion_method("long", torch.long)
short = _make_conversion_method("short", torch.short)
  54. @register_decomposition(torch._ops.ops.aten.complex)
  55. # Note: complex has type promotion tests disabled due to different semantics.
  56. # exact_dtype is for compat with complex_check_dtype from core.
  57. @out_wrapper(exact_dtype=True)
  58. def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
  59. allowed_dtypes = (torch.float32, torch.float64, torch.float16)
  60. torch._check(
  61. real.dtype in allowed_dtypes and imag.dtype in allowed_dtypes,
  62. lambda: (
  63. f"Expected both inputs to be Half, Float or Double tensors but got "
  64. f"{real.dtype} and {imag.dtype}"
  65. ),
  66. )
  67. torch._check(
  68. real.dtype == imag.dtype,
  69. lambda: (
  70. f"Expected object of scalar type {real.dtype} but got "
  71. f"scalar type {imag.dtype} for second argument"
  72. ),
  73. )
  74. result_dtype = utils.corresponding_complex_dtype(real.dtype) # type: ignore[arg-type]
  75. common_shape = _broadcast_shapes(real.shape, imag.shape)
  76. result = real.new_empty(
  77. common_shape,
  78. dtype=result_dtype,
  79. layout=real.layout,
  80. device=real.device,
  81. # pin_memory=real.is_pinned(), # NYI
  82. )
  83. result.real = real
  84. result.imag = imag
  85. return result
  86. @register_decomposition(torch._ops.ops.aten.polar)
  87. # Note: polar has type promotion tests disabled due to different semantics.
  88. # exact_dtype is for compat with complex_check_dtype from core.
  89. @out_wrapper(exact_dtype=True)
  90. def polar(abs: TensorLikeType, angle: TensorLikeType) -> TensorLikeType:
  91. result = torch.complex(abs, angle)
  92. result.real = abs * torch.cos(angle)
  93. result.imag = abs * torch.sin(angle)
  94. return result