from __future__ import annotations

import json
import os
from typing import Any

import torch
from torch import Tensor, nn


class Pooling(nn.Module):
    """
    Performs pooling (max or mean) on the token embeddings.

    Using pooling, it generates a fixed-size sentence embedding from a variable-length sequence of token
    embeddings. This layer also allows using the CLS token if it is returned by the underlying word
    embedding model. Multiple pooling modes can be concatenated.

    Args:
        word_embedding_dimension: Dimensions for the word embeddings
        pooling_mode: Either "cls", "lasttoken", "max", "mean",
            "mean_sqrt_len_tokens", or "weightedmean". If set,
            overwrites the other pooling_mode_* settings
        pooling_mode_cls_token: Use the first token (CLS token) as text
            representations
        pooling_mode_max_tokens: Use max in each dimension over all
            tokens.
        pooling_mode_mean_tokens: Perform mean-pooling
        pooling_mode_mean_sqrt_len_tokens: Perform mean-pooling, but
            divide by sqrt(input_length).
        pooling_mode_weightedmean_tokens: Perform (position) weighted
            mean pooling. See `SGPT: GPT Sentence Embeddings for
            Semantic Search <https://arxiv.org/abs/2202.08904>`_.
        pooling_mode_lasttoken: Perform last token pooling. See `SGPT:
            GPT Sentence Embeddings for Semantic Search
            <https://arxiv.org/abs/2202.08904>`_ and `Text and Code
            Embeddings by Contrastive Pre-Training
            <https://arxiv.org/abs/2201.10005>`_.
    """

    POOLING_MODES = (
        "cls",
        "lasttoken",
        "max",
        "mean",
        "mean_sqrt_len_tokens",
        "weightedmean",
    )

    def __init__(
        self,
        word_embedding_dimension: int,
        pooling_mode: str | None = None,
        pooling_mode_cls_token: bool = False,
        pooling_mode_max_tokens: bool = False,
        pooling_mode_mean_tokens: bool = True,
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_weightedmean_tokens: bool = False,
        pooling_mode_lasttoken: bool = False,
        include_prompt: bool = True,
    ) -> None:
        super().__init__()

        self.config_keys = [
            "word_embedding_dimension",
            "pooling_mode_cls_token",
            "pooling_mode_mean_tokens",
            "pooling_mode_max_tokens",
            "pooling_mode_mean_sqrt_len_tokens",
            "pooling_mode_weightedmean_tokens",
            "pooling_mode_lasttoken",
            "include_prompt",
        ]

        if pooling_mode is not None:  # Set pooling mode by string
            pooling_mode = pooling_mode.lower()

            if pooling_mode not in self.POOLING_MODES:
                raise ValueError(
                    f"Set invalid pooling mode: {pooling_mode}. Valid pooling modes are: {self.POOLING_MODES}."
                )

            pooling_mode_cls_token = pooling_mode == "cls"
            pooling_mode_max_tokens = pooling_mode == "max"
            pooling_mode_mean_tokens = pooling_mode == "mean"
            pooling_mode_mean_sqrt_len_tokens = pooling_mode == "mean_sqrt_len_tokens"
            pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean"
            pooling_mode_lasttoken = pooling_mode == "lasttoken"

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken
        self.include_prompt = include_prompt

        pooling_mode_multiplier = sum(
            [
                pooling_mode_cls_token,
                pooling_mode_max_tokens,
                pooling_mode_mean_tokens,
                pooling_mode_mean_sqrt_len_tokens,
                pooling_mode_weightedmean_tokens,
                pooling_mode_lasttoken,
            ]
        )
        self.pooling_output_dimension = pooling_mode_multiplier * word_embedding_dimension

    def __repr__(self) -> str:
        return f"Pooling({self.get_config_dict()})"

    def get_pooling_mode_str(self) -> str:
        """Returns the pooling mode as string"""
        modes = []
        if self.pooling_mode_cls_token:
            modes.append("cls")
        if self.pooling_mode_mean_tokens:
            modes.append("mean")
        if self.pooling_mode_max_tokens:
            modes.append("max")
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append("mean_sqrt_len_tokens")
        if self.pooling_mode_weightedmean_tokens:
            modes.append("weightedmean")
        if self.pooling_mode_lasttoken:
            modes.append("lasttoken")

        return "+".join(modes)

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        token_embeddings = features["token_embeddings"]
        attention_mask = (
            features["attention_mask"]
            if "attention_mask" in features
            else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
        )
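
        # If a prompt/instruction prefix was prepended and include_prompt is False, zero out the
        # prompt tokens in the attention mask so they do not contribute to the pooled embedding.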
        if not self.include_prompt and "prompt_length" in features:
            attention_mask[:, : features["prompt_length"]] = 0

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            cls_token = features.get("cls_token_embeddings", token_embeddings[:, 0])  # Take first token by default
            output_vectors.append(cls_token)

        if self.pooling_mode_max_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
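
        # Mean pooling: average the token embeddings over non-masked positions; the
        # mean_sqrt_len_tokens variant divides by the square root of the token count instead.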
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
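
        # Weighted mean pooling (SGPT): the token at (1-indexed) position i gets weight i, so
        # later tokens contribute more to the sentence embedding than earlier ones.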
        if self.pooling_mode_weightedmean_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            # token_embeddings shape: bs, seq, hidden_dim
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .to(token_embeddings.dtype)
                .to(token_embeddings.device)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)
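
        # Last-token pooling (useful for decoder-only / causal models, see SGPT): select the
        # embedding of the last non-padding token in each sequence.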
        if self.pooling_mode_lasttoken:
            bs, seq_len, hidden_dim = token_embeddings.shape
            # attention_mask shape: (bs, seq_len)
            # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
            # Use flip and max() to get the last index of 1 in the attention mask

            if torch.jit.is_tracing():
                # Avoid tracing the argmax with int64 input that can not be handled by ONNX Runtime: https://github.com/microsoft/onnxruntime/issues/10068
                attention_mask = attention_mask.to(torch.int32)

            values, indices = attention_mask.flip(1).max(1)
            indices = torch.where(values == 0, seq_len - 1, indices)
            gather_indices = seq_len - indices - 1

            # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (bs, 1, hidden_dim)

            # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
            # Actually no need for the attention mask as we gather the last token where attn_mask = 1
            # but as we set some indices (which shouldn't be attended to) to 0 with clamp, we
            # use the attention mask to ignore them again
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features["sentence_embedding"] = output_vector
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.pooling_output_dimension

    def get_config_dict(self) -> dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str) -> None:
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path: str) -> Pooling:
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        return Pooling(**config)
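

# A minimal usage sketch (illustrative, not part of the library API): it builds dummy token
# embeddings and an attention mask, runs two pooling configurations, and checks the reported
# output dimensions. The shapes and values below are assumptions chosen for the example.
if __name__ == "__main__":
    batch_size, seq_len, dim = 2, 6, 8
    features = {
        "token_embeddings": torch.randn(batch_size, seq_len, dim),
        "attention_mask": torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]]),
    }

    # Plain mean pooling: the output dimension equals the word embedding dimension.
    mean_pooling = Pooling(dim, pooling_mode="mean")
    out = mean_pooling(dict(features))
    assert out["sentence_embedding"].shape == (batch_size, dim)

    # Concatenating CLS and mean pooling doubles the output dimension.
    cls_mean_pooling = Pooling(dim, pooling_mode_cls_token=True, pooling_mode_mean_tokens=True)
    out = cls_mean_pooling(dict(features))
    assert cls_mean_pooling.get_sentence_embedding_dimension() == 2 * dim
    assert out["sentence_embedding"].shape == (batch_size, 2 * dim)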