from .module import Module
from .. import functional as F
from torch import Tensor
from ..common_types import _size_any_t

__all__ = ['Fold', 'Unfold']


class Fold(Module):
    r"""Combines an array of sliding local blocks into a large containing tensor.

    Consider a batched :attr:`input` tensor containing sliding local blocks,
    e.g., patches of images, of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`,
    where :math:`N` is the batch dimension, :math:`C \times \prod(\text{kernel\_size})`
    is the number of values within a block (a block has :math:`\prod(\text{kernel\_size})`
    spatial locations each containing a :math:`C`-channeled vector), and
    :math:`L` is the total number of blocks. (This is exactly the
    same specification as the output shape of :class:`~torch.nn.Unfold`.) This
    operation combines these local blocks into the large :attr:`output` tensor
    of shape :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
    by summing the overlapping values. Similar to :class:`~torch.nn.Unfold`, the
    arguments must satisfy

    .. math::
        L = \prod_d \left\lfloor\frac{\text{output\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`d` is over all spatial dimensions.
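
    For example, with the values used in the Examples below
    (``output_size=(4, 5)``, ``kernel_size=(2, 2)``, and the defaults
    ``dilation=1``, ``padding=0``, ``stride=1``), the formula gives
    :math:`L = 3 \times 4 = 12`; a minimal sketch of the computation:

        >>> # xdoctest: +SKIP
        >>> import math
        >>> output_size, kernel_size = (4, 5), (2, 2)
        >>> dilation, padding, stride = (1, 1), (0, 0), (1, 1)
        >>> math.prod(
        ...     (o + 2 * p - d * (k - 1) - 1) // s + 1
        ...     for o, k, d, p, s in zip(output_size, kernel_size, dilation, padding, stride)
        ... )
        12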

    * :attr:`output_size` describes the spatial shape of the large containing
      tensor of the sliding local blocks. It is useful to resolve the ambiguity
      when multiple input shapes map to the same number of sliding blocks, e.g.,
      with ``stride > 1``.

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.
    * :attr:`padding` controls the amount of implicit zero padding added to
      both sides of the input, :attr:`padding` points for each dimension,
      before reshaping.
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this
      `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        output_size (int or tuple): the shape of the spatial dimensions of the
                                    output (i.e., ``output.sizes()[2:]``)
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in
                                         the input spatial dimensions. Default: 1

    * If :attr:`output_size`, :attr:`kernel_size`, :attr:`dilation`,
      :attr:`padding` or :attr:`stride` is an int or a tuple of length 1 then
      their values will be replicated across all spatial dimensions.

    * For the case of two output spatial dimensions this operation is sometimes
      called ``col2im``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, then
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to constant divisor).
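
        A concrete sketch of this relation, with shapes and parameters
        chosen only for illustration:

        >>> # xdoctest: +SKIP
        >>> fold_params = dict(kernel_size=2, stride=1)
        >>> fold = nn.Fold(output_size=(4, 4), **fold_params)
        >>> unfold = nn.Unfold(**fold_params)
        >>> x = torch.randn(1, 3, 4, 4)
        >>> divisor = fold(unfold(torch.ones_like(x)))
        >>> torch.allclose(fold(unfold(x)), divisor * x)
        True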

    .. warning::
        Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.

    Shape:
        - Input: :math:`(N, C \times \prod(\text{kernel\_size}), L)` or :math:`(C \times \prod(\text{kernel\_size}), L)`
        - Output: :math:`(N, C, \text{output\_size}[0], \text{output\_size}[1], \dots)`
          or :math:`(C, \text{output\_size}[0], \text{output\_size}[1], \dots)` as described above

    Examples::

        >>> fold = nn.Fold(output_size=(4, 5), kernel_size=(2, 2))
        >>> input = torch.randn(1, 3 * 2 * 2, 12)
        >>> output = fold(input)
        >>> output.size()
        torch.Size([1, 3, 4, 5])
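
        >>> # xdoctest: +SKIP
        >>> # Overlapping positions are summed: folding an all-ones input
        >>> # (a sketch) shows how many blocks cover each output location.
        >>> fold(torch.ones(1, 3 * 2 * 2, 12))[0, 0]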

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """

    __constants__ = ['output_size', 'kernel_size', 'dilation', 'padding',
                     'stride']
    output_size: _size_any_t
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        output_size: _size_any_t,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1
    ) -> None:
        super().__init__()
        self.output_size = output_size
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.fold(input, self.output_size, self.kernel_size, self.dilation,
                      self.padding, self.stride)

    def extra_repr(self) -> str:
        return 'output_size={output_size}, kernel_size={kernel_size}, ' \
            'dilation={dilation}, padding={padding}, stride={stride}'.format(
                **self.__dict__
            )


class Unfold(Module):
    r"""Extracts sliding local blocks from a batched input tensor.

    Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
    where :math:`N` is the batch dimension, :math:`C` is the channel dimension,
    and :math:`*` represents arbitrary spatial dimensions. This operation flattens
    each sliding :attr:`kernel_size`-sized block within the spatial dimensions
    of :attr:`input` into a column (i.e., last dimension) of a 3-D :attr:`output`
    tensor of shape :math:`(N, C \times \prod(\text{kernel\_size}), L)`, where
    :math:`C \times \prod(\text{kernel\_size})` is the total number of values
    within each block (a block has :math:`\prod(\text{kernel\_size})` spatial
    locations each containing a :math:`C`-channeled vector), and :math:`L` is
    the total number of such blocks:

    .. math::
        L = \prod_d \left\lfloor\frac{\text{spatial\_size}[d] + 2 \times \text{padding}[d] %
            - \text{dilation}[d] \times (\text{kernel\_size}[d] - 1) - 1}{\text{stride}[d]} + 1\right\rfloor,

    where :math:`\text{spatial\_size}` is formed by the spatial dimensions
    of :attr:`input` (:math:`*` above), and :math:`d` is over all spatial
    dimensions.

    Therefore, indexing :attr:`output` at the last dimension (column dimension)
    gives all values within a certain block.
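
    For example, for an input with spatial size :math:`3 \times 4` and
    ``kernel_size=(2, 3)`` (the shapes used in the Examples below, with
    default dilation, padding and stride), the formula gives
    :math:`L = 2 \times 2 = 4`; a minimal sketch of the computation:

        >>> # xdoctest: +SKIP
        >>> import math
        >>> spatial_size, kernel_size = (3, 4), (2, 3)
        >>> math.prod(
        ...     (s + 2 * 0 - 1 * (k - 1) - 1) // 1 + 1
        ...     for s, k in zip(spatial_size, kernel_size)
        ... )
        4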

    The :attr:`padding`, :attr:`stride` and :attr:`dilation` arguments specify
    how the sliding blocks are retrieved.

    * :attr:`stride` controls the stride for the sliding blocks.
    * :attr:`padding` controls the amount of implicit zero padding added to
      both sides of the input, :attr:`padding` points for each dimension,
      before reshaping.
    * :attr:`dilation` controls the spacing between the kernel points; also
      known as the à trous algorithm. It is harder to describe, but this
      `link`_ has a nice visualization of what :attr:`dilation` does.

    Args:
        kernel_size (int or tuple): the size of the sliding blocks
        dilation (int or tuple, optional): a parameter that controls the
                                           stride of elements within the
                                           neighborhood. Default: 1
        padding (int or tuple, optional): implicit zero padding to be added on
                                          both sides of input. Default: 0
        stride (int or tuple, optional): the stride of the sliding blocks in
                                         the input spatial dimensions. Default: 1

    * If :attr:`kernel_size`, :attr:`dilation`, :attr:`padding` or
      :attr:`stride` is an int or a tuple of length 1, their values will be
      replicated across all spatial dimensions.

    * For the case of two input spatial dimensions this operation is sometimes
      called ``im2col``.

    .. note::
        :class:`~torch.nn.Fold` calculates each combined value in the resulting
        large tensor by summing all values from all containing blocks.
        :class:`~torch.nn.Unfold` extracts the values in the local blocks by
        copying from the large tensor. So, if the blocks overlap, they are not
        inverses of each other.

        In general, folding and unfolding operations are related as
        follows. Consider :class:`~torch.nn.Fold` and
        :class:`~torch.nn.Unfold` instances created with the same
        parameters:

        >>> fold_params = dict(kernel_size=..., dilation=..., padding=..., stride=...)
        >>> fold = nn.Fold(output_size=..., **fold_params)
        >>> unfold = nn.Unfold(**fold_params)

        Then for any (supported) ``input`` tensor the following
        equality holds:

        ::

            fold(unfold(input)) == divisor * input

        where ``divisor`` is a tensor that depends only on the shape
        and dtype of the ``input``:

        >>> # xdoctest: +SKIP
        >>> input_ones = torch.ones(input.shape, dtype=input.dtype)
        >>> divisor = fold(unfold(input_ones))

        When the ``divisor`` tensor contains no zero elements, then
        ``fold`` and ``unfold`` operations are inverses of each
        other (up to constant divisor).

    .. warning::
        Currently, only 4-D input tensors (batched image-like tensors) are
        supported.

    Shape:
        - Input: :math:`(N, C, *)`
        - Output: :math:`(N, C \times \prod(\text{kernel\_size}), L)` as described above

    Examples::

        >>> unfold = nn.Unfold(kernel_size=(2, 3))
        >>> input = torch.randn(2, 5, 3, 4)
        >>> output = unfold(input)
        >>> # each patch contains 30 values (2x3=6 vectors, each of 5 channels)
        >>> # 4 blocks (2x3 kernels) in total in the 3x4 input
        >>> output.size()
        torch.Size([2, 30, 4])
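
        >>> # xdoctest: +SKIP
        >>> # A sketch of the column layout: column 0 holds the flattened
        >>> # top-left 2x3 block, ordered channel-first.
        >>> torch.equal(output[0, :, 0], input[0, :, 0:2, 0:3].reshape(-1))
        True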

        >>> # xdoctest: +IGNORE_WANT
        >>> # Convolution is equivalent with Unfold + Matrix Multiplication + Fold (or view to output shape)
        >>> inp = torch.randn(1, 3, 10, 12)
        >>> w = torch.randn(2, 3, 4, 5)
        >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5))
        >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        >>> out = torch.nn.functional.fold(out_unf, (7, 8), (1, 1))
        >>> # or equivalently (and avoiding a copy),
        >>> # out = out_unf.view(1, 2, 7, 8)
        >>> (torch.nn.functional.conv2d(inp, w) - out).abs().max()
        tensor(1.9073e-06)
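
        >>> # xdoctest: +SKIP
        >>> # The same identity holds with stride and padding; a sketch with
        >>> # hypothetical values stride=2, padding=1 (output is then 5x5):
        >>> inp_unf = torch.nn.functional.unfold(inp, (4, 5), padding=1, stride=2)
        >>> out_unf = inp_unf.transpose(1, 2).matmul(w.view(w.size(0), -1).t()).transpose(1, 2)
        >>> out = out_unf.view(1, 2, 5, 5)
        >>> (torch.nn.functional.conv2d(inp, w, padding=1, stride=2) - out).abs().max()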

    .. _link:
        https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

    """

    __constants__ = ['kernel_size', 'dilation', 'padding', 'stride']
    kernel_size: _size_any_t
    dilation: _size_any_t
    padding: _size_any_t
    stride: _size_any_t

    def __init__(
        self,
        kernel_size: _size_any_t,
        dilation: _size_any_t = 1,
        padding: _size_any_t = 0,
        stride: _size_any_t = 1
    ) -> None:
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, input: Tensor) -> Tensor:
        return F.unfold(input, self.kernel_size, self.dilation,
                        self.padding, self.stride)

    def extra_repr(self) -> str:
        return 'kernel_size={kernel_size}, dilation={dilation}, padding={padding},' \
            ' stride={stride}'.format(**self.__dict__)