flatten.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. # mypy: allow-untyped-defs
  2. from .module import Module
  3. from typing import Tuple, Union
  4. from torch import Tensor
  5. from torch.types import _size
  6. __all__ = ['Flatten', 'Unflatten']
  7. class Flatten(Module):
  8. r"""
  9. Flattens a contiguous range of dims into a tensor.
  10. For use with :class:`~nn.Sequential`, see :meth:`torch.flatten` for details.
  11. Shape:
  12. - Input: :math:`(*, S_{\text{start}},..., S_{i}, ..., S_{\text{end}}, *)`,'
  13. where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
  14. number of dimensions including none.
  15. - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.
  16. Args:
  17. start_dim: first dim to flatten (default = 1).
  18. end_dim: last dim to flatten (default = -1).
  19. Examples::
  20. >>> input = torch.randn(32, 1, 5, 5)
  21. >>> # With default parameters
  22. >>> m = nn.Flatten()
  23. >>> output = m(input)
  24. >>> output.size()
  25. torch.Size([32, 25])
  26. >>> # With non-default parameters
  27. >>> m = nn.Flatten(0, 2)
  28. >>> output = m(input)
  29. >>> output.size()
  30. torch.Size([160, 5])
  31. """
  32. __constants__ = ['start_dim', 'end_dim']
  33. start_dim: int
  34. end_dim: int
  35. def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
  36. super().__init__()
  37. self.start_dim = start_dim
  38. self.end_dim = end_dim
  39. def forward(self, input: Tensor) -> Tensor:
  40. return input.flatten(self.start_dim, self.end_dim)
  41. def extra_repr(self) -> str:
  42. return f'start_dim={self.start_dim}, end_dim={self.end_dim}'
  43. class Unflatten(Module):
  44. r"""
  45. Unflattens a tensor dim expanding it to a desired shape. For use with :class:`~nn.Sequential`.
  46. * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
  47. be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.
  48. * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
  49. a `tuple` of ints or a `list` of ints or `torch.Size` for `Tensor` input; a `NamedShape`
  50. (tuple of `(name, size)` tuples) for `NamedTensor` input.
  51. Shape:
  52. - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
  53. dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
  54. - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
  55. :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
  56. Args:
  57. dim (Union[int, str]): Dimension to be unflattened
  58. unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension
  59. Examples:
  60. >>> input = torch.randn(2, 50)
  61. >>> # With tuple of ints
  62. >>> m = nn.Sequential(
  63. >>> nn.Linear(50, 50),
  64. >>> nn.Unflatten(1, (2, 5, 5))
  65. >>> )
  66. >>> output = m(input)
  67. >>> output.size()
  68. torch.Size([2, 2, 5, 5])
  69. >>> # With torch.Size
  70. >>> m = nn.Sequential(
  71. >>> nn.Linear(50, 50),
  72. >>> nn.Unflatten(1, torch.Size([2, 5, 5]))
  73. >>> )
  74. >>> output = m(input)
  75. >>> output.size()
  76. torch.Size([2, 2, 5, 5])
  77. >>> # With namedshape (tuple of tuples)
  78. >>> input = torch.randn(2, 50, names=('N', 'features'))
  79. >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
  80. >>> output = unflatten(input)
  81. >>> output.size()
  82. torch.Size([2, 2, 5, 5])
  83. """
  84. NamedShape = Tuple[Tuple[str, int]]
  85. __constants__ = ['dim', 'unflattened_size']
  86. dim: Union[int, str]
  87. unflattened_size: Union[_size, NamedShape]
  88. def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
  89. super().__init__()
  90. if isinstance(dim, int):
  91. self._require_tuple_int(unflattened_size)
  92. elif isinstance(dim, str):
  93. self._require_tuple_tuple(unflattened_size)
  94. else:
  95. raise TypeError("invalid argument type for dim parameter")
  96. self.dim = dim
  97. self.unflattened_size = unflattened_size
  98. def _require_tuple_tuple(self, input):
  99. if (isinstance(input, tuple)):
  100. for idx, elem in enumerate(input):
  101. if not isinstance(elem, tuple):
  102. raise TypeError("unflattened_size must be tuple of tuples, " +
  103. f"but found element of type {type(elem).__name__} at pos {idx}")
  104. return
  105. raise TypeError("unflattened_size must be a tuple of tuples, " +
  106. f"but found type {type(input).__name__}")
  107. def _require_tuple_int(self, input):
  108. if (isinstance(input, (tuple, list))):
  109. for idx, elem in enumerate(input):
  110. if not isinstance(elem, int):
  111. raise TypeError("unflattened_size must be tuple of ints, " +
  112. f"but found element of type {type(elem).__name__} at pos {idx}")
  113. return
  114. raise TypeError(f"unflattened_size must be a tuple of ints, but found type {type(input).__name__}")
  115. def forward(self, input: Tensor) -> Tensor:
  116. return input.unflatten(self.dim, self.unflattened_size)
  117. def extra_repr(self) -> str:
  118. return f'dim={self.dim}, unflattened_size={self.unflattened_size}'