from __future__ import annotations

import json
import os
from typing import Any

import torch
from torch import Tensor, nn


class Pooling(nn.Module):
    """
    Performs pooling (max or mean) on the token embeddings.

    Using pooling, this layer generates a fixed-sized sentence embedding from a
    variable-sized sentence. It can also use the CLS token if it is returned by the
    underlying word embedding model, and multiple pooling modes can be concatenated.

    Args:
        word_embedding_dimension: Dimension of the word embeddings.
        pooling_mode: Either "cls", "lasttoken", "max", "mean",
            "mean_sqrt_len_tokens", or "weightedmean". If set, it overrides the
            other pooling_mode_* settings.
        pooling_mode_cls_token: Use the first token (CLS token) as the text
            representation.
        pooling_mode_max_tokens: Take the maximum in each dimension over all tokens.
        pooling_mode_mean_tokens: Perform mean pooling.
        pooling_mode_mean_sqrt_len_tokens: Perform mean pooling, but divide by
            sqrt(input_length).
        pooling_mode_weightedmean_tokens: Perform (position-)weighted mean pooling.
            See `SGPT: GPT Sentence Embeddings for Semantic Search
            <https://arxiv.org/abs/2202.08904>`_.
        pooling_mode_lasttoken: Perform last-token pooling. See `SGPT: GPT Sentence
            Embeddings for Semantic Search <https://arxiv.org/abs/2202.08904>`_ and
            `Text and Code Embeddings by Contrastive Pre-Training
            <https://arxiv.org/abs/2201.10005>`_.
        include_prompt: If False and the input features contain a "prompt_length",
            the prompt tokens are masked out and excluded from pooling.
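
    Example (a minimal sketch; ``models.Transformer`` and ``SentenceTransformer`` are
    assumed to be available from the surrounding sentence-transformers package)::

        from sentence_transformers import SentenceTransformer, models

        word_embedding_model = models.Transformer("bert-base-uncased")
        pooling_model = models.Pooling(
            word_embedding_model.get_word_embedding_dimension(),
            pooling_mode="mean",
        )
        model = SentenceTransformer(modules=[word_embedding_model, pooling_model])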
- """

    POOLING_MODES = (
        "cls",
        "lasttoken",
        "max",
        "mean",
        "mean_sqrt_len_tokens",
        "weightedmean",
    )

    def __init__(
        self,
        word_embedding_dimension: int,
        pooling_mode: str | None = None,
        pooling_mode_cls_token: bool = False,
        pooling_mode_max_tokens: bool = False,
        pooling_mode_mean_tokens: bool = True,
        pooling_mode_mean_sqrt_len_tokens: bool = False,
        pooling_mode_weightedmean_tokens: bool = False,
        pooling_mode_lasttoken: bool = False,
        include_prompt: bool = True,
    ) -> None:
        super().__init__()

        self.config_keys = [
            "word_embedding_dimension",
            "pooling_mode_cls_token",
            "pooling_mode_mean_tokens",
            "pooling_mode_max_tokens",
            "pooling_mode_mean_sqrt_len_tokens",
            "pooling_mode_weightedmean_tokens",
            "pooling_mode_lasttoken",
            "include_prompt",
        ]
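        # These keys define what get_config_dict() exports, and therefore what
        # save() writes to config.json and load() passes back into __init__().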

        if pooling_mode is not None:  # Set pooling mode by string
            pooling_mode = pooling_mode.lower()

            if pooling_mode not in self.POOLING_MODES:
                raise ValueError(
                    f"Invalid pooling mode: {pooling_mode}. Valid pooling modes are: {self.POOLING_MODES}."
                )

            pooling_mode_cls_token = pooling_mode == "cls"
            pooling_mode_max_tokens = pooling_mode == "max"
            pooling_mode_mean_tokens = pooling_mode == "mean"
            pooling_mode_mean_sqrt_len_tokens = pooling_mode == "mean_sqrt_len_tokens"
            pooling_mode_weightedmean_tokens = pooling_mode == "weightedmean"
            pooling_mode_lasttoken = pooling_mode == "lasttoken"

        self.word_embedding_dimension = word_embedding_dimension
        self.pooling_mode_cls_token = pooling_mode_cls_token
        self.pooling_mode_mean_tokens = pooling_mode_mean_tokens
        self.pooling_mode_max_tokens = pooling_mode_max_tokens
        self.pooling_mode_mean_sqrt_len_tokens = pooling_mode_mean_sqrt_len_tokens
        self.pooling_mode_weightedmean_tokens = pooling_mode_weightedmean_tokens
        self.pooling_mode_lasttoken = pooling_mode_lasttoken

        self.include_prompt = include_prompt

        pooling_mode_multiplier = sum(
            [
                pooling_mode_cls_token,
                pooling_mode_max_tokens,
                pooling_mode_mean_tokens,
                pooling_mode_mean_sqrt_len_tokens,
                pooling_mode_weightedmean_tokens,
                pooling_mode_lasttoken,
            ]
        )
        self.pooling_output_dimension = pooling_mode_multiplier * word_embedding_dimension
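        # Each enabled pooling mode contributes one word_embedding_dimension-sized
        # vector to the concatenated output, e.g. enabling "cls" and "mean" with
        # 768-dimensional embeddings yields a 1536-dimensional sentence embedding.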

    def __repr__(self) -> str:
        return f"Pooling({self.get_config_dict()})"

    def get_pooling_mode_str(self) -> str:
        """Returns the enabled pooling mode(s) as a string, e.g. "mean" or "cls+mean"."""
        modes = []
        if self.pooling_mode_cls_token:
            modes.append("cls")
        if self.pooling_mode_mean_tokens:
            modes.append("mean")
        if self.pooling_mode_max_tokens:
            modes.append("max")
        if self.pooling_mode_mean_sqrt_len_tokens:
            modes.append("mean_sqrt_len_tokens")
        if self.pooling_mode_weightedmean_tokens:
            modes.append("weightedmean")
        if self.pooling_mode_lasttoken:
            modes.append("lasttoken")

        return "+".join(modes)

    def forward(self, features: dict[str, Tensor]) -> dict[str, Tensor]:
        token_embeddings = features["token_embeddings"]
        attention_mask = (
            features["attention_mask"]
            if "attention_mask" in features
            else torch.ones(token_embeddings.shape[:-1], device=token_embeddings.device, dtype=torch.int64)
        )
        if not self.include_prompt and "prompt_length" in features:
            attention_mask[:, : features["prompt_length"]] = 0
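            # With include_prompt=False, prompt/instruction tokens are zeroed in the
            # attention mask so that they do not contribute to the pooled embedding.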

        ## Pooling strategy
        output_vectors = []
        if self.pooling_mode_cls_token:
            cls_token = features.get("cls_token_embeddings", token_embeddings[:, 0])  # Take first token by default
            output_vectors.append(cls_token)
        if self.pooling_mode_max_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
            max_over_time = torch.max(token_embeddings, 1)[0]
            output_vectors.append(max_over_time)
        if self.pooling_mode_mean_tokens or self.pooling_mode_mean_sqrt_len_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            # Clamp to avoid division by zero for sequences without any attended tokens
            sum_mask = torch.clamp(sum_mask, min=1e-9)

            if self.pooling_mode_mean_tokens:
                output_vectors.append(sum_embeddings / sum_mask)
            if self.pooling_mode_mean_sqrt_len_tokens:
                output_vectors.append(sum_embeddings / torch.sqrt(sum_mask))
        if self.pooling_mode_weightedmean_tokens:
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            # token_embeddings shape: bs, seq, hidden_dim
            weights = (
                torch.arange(start=1, end=token_embeddings.shape[1] + 1)
                .unsqueeze(0)
                .unsqueeze(-1)
                .expand(token_embeddings.size())
                .to(token_embeddings.dtype)
                .to(token_embeddings.device)
            )
            assert weights.shape == token_embeddings.shape == input_mask_expanded.shape
            input_mask_expanded = input_mask_expanded * weights

            sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)

            # If tokens are weighted (by WordWeights layer), feature 'token_weights_sum' will be present
            if "token_weights_sum" in features:
                sum_mask = features["token_weights_sum"].unsqueeze(-1).expand(sum_embeddings.size())
            else:
                sum_mask = input_mask_expanded.sum(1)

            sum_mask = torch.clamp(sum_mask, min=1e-9)
            output_vectors.append(sum_embeddings / sum_mask)
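            # The linearly increasing position weights (1, 2, ..., seq_len) give later
            # tokens more influence on the mean, following SGPT weighted-mean pooling.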
        if self.pooling_mode_lasttoken:
            bs, seq_len, hidden_dim = token_embeddings.shape
            # attention_mask shape: (bs, seq_len)
            # Get shape [bs] indices of the last token (i.e. the last token for each batch item)
            # Use flip and max() to get the last index of 1 in the attention mask

            if torch.jit.is_tracing():
                # Avoid tracing the argmax with int64 input that cannot be handled by ONNX Runtime: https://github.com/microsoft/onnxruntime/issues/10068
                attention_mask = attention_mask.to(torch.int32)

            values, indices = attention_mask.flip(1).max(1)
            indices = torch.where(values == 0, seq_len - 1, indices)
            gather_indices = seq_len - indices - 1

            # Turn indices from shape [bs] --> [bs, 1, hidden_dim]
            gather_indices = gather_indices.unsqueeze(-1).repeat(1, hidden_dim)
            gather_indices = gather_indices.unsqueeze(1)
            assert gather_indices.shape == (bs, 1, hidden_dim)

            # Gather along the 1st dim (seq_len) (bs, seq_len, hidden_dim -> bs, hidden_dim)
            # The attention mask would not be needed when we gather the last token where
            # attn_mask = 1, but fully padded sequences had their gather index forced to 0
            # above, so the attention mask is applied to zero those embeddings out again.
            input_mask_expanded = (
                attention_mask.unsqueeze(-1).expand(token_embeddings.size()).to(token_embeddings.dtype)
            )
            embedding = torch.gather(token_embeddings * input_mask_expanded, 1, gather_indices).squeeze(dim=1)
            output_vectors.append(embedding)

        output_vector = torch.cat(output_vectors, 1)
        features["sentence_embedding"] = output_vector
        return features

    def get_sentence_embedding_dimension(self) -> int:
        return self.pooling_output_dimension

    def get_config_dict(self) -> dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str) -> None:
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)
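        # Pooling has no trainable parameters, so persisting the configuration is sufficient.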

    @staticmethod
    def load(input_path: str) -> Pooling:
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        return Pooling(**config)
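

# Minimal usage sketch: shows how the Pooling layer turns token embeddings into a
# fixed-size sentence embedding. The tensor shapes below are illustrative assumptions.
if __name__ == "__main__":
    batch_size, seq_len, hidden_dim = 2, 6, 8
    token_embeddings = torch.randn(batch_size, seq_len, hidden_dim)

    # Mark the last two tokens of the second sequence as padding
    attention_mask = torch.ones(batch_size, seq_len, dtype=torch.int64)
    attention_mask[1, -2:] = 0

    pooling = Pooling(word_embedding_dimension=hidden_dim, pooling_mode="mean")
    features = pooling({"token_embeddings": token_embeddings, "attention_mask": attention_mask})

    # One fixed-size embedding per input sequence: (batch_size, hidden_dim)
    print(features["sentence_embedding"].shape)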