fft.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # mypy: ignore-errors
  2. from __future__ import annotations
  3. import functools
  4. import torch
  5. from . import _dtypes_impl, _util
  6. from ._normalizations import ArrayLike, normalizer
  7. def upcast(func):
  8. """NumPy fft casts inputs to 64 bit and *returns 64-bit results*."""
  9. @functools.wraps(func)
  10. def wrapped(tensor, *args, **kwds):
  11. target_dtype = (
  12. _dtypes_impl.default_dtypes().complex_dtype
  13. if tensor.is_complex()
  14. else _dtypes_impl.default_dtypes().float_dtype
  15. )
  16. tensor = _util.cast_if_needed(tensor, target_dtype)
  17. return func(tensor, *args, **kwds)
  18. return wrapped
  19. @normalizer
  20. @upcast
  21. def fft(a: ArrayLike, n=None, axis=-1, norm=None):
  22. return torch.fft.fft(a, n, dim=axis, norm=norm)
  23. @normalizer
  24. @upcast
  25. def ifft(a: ArrayLike, n=None, axis=-1, norm=None):
  26. return torch.fft.ifft(a, n, dim=axis, norm=norm)
  27. @normalizer
  28. @upcast
  29. def rfft(a: ArrayLike, n=None, axis=-1, norm=None):
  30. return torch.fft.rfft(a, n, dim=axis, norm=norm)
  31. @normalizer
  32. @upcast
  33. def irfft(a: ArrayLike, n=None, axis=-1, norm=None):
  34. return torch.fft.irfft(a, n, dim=axis, norm=norm)
  35. @normalizer
  36. @upcast
  37. def fftn(a: ArrayLike, s=None, axes=None, norm=None):
  38. return torch.fft.fftn(a, s, dim=axes, norm=norm)
  39. @normalizer
  40. @upcast
  41. def ifftn(a: ArrayLike, s=None, axes=None, norm=None):
  42. return torch.fft.ifftn(a, s, dim=axes, norm=norm)
  43. @normalizer
  44. @upcast
  45. def rfftn(a: ArrayLike, s=None, axes=None, norm=None):
  46. return torch.fft.rfftn(a, s, dim=axes, norm=norm)
  47. @normalizer
  48. @upcast
  49. def irfftn(a: ArrayLike, s=None, axes=None, norm=None):
  50. return torch.fft.irfftn(a, s, dim=axes, norm=norm)
  51. @normalizer
  52. @upcast
  53. def fft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
  54. return torch.fft.fft2(a, s, dim=axes, norm=norm)
  55. @normalizer
  56. @upcast
  57. def ifft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
  58. return torch.fft.ifft2(a, s, dim=axes, norm=norm)
  59. @normalizer
  60. @upcast
  61. def rfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
  62. return torch.fft.rfft2(a, s, dim=axes, norm=norm)
  63. @normalizer
  64. @upcast
  65. def irfft2(a: ArrayLike, s=None, axes=(-2, -1), norm=None):
  66. return torch.fft.irfft2(a, s, dim=axes, norm=norm)
  67. @normalizer
  68. @upcast
  69. def hfft(a: ArrayLike, n=None, axis=-1, norm=None):
  70. return torch.fft.hfft(a, n, dim=axis, norm=norm)
  71. @normalizer
  72. @upcast
  73. def ihfft(a: ArrayLike, n=None, axis=-1, norm=None):
  74. return torch.fft.ihfft(a, n, dim=axis, norm=norm)
  75. @normalizer
  76. def fftfreq(n, d=1.0):
  77. return torch.fft.fftfreq(n, d)
  78. @normalizer
  79. def rfftfreq(n, d=1.0):
  80. return torch.fft.rfftfreq(n, d)
  81. @normalizer
  82. def fftshift(x: ArrayLike, axes=None):
  83. return torch.fft.fftshift(x, axes)
  84. @normalizer
  85. def ifftshift(x: ArrayLike, axes=None):
  86. return torch.fft.ifftshift(x, axes)