sparse.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. from torch import Tensor
  5. from torch.nn.parameter import Parameter
  6. from .module import Module
  7. from .. import functional as F
  8. from .. import init
  9. __all__ = ['Embedding', 'EmbeddingBag']
  10. class Embedding(Module):
  11. r"""A simple lookup table that stores embeddings of a fixed dictionary and size.
  12. This module is often used to store word embeddings and retrieve them using indices.
  13. The input to the module is a list of indices, and the output is the corresponding
  14. word embeddings.
  15. Args:
  16. num_embeddings (int): size of the dictionary of embeddings
  17. embedding_dim (int): the size of each embedding vector
  18. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  19. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  20. i.e. it remains as a fixed "pad". For a newly constructed Embedding,
  21. the embedding vector at :attr:`padding_idx` will default to all zeros,
  22. but can be updated to another value to be used as the padding vector.
  23. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  24. is renormalized to have norm :attr:`max_norm`.
  25. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  26. scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse of frequency of
  27. the words in the mini-batch. Default ``False``.
  28. sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor.
  29. See Notes for more details regarding sparse gradients.
  30. Attributes:
  31. weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim)
  32. initialized from :math:`\mathcal{N}(0, 1)`
  33. Shape:
  34. - Input: :math:`(*)`, IntTensor or LongTensor of arbitrary shape containing the indices to extract
  35. - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}`
  36. .. note::
  37. Keep in mind that only a limited number of optimizers support
  38. sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`),
  39. :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`)
  40. .. note::
  41. When :attr:`max_norm` is not ``None``, :class:`Embedding`'s forward method will modify the
  42. :attr:`weight` tensor in-place. Since tensors needed for gradient computations cannot be
  43. modified in-place, performing a differentiable operation on ``Embedding.weight`` before
  44. calling :class:`Embedding`'s forward method requires cloning ``Embedding.weight`` when
  45. :attr:`max_norm` is not ``None``. For example::
  46. n, d, m = 3, 5, 7
  47. embedding = nn.Embedding(n, d, max_norm=True)
  48. W = torch.randn((m, d), requires_grad=True)
  49. idx = torch.tensor([1, 2])
  50. a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable
  51. b = embedding(idx) @ W.t() # modifies weight in-place
  52. out = (a.unsqueeze(0) + b.unsqueeze(1))
  53. loss = out.sigmoid().prod()
  54. loss.backward()
  55. Examples::
  56. >>> # an Embedding module containing 10 tensors of size 3
  57. >>> embedding = nn.Embedding(10, 3)
  58. >>> # a batch of 2 samples of 4 indices each
  59. >>> input = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])
  60. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  61. >>> embedding(input)
  62. tensor([[[-0.0251, -1.6902, 0.7172],
  63. [-0.6431, 0.0748, 0.6969],
  64. [ 1.4970, 1.3448, -0.9685],
  65. [-0.3677, -2.7265, -0.1685]],
  66. [[ 1.4970, 1.3448, -0.9685],
  67. [ 0.4362, -0.4004, 0.9400],
  68. [-0.6431, 0.0748, 0.6969],
  69. [ 0.9124, -2.3616, 1.1151]]])
  70. >>> # example with padding_idx
  71. >>> embedding = nn.Embedding(10, 3, padding_idx=0)
  72. >>> input = torch.LongTensor([[0, 2, 0, 5]])
  73. >>> embedding(input)
  74. tensor([[[ 0.0000, 0.0000, 0.0000],
  75. [ 0.1535, -2.0309, 0.9315],
  76. [ 0.0000, 0.0000, 0.0000],
  77. [-0.1655, 0.9897, 0.0635]]])
  78. >>> # example of changing `pad` vector
  79. >>> padding_idx = 0
  80. >>> embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
  81. >>> embedding.weight
  82. Parameter containing:
  83. tensor([[ 0.0000, 0.0000, 0.0000],
  84. [-0.7895, -0.7089, -0.0364],
  85. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  86. >>> with torch.no_grad():
  87. ... embedding.weight[padding_idx] = torch.ones(3)
  88. >>> embedding.weight
  89. Parameter containing:
  90. tensor([[ 1.0000, 1.0000, 1.0000],
  91. [-0.7895, -0.7089, -0.0364],
  92. [ 0.6778, 0.5803, 0.2678]], requires_grad=True)
  93. """
  94. __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', 'max_norm',
  95. 'norm_type', 'scale_grad_by_freq', 'sparse']
  96. num_embeddings: int
  97. embedding_dim: int
  98. padding_idx: Optional[int]
  99. max_norm: Optional[float]
  100. norm_type: float
  101. scale_grad_by_freq: bool
  102. weight: Tensor
  103. freeze: bool
  104. sparse: bool
  105. def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None,
  106. max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
  107. sparse: bool = False, _weight: Optional[Tensor] = None, _freeze: bool = False,
  108. device=None, dtype=None) -> None:
  109. factory_kwargs = {'device': device, 'dtype': dtype}
  110. super().__init__()
  111. self.num_embeddings = num_embeddings
  112. self.embedding_dim = embedding_dim
  113. if padding_idx is not None:
  114. if padding_idx > 0:
  115. assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings'
  116. elif padding_idx < 0:
  117. assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings'
  118. padding_idx = self.num_embeddings + padding_idx
  119. self.padding_idx = padding_idx
  120. self.max_norm = max_norm
  121. self.norm_type = norm_type
  122. self.scale_grad_by_freq = scale_grad_by_freq
  123. if _weight is None:
  124. self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs),
  125. requires_grad=not _freeze)
  126. self.reset_parameters()
  127. else:
  128. assert list(_weight.shape) == [num_embeddings, embedding_dim], \
  129. 'Shape of weight does not match num_embeddings and embedding_dim'
  130. self.weight = Parameter(_weight, requires_grad=not _freeze)
  131. self.sparse = sparse
  132. def reset_parameters(self) -> None:
  133. init.normal_(self.weight)
  134. self._fill_padding_idx_with_zero()
  135. def _fill_padding_idx_with_zero(self) -> None:
  136. if self.padding_idx is not None:
  137. with torch.no_grad():
  138. self.weight[self.padding_idx].fill_(0)
  139. def forward(self, input: Tensor) -> Tensor:
  140. return F.embedding(
  141. input, self.weight, self.padding_idx, self.max_norm,
  142. self.norm_type, self.scale_grad_by_freq, self.sparse)
  143. def extra_repr(self) -> str:
  144. s = '{num_embeddings}, {embedding_dim}'
  145. if self.padding_idx is not None:
  146. s += ', padding_idx={padding_idx}'
  147. if self.max_norm is not None:
  148. s += ', max_norm={max_norm}'
  149. if self.norm_type != 2:
  150. s += ', norm_type={norm_type}'
  151. if self.scale_grad_by_freq is not False:
  152. s += ', scale_grad_by_freq={scale_grad_by_freq}'
  153. if self.sparse is not False:
  154. s += ', sparse=True'
  155. return s.format(**self.__dict__)
  156. @classmethod
  157. def from_pretrained(cls, embeddings, freeze=True, padding_idx=None,
  158. max_norm=None, norm_type=2., scale_grad_by_freq=False,
  159. sparse=False):
  160. r"""Create Embedding instance from given 2-dimensional FloatTensor.
  161. Args:
  162. embeddings (Tensor): FloatTensor containing weights for the Embedding.
  163. First dimension is being passed to Embedding as ``num_embeddings``, second as ``embedding_dim``.
  164. freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
  165. Equivalent to ``embedding.weight.requires_grad = False``. Default: ``True``
  166. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the gradient;
  167. therefore, the embedding vector at :attr:`padding_idx` is not updated during training,
  168. i.e. it remains as a fixed "pad".
  169. max_norm (float, optional): See module initialization documentation.
  170. norm_type (float, optional): See module initialization documentation. Default ``2``.
  171. scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
  172. sparse (bool, optional): See module initialization documentation.
  173. Examples::
  174. >>> # FloatTensor containing pretrained weights
  175. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  176. >>> embedding = nn.Embedding.from_pretrained(weight)
  177. >>> # Get embeddings for index 1
  178. >>> input = torch.LongTensor([1])
  179. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  180. >>> embedding(input)
  181. tensor([[ 4.0000, 5.1000, 6.3000]])
  182. """
  183. assert embeddings.dim() == 2, \
  184. 'Embeddings parameter is expected to be 2-dimensional'
  185. rows, cols = embeddings.shape
  186. embedding = cls(
  187. num_embeddings=rows,
  188. embedding_dim=cols,
  189. _weight=embeddings,
  190. _freeze=freeze,
  191. padding_idx=padding_idx,
  192. max_norm=max_norm,
  193. norm_type=norm_type,
  194. scale_grad_by_freq=scale_grad_by_freq,
  195. sparse=sparse)
  196. return embedding
  197. class EmbeddingBag(Module):
  198. r"""Compute sums or means of 'bags' of embeddings, without instantiating the intermediate embeddings.
  199. For bags of constant length, no :attr:`per_sample_weights`, no indices equal to :attr:`padding_idx`,
  200. and with 2D inputs, this class
  201. * with ``mode="sum"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.sum(dim=1)``,
  202. * with ``mode="mean"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.mean(dim=1)``,
  203. * with ``mode="max"`` is equivalent to :class:`~torch.nn.Embedding` followed by ``torch.max(dim=1)``.
  204. However, :class:`~torch.nn.EmbeddingBag` is much more time and memory efficient than using a chain of these
  205. operations.
  206. EmbeddingBag also supports per-sample weights as an argument to the forward
  207. pass. This scales the output of the Embedding before performing a weighted
  208. reduction as specified by ``mode``. If :attr:`per_sample_weights` is passed, the
  209. only supported ``mode`` is ``"sum"``, which computes a weighted sum according to
  210. :attr:`per_sample_weights`.
  211. Args:
  212. num_embeddings (int): size of the dictionary of embeddings
  213. embedding_dim (int): the size of each embedding vector
  214. max_norm (float, optional): If given, each embedding vector with norm larger than :attr:`max_norm`
  215. is renormalized to have norm :attr:`max_norm`.
  216. norm_type (float, optional): The p of the p-norm to compute for the :attr:`max_norm` option. Default ``2``.
  217. scale_grad_by_freq (bool, optional): if given, this will scale gradients by the inverse of frequency of
  218. the words in the mini-batch. Default ``False``.
  219. Note: this option is not supported when ``mode="max"``.
  220. mode (str, optional): ``"sum"``, ``"mean"`` or ``"max"``. Specifies the way to reduce the bag.
  221. ``"sum"`` computes the weighted sum, taking :attr:`per_sample_weights`
  222. into consideration. ``"mean"`` computes the average of the values
  223. in the bag, ``"max"`` computes the max value over each bag.
  224. Default: ``"mean"``
  225. sparse (bool, optional): if ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. See
  226. Notes for more details regarding sparse gradients. Note: this option is not
  227. supported when ``mode="max"``.
  228. include_last_offset (bool, optional): if ``True``, :attr:`offsets` has one additional element, where the last element
  229. is equivalent to the size of `indices`. This matches the CSR format.
  230. padding_idx (int, optional): If specified, the entries at :attr:`padding_idx` do not contribute to the
  231. gradient; therefore, the embedding vector at :attr:`padding_idx` is not updated
  232. during training, i.e. it remains as a fixed "pad". For a newly constructed
  233. EmbeddingBag, the embedding vector at :attr:`padding_idx` will default to all
  234. zeros, but can be updated to another value to be used as the padding vector.
  235. Note that the embedding vector at :attr:`padding_idx` is excluded from the
  236. reduction.
  237. Attributes:
  238. weight (Tensor): the learnable weights of the module of shape `(num_embeddings, embedding_dim)`
  239. initialized from :math:`\mathcal{N}(0, 1)`.
  240. Examples::
  241. >>> # an EmbeddingBag module containing 10 tensors of size 3
  242. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
  243. >>> # a batch of 2 samples of 4 indices each
  244. >>> input = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long)
  245. >>> offsets = torch.tensor([0, 4], dtype=torch.long)
  246. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  247. >>> embedding_sum(input, offsets)
  248. tensor([[-0.8861, -5.4350, -0.0523],
  249. [ 1.1306, -2.5798, -1.0044]])
  250. >>> # Example with padding_idx
  251. >>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum', padding_idx=2)
  252. >>> input = torch.tensor([2, 2, 2, 2, 4, 3, 2, 9], dtype=torch.long)
  253. >>> offsets = torch.tensor([0, 4], dtype=torch.long)
  254. >>> embedding_sum(input, offsets)
  255. tensor([[ 0.0000, 0.0000, 0.0000],
  256. [-0.7082, 3.2145, -2.6251]])
  257. >>> # An EmbeddingBag can be loaded from an Embedding like so
  258. >>> embedding = nn.Embedding(10, 3, padding_idx=2)
  259. >>> embedding_sum = nn.EmbeddingBag.from_pretrained(
  260. embedding.weight,
  261. padding_idx=embedding.padding_idx,
  262. mode='sum')
  263. """
  264. __constants__ = ['num_embeddings', 'embedding_dim', 'max_norm', 'norm_type',
  265. 'scale_grad_by_freq', 'mode', 'sparse', 'include_last_offset',
  266. 'padding_idx']
  267. num_embeddings: int
  268. embedding_dim: int
  269. max_norm: Optional[float]
  270. norm_type: float
  271. scale_grad_by_freq: bool
  272. weight: Tensor
  273. mode: str
  274. sparse: bool
  275. include_last_offset: bool
  276. padding_idx: Optional[int]
  277. def __init__(self, num_embeddings: int, embedding_dim: int,
  278. max_norm: Optional[float] = None, norm_type: float = 2., scale_grad_by_freq: bool = False,
  279. mode: str = 'mean', sparse: bool = False, _weight: Optional[Tensor] = None,
  280. include_last_offset: bool = False, padding_idx: Optional[int] = None,
  281. device=None, dtype=None) -> None:
  282. factory_kwargs = {'device': device, 'dtype': dtype}
  283. super().__init__()
  284. self.num_embeddings = num_embeddings
  285. self.embedding_dim = embedding_dim
  286. self.max_norm = max_norm
  287. self.norm_type = norm_type
  288. self.scale_grad_by_freq = scale_grad_by_freq
  289. if padding_idx is not None:
  290. if padding_idx > 0:
  291. assert padding_idx < self.num_embeddings, 'padding_idx must be within num_embeddings'
  292. elif padding_idx < 0:
  293. assert padding_idx >= -self.num_embeddings, 'padding_idx must be within num_embeddings'
  294. padding_idx = self.num_embeddings + padding_idx
  295. self.padding_idx = padding_idx
  296. if _weight is None:
  297. self.weight = Parameter(torch.empty((num_embeddings, embedding_dim), **factory_kwargs))
  298. self.reset_parameters()
  299. else:
  300. assert list(_weight.shape) == [num_embeddings, embedding_dim], \
  301. 'Shape of weight does not match num_embeddings and embedding_dim'
  302. self.weight = Parameter(_weight)
  303. self.mode = mode
  304. self.sparse = sparse
  305. self.include_last_offset = include_last_offset
  306. def reset_parameters(self) -> None:
  307. init.normal_(self.weight)
  308. self._fill_padding_idx_with_zero()
  309. def _fill_padding_idx_with_zero(self) -> None:
  310. if self.padding_idx is not None:
  311. with torch.no_grad():
  312. self.weight[self.padding_idx].fill_(0)
  313. def forward(self, input: Tensor, offsets: Optional[Tensor] = None, per_sample_weights: Optional[Tensor] = None) -> Tensor:
  314. """Forward pass of EmbeddingBag.
  315. Args:
  316. input (Tensor): Tensor containing bags of indices into the embedding matrix.
  317. offsets (Tensor, optional): Only used when :attr:`input` is 1D. :attr:`offsets` determines
  318. the starting index position of each bag (sequence) in :attr:`input`.
  319. per_sample_weights (Tensor, optional): a tensor of float / double weights, or None
  320. to indicate all weights should be taken to be ``1``. If specified, :attr:`per_sample_weights`
  321. must have exactly the same shape as input and is treated as having the same
  322. :attr:`offsets`, if those are not ``None``. Only supported for ``mode='sum'``.
  323. Returns:
  324. Tensor output shape of `(B, embedding_dim)`.
  325. .. note::
  326. A few notes about ``input`` and ``offsets``:
  327. - :attr:`input` and :attr:`offsets` have to be of the same type, either int or long
  328. - If :attr:`input` is 2D of shape `(B, N)`, it will be treated as ``B`` bags (sequences)
  329. each of fixed length ``N``, and this will return ``B`` values aggregated in a way
  330. depending on the :attr:`mode`. :attr:`offsets` is ignored and required to be ``None`` in this case.
  331. - If :attr:`input` is 1D of shape `(N)`, it will be treated as a concatenation of
  332. multiple bags (sequences). :attr:`offsets` is required to be a 1D tensor containing the
  333. starting index positions of each bag in :attr:`input`. Therefore, for :attr:`offsets` of shape `(B)`,
  334. :attr:`input` will be viewed as having ``B`` bags. Empty bags (i.e., having 0-length) will have
  335. returned vectors filled by zeros.
  336. """
  337. return F.embedding_bag(input, self.weight, offsets,
  338. self.max_norm, self.norm_type,
  339. self.scale_grad_by_freq, self.mode, self.sparse,
  340. per_sample_weights, self.include_last_offset,
  341. self.padding_idx)
  342. def extra_repr(self) -> str:
  343. s = '{num_embeddings}, {embedding_dim}'
  344. if self.max_norm is not None:
  345. s += ', max_norm={max_norm}'
  346. if self.norm_type != 2:
  347. s += ', norm_type={norm_type}'
  348. if self.scale_grad_by_freq is not False:
  349. s += ', scale_grad_by_freq={scale_grad_by_freq}'
  350. s += ', mode={mode}'
  351. if self.padding_idx is not None:
  352. s += ', padding_idx={padding_idx}'
  353. return s.format(**{k: repr(v) for k, v in self.__dict__.items()})
  354. @classmethod
  355. def from_pretrained(cls, embeddings: Tensor, freeze: bool = True, max_norm: Optional[float] = None,
  356. norm_type: float = 2., scale_grad_by_freq: bool = False,
  357. mode: str = 'mean', sparse: bool = False, include_last_offset: bool = False,
  358. padding_idx: Optional[int] = None) -> 'EmbeddingBag':
  359. r"""Create EmbeddingBag instance from given 2-dimensional FloatTensor.
  360. Args:
  361. embeddings (Tensor): FloatTensor containing weights for the EmbeddingBag.
  362. First dimension is being passed to EmbeddingBag as 'num_embeddings', second as 'embedding_dim'.
  363. freeze (bool, optional): If ``True``, the tensor does not get updated in the learning process.
  364. Equivalent to ``embeddingbag.weight.requires_grad = False``. Default: ``True``
  365. max_norm (float, optional): See module initialization documentation. Default: ``None``
  366. norm_type (float, optional): See module initialization documentation. Default ``2``.
  367. scale_grad_by_freq (bool, optional): See module initialization documentation. Default ``False``.
  368. mode (str, optional): See module initialization documentation. Default: ``"mean"``
  369. sparse (bool, optional): See module initialization documentation. Default: ``False``.
  370. include_last_offset (bool, optional): See module initialization documentation. Default: ``False``.
  371. padding_idx (int, optional): See module initialization documentation. Default: ``None``.
  372. Examples::
  373. >>> # FloatTensor containing pretrained weights
  374. >>> weight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
  375. >>> embeddingbag = nn.EmbeddingBag.from_pretrained(weight)
  376. >>> # Get embeddings for index 1
  377. >>> input = torch.LongTensor([[1, 0]])
  378. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  379. >>> embeddingbag(input)
  380. tensor([[ 2.5000, 3.7000, 4.6500]])
  381. """
  382. assert embeddings.dim() == 2, \
  383. 'Embeddings parameter is expected to be 2-dimensional'
  384. rows, cols = embeddings.shape
  385. embeddingbag = cls(
  386. num_embeddings=rows,
  387. embedding_dim=cols,
  388. _weight=embeddings,
  389. max_norm=max_norm,
  390. norm_type=norm_type,
  391. scale_grad_by_freq=scale_grad_by_freq,
  392. mode=mode,
  393. sparse=sparse,
  394. include_last_offset=include_last_offset,
  395. padding_idx=padding_idx)
  396. embeddingbag.weight.requires_grad = not freeze
  397. return embeddingbag