# StaticEmbedding.py

from __future__ import annotations

import math
import os
from pathlib import Path

import numpy as np
import torch
from safetensors.torch import load_file as load_safetensors_file
from safetensors.torch import save_file as save_safetensors_file
from tokenizers import Tokenizer
from torch import nn
from transformers import PreTrainedTokenizerFast

from sentence_transformers.util import get_device_name


class StaticEmbedding(nn.Module):
    def __init__(
        self,
        tokenizer: Tokenizer | PreTrainedTokenizerFast,
        embedding_weights: np.ndarray | torch.Tensor | None = None,
        embedding_dim: int | None = None,
        **kwargs,
    ) -> None:
  21. """
  22. Initializes the StaticEmbedding model given a tokenizer. The model is a simple embedding bag model that
  23. takes the mean of trained per-token embeddings to compute text embeddings.
  24. Args:
  25. tokenizer (Tokenizer | PreTrainedTokenizerFast): The tokenizer to be used. Must be a fast tokenizer
  26. from ``transformers`` or ``tokenizers``.
  27. embedding_weights (np.array | torch.Tensor | None, optional): Pre-trained embedding weights.
  28. Defaults to None.
  29. embedding_dim (int | None, optional): Dimension of the embeddings. Required if embedding_weights
  30. is not provided. Defaults to None.
  31. Example::
  32. from sentence_transformers import SentenceTransformer
  33. from sentence_transformers.models import StaticEmbedding
  34. from tokenizers import Tokenizer
  35. # Pre-distilled embeddings:
  36. static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
  37. # or distill your own embeddings:
  38. static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda")
  39. # or start with randomized embeddings:
  40. tokenizer = Tokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
  41. static_embedding = StaticEmbedding(tokenizer, embedding_dim=512)
  42. model = SentenceTransformer(modules=[static_embedding])
  43. embeddings = model.encode(["What are Pandas?", "The giant panda (Ailuropoda melanoleuca; Chinese: 大熊猫; pinyin: dàxióngmāo), also known as the panda bear or simply the panda, is a bear native to south central China."])
  44. similarity = model.similarity(embeddings[0], embeddings[1])
  45. # tensor([[0.9177]]) (If you use the distilled bge-base)
  46. Raises:
  47. ValueError: If the tokenizer is not a fast tokenizer.
  48. ValueError: If neither `embedding_weights` nor `embedding_dim` is provided.
  49. """
        super().__init__()

        if isinstance(tokenizer, PreTrainedTokenizerFast):
            tokenizer = tokenizer._tokenizer
        elif not isinstance(tokenizer, Tokenizer):
            raise ValueError(
                "The tokenizer must be fast (i.e. Rust-backed) to use this class. "
                "Use Tokenizer.from_pretrained() from `tokenizers` to load a fast tokenizer."
            )

        if embedding_weights is not None:
            if isinstance(embedding_weights, np.ndarray):
                embedding_weights = torch.from_numpy(embedding_weights)

            self.embedding = nn.EmbeddingBag.from_pretrained(embedding_weights, freeze=False)
        elif embedding_dim is not None:
            self.embedding = nn.EmbeddingBag(tokenizer.get_vocab_size(), embedding_dim)
        else:
            raise ValueError("Either `embedding_weights` or `embedding_dim` must be provided.")

        self.num_embeddings = self.embedding.num_embeddings
        self.embedding_dim = self.embedding.embedding_dim

        self.tokenizer: Tokenizer = tokenizer
        self.tokenizer.no_padding()

        # For the model card
        self.base_model = kwargs.get("base_model", None)

    def tokenize(self, texts: list[str], **kwargs) -> dict[str, torch.Tensor]:
        encodings = self.tokenizer.encode_batch(texts, add_special_tokens=False)
        encodings_ids = [encoding.ids for encoding in encodings]

        # nn.EmbeddingBag expects all token ids concatenated into one 1D tensor, plus the start offset
        # of each text within that tensor, rather than a padded 2D batch.
        offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
        input_ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
        return {"input_ids": input_ids, "offsets": offsets}

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        features["sentence_embedding"] = self.embedding(features["input_ids"], features["offsets"])
        return features
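
    # A minimal sketch of the tokenize/forward contract above, assuming `module` is a StaticEmbedding
    # instance; the texts, token ids, and offsets shown are illustrative, not real tokenizer output:
    #
    #   features = module.tokenize(["hello world", "hi"])
    #   # features["input_ids"] -> tensor([1012, 2054, 7632, 5117])   (ids of both texts, concatenated)
    #   # features["offsets"]   -> tensor([0, 3])                     (text i starts at offsets[i])
    #   features = module.forward(features)
    #   # features["sentence_embedding"].shape -> (2, embedding_dim)  (mean of each text's token embeddings)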

    def get_config_dict(self) -> dict[str, float]:
        return {}

    @property
    def max_seq_length(self) -> int:
        return math.inf

    def get_sentence_embedding_dimension(self) -> int:
        return self.embedding_dim

    def save(self, save_dir: str, safe_serialization: bool = True, **kwargs) -> None:
        if safe_serialization:
            save_safetensors_file(self.state_dict(), os.path.join(save_dir, "model.safetensors"))
        else:
            torch.save(self.state_dict(), os.path.join(save_dir, "pytorch_model.bin"))

        self.tokenizer.save(str(Path(save_dir) / "tokenizer.json"))

    @staticmethod
    def load(load_dir: str, **kwargs) -> StaticEmbedding:
        tokenizer = Tokenizer.from_file(str(Path(load_dir) / "tokenizer.json"))
        if os.path.exists(os.path.join(load_dir, "model.safetensors")):
            weights = load_safetensors_file(os.path.join(load_dir, "model.safetensors"))
        else:
            weights = torch.load(
                os.path.join(load_dir, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
            )
        weights = weights["embedding.weight"]
        return StaticEmbedding(tokenizer, embedding_weights=weights)
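
    # A minimal save/load round-trip sketch; the directory name is illustrative and is assumed to exist:
    #
    #   static_embedding.save("static-embedding-checkpoint")
    #   # writes model.safetensors (or pytorch_model.bin) plus tokenizer.json into the directory
    #   restored = StaticEmbedding.load("static-embedding-checkpoint")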

    @classmethod
    def from_distillation(
        cls,
        model_name: str,
        vocabulary: list[str] | None = None,
        device: str | None = None,
        pca_dims: int | None = 256,
        apply_zipf: bool = True,
        use_subword: bool = True,
    ) -> StaticEmbedding:
  114. """
  115. Creates a StaticEmbedding instance from a distillation process using the `model2vec` package.
  116. Args:
  117. model_name (str): The name of the model to distill.
  118. vocabulary (list[str] | None, optional): A list of vocabulary words to use. Defaults to None.
  119. device (str): The device to run the distillation on (e.g., 'cpu', 'cuda'). If not specified,
  120. the strongest device is automatically detected. Defaults to None.
  121. pca_dims (int | None, optional): The number of dimensions for PCA reduction. Defaults to 256.
  122. apply_zipf (bool): Whether to apply Zipf's law during distillation. Defaults to True.
  123. use_subword (bool): Whether to use subword tokenization. Defaults to True.
  124. Returns:
  125. StaticEmbedding: An instance of StaticEmbedding initialized with the distilled model's
  126. tokenizer and embedding weights.
  127. Raises:
  128. ImportError: If the `model2vec` package is not installed.
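
        Example::

            # A minimal usage sketch; the model name mirrors the class-level example and requires
            # `model2vec[distill]` to be installed:
            from sentence_transformers import SentenceTransformer
            from sentence_transformers.models import StaticEmbedding

            static_embedding = StaticEmbedding.from_distillation("BAAI/bge-base-en-v1.5", device="cuda", pca_dims=256)
            model = SentenceTransformer(modules=[static_embedding])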
  129. """
  130. try:
  131. from model2vec.distill import distill
  132. except ImportError:
  133. raise ImportError(
  134. "To use this method, please install the `model2vec` package: `pip install model2vec[distill]`"
  135. )
  136. device = get_device_name()
  137. static_model = distill(
  138. model_name,
  139. vocabulary=vocabulary,
  140. device=device,
  141. pca_dims=pca_dims,
  142. apply_zipf=apply_zipf,
  143. use_subword=use_subword,
  144. )
  145. if isinstance(static_model.embedding, np.ndarray):
  146. embedding_weights = torch.from_numpy(static_model.embedding)
  147. else:
  148. embedding_weights = static_model.embedding.weight
  149. tokenizer: Tokenizer = static_model.tokenizer
  150. return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_name)

    @classmethod
    def from_model2vec(cls, model_id_or_path: str) -> StaticEmbedding:
        """
        Create a StaticEmbedding instance from a model2vec model. This method loads a pre-trained model2vec model
        and extracts the embedding weights and tokenizer to create a StaticEmbedding instance.

        Args:
            model_id_or_path (str): The identifier or path to the pre-trained model2vec model.

        Returns:
            StaticEmbedding: An instance of StaticEmbedding initialized with the tokenizer and embedding weights
            of the model2vec model.

        Raises:
            ImportError: If the `model2vec` package is not installed.
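
        Example::

            # A minimal usage sketch; the model id mirrors the class-level example and requires
            # the `model2vec` package to be installed:
            from sentence_transformers import SentenceTransformer
            from sentence_transformers.models import StaticEmbedding

            static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_base_output")
            model = SentenceTransformer(modules=[static_embedding])
            embeddings = model.encode(["What are Pandas?"])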
  163. """
  164. try:
  165. from model2vec import StaticModel
  166. except ImportError:
  167. raise ImportError("To use this method, please install the `model2vec` package: `pip install model2vec`")
  168. static_model = StaticModel.from_pretrained(model_id_or_path)
  169. if isinstance(static_model.embedding, np.ndarray):
  170. embedding_weights = torch.from_numpy(static_model.embedding)
  171. else:
  172. embedding_weights = static_model.embedding.weight
  173. tokenizer: Tokenizer = static_model.tokenizer
  174. return cls(tokenizer, embedding_weights=embedding_weights, base_model=model_id_or_path)