# fft.py
import math
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
from torch._decomp import register_decomposition
from torch._prims_common import DimsType, ShapeType, TensorLikeType
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper

__all__ = [
    # Transforms
    "fft",
    "fft2",
    "fftn",
    "hfft",
    "hfft2",
    "hfftn",
    "rfft",
    "rfft2",
    "rfftn",
    "ifft",
    "ifft2",
    "ifftn",
    "ihfft",
    "ihfft2",
    "ihfftn",
    "irfft",
    "irfft2",
    "irfftn",
    # Helpers
    "fftshift",
    "ifftshift",
]

NormType = Union[None, Literal["forward", "backward", "ortho"]]
_NORM_VALUES = {None, "forward", "backward", "ortho"}
aten = torch._ops.ops.aten


def _apply_norm(
    x: TensorLikeType, norm: NormType, signal_numel: int, forward: bool
) -> TensorLikeType:
    """Apply normalization to the un-normalized FFT result"""
    torch._check(norm in _NORM_VALUES, lambda: f"Invalid normalization mode: {norm}")

    if norm == "ortho":
        return x * (1 / math.sqrt(signal_numel))

    normalize = (not forward and (norm is None or norm == "backward")) or (
        forward and norm == "forward"
    )
    return x * (1 / signal_numel) if normalize else x
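
# Scaling conventions implemented by _apply_norm for a length-n transform
# (illustrative summary of the branches above):
#
#   norm                  forward      inverse
#   None or "backward"    1            1/n
#   "ortho"               1/sqrt(n)    1/sqrt(n)
#   "forward"             1/n          1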


def _promote_type_fft(
    dtype: torch.dtype, require_complex: bool, device: torch.device
) -> torch.dtype:
    """Helper to promote a dtype to one supported by the FFT primitives"""
    if dtype.is_complex:
        return dtype

    # Promote integral to default float type
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()

    allowed_types = [torch.float32, torch.float64]
    maybe_support_half = device.type in ["cuda", "meta"]
    if maybe_support_half:
        allowed_types.append(torch.float16)
    torch._check(dtype in allowed_types, lambda: f"Unsupported dtype {dtype}")

    if require_complex:
        dtype = utils.corresponding_complex_dtype(dtype)

    return dtype


def _maybe_promote_tensor_fft(
    t: TensorLikeType, require_complex: bool = False
) -> TensorLikeType:
    """Helper to promote a tensor to a dtype supported by the FFT primitives"""
    cur_type = t.dtype
    new_type = _promote_type_fft(cur_type, require_complex, t.device)
    return _maybe_convert_to_dtype(t, new_type)  # type: ignore[return-value]
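
# Illustrative promotions (assuming the default dtype is torch.float32):
#   int64,   require_complex=False -> float32
#   int64,   require_complex=True  -> complex64
#   float64, require_complex=True  -> complex128
#   complex64                      -> complex64 (returned unchanged)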


def _resize_fft_input(
    x: TensorLikeType, dims: Tuple[int, ...], sizes: Tuple[int, ...]
) -> TensorLikeType:
    """
    Fixes the shape of x such that x.size(dims[i]) == sizes[i],
    either by zero-padding, or by slicing x starting from 0.
    """
    assert len(dims) == len(sizes)

    must_copy = False
    x_sizes = x.shape
    pad_amount = [0] * len(x_sizes) * 2
    for i in range(len(dims)):
        if sizes[i] == -1:
            continue

        if x_sizes[dims[i]] < sizes[i]:
            must_copy = True
            pad_idx = len(pad_amount) - 2 * dims[i] - 1
            pad_amount[pad_idx] = sizes[i] - x_sizes[dims[i]]

        if x_sizes[dims[i]] > sizes[i]:
            x = x.narrow(dims[i], 0, sizes[i])

    return torch.constant_pad_nd(x, pad_amount) if must_copy else x
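
# Illustrative shapes: with x.shape == (3, 4),
#   _resize_fft_input(x, dims=(1,), sizes=(6,))  # zero-pads dim 1 -> (3, 6)
#   _resize_fft_input(x, dims=(1,), sizes=(2,))  # slices dim 1    -> (3, 2)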


def _fft_c2r(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to real FFT (irfft or hfft)"""
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    last_dim_size = n if n is not None else 2 * (input.shape[dim] - 1)
    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    if n is not None:
        input = _resize_fft_input(input, dims=dims, sizes=(last_dim_size // 2 + 1,))

    if forward:
        input = torch.conj(input)

    output = prims.fft_c2r(input, dim=dims, last_dim_size=last_dim_size)
    return _apply_norm(output, norm=norm, signal_numel=last_dim_size, forward=forward)


def _fft_r2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
    onesided: bool,
) -> TensorLikeType:
    """Common code for performing any real to complex FFT (rfft or ihfft)"""
    torch._check(
        not input.dtype.is_complex,
        lambda: f"{func_name} expects a floating point input tensor, but got {input.dtype}",
    )
    input = _maybe_promote_tensor_fft(input)
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_r2c(input, dim=dims, onesided=onesided)
    ret = _apply_norm(ret, norm, dim_size, forward)
    return ret if forward else torch.conj(ret)


def _fft_c2c(
    func_name: str,
    input: TensorLikeType,
    n: Optional[int],
    dim: int,
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for performing any complex to complex FFT (fft or ifft)"""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{func_name} expects a complex input tensor, but got {input.dtype}",
    )
    dims = (utils.canonicalize_dim(input.ndim, dim, wrap_scalar=False),)
    dim_size = n if n is not None else input.shape[dim]
    torch._check(
        dim_size >= 1, lambda: f"Invalid number of data points ({dim_size}) specified"
    )

    if n is not None:
        input = _resize_fft_input(input, dims, (n,))

    ret = prims.fft_c2c(input, dim=dims, forward=forward)
    return _apply_norm(ret, norm, dim_size, forward)
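
# Summary of how the three helpers above map onto the public API:
#   _fft_c2c: complex -> complex  (fft/ifft on complex input)
#   _fft_r2c: real    -> complex  (rfft, ihfft, and fft/ifft on real input)
#   _fft_c2r: complex -> real     (irfft, hfft)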


@register_decomposition(aten.fft_fft)
@out_wrapper()
def fft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("fft", input, n, dim, norm, forward=True)
    else:
        return _fft_r2c("fft", input, n, dim, norm, forward=True, onesided=False)


@register_decomposition(aten.fft_ifft)
@out_wrapper()
def ifft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    if input.dtype.is_complex:
        return _fft_c2c("ifft", input, n, dim, norm, forward=False)
    else:
        return _fft_r2c("ifft", input, n, dim, norm, forward=False, onesided=False)


@register_decomposition(aten.fft_rfft)
@out_wrapper()
def rfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("rfft", input, n, dim, norm, forward=True, onesided=True)


@register_decomposition(aten.fft_irfft)
@out_wrapper()
def irfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("irfft", input, n, dim, norm, forward=False)


@register_decomposition(aten.fft_hfft)
@out_wrapper()
def hfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_c2r("hfft", input, n, dim, norm, forward=True)


@register_decomposition(aten.fft_ihfft)
@out_wrapper()
def ihfft(
    input: TensorLikeType,
    n: Optional[int] = None,
    dim: int = -1,
    norm: NormType = None,
) -> TensorLikeType:
    return _fft_r2c("ihfft", input, n, dim, norm, forward=False, onesided=True)
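
# Round-trip sanity sketch for the 1-D transforms above (illustrative; `x`
# and `X` are stand-in tensors, not part of this module):
#
#   x = torch.randn(8)
#   torch.testing.assert_close(torch.fft.irfft(torch.fft.rfft(x), n=8), x)
#
#   X = torch.randn(10)
#   t = torch.fft.ihfft(X)  # one-sided complex representation
#   torch.testing.assert_close(torch.fft.hfft(t, n=10), X)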


class _ShapeAndDims(NamedTuple):
    shape: Tuple[int, ...]
    dims: Tuple[int, ...]


def _canonicalize_fft_shape_and_dim_args(
    input: TensorLikeType, shape: Optional[ShapeType], dim: Optional[DimsType]
) -> _ShapeAndDims:
    """Convert the shape and dim arguments into a canonical form where neither are optional"""
    input_dim = input.ndim
    input_sizes = input.shape

    if dim is not None:
        if not isinstance(dim, Sequence):
            dim = (dim,)
        ret_dims = utils.canonicalize_dims(input_dim, dim, wrap_scalar=False)

        # Check dims are unique
        torch._check(
            len(set(ret_dims)) == len(ret_dims), lambda: "FFT dims must be unique"
        )

    if shape is not None:
        if not isinstance(shape, Sequence):
            shape = (shape,)

        # Has shape, might have dim
        torch._check(
            dim is None or len(dim) == len(shape),
            lambda: "When given, dim and shape arguments must have the same length",
        )
        transform_ndim = len(shape)
        torch._check(
            transform_ndim <= input_dim,
            lambda: f"Got shape with {transform_ndim} values but input tensor "
            f"only has {input_dim} dimensions.",
        )

        # If shape is given, dims defaults to the last len(shape) dimensions
        if dim is None:
            ret_dims = tuple(range(input_dim - transform_ndim, input_dim))

        # Translate any -1 values in shape to the default length
        ret_shape = tuple(
            s if s != -1 else input_sizes[d] for (s, d) in zip(shape, ret_dims)  # type: ignore[possibly-undefined]
        )
    elif dim is None:
        # No shape, no dim
        ret_dims = tuple(range(input_dim))
        ret_shape = tuple(input_sizes)
    else:
        # No shape, has dim
        ret_shape = tuple(input_sizes[d] for d in ret_dims)  # type: ignore[possibly-undefined]

    for n in ret_shape:
        torch._check(n > 0, lambda: f"Invalid number of data points ({n}) specified")

    return _ShapeAndDims(shape=ret_shape, dims=ret_dims)  # type: ignore[possibly-undefined]
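
# Illustrative canonicalizations for input.shape == (4, 5, 6):
#   shape=None,   dim=None    -> shape=(4, 5, 6), dims=(0, 1, 2)
#   shape=(8, 8), dim=None    -> shape=(8, 8),    dims=(1, 2)
#   shape=None,   dim=(0, -1) -> shape=(4, 6),    dims=(0, 2)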


def _prod(xs: Iterable[int]) -> int:
    """Compute the product of an iterable of ints"""
    prod = 1
    for x in xs:
        prod *= x
    return prod


def _fftn_c2c(
    function_name: str,
    input: TensorLikeType,
    shape: Tuple[int, ...],
    dim: Tuple[int, ...],
    norm: NormType,
    forward: bool,
) -> TensorLikeType:
    """Common code for n-dimensional complex to complex FFTs (fftn or ifftn)"""
    torch._check(
        input.dtype.is_complex,
        lambda: f"{function_name} expects a complex input tensor, "
        f"but got {input.dtype}",
    )
    x = _resize_fft_input(input, dim, shape)
    output = prims.fft_c2c(x, dim=dim, forward=forward)
    return _apply_norm(output, norm=norm, signal_numel=_prod(shape), forward=forward)


@register_decomposition(aten.fft_fftn)
@out_wrapper()
def fftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("fftn", x, shape, dim, norm, forward=True)


@register_decomposition(aten.fft_ifftn)
@out_wrapper()
def ifftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    x = _maybe_promote_tensor_fft(input, require_complex=True)
    return _fftn_c2c("ifftn", x, shape, dim, norm, forward=False)


@register_decomposition(aten.fft_rfftn)
@out_wrapper()
def rfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    torch._check(
        not input.dtype.is_complex,
        lambda: f"rfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_r2c(input, dim=dim, onesided=True)
    return _apply_norm(out, norm=norm, signal_numel=_prod(shape), forward=True)


@register_decomposition(aten.fft_ihfftn)
@out_wrapper()
def ihfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    torch._check(
        not input.dtype.is_complex,
        lambda: f"ihfftn expects a real-valued input tensor, but got {input.dtype}",
    )
    shape, dim = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: "ihfftn must transform at least one axis")
    input = _maybe_promote_tensor_fft(input, require_complex=False)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_r2c(input, dim=dim[-1:], onesided=True)

    if len(dim) == 1:
        tmp = _apply_norm(tmp, norm=norm, signal_numel=shape[0], forward=False)
        return prims.conj(tmp)

    tmp = prims.conj_physical(tmp)
    tmp = prims.fft_c2c(tmp, dim=dim[:-1], forward=False)
    return _apply_norm(tmp, norm=norm, signal_numel=_prod(shape), forward=False)
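
# Note on ihfftn above: the n-dimensional inverse Hermitian FFT factors into
# a one-sided r2c transform over the last axis followed by an inverse c2c
# transform over the remaining axes. The conjugations exploit the identity
# ifft(x) == conj(fft(conj(x))) / n, which for real x reduces to
# conj(fft(x)) / n, so the inverse r2c step can reuse the forward primitive.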


class _CanonicalizeC2rReturn(NamedTuple):
    shape: Tuple[int, ...]
    dim: Tuple[int, ...]
    last_dim_size: int


def _canonicalize_fft_c2r_shape_and_dim_args(
    fname: str,
    input: TensorLikeType,
    s: Optional[ShapeType],
    dim: Optional[DimsType],
) -> _CanonicalizeC2rReturn:
    """Canonicalize shape and dim arguments for n-dimensional c2r transforms,
    as well as calculating the last_dim_size which is shape[dim[-1]] for the output"""
    (shape, dim) = _canonicalize_fft_shape_and_dim_args(input, s, dim)
    torch._check(len(shape) > 0, lambda: f"{fname} must transform at least one axis")

    if s is None or s[-1] == -1:
        last_dim_size = 2 * (input.shape[dim[-1]] - 1)
    else:
        last_dim_size = shape[-1]

    torch._check(
        last_dim_size >= 1,
        lambda: f"Invalid number of data points ({last_dim_size}) specified",
    )

    shape_list = list(shape)
    shape_list[-1] = last_dim_size // 2 + 1
    return _CanonicalizeC2rReturn(
        shape=tuple(shape_list), dim=dim, last_dim_size=last_dim_size
    )
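
# Illustrative c2r canonicalization: for a complex input with
# input.shape[-1] == 5 and s=None,
#   last_dim_size == 2 * (5 - 1) == 8  # assumes an even-length real output
#   shape[-1]     == 8 // 2 + 1 == 5   # expected size of the complex input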


@register_decomposition(aten.fft_irfftn)
@out_wrapper()
def irfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "irfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)
    out = prims.fft_c2r(input, dim=dim, last_dim_size=last_dim_size)
    return _apply_norm(out, norm, _prod(out.shape[d] for d in dim), forward=False)


@register_decomposition(aten.fft_hfftn)
@out_wrapper()
def hfftn(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = None,
    norm: NormType = None,
) -> TensorLikeType:
    shape, dim, last_dim_size = _canonicalize_fft_c2r_shape_and_dim_args(
        "hfftn", input, s, dim
    )
    input = _maybe_promote_tensor_fft(input, require_complex=True)
    input = _resize_fft_input(input, dim, shape)

    tmp = prims.fft_c2c(input, dim=dim[:-1], forward=True) if len(dim) > 1 else input
    tmp = _apply_norm(tmp, norm, _prod(shape[:-1]), forward=True)
    tmp = prims.conj_physical(tmp)
    out = prims.fft_c2r(tmp, dim=dim[-1:], last_dim_size=last_dim_size)
    return _apply_norm(out, norm, last_dim_size, forward=True)
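
# Note on hfftn above: the forward transform over all but the last axis is a
# plain c2c FFT; the last axis is then handled by fft_c2r, which computes an
# *inverse* c2r transform, so the input is conjugated first (mirroring the
# `forward` branch of _fft_c2r).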


@register_decomposition(aten.fft_fft2)
@out_wrapper()
def fft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.fftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ifft2)
@out_wrapper()
def ifft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ifftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_rfft2)
@out_wrapper()
def rfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.rfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_irfft2)
@out_wrapper()
def irfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.irfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_hfft2)
@out_wrapper()
def hfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.hfftn(input, s=s, dim=dim, norm=norm)


@register_decomposition(aten.fft_ihfft2)
@out_wrapper()
def ihfft2(
    input: TensorLikeType,
    s: Optional[ShapeType] = None,
    dim: Optional[DimsType] = (-2, -1),
    norm: NormType = None,
) -> TensorLikeType:
    return torch.fft.ihfftn(input, s=s, dim=dim, norm=norm)
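
# The 2-D transforms above are thin wrappers: each delegates to its n-D
# counterpart with dim defaulting to the last two axes, (-2, -1).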


def _default_alldims(dim: Optional[DimsType], x: TensorLikeType) -> List[int]:
    """Convert Optional[DimsType] to a simple list, defaulting to all dimensions"""
    if dim is None:
        return list(range(x.ndim))
    elif not isinstance(dim, Sequence):
        return [dim]
    else:
        return list(dim)


@register_decomposition(aten.fft_fftshift)
def fftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [input.shape[d] // 2 for d in dims]
    return torch.roll(input, shift, dims)


@register_decomposition(aten.fft_ifftshift)
def ifftshift(input: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType:
    dims = _default_alldims(dim, input)
    shift = [(input.shape[d] + 1) // 2 for d in dims]
    return torch.roll(input, shift, dims)
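
# Illustrative round trip for the helpers above (values from torch.fft.fftfreq):
#   freqs = torch.fft.fftfreq(5)   # [0.0, 0.2, 0.4, -0.4, -0.2]
#   torch.fft.fftshift(freqs)      # [-0.4, -0.2, 0.0, 0.2, 0.4]
#   torch.fft.ifftshift(torch.fft.fftshift(freqs))  # recovers freqs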