from __future__ import annotations

import copy
import importlib
import json
import logging
import math
import os
import queue
import shutil
import sys
import tempfile
import traceback
import warnings
from collections import OrderedDict
from contextlib import contextmanager
from multiprocessing import Queue
from pathlib import Path
from typing import Any, Callable, Iterable, Iterator, Literal, overload

import numpy as np
import torch
import torch.multiprocessing as mp
import transformers
from huggingface_hub import HfApi
from numpy import ndarray
from torch import Tensor, device, nn
from tqdm.autonotebook import trange
from transformers import is_torch_npu_available
from transformers.dynamic_module_utils import get_class_from_dynamic_module, get_relative_import_files

from sentence_transformers.model_card import SentenceTransformerModelCardData, generate_model_card
from sentence_transformers.similarity_functions import SimilarityFunction

from . import __MODEL_HUB_ORGANIZATION__, __version__
from .evaluation import SentenceEvaluator
from .fit_mixin import FitMixin
from .models import Normalize, Pooling, Transformer
from .quantization import quantize_embeddings
from .util import (
    batch_to_device,
    get_device_name,
    import_from_string,
    is_sentence_transformer_model,
    load_dir_path,
    load_file_path,
    save_to_hub_args_decorator,
    truncate_embeddings,
)

logger = logging.getLogger(__name__)


class SentenceTransformer(nn.Sequential, FitMixin):
    """
    Loads or creates a SentenceTransformer model that can be used to map sentences / text to embeddings.

    Args:
        model_name_or_path (str, optional): If it is a filepath on disc, it loads the model from that path. If it is not a path,
            it first tries to download a pre-trained SentenceTransformer model. If that fails, tries to construct a model
            from the Hugging Face Hub with that name.
        modules (Iterable[nn.Module], optional): A list of torch Modules that should be called sequentially, can be used to create custom
            SentenceTransformer models from scratch.
        device (str, optional): Device (like "cuda", "cpu", "mps", "npu") that should be used for computation. If None, checks if a GPU
            can be used.
        prompts (Dict[str, str], optional): A dictionary with prompts for the model. The key is the prompt name, the value is the prompt text.
            The prompt text will be prepended before any text to encode. For example:
            `{"query": "query: ", "passage": "passage: "}` or `{"clustering": "Identify the main category based on the
            titles in "}`.
        default_prompt_name (str, optional): The name of the prompt that should be used by default. If not set,
            no prompt will be applied.
        similarity_fn_name (str or SimilarityFunction, optional): The name of the similarity function to use. Valid options are "cosine", "dot",
            "euclidean", and "manhattan". If not set, it is automatically set to "cosine" if `similarity` or
            `similarity_pairwise` are called while `model.similarity_fn_name` is still `None`.
        cache_folder (str, optional): Path to store models. Can also be set by the SENTENCE_TRANSFORMERS_HOME environment variable.
        trust_remote_code (bool, optional): Whether or not to allow for custom models defined on the Hub in their own modeling files.
            This option should only be set to True for repositories you trust and in which you have read the code, as it
            will execute code present on the Hub on your local machine.
        revision (str, optional): The specific model version to use. It can be a branch name, a tag name, or a commit id,
            for a stored model on Hugging Face.
        local_files_only (bool, optional): Whether or not to only look at local files (i.e., do not try to download the model).
        token (bool or str, optional): Hugging Face authentication token to download private models.
        use_auth_token (bool or str, optional): Deprecated argument. Please use `token` instead.
        truncate_dim (int, optional): The dimension to truncate sentence embeddings to. `None` does no truncation. Truncation is
            only applicable during inference when :meth:`SentenceTransformer.encode` is called.
        model_kwargs (Dict[str, Any], optional): Additional model configuration parameters to be passed to the Hugging Face Transformers model.
            Particularly useful options are:

            - ``torch_dtype``: Override the default `torch.dtype` and load the model under a specific `dtype`.
              The different options are:

              1. ``torch.float16``, ``torch.bfloat16`` or ``torch.float``: load in a specified
                 ``dtype``, ignoring the model's ``config.torch_dtype`` if one exists. If not specified, the model
                 will be loaded in ``torch.float`` (fp32).
              2. ``"auto"``: the ``torch_dtype`` entry in the ``config.json`` file of the model will be used if present.
                 If that entry isn't found, the ``dtype`` of the first floating-point weight in the checkpoint is used
                 instead. This loads the model in the ``dtype`` it was saved in at the end of training. Note that this
                 is not a reliable indicator of how the model was trained, since it may have been trained in a
                 half-precision dtype but saved in fp32.
            - ``attn_implementation``: The attention implementation to use in the model (if relevant). Can be any of
              `"eager"` (manual implementation of the attention), `"sdpa"` (using `F.scaled_dot_product_attention
              <https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html>`_),
              or `"flash_attention_2"` (using `Dao-AILab/flash-attention <https://github.com/Dao-AILab/flash-attention>`_).
              By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"`
              implementation.
            - ``provider``: If backend is "onnx", this is the provider to use for inference, for example "CPUExecutionProvider",
              "CUDAExecutionProvider", etc. See https://onnxruntime.ai/docs/execution-providers/ for all ONNX execution providers.
            - ``file_name``: If backend is "onnx" or "openvino", this is the file name to load, useful for loading optimized
              or quantized ONNX or OpenVINO models.
            - ``export``: If backend is "onnx" or "openvino", then this is a boolean flag specifying whether this model should
              be exported to the backend. If not specified, the model will be exported only if the model repository or directory
              does not already contain an exported model.

            See the `PreTrainedModel.from_pretrained
            <https://huggingface.co/docs/transformers/en/main_classes/model#transformers.PreTrainedModel.from_pretrained>`_
            documentation for more details.
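
            For example, a minimal sketch of loading the underlying transformer in half precision
            (the model name is illustrative; any pre-trained Sentence Transformer model can be
            substituted)::

                import torch
                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2", model_kwargs={"torch_dtype": torch.float16})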
        tokenizer_kwargs (Dict[str, Any], optional): Additional tokenizer configuration parameters to be passed to the Hugging Face Transformers tokenizer.
            See the `AutoTokenizer.from_pretrained
            <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained>`_
            documentation for more details.
        config_kwargs (Dict[str, Any], optional): Additional model configuration parameters to be passed to the Hugging Face Transformers config.
            See the `AutoConfig.from_pretrained
            <https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoConfig.from_pretrained>`_
            documentation for more details.
        model_card_data (:class:`~sentence_transformers.model_card.SentenceTransformerModelCardData`, optional): A model
            card data object that contains information about the model. This is used to generate a model card when saving
            the model. If not set, a default model card data object is created.
        backend (str): The backend to use for inference. Can be one of "torch" (default), "onnx", or "openvino".
            See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for benchmarking information
            on the different backends.

    Example:
        ::

            from sentence_transformers import SentenceTransformer

            # Load a pre-trained SentenceTransformer model
            model = SentenceTransformer('all-mpnet-base-v2')

            # Encode some texts
            sentences = [
                "The weather is lovely today.",
                "It's so sunny outside!",
                "He drove to the stadium.",
            ]
            embeddings = model.encode(sentences)
            print(embeddings.shape)
            # (3, 768)

            # Get the similarity scores between all sentences
            similarities = model.similarity(embeddings, embeddings)
            print(similarities)
            # tensor([[1.0000, 0.6817, 0.0492],
            #         [0.6817, 1.0000, 0.0421],
            #         [0.0492, 0.0421, 1.0000]])
    """

    def __init__(
        self,
        model_name_or_path: str | None = None,
        modules: Iterable[nn.Module] | None = None,
        device: str | None = None,
        prompts: dict[str, str] | None = None,
        default_prompt_name: str | None = None,
        similarity_fn_name: str | SimilarityFunction | None = None,
        cache_folder: str | None = None,
        trust_remote_code: bool = False,
        revision: str | None = None,
        local_files_only: bool = False,
        token: bool | str | None = None,
        use_auth_token: bool | str | None = None,
        truncate_dim: int | None = None,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        config_kwargs: dict[str, Any] | None = None,
        model_card_data: SentenceTransformerModelCardData | None = None,
        backend: Literal["torch", "onnx", "openvino"] = "torch",
    ) -> None:
        # Note: self._load_sbert_model can also update `self.prompts` and `self.default_prompt_name`
        self.prompts = prompts or {}
        self.default_prompt_name = default_prompt_name
        self.similarity_fn_name = similarity_fn_name
        self.trust_remote_code = trust_remote_code
        self.truncate_dim = truncate_dim
        self.model_card_data = model_card_data or SentenceTransformerModelCardData()
        self.module_kwargs = None
        self._model_card_vars = {}
        self._model_card_text = None
        self._model_config = {}
        self.backend = backend

        if use_auth_token is not None:
            warnings.warn(
                "The `use_auth_token` argument is deprecated and will be removed in v4 of SentenceTransformers.",
                FutureWarning,
            )
            if token is not None:
                raise ValueError(
                    "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
                )
            token = use_auth_token

        if cache_folder is None:
            cache_folder = os.getenv("SENTENCE_TRANSFORMERS_HOME")

        if device is None:
            device = get_device_name()
            logger.info(f"Use pytorch device_name: {device}")

        if device == "hpu" and importlib.util.find_spec("optimum") is not None:
            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

            adapt_transformers_to_gaudi()

        if model_name_or_path is not None and model_name_or_path != "":
            logger.info(f"Load pretrained SentenceTransformer: {model_name_or_path}")

            # Old models that don't belong to any organization
            basic_transformer_models = [
                "albert-base-v1",
                "albert-base-v2",
                "albert-large-v1",
                "albert-large-v2",
                "albert-xlarge-v1",
                "albert-xlarge-v2",
                "albert-xxlarge-v1",
                "albert-xxlarge-v2",
                "bert-base-cased-finetuned-mrpc",
                "bert-base-cased",
                "bert-base-chinese",
                "bert-base-german-cased",
                "bert-base-german-dbmdz-cased",
                "bert-base-german-dbmdz-uncased",
                "bert-base-multilingual-cased",
                "bert-base-multilingual-uncased",
                "bert-base-uncased",
                "bert-large-cased-whole-word-masking-finetuned-squad",
                "bert-large-cased-whole-word-masking",
                "bert-large-cased",
                "bert-large-uncased-whole-word-masking-finetuned-squad",
                "bert-large-uncased-whole-word-masking",
                "bert-large-uncased",
                "camembert-base",
                "ctrl",
                "distilbert-base-cased-distilled-squad",
                "distilbert-base-cased",
                "distilbert-base-german-cased",
                "distilbert-base-multilingual-cased",
                "distilbert-base-uncased-distilled-squad",
                "distilbert-base-uncased-finetuned-sst-2-english",
                "distilbert-base-uncased",
                "distilgpt2",
                "distilroberta-base",
                "gpt2-large",
                "gpt2-medium",
                "gpt2-xl",
                "gpt2",
                "openai-gpt",
                "roberta-base-openai-detector",
                "roberta-base",
                "roberta-large-mnli",
                "roberta-large-openai-detector",
                "roberta-large",
                "t5-11b",
                "t5-3b",
                "t5-base",
                "t5-large",
                "t5-small",
                "transfo-xl-wt103",
                "xlm-clm-ende-1024",
                "xlm-clm-enfr-1024",
                "xlm-mlm-100-1280",
                "xlm-mlm-17-1280",
                "xlm-mlm-en-2048",
                "xlm-mlm-ende-1024",
                "xlm-mlm-enfr-1024",
                "xlm-mlm-enro-1024",
                "xlm-mlm-tlm-xnli15-1024",
                "xlm-mlm-xnli15-1024",
                "xlm-roberta-base",
                "xlm-roberta-large-finetuned-conll02-dutch",
                "xlm-roberta-large-finetuned-conll02-spanish",
                "xlm-roberta-large-finetuned-conll03-english",
                "xlm-roberta-large-finetuned-conll03-german",
                "xlm-roberta-large",
                "xlnet-base-cased",
                "xlnet-large-cased",
            ]
            if not os.path.exists(model_name_or_path):
                # Not a path, load from hub
                if "\\" in model_name_or_path or model_name_or_path.count("/") > 1:
                    raise ValueError(f"Path {model_name_or_path} not found")
                if "/" not in model_name_or_path and model_name_or_path.lower() not in basic_transformer_models:
                    # A model from sentence-transformers
                    model_name_or_path = __MODEL_HUB_ORGANIZATION__ + "/" + model_name_or_path

            if is_sentence_transformer_model(
                model_name_or_path,
                token,
                cache_folder=cache_folder,
                revision=revision,
                local_files_only=local_files_only,
            ):
                modules, self.module_kwargs = self._load_sbert_model(
                    model_name_or_path,
                    token=token,
                    cache_folder=cache_folder,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    local_files_only=local_files_only,
                    model_kwargs=model_kwargs,
                    tokenizer_kwargs=tokenizer_kwargs,
                    config_kwargs=config_kwargs,
                )
            else:
                modules = self._load_auto_model(
                    model_name_or_path,
                    token=token,
                    cache_folder=cache_folder,
                    revision=revision,
                    trust_remote_code=trust_remote_code,
                    local_files_only=local_files_only,
                    model_kwargs=model_kwargs,
                    tokenizer_kwargs=tokenizer_kwargs,
                    config_kwargs=config_kwargs,
                )

        if modules is not None and not isinstance(modules, OrderedDict):
            modules = OrderedDict([(str(idx), module) for idx, module in enumerate(modules)])

        super().__init__(modules)

        # Ensure all tensors in the model are of the same dtype as the first tensor
        # This is necessary if the first module has been given a lower precision via
        # model_kwargs["torch_dtype"]. The rest of the model should be loaded in the same dtype
        # See #2887 for more details
        try:
            dtype = next(self.parameters()).dtype
            self.to(dtype)
        except StopIteration:
            pass

        self.to(device)
        self.is_hpu_graph_enabled = False

        if self.default_prompt_name is not None and self.default_prompt_name not in self.prompts:
            raise ValueError(
                f"Default prompt name '{self.default_prompt_name}' not found in the configured prompts "
                f"dictionary with keys {list(self.prompts.keys())!r}."
            )

        if self.prompts:
            logger.info(f"{len(self.prompts)} prompts are loaded, with the keys: {list(self.prompts.keys())}")
        if self.default_prompt_name:
            logger.warning(
                f"Default prompt name is set to '{self.default_prompt_name}'. "
                "This prompt will be applied to all `encode()` calls, except if `encode()` "
                "is called with `prompt` or `prompt_name` parameters."
            )

        # Ideally, INSTRUCTOR models should set `include_prompt=False` in their pooling configuration, but
        # that would be a breaking change for users currently using the InstructorEmbedding project.
        # So, instead we hardcode setting it for the main INSTRUCTOR models, and otherwise give a warning if we
        # suspect the user is using an INSTRUCTOR model.
        if model_name_or_path in ("hkunlp/instructor-base", "hkunlp/instructor-large", "hkunlp/instructor-xl"):
            self.set_pooling_include_prompt(include_prompt=False)
        elif (
            model_name_or_path
            and "/" in model_name_or_path
            and "instructor" in model_name_or_path.split("/")[1].lower()
        ):
            if any([module.include_prompt for module in self if isinstance(module, Pooling)]):
                logger.warning(
                    "Instructor models require `include_prompt=False` in the pooling configuration. "
                    "Either update the model configuration or call `model.set_pooling_include_prompt(False)` after loading the model."
                )

        # Pass the model to the model card data for later use in generating a model card upon saving this model
        self.model_card_data.register_model(self)

    def get_backend(self) -> Literal["torch", "onnx", "openvino"]:
        """Return the backend used for inference, which can be one of "torch", "onnx", or "openvino".

        Returns:
            str: The backend used for inference.
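
        Example (a minimal sketch; the model name is illustrative):
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-MiniLM-L6-v2")
                model.get_backend()
                # => 'torch'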
  351. """
  352. return self.backend

    @overload
    def encode(
        self,
        sentences: str,
        prompt_name: str | None = ...,
        prompt: str | None = ...,
        batch_size: int = ...,
        show_progress_bar: bool | None = ...,
        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
        convert_to_numpy: Literal[False] = ...,
        convert_to_tensor: Literal[False] = ...,
        device: str = ...,
        normalize_embeddings: bool = ...,
        **kwargs,
    ) -> Tensor: ...

    @overload
    def encode(
        self,
        sentences: str | list[str],
        prompt_name: str | None = ...,
        prompt: str | None = ...,
        batch_size: int = ...,
        show_progress_bar: bool | None = ...,
        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
        convert_to_numpy: Literal[True] = ...,
        convert_to_tensor: Literal[False] = ...,
        device: str = ...,
        normalize_embeddings: bool = ...,
        **kwargs,
    ) -> np.ndarray: ...

    @overload
    def encode(
        self,
        sentences: str | list[str],
        prompt_name: str | None = ...,
        prompt: str | None = ...,
        batch_size: int = ...,
        show_progress_bar: bool | None = ...,
        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
        convert_to_numpy: bool = ...,
        convert_to_tensor: Literal[True] = ...,
        device: str = ...,
        normalize_embeddings: bool = ...,
        **kwargs,
    ) -> Tensor: ...

    @overload
    def encode(
        self,
        sentences: list[str] | np.ndarray,
        prompt_name: str | None = ...,
        prompt: str | None = ...,
        batch_size: int = ...,
        show_progress_bar: bool | None = ...,
        output_value: Literal["sentence_embedding", "token_embeddings"] | None = ...,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = ...,
        convert_to_numpy: Literal[False] = ...,
        convert_to_tensor: Literal[False] = ...,
        device: str = ...,
        normalize_embeddings: bool = ...,
        **kwargs,
    ) -> list[Tensor]: ...

    def encode(
        self,
        sentences: str | list[str],
        prompt_name: str | None = None,
        prompt: str | None = None,
        batch_size: int = 32,
        show_progress_bar: bool | None = None,
        output_value: Literal["sentence_embedding", "token_embeddings"] | None = "sentence_embedding",
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        convert_to_numpy: bool = True,
        convert_to_tensor: bool = False,
        device: str | None = None,
        normalize_embeddings: bool = False,
        **kwargs,
    ) -> list[Tensor] | np.ndarray | Tensor:
        """
        Computes sentence embeddings.

        Args:
            sentences (Union[str, List[str]]): The sentences to embed.
            prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
                which is either set in the constructor or loaded from the model configuration. For example, if
                ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...}, then the sentence "What
                is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
                is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
            prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
                sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
                because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
            batch_size (int, optional): The batch size used for the computation. Defaults to 32.
            show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. Defaults to None.
            output_value (Optional[Literal["sentence_embedding", "token_embeddings"]], optional): The type of embeddings to return:
                "sentence_embedding" to get sentence embeddings, "token_embeddings" to get wordpiece token embeddings, and `None`
                to get all output values. Defaults to "sentence_embedding".
            precision (Literal["float32", "int8", "uint8", "binary", "ubinary"], optional): The precision to use for the embeddings.
                Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions are quantized embeddings.
                Quantized embeddings are smaller in size and faster to compute, but may have a lower accuracy. They are useful for
                reducing the size of the embeddings of a corpus for semantic search, among other tasks. Defaults to "float32".
            convert_to_numpy (bool, optional): Whether the output should be a list of numpy vectors. If False, it is a list of PyTorch tensors.
                Defaults to True.
            convert_to_tensor (bool, optional): Whether the output should be one large tensor. Overwrites `convert_to_numpy`.
                Defaults to False.
            device (str, optional): Which :class:`torch.device` to use for the computation. Defaults to None.
            normalize_embeddings (bool, optional): Whether to normalize returned vectors to have length 1. In that case,
                the faster dot-product (util.dot_score) can be used instead of cosine similarity. Defaults to False.

        Returns:
            Union[List[Tensor], ndarray, Tensor]: By default, a 2d numpy array with shape [num_inputs, output_dimension] is returned.
            If only one string input is provided, then the output is a 1d array with shape [output_dimension]. If ``convert_to_tensor``,
            a torch Tensor is returned instead. If ``self.truncate_dim <= output_dimension``, then output_dimension is ``self.truncate_dim``.

        Example:
            ::

                from sentence_transformers import SentenceTransformer

                # Load a pre-trained SentenceTransformer model
                model = SentenceTransformer('all-mpnet-base-v2')

                # Encode some texts
                sentences = [
                    "The weather is lovely today.",
                    "It's so sunny outside!",
                    "He drove to the stadium.",
                ]
                embeddings = model.encode(sentences)
                print(embeddings.shape)
                # (3, 768)
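
            A minimal sketch of encoding with a prompt follows (the prompt text "query: " is
            illustrative; prompts can also be loaded from the model configuration)::

                model = SentenceTransformer("all-mpnet-base-v2", prompts={"query": "query: "})
                # "query: " is prepended to the sentence before encoding
                query_embedding = model.encode("What is the capital of France?", prompt_name="query")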
  478. """
  479. if self.device.type == "hpu" and not self.is_hpu_graph_enabled:
  480. import habana_frameworks.torch as ht
  481. ht.hpu.wrap_in_hpu_graph(self, disable_tensor_cache=True)
  482. self.is_hpu_graph_enabled = True
  483. self.eval()
  484. if show_progress_bar is None:
  485. show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)
  486. if convert_to_tensor:
  487. convert_to_numpy = False
  488. if output_value != "sentence_embedding":
  489. convert_to_tensor = False
  490. convert_to_numpy = False
  491. input_was_string = False
  492. if isinstance(sentences, str) or not hasattr(
  493. sentences, "__len__"
  494. ): # Cast an individual sentence to a list with length 1
  495. sentences = [sentences]
  496. input_was_string = True
  497. if prompt is None:
  498. if prompt_name is not None:
  499. try:
  500. prompt = self.prompts[prompt_name]
  501. except KeyError:
  502. raise ValueError(
  503. f"Prompt name '{prompt_name}' not found in the configured prompts dictionary with keys {list(self.prompts.keys())!r}."
  504. )
  505. elif self.default_prompt_name is not None:
  506. prompt = self.prompts.get(self.default_prompt_name, None)
  507. else:
  508. if prompt_name is not None:
  509. logger.warning(
  510. "Encode with either a `prompt`, a `prompt_name`, or neither, but not both. "
  511. "Ignoring the `prompt_name` in favor of `prompt`."
  512. )
  513. extra_features = {}
  514. if prompt is not None:
  515. sentences = [prompt + sentence for sentence in sentences]
  516. # Some models (e.g. INSTRUCTOR, GRIT) require removing the prompt before pooling
  517. # Tracking the prompt length allow us to remove the prompt during pooling
  518. tokenized_prompt = self.tokenize([prompt])
  519. if "input_ids" in tokenized_prompt:
  520. extra_features["prompt_length"] = tokenized_prompt["input_ids"].shape[-1] - 1
  521. if device is None:
  522. device = self.device
  523. self.to(device)
  524. all_embeddings = []
  525. length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
  526. sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
  527. for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
  528. sentences_batch = sentences_sorted[start_index : start_index + batch_size]
  529. features = self.tokenize(sentences_batch)
  530. if self.device.type == "hpu":
  531. if "input_ids" in features:
  532. curr_tokenize_len = features["input_ids"].shape
  533. additional_pad_len = 2 ** math.ceil(math.log2(curr_tokenize_len[1])) - curr_tokenize_len[1]
  534. features["input_ids"] = torch.cat(
  535. (
  536. features["input_ids"],
  537. torch.ones((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
  538. ),
  539. -1,
  540. )
  541. features["attention_mask"] = torch.cat(
  542. (
  543. features["attention_mask"],
  544. torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
  545. ),
  546. -1,
  547. )
  548. if "token_type_ids" in features:
  549. features["token_type_ids"] = torch.cat(
  550. (
  551. features["token_type_ids"],
  552. torch.zeros((curr_tokenize_len[0], additional_pad_len), dtype=torch.int8),
  553. ),
  554. -1,
  555. )
  556. features = batch_to_device(features, device)
  557. features.update(extra_features)
  558. with torch.no_grad():
  559. out_features = self.forward(features, **kwargs)
  560. if self.device.type == "hpu":
  561. out_features = copy.deepcopy(out_features)
  562. out_features["sentence_embedding"] = truncate_embeddings(
  563. out_features["sentence_embedding"], self.truncate_dim
  564. )
  565. if output_value == "token_embeddings":
  566. embeddings = []
  567. for token_emb, attention in zip(out_features[output_value], out_features["attention_mask"]):
  568. last_mask_id = len(attention) - 1
  569. while last_mask_id > 0 and attention[last_mask_id].item() == 0:
  570. last_mask_id -= 1
  571. embeddings.append(token_emb[0 : last_mask_id + 1])
  572. elif output_value is None: # Return all outputs
  573. embeddings = []
  574. for sent_idx in range(len(out_features["sentence_embedding"])):
  575. row = {name: out_features[name][sent_idx] for name in out_features}
  576. embeddings.append(row)
  577. else: # Sentence embeddings
  578. embeddings = out_features[output_value]
  579. embeddings = embeddings.detach()
  580. if normalize_embeddings:
  581. embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
  582. # fixes for #522 and #487 to avoid oom problems on gpu with large datasets
  583. if convert_to_numpy:
  584. embeddings = embeddings.cpu()
  585. all_embeddings.extend(embeddings)
  586. all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)]
  587. if precision and precision != "float32":
  588. all_embeddings = quantize_embeddings(all_embeddings, precision=precision)
  589. if convert_to_tensor:
  590. if len(all_embeddings):
  591. if isinstance(all_embeddings, np.ndarray):
  592. all_embeddings = torch.from_numpy(all_embeddings)
  593. else:
  594. all_embeddings = torch.stack(all_embeddings)
  595. else:
  596. all_embeddings = torch.Tensor()
  597. elif convert_to_numpy:
  598. if not isinstance(all_embeddings, np.ndarray):
  599. if all_embeddings and all_embeddings[0].dtype == torch.bfloat16:
  600. all_embeddings = np.asarray([emb.float().numpy() for emb in all_embeddings])
  601. else:
  602. all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
  603. elif isinstance(all_embeddings, np.ndarray):
  604. all_embeddings = [torch.from_numpy(embedding) for embedding in all_embeddings]
  605. if input_was_string:
  606. all_embeddings = all_embeddings[0]
  607. return all_embeddings

    def forward(self, input: dict[str, Tensor], **kwargs) -> dict[str, Tensor]:
        # Without per-module kwargs, fall back to the plain nn.Sequential forward
        if self.module_kwargs is None:
            return super().forward(input)

        # Otherwise, pass each module only the keyword arguments it declared
        for module_name, module in self.named_children():
            module_kwarg_keys = self.module_kwargs.get(module_name, [])
            module_kwargs = {key: value for key, value in kwargs.items() if key in module_kwarg_keys}
            input = module(input, **module_kwargs)
        return input

    @property
    def similarity_fn_name(self) -> str | None:
        """Return the name of the similarity function used by :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise`.

        Returns:
            Optional[str]: The name of the similarity function. Can be None if not set, in which case any uses of
            :meth:`SentenceTransformer.similarity` and :meth:`SentenceTransformer.similarity_pairwise` default to "cosine".

        Example:
            >>> model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
            >>> model.similarity_fn_name
            'dot'
        """
        return self._similarity_fn_name

    @similarity_fn_name.setter
    def similarity_fn_name(self, value: str | SimilarityFunction) -> None:
        if isinstance(value, SimilarityFunction):
            value = value.value
        self._similarity_fn_name = value

        if value is not None:
            self._similarity = SimilarityFunction.to_similarity_fn(value)
            self._similarity_pairwise = SimilarityFunction.to_similarity_pairwise_fn(value)

    @overload
    def similarity(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...

    @overload
    def similarity(self, embeddings1: ndarray, embeddings2: ndarray) -> Tensor: ...

    @property
    def similarity(self) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Compute the similarity between two collections of embeddings. The output will be a matrix with the similarity
        scores between all embeddings from the first parameter and all embeddings from the second parameter. This
        differs from `similarity_pairwise` which computes the similarity between each pair of embeddings.

        Args:
            embeddings1 (Union[Tensor, ndarray]): [num_embeddings_1, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
            embeddings2 (Union[Tensor, ndarray]): [num_embeddings_2, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.

        Returns:
            Tensor: A [num_embeddings_1, num_embeddings_2]-shaped torch tensor with similarity scores.

        Example:
            ::

                >>> model = SentenceTransformer("all-mpnet-base-v2")
                >>> sentences = [
                ...     "The weather is so nice!",
                ...     "It's so sunny outside.",
                ...     "He's driving to the movie theater.",
                ...     "She's going to the cinema.",
                ... ]
                >>> embeddings = model.encode(sentences, normalize_embeddings=True)
                >>> model.similarity(embeddings, embeddings)
                tensor([[1.0000, 0.7235, 0.0290, 0.1309],
                        [0.7235, 1.0000, 0.0613, 0.1129],
                        [0.0290, 0.0613, 1.0000, 0.5027],
                        [0.1309, 0.1129, 0.5027, 1.0000]])
                >>> model.similarity_fn_name
                "cosine"
                >>> model.similarity_fn_name = "euclidean"
                >>> model.similarity(embeddings, embeddings)
                tensor([[-0.0000, -0.7437, -1.3935, -1.3184],
                        [-0.7437, -0.0000, -1.3702, -1.3320],
                        [-1.3935, -1.3702, -0.0000, -0.9973],
                        [-1.3184, -1.3320, -0.9973, -0.0000]])
        """
        if self.similarity_fn_name is None:
            self.similarity_fn_name = SimilarityFunction.COSINE
        return self._similarity

    @overload
    def similarity_pairwise(self, embeddings1: Tensor, embeddings2: Tensor) -> Tensor: ...

    @overload
    def similarity_pairwise(self, embeddings1: ndarray, embeddings2: ndarray) -> Tensor: ...

    @property
    def similarity_pairwise(self) -> Callable[[Tensor | ndarray, Tensor | ndarray], Tensor]:
        """
        Compute the similarity between two collections of embeddings. The output will be a vector with the similarity
        scores between each pair of embeddings.

        Args:
            embeddings1 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.
            embeddings2 (Union[Tensor, ndarray]): [num_embeddings, embedding_dim] or [embedding_dim]-shaped numpy array or torch tensor.

        Returns:
            Tensor: A [num_embeddings]-shaped torch tensor with pairwise similarity scores.

        Example:
            ::

                >>> model = SentenceTransformer("all-mpnet-base-v2")
                >>> sentences = [
                ...     "The weather is so nice!",
                ...     "It's so sunny outside.",
                ...     "He's driving to the movie theater.",
                ...     "She's going to the cinema.",
                ... ]
                >>> embeddings = model.encode(sentences, normalize_embeddings=True)
                >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
                tensor([0.7235, 0.5027])
                >>> model.similarity_fn_name
                "cosine"
                >>> model.similarity_fn_name = "euclidean"
                >>> model.similarity_pairwise(embeddings[::2], embeddings[1::2])
                tensor([-0.7437, -0.9973])
        """
        if self.similarity_fn_name is None:
            self.similarity_fn_name = SimilarityFunction.COSINE
        return self._similarity_pairwise

    def start_multi_process_pool(
        self, target_devices: list[str] | None = None
    ) -> dict[Literal["input", "output", "processes"], Any]:
        """
        Starts a multi-process pool to process the encoding with several independent processes
        via :meth:`SentenceTransformer.encode_multi_process <sentence_transformers.SentenceTransformer.encode_multi_process>`.

        This method is recommended if you want to encode on multiple GPUs or CPUs. It is advised
        to start only one process per GPU. This method works together with encode_multi_process
        and stop_multi_process_pool.

        Args:
            target_devices (List[str], optional): PyTorch target devices, e.g. ["cuda:0", "cuda:1", ...],
                ["npu:0", "npu:1", ...], or ["cpu", "cpu", "cpu", "cpu"]. If target_devices is None and CUDA/NPU
                is available, then all available CUDA/NPU devices will be used. If target_devices is None and
                CUDA/NPU is not available, then 4 CPU devices will be used.

        Returns:
            Dict[str, Any]: A dictionary with the target processes, an input queue, and an output queue.
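
        Example (a minimal sketch of the pool lifecycle; the model name is illustrative, and the
        ``__main__`` guard is needed because the pool spawns new processes):
            ::

                from sentence_transformers import SentenceTransformer

                def main():
                    model = SentenceTransformer("all-mpnet-base-v2")
                    pool = model.start_multi_process_pool()
                    embeddings = model.encode_multi_process(["Hello!", "How are you?"] * 1000, pool)
                    model.stop_multi_process_pool(pool)

                if __name__ == "__main__":
                    main()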
  729. """
  730. if target_devices is None:
  731. if torch.cuda.is_available():
  732. target_devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
  733. elif is_torch_npu_available():
  734. target_devices = [f"npu:{i}" for i in range(torch.npu.device_count())]
  735. else:
  736. logger.info("CUDA/NPU is not available. Starting 4 CPU workers")
  737. target_devices = ["cpu"] * 4
  738. logger.info("Start multi-process pool on devices: {}".format(", ".join(map(str, target_devices))))
  739. self.to("cpu")
  740. self.share_memory()
  741. ctx = mp.get_context("spawn")
  742. input_queue = ctx.Queue()
  743. output_queue = ctx.Queue()
  744. processes = []
  745. for device_id in target_devices:
  746. p = ctx.Process(
  747. target=SentenceTransformer._encode_multi_process_worker,
  748. args=(device_id, self, input_queue, output_queue),
  749. daemon=True,
  750. )
  751. p.start()
  752. processes.append(p)
  753. return {"input": input_queue, "output": output_queue, "processes": processes}

    @staticmethod
    def stop_multi_process_pool(pool: dict[Literal["input", "output", "processes"], Any]) -> None:
        """
        Stops all processes started with start_multi_process_pool.

        Args:
            pool (Dict[str, object]): A dictionary containing the input queue, output queue, and process list.

        Returns:
            None
        """
        for p in pool["processes"]:
            p.terminate()

        for p in pool["processes"]:
            p.join()
            p.close()

        pool["input"].close()
        pool["output"].close()

    def encode_multi_process(
        self,
        sentences: list[str],
        pool: dict[Literal["input", "output", "processes"], Any],
        prompt_name: str | None = None,
        prompt: str | None = None,
        batch_size: int = 32,
        chunk_size: int | None = None,
        show_progress_bar: bool | None = None,
        precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32",
        normalize_embeddings: bool = False,
    ) -> np.ndarray:
        """
        Encodes a list of sentences using multiple processes and GPUs via
        :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.
        The sentences are chunked into smaller packages and sent to individual processes, which encode them on different
        GPUs or CPUs. This method is only suitable for encoding large sets of sentences.

        Args:
            sentences (List[str]): List of sentences to encode.
            pool (Dict[Literal["input", "output", "processes"], Any]): A pool of workers started with
                :meth:`SentenceTransformer.start_multi_process_pool <sentence_transformers.SentenceTransformer.start_multi_process_pool>`.
            prompt_name (Optional[str], optional): The name of the prompt to use for encoding. Must be a key in the `prompts` dictionary,
                which is either set in the constructor or loaded from the model configuration. For example, if
                ``prompt_name`` is "query" and ``prompts`` is {"query": "query: ", ...}, then the sentence "What
                is the capital of France?" will be encoded as "query: What is the capital of France?" because the sentence
                is appended to the prompt. If ``prompt`` is also set, this argument is ignored. Defaults to None.
            prompt (Optional[str], optional): The prompt to use for encoding. For example, if the prompt is "query: ", then the
                sentence "What is the capital of France?" will be encoded as "query: What is the capital of France?"
                because the sentence is appended to the prompt. If ``prompt`` is set, ``prompt_name`` is ignored. Defaults to None.
            batch_size (int): Encode sentences with batch size. (default: 32)
            chunk_size (int): Sentences are chunked and sent to the individual processes. If None, it determines a
                sensible size. Defaults to None.
            show_progress_bar (bool, optional): Whether to output a progress bar when encoding sentences. Defaults to None.
            precision (Literal["float32", "int8", "uint8", "binary", "ubinary"]): The precision to use for the
                embeddings. Can be "float32", "int8", "uint8", "binary", or "ubinary". All non-float32 precisions
                are quantized embeddings. Quantized embeddings are smaller in size and faster to compute, but may
                have lower accuracy. They are useful for reducing the size of the embeddings of a corpus for
                semantic search, among other tasks. Defaults to "float32".
            normalize_embeddings (bool): Whether to normalize returned vectors to have length 1. In that case,
                the faster dot-product (util.dot_score) can be used instead of cosine similarity. Defaults to False.

        Returns:
            np.ndarray: A 2D numpy array with shape [num_inputs, output_dimension].

        Example:
            ::

                from sentence_transformers import SentenceTransformer

                def main():
                    model = SentenceTransformer("all-mpnet-base-v2")
                    sentences = ["The weather is so nice!", "It's so sunny outside.", "He's driving to the movie theater.", "She's going to the cinema."] * 1000

                    pool = model.start_multi_process_pool()
                    embeddings = model.encode_multi_process(sentences, pool)
                    model.stop_multi_process_pool(pool)

                    print(embeddings.shape)
                    # => (4000, 768)

                if __name__ == "__main__":
                    main()
        """
        if chunk_size is None:
            chunk_size = min(math.ceil(len(sentences) / len(pool["processes"]) / 10), 5000)

        if show_progress_bar is None:
            show_progress_bar = logger.getEffectiveLevel() in (logging.INFO, logging.DEBUG)

        logger.debug(f"Chunk data into {math.ceil(len(sentences) / chunk_size)} packages of size {chunk_size}")

        input_queue = pool["input"]
        last_chunk_id = 0
        chunk = []

        for sentence in sentences:
            chunk.append(sentence)
            if len(chunk) >= chunk_size:
                input_queue.put(
                    [last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings]
                )
                last_chunk_id += 1
                chunk = []

        if len(chunk) > 0:
            input_queue.put([last_chunk_id, batch_size, chunk, prompt_name, prompt, precision, normalize_embeddings])
            last_chunk_id += 1

        output_queue = pool["output"]
        results_list = sorted(
            [output_queue.get() for _ in trange(last_chunk_id, desc="Chunks", disable=not show_progress_bar)],
            key=lambda x: x[0],
        )
        embeddings = np.concatenate([result[1] for result in results_list])
        return embeddings

    @staticmethod
    def _encode_multi_process_worker(
        target_device: str, model: SentenceTransformer, input_queue: Queue, results_queue: Queue
    ) -> None:
        """
        Internal working process to encode sentences in multi-process setup
        """
        while True:
            try:
                chunk_id, batch_size, sentences, prompt_name, prompt, precision, normalize_embeddings = (
                    input_queue.get()
                )
                embeddings = model.encode(
                    sentences,
                    prompt_name=prompt_name,
                    prompt=prompt,
                    device=target_device,
                    show_progress_bar=False,
                    precision=precision,
                    convert_to_numpy=True,
                    batch_size=batch_size,
                    normalize_embeddings=normalize_embeddings,
                )

                results_queue.put([chunk_id, embeddings])
            except queue.Empty:
                break

    def set_pooling_include_prompt(self, include_prompt: bool) -> None:
        """
        Sets the `include_prompt` attribute in the pooling layer in the model, if there is one.

        This is useful for INSTRUCTOR models, as the prompt should be excluded from the pooling strategy
        for these models.

        Args:
            include_prompt (bool): Whether to include the prompt in the pooling layer.

        Returns:
            None
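
        Example (a minimal sketch; the model name is illustrative, and note that the constructor
        already sets this automatically for the main INSTRUCTOR models):
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("hkunlp/instructor-large")
                model.set_pooling_include_prompt(include_prompt=False)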
  887. """
  888. for module in self:
  889. if isinstance(module, Pooling):
  890. module.include_prompt = include_prompt
  891. break

    def get_max_seq_length(self) -> int | None:
        """
        Returns the maximal sequence length that the model accepts. Longer inputs will be truncated.

        Returns:
            Optional[int]: The maximal sequence length that the model accepts, or None if it is not defined.
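
        Example (a minimal sketch; the model name is illustrative and the exact value depends on
        the underlying transformer):
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2")
                model.get_max_seq_length()
                # => e.g. 384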
  897. """
  898. if hasattr(self._first_module(), "max_seq_length"):
  899. return self._first_module().max_seq_length
  900. return None

    def tokenize(self, texts: list[str] | list[dict] | list[tuple[str, str]]) -> dict[str, Tensor]:
        """
        Tokenizes the texts.

        Args:
            texts (Union[List[str], List[Dict], List[Tuple[str, str]]]): A list of texts to be tokenized.

        Returns:
            Dict[str, Tensor]: A dictionary of tensors with the tokenized texts. Common keys are "input_ids",
            "attention_mask", and "token_type_ids".
  909. """
  910. return self._first_module().tokenize(texts)

    def get_sentence_features(self, *features) -> dict[Literal["sentence_embedding"], Tensor]:
        return self._first_module().get_sentence_features(*features)

    def get_sentence_embedding_dimension(self) -> int | None:
        """
        Returns the number of dimensions in the output of :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>`.

        Returns:
            Optional[int]: The number of dimensions in the output of `encode`. If it's not known, it's `None`.
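
        Example (a minimal sketch; the model name is illustrative, and its native output dimension
        is 768):
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2", truncate_dim=256)
                model.get_sentence_embedding_dimension()
                # => 256 (the requested truncation, rather than the native 768)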
  918. """
  919. output_dim = None
  920. for mod in reversed(self._modules.values()):
  921. sent_embedding_dim_method = getattr(mod, "get_sentence_embedding_dimension", None)
  922. if callable(sent_embedding_dim_method):
  923. output_dim = sent_embedding_dim_method()
  924. break
  925. if self.truncate_dim is not None:
  926. # The user requested truncation. If they set it to a dim greater than output_dim,
  927. # no truncation will actually happen. So return output_dim instead of self.truncate_dim
  928. return min(output_dim or np.inf, self.truncate_dim)
  929. return output_dim

    @contextmanager
    def truncate_sentence_embeddings(self, truncate_dim: int | None) -> Iterator[None]:
        """
        In this context, :meth:`SentenceTransformer.encode <sentence_transformers.SentenceTransformer.encode>` outputs
        sentence embeddings truncated at dimension ``truncate_dim``.

        This may be useful when you are using the same model for different applications where different dimensions
        are needed.

        Args:
            truncate_dim (int, optional): The dimension to truncate sentence embeddings to. ``None`` does no truncation.

        Example:
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2")

                with model.truncate_sentence_embeddings(truncate_dim=16):
                    embeddings_truncated = model.encode(["hello there", "hiya"])
                assert embeddings_truncated.shape[-1] == 16
        """
        original_output_dim = self.truncate_dim
        try:
            self.truncate_dim = truncate_dim
            yield
        finally:
            self.truncate_dim = original_output_dim

    def _first_module(self) -> torch.nn.Module:
        """Returns the first module of this sequential embedder"""
        return self._modules[next(iter(self._modules))]

    def _last_module(self) -> torch.nn.Module:
        """Returns the last module of this sequential embedder"""
        return self._modules[next(reversed(self._modules))]

    def save(
        self,
        path: str,
        model_name: str | None = None,
        create_model_card: bool = True,
        train_datasets: list[str] | None = None,
        safe_serialization: bool = True,
    ) -> None:
        """
        Saves a model and its configuration files to a directory, so that it can be loaded
        with ``SentenceTransformer(path)`` again.

        Args:
            path (str): Path on disc where the model will be saved.
            model_name (str, optional): Optional model name.
            create_model_card (bool, optional): If True, create a README.md with basic information about this model.
            train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.
            safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model
                the traditional (but unsafe) PyTorch way.
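
        Example (a minimal sketch; the model name and the output path are illustrative):
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-MiniLM-L6-v2")
                model.save("output/my-model")
                # The saved directory can then be loaded again:
                loaded_model = SentenceTransformer("output/my-model")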
  977. """
  978. if path is None:
  979. return
  980. os.makedirs(path, exist_ok=True)
  981. logger.info(f"Save model to {path}")
  982. modules_config = []
  983. # Save some model info
  984. self._model_config["__version__"] = {
  985. "sentence_transformers": __version__,
  986. "transformers": transformers.__version__,
  987. "pytorch": torch.__version__,
  988. }
  989. with open(os.path.join(path, "config_sentence_transformers.json"), "w") as fOut:
  990. config = self._model_config.copy()
  991. config["prompts"] = self.prompts
  992. config["default_prompt_name"] = self.default_prompt_name
  993. config["similarity_fn_name"] = self.similarity_fn_name
  994. json.dump(config, fOut, indent=2)
  995. # Save modules
  996. for idx, name in enumerate(self._modules):
  997. module = self._modules[name]
  998. if idx == 0 and hasattr(module, "save_in_root"): # Save first module in the main folder
  999. model_path = path + "/"
  1000. else:
  1001. model_path = os.path.join(path, str(idx) + "_" + type(module).__name__)
  1002. os.makedirs(model_path, exist_ok=True)
  1003. # Try to save with safetensors, but fall back to the traditional PyTorch way if the module doesn't support it
  1004. try:
  1005. module.save(model_path, safe_serialization=safe_serialization)
  1006. except TypeError:
  1007. module.save(model_path)
  1008. # "module" only works for Sentence Transformers as the modules have the same names as the classes
  1009. class_ref = type(module).__module__
  1010. # For remote modules, we want to remove "transformers_modules.{repo_name}":
  1011. if class_ref.startswith("transformers_modules."):
  1012. class_file = sys.modules[class_ref].__file__
  1013. # Save the custom module file
  1014. dest_file = Path(model_path) / (Path(class_file).name)
  1015. shutil.copy(class_file, dest_file)
                # Save all files imported in the custom module file
                for needed_file in get_relative_import_files(class_file):
                    dest_file = Path(model_path) / (Path(needed_file).name)
                    shutil.copy(needed_file, dest_file)

                # For remote modules, we want to ignore the "transformers_modules.{repo_id}" part,
                # i.e. we only want the filename
                class_ref = f"{class_ref.split('.')[-1]}.{type(module).__name__}"
            # For other cases, we want to add the class name:
            elif not class_ref.startswith("sentence_transformers."):
                class_ref = f"{class_ref}.{type(module).__name__}"

            modules_config.append({"idx": idx, "name": name, "path": os.path.basename(model_path), "type": class_ref})

        with open(os.path.join(path, "modules.json"), "w") as fOut:
            json.dump(modules_config, fOut, indent=2)

        # Create model card
        if create_model_card:
            self._create_model_card(path, model_name, train_datasets)

    def save_pretrained(
        self,
        path: str,
        model_name: str | None = None,
        create_model_card: bool = True,
        train_datasets: list[str] | None = None,
        safe_serialization: bool = True,
    ) -> None:
        """
        Saves a model and its configuration files to a directory, so that it can be loaded
        with ``SentenceTransformer(path)`` again. Alias of :meth:`save`.

        Args:
            path (str): Path on disc where the model will be saved.
            model_name (str, optional): Optional model name.
            create_model_card (bool, optional): If True, create a README.md with basic information about this model.
            train_datasets (List[str], optional): Optional list with the names of the datasets used to train the model.
            safe_serialization (bool, optional): If True, save the model using safetensors. If False, save the model
                the traditional (but unsafe) PyTorch way.
        """
        self.save(
            path,
            model_name=model_name,
            create_model_card=create_model_card,
            train_datasets=train_datasets,
            safe_serialization=safe_serialization,
        )

    def _create_model_card(
        self, path: str, model_name: str | None = None, train_datasets: list[str] | None = "deprecated"
    ) -> None:
        """
        Creates an automatic model card and stores it in the specified path. If no training was done and the
        loaded model was a Sentence Transformer model already, then its model card is reused.

        Args:
            path (str): The path where the model card will be stored.
            model_name (Optional[str], optional): The name of the model. Defaults to None.
            train_datasets (Optional[List[str]], optional): Deprecated argument. Defaults to "deprecated".

        Returns:
            None
        """
        if model_name:
            model_path = Path(model_name)
            if not model_path.exists() and not self.model_card_data.model_id:
                self.model_card_data.model_id = model_name

        # If we loaded a Sentence Transformer model from the Hub, and no training was done, then
        # we don't generate a new model card, but reuse the old one instead.
        if self._model_card_text and self.model_card_data.trainer is None:
            model_card = self._model_card_text
            if self.model_card_data.model_id:
                # If the original model card was saved with the placeholder model_id, replace it with the new model_id
                model_card = model_card.replace(
                    'model = SentenceTransformer("sentence_transformers_model_id"',
                    f'model = SentenceTransformer("{self.model_card_data.model_id}"',
                )
        else:
            try:
                model_card = generate_model_card(self)
            except Exception:
                logger.error(
                    f"Error while generating model card:\n{traceback.format_exc()}"
                    "Consider opening an issue on https://github.com/UKPLab/sentence-transformers/issues with this traceback.\n"
                    "Skipping model card creation."
                )
                return

        with open(os.path.join(path, "README.md"), "w", encoding="utf8") as fOut:
            fOut.write(model_card)

    @save_to_hub_args_decorator
    def save_to_hub(
        self,
        repo_id: str,
        organization: str | None = None,
        token: str | None = None,
        private: bool | None = None,
        safe_serialization: bool = True,
        commit_message: str = "Add new SentenceTransformer model.",
        local_model_path: str | None = None,
        exist_ok: bool = False,
        replace_model_card: bool = False,
        train_datasets: list[str] | None = None,
    ) -> str:
        """
        DEPRECATED, use `push_to_hub` instead.

        Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

        Args:
            repo_id (str): Repository name for your model in the Hub, including the user or organization.
            token (str, optional): An authentication token (see https://huggingface.co/settings/token).
            private (bool, optional): Set to true for hosting a private model.
            safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way.
            commit_message (str, optional): Message to commit while pushing.
            local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded.
            exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible.
            replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card.
            train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.

        Returns:
            str: The url of the commit of your model in the repository on the Hugging Face Hub.
        """
        logger.warning(
            "The `save_to_hub` method is deprecated and will be removed in a future version of SentenceTransformers."
            " Please use `push_to_hub` instead for future model uploads."
        )

        if organization:
            if "/" not in repo_id:
                logger.warning(
                    f'Providing an `organization` to `save_to_hub` is deprecated, please use `repo_id="{organization}/{repo_id}"` instead.'
                )
                repo_id = f"{organization}/{repo_id}"
            elif repo_id.split("/")[0] != organization:
                raise ValueError("Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id`.")
            else:
                logger.warning(
                    f'Providing an `organization` to `save_to_hub` is deprecated, please only use `repo_id="{repo_id}"` instead.'
                )

        return self.push_to_hub(
            repo_id=repo_id,
            token=token,
            private=private,
            safe_serialization=safe_serialization,
            commit_message=commit_message,
            local_model_path=local_model_path,
            exist_ok=exist_ok,
            replace_model_card=replace_model_card,
            train_datasets=train_datasets,
        )

    def push_to_hub(
        self,
        repo_id: str,
        token: str | None = None,
        private: bool | None = None,
        safe_serialization: bool = True,
        commit_message: str | None = None,
        local_model_path: str | None = None,
        exist_ok: bool = False,
        replace_model_card: bool = False,
        train_datasets: list[str] | None = None,
        revision: str | None = None,
        create_pr: bool = False,
    ) -> str:
        """
        Uploads all elements of this Sentence Transformer to a new HuggingFace Hub repository.

        Args:
            repo_id (str): Repository name for your model in the Hub, including the user or organization.
            token (str, optional): An authentication token (see https://huggingface.co/settings/token).
            private (bool, optional): Set to true for hosting a private model.
            safe_serialization (bool, optional): If true, save the model using safetensors. If false, save the model the traditional PyTorch way.
            commit_message (str, optional): Message to commit while pushing.
            local_model_path (str, optional): Path of the model locally. If set, this file path will be uploaded. Otherwise, the current model will be uploaded.
            exist_ok (bool, optional): If true, saving to an existing repository is OK. If false, saving only to a new repository is possible.
            replace_model_card (bool, optional): If true, replace an existing model card in the hub with the automatically created model card.
            train_datasets (List[str], optional): Datasets used to train the model. If set, the datasets will be added to the model card in the Hub.
            revision (str, optional): Branch to push the uploaded files to.
            create_pr (bool, optional): If True, create a pull request instead of pushing directly to the main branch.

        Returns:
            str: The url of the commit of your model in the repository on the Hugging Face Hub.
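
        Example:
            A hedged sketch; ``my-username/my-model`` is a placeholder repository id, and a valid
            Hugging Face token (e.g. via ``huggingface-cli login``) is assumed::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2")
                url = model.push_to_hub("my-username/my-model", exist_ok=True)
                print(url)  # Commit URL (or PR URL when create_pr=True)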
  1186. """
  1187. api = HfApi(token=token)
  1188. repo_url = api.create_repo(
  1189. repo_id=repo_id,
  1190. private=private,
  1191. repo_type=None,
  1192. exist_ok=exist_ok or create_pr,
  1193. )
  1194. repo_id = repo_url.repo_id # Update the repo_id in case the old repo_id didn't contain a user or organization
  1195. self.model_card_data.set_model_id(repo_id)
  1196. if revision is not None:
  1197. api.create_branch(repo_id=repo_id, branch=revision, exist_ok=True)
  1198. if commit_message is None:
  1199. backend = self.get_backend()
  1200. if backend == "torch":
  1201. commit_message = "Add new SentenceTransformer model"
  1202. else:
  1203. commit_message = f"Add new SentenceTransformer model with an {backend} backend"
  1204. commit_description = ""
  1205. if create_pr:
  1206. commit_description = f"""\
  1207. Hello!
  1208. *This pull request has been automatically generated from the [`push_to_hub`](https://sbert.net/docs/package_reference/sentence_transformer/SentenceTransformer.html#sentence_transformers.SentenceTransformer.push_to_hub) method from the Sentence Transformers library.*
  1209. ## Full Model Architecture:
  1210. ```
  1211. {self}
  1212. ```
  1213. ## Tip:
  1214. Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
  1215. ```python
  1216. from sentence_transformers import SentenceTransformer
  1217. # TODO: Fill in the PR number
  1218. pr_number = 2
  1219. model = SentenceTransformer(
  1220. "{repo_id}",
  1221. revision=f"refs/pr/{{pr_number}}",
  1222. backend="{self.get_backend()}",
  1223. )
  1224. # Verify that everything works as expected
  1225. embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."])
  1226. print(embeddings.shape)
  1227. similarities = model.similarity(embeddings, embeddings)
  1228. print(similarities)
  1229. ```
  1230. """
  1231. if local_model_path:
  1232. folder_url = api.upload_folder(
  1233. repo_id=repo_id,
  1234. folder_path=local_model_path,
  1235. commit_message=commit_message,
  1236. commit_description=commit_description,
  1237. revision=revision,
  1238. create_pr=create_pr,
  1239. )
  1240. else:
  1241. with tempfile.TemporaryDirectory() as tmp_dir:
  1242. create_model_card = replace_model_card or not os.path.exists(os.path.join(tmp_dir, "README.md"))
  1243. self.save_pretrained(
  1244. tmp_dir,
  1245. model_name=repo_url.repo_id,
  1246. create_model_card=create_model_card,
  1247. train_datasets=train_datasets,
  1248. safe_serialization=safe_serialization,
  1249. )
  1250. folder_url = api.upload_folder(
  1251. repo_id=repo_id,
  1252. folder_path=tmp_dir,
  1253. commit_message=commit_message,
  1254. commit_description=commit_description,
  1255. revision=revision,
  1256. create_pr=create_pr,
  1257. )
  1258. if create_pr:
  1259. return folder_url.pr_url
  1260. return folder_url.commit_url
    def _text_length(self, text: list[int] | list[list[int]]) -> int:
        """
        Helper function to get the length of the input text. The text can be either a list of ints
        (i.e. a single tokenized text) or a list of lists of ints (i.e. several tokenized texts).
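
        Example:
            Illustrative calls (the token ids are made up)::

                model._text_length([101, 7592, 102])                # single tokenized text -> 3
                model._text_length([[101, 102], [101, 103, 102]])   # several texts -> 2 + 3 = 5
                model._text_length({"input_ids": [101, 102]})       # dict input -> 2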
  1266. """
  1267. if isinstance(text, dict): # {key: value} case
  1268. return len(next(iter(text.values())))
  1269. elif not hasattr(text, "__len__"): # Object has no len() method
  1270. return 1
  1271. elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
  1272. return len(text)
  1273. else:
  1274. return sum([len(t) for t in text]) # Sum of length of individual strings
    def evaluate(self, evaluator: SentenceEvaluator, output_path: str | None = None) -> dict[str, float] | float:
        """
        Evaluate the model based on an evaluator

        Args:
            evaluator (SentenceEvaluator): The evaluator used to evaluate the model.
            output_path (str, optional): The path where the evaluator can write the results. Defaults to None.

        Returns:
            The evaluation results.
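
        Example:
            A sketch assuming you already have ``sentences1``, ``sentences2``, and gold ``scores``
            lists of your own::

                from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator

                evaluator = EmbeddingSimilarityEvaluator(sentences1, sentences2, scores)
                results = model.evaluate(evaluator, output_path="eval/")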
  1283. """
  1284. if output_path is not None:
  1285. os.makedirs(output_path, exist_ok=True)
  1286. return evaluator(self, output_path)
    def _load_auto_model(
        self,
        model_name_or_path: str,
        token: bool | str | None,
        cache_folder: str | None,
        revision: str | None = None,
        trust_remote_code: bool = False,
        local_files_only: bool = False,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        config_kwargs: dict[str, Any] | None = None,
    ) -> list[nn.Module]:
        """
        Creates a simple Transformer + Mean Pooling model and returns the modules

        Args:
            model_name_or_path (str): The name or path of the pre-trained model.
            token (Optional[Union[bool, str]]): The token to use for the model.
            cache_folder (Optional[str]): The folder to cache the model.
            revision (Optional[str], optional): The revision of the model. Defaults to None.
            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
            local_files_only (bool, optional): Whether to use only local files. Defaults to False.
            model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.
            tokenizer_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the tokenizer. Defaults to None.
            config_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the config. Defaults to None.

        Returns:
            List[nn.Module]: A list containing the transformer model and the pooling model.
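
        Example:
            This path is taken implicitly when the checkpoint is a plain ``transformers`` model
            without a ``modules.json`` (the checkpoint name below is illustrative)::

                from sentence_transformers import SentenceTransformer

                # Triggers _load_auto_model: a Transformer module plus mean Pooling are created
                model = SentenceTransformer("bert-base-uncased")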
  1313. """
  1314. logger.warning(
  1315. f"No sentence-transformers model found with name {model_name_or_path}. Creating a new one with mean pooling."
  1316. )
  1317. shared_kwargs = {
  1318. "token": token,
  1319. "trust_remote_code": trust_remote_code,
  1320. "revision": revision,
  1321. "local_files_only": local_files_only,
  1322. }
  1323. model_kwargs = shared_kwargs if model_kwargs is None else {**shared_kwargs, **model_kwargs}
  1324. tokenizer_kwargs = shared_kwargs if tokenizer_kwargs is None else {**shared_kwargs, **tokenizer_kwargs}
  1325. config_kwargs = shared_kwargs if config_kwargs is None else {**shared_kwargs, **config_kwargs}
  1326. transformer_model = Transformer(
  1327. model_name_or_path,
  1328. cache_dir=cache_folder,
  1329. model_args=model_kwargs,
  1330. tokenizer_args=tokenizer_kwargs,
  1331. config_args=config_kwargs,
  1332. backend=self.backend,
  1333. )
  1334. pooling_model = Pooling(transformer_model.get_word_embedding_dimension(), "mean")
  1335. self.model_card_data.set_base_model(model_name_or_path, revision=revision)
  1336. return [transformer_model, pooling_model]
    def _load_module_class_from_ref(
        self,
        class_ref: str,
        model_name_or_path: str,
        trust_remote_code: bool,
        revision: str | None,
        model_kwargs: dict[str, Any] | None,
    ) -> nn.Module:
        # If the class is from sentence_transformers, we can directly import it;
        # otherwise, we try to import it dynamically, and if that fails, we fall back to the default import
        if class_ref.startswith("sentence_transformers."):
            return import_from_string(class_ref)

        if trust_remote_code:
            code_revision = model_kwargs.pop("code_revision", None) if model_kwargs else None
            try:
                return get_class_from_dynamic_module(
                    class_ref,
                    model_name_or_path,
                    revision=revision,
                    code_revision=code_revision,
                )
            except OSError:
                # Ignore the error if the file does not exist, and fall back to the default import
                pass

        return import_from_string(class_ref)

    def _load_sbert_model(
        self,
        model_name_or_path: str,
        token: bool | str | None,
        cache_folder: str | None,
        revision: str | None = None,
        trust_remote_code: bool = False,
        local_files_only: bool = False,
        model_kwargs: dict[str, Any] | None = None,
        tokenizer_kwargs: dict[str, Any] | None = None,
        config_kwargs: dict[str, Any] | None = None,
    ) -> tuple[OrderedDict[str, nn.Module], OrderedDict[str, list[str]]]:
        """
        Loads a full SentenceTransformer model using the modules.json file.

        Args:
            model_name_or_path (str): The name or path of the pre-trained model.
            token (Optional[Union[bool, str]]): The token to use for the model.
            cache_folder (Optional[str]): The folder to cache the model.
            revision (Optional[str], optional): The revision of the model. Defaults to None.
            trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
            local_files_only (bool, optional): Whether to use only local files. Defaults to False.
            model_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the model. Defaults to None.
            tokenizer_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the tokenizer. Defaults to None.
            config_kwargs (Optional[Dict[str, Any]], optional): Additional keyword arguments for the config. Defaults to None.

        Returns:
            Tuple[OrderedDict[str, nn.Module], OrderedDict[str, List[str]]]: The modules of the model and their
            per-module keyword argument names, both keyed by module name.
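
        Example:
            A typical ``modules.json`` consumed by this method looks roughly like this
            (contents illustrative)::

                [
                    {"idx": 0, "name": "0", "path": "", "type": "sentence_transformers.models.Transformer"},
                    {"idx": 1, "name": "1", "path": "1_Pooling", "type": "sentence_transformers.models.Pooling"}
                ]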
  1388. """
  1389. # Check if the config_sentence_transformers.json file exists (exists since v2 of the framework)
  1390. config_sentence_transformers_json_path = load_file_path(
  1391. model_name_or_path,
  1392. "config_sentence_transformers.json",
  1393. token=token,
  1394. cache_folder=cache_folder,
  1395. revision=revision,
  1396. local_files_only=local_files_only,
  1397. )
  1398. if config_sentence_transformers_json_path is not None:
  1399. with open(config_sentence_transformers_json_path) as fIn:
  1400. self._model_config = json.load(fIn)
  1401. if (
  1402. "__version__" in self._model_config
  1403. and "sentence_transformers" in self._model_config["__version__"]
  1404. and self._model_config["__version__"]["sentence_transformers"] > __version__
  1405. ):
                logger.warning(
                    "You are trying to use a model that was created with version {}, however, your version is {}. "
                    "This might cause unexpected behavior or errors. In that case, try to update to the latest version.\n\n\n".format(
                        self._model_config["__version__"]["sentence_transformers"], __version__
                    )
                )
            # Set score functions & prompts if not already overridden by the __init__ calls
            if self.similarity_fn_name is None:
                self.similarity_fn_name = self._model_config.get("similarity_fn_name", None)
            if not self.prompts:
                self.prompts = self._model_config.get("prompts", {})
            if not self.default_prompt_name:
                self.default_prompt_name = self._model_config.get("default_prompt_name", None)

        # Check if a readme exists
        model_card_path = load_file_path(
            model_name_or_path,
            "README.md",
            token=token,
            cache_folder=cache_folder,
            revision=revision,
            local_files_only=local_files_only,
        )
        if model_card_path is not None:
            try:
                with open(model_card_path, encoding="utf8") as fIn:
                    self._model_card_text = fIn.read()
            except Exception:
                pass

        # Load the modules of sentence transformer
        modules_json_path = load_file_path(
            model_name_or_path,
            "modules.json",
            token=token,
            cache_folder=cache_folder,
            revision=revision,
            local_files_only=local_files_only,
        )
        with open(modules_json_path) as fIn:
            modules_config = json.load(fIn)

        modules = OrderedDict()
        module_kwargs = OrderedDict()
        for module_config in modules_config:
            class_ref = module_config["type"]
            module_class = self._load_module_class_from_ref(
                class_ref, model_name_or_path, trust_remote_code, revision, model_kwargs
            )

            # For Transformer, don't load the full directory, rely on `transformers` instead
            # But, do load the config file first.
            if module_config["path"] == "":
                kwargs = {}
                for config_name in [
                    "sentence_bert_config.json",
                    "sentence_roberta_config.json",
                    "sentence_distilbert_config.json",
                    "sentence_camembert_config.json",
                    "sentence_albert_config.json",
                    "sentence_xlm-roberta_config.json",
                    "sentence_xlnet_config.json",
                ]:
                    config_path = load_file_path(
                        model_name_or_path,
                        config_name,
                        token=token,
                        cache_folder=cache_folder,
                        revision=revision,
                        local_files_only=local_files_only,
                    )
                    if config_path is not None:
                        with open(config_path) as fIn:
                            kwargs = json.load(fIn)
                            # Don't allow configs to set trust_remote_code
                            if "model_args" in kwargs and "trust_remote_code" in kwargs["model_args"]:
                                kwargs["model_args"].pop("trust_remote_code")
                            if "tokenizer_args" in kwargs and "trust_remote_code" in kwargs["tokenizer_args"]:
                                kwargs["tokenizer_args"].pop("trust_remote_code")
                            if "config_args" in kwargs and "trust_remote_code" in kwargs["config_args"]:
                                kwargs["config_args"].pop("trust_remote_code")
                        break

                hub_kwargs = {
                    "token": token,
                    "trust_remote_code": trust_remote_code,
                    "revision": revision,
                    "local_files_only": local_files_only,
                }
                # 3rd priority: config file
                if "model_args" not in kwargs:
                    kwargs["model_args"] = {}
                if "tokenizer_args" not in kwargs:
                    kwargs["tokenizer_args"] = {}
                if "config_args" not in kwargs:
                    kwargs["config_args"] = {}

                # 2nd priority: hub_kwargs
                kwargs["model_args"].update(hub_kwargs)
                kwargs["tokenizer_args"].update(hub_kwargs)
                kwargs["config_args"].update(hub_kwargs)

                # 1st priority: kwargs passed to SentenceTransformer
                if model_kwargs:
                    kwargs["model_args"].update(model_kwargs)
                if tokenizer_kwargs:
                    kwargs["tokenizer_args"].update(tokenizer_kwargs)
                if config_kwargs:
                    kwargs["config_args"].update(config_kwargs)

                # Try to initialize the module with a lot of kwargs, but only if the module supports them
                # Otherwise we fall back to the load method
                try:
                    module = module_class(model_name_or_path, cache_dir=cache_folder, backend=self.backend, **kwargs)
                except TypeError:
                    module = module_class.load(model_name_or_path)
            else:
                # Normalize does not require any files to be loaded
                if module_class == Normalize:
                    module_path = None
                else:
                    module_path = load_dir_path(
                        model_name_or_path,
                        module_config["path"],
                        token=token,
                        cache_folder=cache_folder,
                        revision=revision,
                        local_files_only=local_files_only,
                    )
                module = module_class.load(module_path)

            modules[module_config["name"]] = module
            module_kwargs[module_config["name"]] = module_config.get("kwargs", [])

        if revision is None:
            path_parts = Path(modules_json_path)
            if len(path_parts.parts) >= 2:
                revision_path_part = Path(modules_json_path).parts[-2]
                if len(revision_path_part) == 40:
                    revision = revision_path_part
        self.model_card_data.set_base_model(model_name_or_path, revision=revision)
        return modules, module_kwargs

    @staticmethod
    def load(input_path) -> SentenceTransformer:
        return SentenceTransformer(input_path)

    @property
    def device(self) -> device:
        """
        Get torch.device from module, assuming that the whole module has one device.
        In case there are no PyTorch parameters, fall back to CPU.
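
        Example:
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2")
                print(model.device)
                # => device(type='cpu')  (or e.g. device(type='cuda', index=0) on a GPU machine)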
  1545. """
  1546. if isinstance(self[0], Transformer):
  1547. return self[0].auto_model.device
  1548. try:
  1549. return next(self.parameters()).device
  1550. except StopIteration:
  1551. # For nn.DataParallel compatibility in PyTorch 1.5
  1552. def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
  1553. tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
  1554. return tuples
  1555. gen = self._named_members(get_members_fn=find_tensor_attributes)
  1556. try:
  1557. first_tuple = next(gen)
  1558. return first_tuple[1].device
  1559. except StopIteration:
  1560. return torch.device("cpu")
    @property
    def tokenizer(self) -> Any:
        """
        Property to get the tokenizer that is used by this model
        """
        return self._first_module().tokenizer

    @tokenizer.setter
    def tokenizer(self, value) -> None:
        """
        Property to set the tokenizer that should be used by this model
        """
        self._first_module().tokenizer = value

    @property
    def max_seq_length(self) -> int:
        """
        Returns the maximal input sequence length for the model. Longer inputs will be truncated.

        Returns:
            int: The maximal input sequence length.

        Example:
            ::

                from sentence_transformers import SentenceTransformer

                model = SentenceTransformer("all-mpnet-base-v2")
                print(model.max_seq_length)
                # => 384
        """
        return self._first_module().max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value) -> None:
        """
        Property to set the maximal input sequence length for the model. Longer inputs will be truncated.
        """
        self._first_module().max_seq_length = value

    @property
    def _target_device(self) -> torch.device:
        logger.warning(
            "`SentenceTransformer._target_device` has been deprecated, please use `SentenceTransformer.device` instead.",
        )
        return self.device

    @_target_device.setter
    def _target_device(self, device: int | str | torch.device | None = None) -> None:
        self.to(device)

    @property
    def _no_split_modules(self) -> list[str]:
        try:
            return self._first_module()._no_split_modules
        except AttributeError:
            return []

    @property
    def _keys_to_ignore_on_save(self) -> list[str]:
        try:
            return self._first_module()._keys_to_ignore_on_save
        except AttributeError:
            return []

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) -> None:
        # Propagate the gradient checkpointing to the transformer model
        for module in self:
            if isinstance(module, Transformer):
                return module.auto_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs)
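
    # Usage sketch (illustrative, not part of the API surface): gradient checkpointing trades
    # extra forward compute for lower activation memory during training. The checkpoint name and
    # the kwargs below are assumptions, forwarded as-is to `transformers`.
    #
    #   model = SentenceTransformer("all-mpnet-base-v2")
    #   model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})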