# mypy: allow-untyped-defs
import math
from typing import Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter

from .. import functional as F
from .. import init
from .module import Module
from .lazy import LazyModuleMixin

__all__ = [
    'Bilinear',
    'Identity',
    'LazyLinear',
    'Linear',
]


class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)

    Shape:
        - Input: :math:`(*)`, where :math:`*` means any number of dimensions.
        - Output: :math:`(*)`, same shape as the input.

    Examples::

        >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 20])
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__()

    def forward(self, input: Tensor) -> Tensor:
        return input


class Linear(Module):
    r"""Applies an affine linear transformation to the incoming data: :math:`y = xA^T + b`.

    This module supports :ref:`TensorFloat32<tf32_on_ampere>`.

    On certain ROCm devices, when using float16 inputs this module will use
    :ref:`different precision<fp16_on_mi200>` for backward.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input: :math:`(*, H_{in})` where :math:`*` means any number of
          dimensions including none and :math:`H_{in} = \text{in\_features}`.
        - Output: :math:`(*, H_{out})` where all but the last dimension
          are the same shape as the input and :math:`H_{out} = \text{out\_features}`.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`

    Examples::

        >>> m = nn.Linear(20, 30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])
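
        >>> # illustrative extra example: the transformation acts on the last
        >>> # dimension, so any number of leading (batch) dimensions is supported
        >>> input = torch.randn(4, 16, 20)
        >>> m(input).size()
        torch.Size([4, 16, 30])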
  65. """
  66. __constants__ = ['in_features', 'out_features']
  67. in_features: int
  68. out_features: int
  69. weight: Tensor
  70. def __init__(self, in_features: int, out_features: int, bias: bool = True,
  71. device=None, dtype=None) -> None:
  72. factory_kwargs = {'device': device, 'dtype': dtype}
  73. super().__init__()
  74. self.in_features = in_features
  75. self.out_features = out_features
  76. self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
  77. if bias:
  78. self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
  79. else:
  80. self.register_parameter('bias', None)
  81. self.reset_parameters()
    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
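        # (Explanatory sketch of the equivalence: for a weight of shape
        # (out_features, in_features), fan_in == in_features, and kaiming_uniform_
        # with a=sqrt(5) uses gain = sqrt(2 / (1 + a^2)) = sqrt(1/3) and
        # bound = gain * sqrt(3 / fan_in), which simplifies to 1/sqrt(in_features).)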
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'


# This class exists solely to avoid triggering an obscure error when scripting
# an improperly quantized attention layer. See this issue for details:
# https://github.com/pytorch/pytorch/issues/58969
# TODO: fail fast on quantization API usage error, then remove this class
# and replace uses of it with plain Linear
class NonDynamicallyQuantizableLinear(Linear):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        super().__init__(in_features, out_features, bias=bias,
                         device=device, dtype=dtype)


class Bilinear(Module):
    r"""Applies a bilinear transformation to the incoming data: :math:`y = x_1^T A x_2 + b`.

    Args:
        in1_features: size of each first input sample
        in2_features: size of each second input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Shape:
        - Input1: :math:`(*, H_{in1})` where :math:`H_{in1}=\text{in1\_features}` and
          :math:`*` means any number of additional dimensions including none. All but the last
          dimension of the inputs should be the same.
        - Input2: :math:`(*, H_{in2})` where :math:`H_{in2}=\text{in2\_features}`.
        - Output: :math:`(*, H_{out})` where :math:`H_{out}=\text{out\_features}`
          and all but the last dimension are the same shape as the input.

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in1\_features}, \text{in2\_features})`.
            The values are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in1\_features}}`

    Examples::

        >>> m = nn.Bilinear(20, 30, 40)
        >>> input1 = torch.randn(128, 20)
        >>> input2 = torch.randn(128, 30)
        >>> output = m(input1, input2)
        >>> print(output.size())
        torch.Size([128, 40])
    """

    __constants__ = ['in1_features', 'in2_features', 'out_features']
    in1_features: int
    in2_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.in1_features = in1_features
        self.in2_features = in2_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in1_features, in2_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        bound = 1 / math.sqrt(self.weight.size(1))
        init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
        return F.bilinear(input1, input2, self.weight, self.bias)

    def extra_repr(self) -> str:
        return (f'in1_features={self.in1_features}, in2_features={self.in2_features}, '
                f'out_features={self.out_features}, bias={self.bias is not None}')
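
# Note (illustrative, not part of the original source): for 2-D inputs, the bilinear
# transformation computed by Bilinear/F.bilinear is equivalent to the einsum below,
# which can help when reasoning about the (out_features, in1_features, in2_features)
# weight layout:
#
#     out = torch.einsum('bi,oij,bj->bo', input1, m.weight, input2)
#     if m.bias is not None:
#         out = out + m.bias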


class LazyLinear(LazyModuleMixin, Linear):
    r"""A :class:`torch.nn.Linear` module where `in_features` is inferred.

    In this module, the `weight` and `bias` are of :class:`torch.nn.UninitializedParameter`
    class. They will be initialized after the first call to ``forward`` is done and the
    module will become a regular :class:`torch.nn.Linear` module. The ``in_features``
    argument of the :class:`Linear` is inferred from ``input.shape[-1]``.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``

    Attributes:
        weight: the learnable weights of the module of shape
            :math:`(\text{out\_features}, \text{in\_features})`. The values are
            initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})`, where
            :math:`k = \frac{1}{\text{in\_features}}`
        bias: the learnable bias of the module of shape :math:`(\text{out\_features})`.
            If :attr:`bias` is ``True``, the values are initialized from
            :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` where
            :math:`k = \frac{1}{\text{in\_features}}`
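
    Examples::

        >>> # illustrative usage: ``in_features`` (20 here) is inferred from the
        >>> # first input, after which the module behaves like a regular Linear
        >>> m = nn.LazyLinear(30)
        >>> input = torch.randn(128, 20)
        >>> output = m(input)
        >>> print(output.size())
        torch.Size([128, 30])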
  186. """
  187. cls_to_become = Linear # type: ignore[assignment]
  188. weight: UninitializedParameter
  189. bias: UninitializedParameter # type: ignore[assignment]
  190. def __init__(self, out_features: int, bias: bool = True,
  191. device=None, dtype=None) -> None:
  192. factory_kwargs = {'device': device, 'dtype': dtype}
  193. # bias is hardcoded to False to avoid creating tensor
  194. # that will soon be overwritten.
  195. super().__init__(0, 0, False)
  196. self.weight = UninitializedParameter(**factory_kwargs)
  197. self.out_features = out_features
  198. if bias:
  199. self.bias = UninitializedParameter(**factory_kwargs)
  200. def reset_parameters(self) -> None:
  201. if not self.has_uninitialized_params() and self.in_features != 0:
  202. super().reset_parameters()
  203. def initialize_parameters(self, input) -> None: # type: ignore[override]
  204. if self.has_uninitialized_params():
  205. with torch.no_grad():
  206. self.in_features = input.shape[-1]
  207. self.weight.materialize((self.out_features, self.in_features))
  208. if self.bias is not None:
  209. self.bias.materialize((self.out_features,))
  210. self.reset_parameters()
  211. # TODO: PartialLinear - maybe in sparse?