from __future__ import annotations

import logging
import os
import shutil
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Callable, Literal

import huggingface_hub

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sentence_transformers.SentenceTransformer import SentenceTransformer

    try:
        from optimum.onnxruntime.configuration import OptimizationConfig, QuantizationConfig
    except ImportError:
        pass


def export_optimized_onnx_model(
    model: SentenceTransformer,
    optimization_config: OptimizationConfig | Literal["O1", "O2", "O3", "O4"],
    model_name_or_path: str,
    push_to_hub: bool = False,
    create_pr: bool = False,
    file_suffix: str | None = None,
) -> None:
  24. """
  25. Export an optimized ONNX model from a SentenceTransformer model.
  26. The O1-O4 optimization levels are defined by Optimum and are documented here:
  27. https://huggingface.co/docs/optimum/main/en/onnxruntime/usage_guides/optimization
  28. The optimization levels are:
  29. - O1: basic general optimizations.
  30. - O2: basic and extended general optimizations, transformers-specific fusions.
  31. - O3: same as O2 with GELU approximation.
  32. - O4: same as O3 with mixed precision (fp16, GPU-only)
  33. See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.
  34. Args:
  35. model (SentenceTransformer): The SentenceTransformer model to be optimized. Must be loaded with `backend="onnx"`.
  36. optimization_config (OptimizationConfig | Literal["O1", "O2", "O3", "O4"]): The optimization configuration or level.
  37. model_name_or_path (str): The path or Hugging Face Hub repository name where the optimized model will be saved.
  38. push_to_hub (bool, optional): Whether to push the optimized model to the Hugging Face Hub. Defaults to False.
  39. create_pr (bool, optional): Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.
  40. file_suffix (str | None, optional): The suffix to add to the optimized model file name. Defaults to None.
  41. Raises:
  42. ImportError: If the required packages `optimum` and `onnxruntime` are not installed.
  43. ValueError: If the provided model is not a valid SentenceTransformer model loaded with `backend="onnx"`.
  44. ValueError: If the provided optimization_config is not valid.
  45. Returns:
  46. None
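
    Example:
        A minimal sketch; the model name and output directory below are illustrative and assume a
        checkpoint with an ONNX export is available::

            from sentence_transformers import SentenceTransformer
            from sentence_transformers.backend import export_optimized_onnx_model

            # Load a Transformer-based SentenceTransformer model with the ONNX backend
            model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")

            # Saves model_O3.onnx under path/to/save/directory/onnx/
            export_optimized_onnx_model(model, "O3", "path/to/save/directory")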
  47. """
  48. from sentence_transformers import SentenceTransformer
  49. from sentence_transformers.models.Transformer import Transformer
  50. try:
  51. from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTOptimizer
  52. from optimum.onnxruntime.configuration import AutoOptimizationConfig
  53. except ImportError:
  54. raise ImportError(
  55. "Please install Optimum and ONNX Runtime to use this function. "
  56. "You can install them with pip: `pip install optimum[onnxruntime]` "
  57. "or `pip install optimum[onnxruntime-gpu]`"
  58. )
  59. if (
  60. not isinstance(model, SentenceTransformer)
  61. or not len(model)
  62. or not isinstance(model[0], Transformer)
  63. or not isinstance(model[0].auto_model, ORTModelForFeatureExtraction)
  64. ):
  65. raise ValueError(
  66. 'The model must be a Transformer-based SentenceTransformer model loaded with `backend="onnx"`.'
  67. )
  68. ort_model: ORTModelForFeatureExtraction = model[0].auto_model
  69. optimizer = ORTOptimizer.from_pretrained(ort_model)
  70. if isinstance(optimization_config, str):
  71. if optimization_config not in AutoOptimizationConfig._LEVELS:
  72. raise ValueError(
  73. "optimization_config must be an OptimizationConfig instance or one of 'O1', 'O2', 'O3', 'O4'."
  74. )
  75. file_suffix = file_suffix or optimization_config
  76. optimization_config = getattr(AutoOptimizationConfig, optimization_config)()
  77. if file_suffix is None:
  78. file_suffix = "optimized"
  79. save_or_push_to_hub_onnx_model(
  80. export_function=lambda save_dir: optimizer.optimize(optimization_config, save_dir, file_suffix=file_suffix),
  81. export_function_name="export_optimized_onnx_model",
  82. config=optimization_config,
  83. model_name_or_path=model_name_or_path,
  84. push_to_hub=push_to_hub,
  85. create_pr=create_pr,
  86. file_suffix=file_suffix,
  87. )


def export_dynamic_quantized_onnx_model(
    model: SentenceTransformer,
    quantization_config: QuantizationConfig | Literal["arm64", "avx2", "avx512", "avx512_vnni"],
    model_name_or_path: str,
    push_to_hub: bool = False,
    create_pr: bool = False,
    file_suffix: str | None = None,
) -> None:
  96. """
  97. Export a quantized ONNX model from a SentenceTransformer model.
  98. This function applies dynamic quantization, i.e. without a calibration dataset.
  99. Each of the default quantization configurations quantize the model to int8, allowing
  100. for faster inference on CPUs, but are likely slower on GPUs.
  101. See https://sbert.net/docs/sentence_transformer/usage/efficiency.html for more information & benchmarks.
  102. Args:
  103. model (SentenceTransformer): The SentenceTransformer model to be quantized. Must be loaded with `backend="onnx"`.
  104. quantization_config (QuantizationConfig): The quantization configuration.
  105. model_name_or_path (str): The path or Hugging Face Hub repository name where the quantized model will be saved.
  106. push_to_hub (bool, optional): Whether to push the quantized model to the Hugging Face Hub. Defaults to False.
  107. create_pr (bool, optional): Whether to create a pull request when pushing to the Hugging Face Hub. Defaults to False.
  108. file_suffix (str | None, optional): The suffix to add to the quantized model file name. Defaults to None.
  109. Raises:
  110. ImportError: If the required packages `optimum` and `onnxruntime` are not installed.
  111. ValueError: If the provided model is not a valid SentenceTransformer model loaded with `backend="onnx"`.
  112. ValueError: If the provided quantization_config is not valid.
  113. Returns:
  114. None
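
    Example:
        A minimal sketch; the model name and output directory below are illustrative and assume a
        checkpoint with an ONNX export is available::

            from sentence_transformers import SentenceTransformer
            from sentence_transformers.backend import export_dynamic_quantized_onnx_model

            # Load a Transformer-based SentenceTransformer model with the ONNX backend
            model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")

            # Saves model_qint8_avx512_vnni.onnx under path/to/save/directory/onnx/
            export_dynamic_quantized_onnx_model(model, "avx512_vnni", "path/to/save/directory")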
  115. """
  116. from sentence_transformers import SentenceTransformer
  117. from sentence_transformers.models.Transformer import Transformer
  118. try:
  119. from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTQuantizer
  120. from optimum.onnxruntime.configuration import AutoQuantizationConfig
  121. except ImportError:
  122. raise ImportError(
  123. "Please install Optimum and ONNX Runtime to use this function. "
  124. "You can install them with pip: `pip install optimum[onnxruntime]` "
  125. "or `pip install optimum[onnxruntime-gpu]`"
  126. )
  127. if (
  128. not isinstance(model, SentenceTransformer)
  129. or not len(model)
  130. or not isinstance(model[0], Transformer)
  131. or not isinstance(model[0].auto_model, ORTModelForFeatureExtraction)
  132. ):
  133. raise ValueError(
  134. 'The model must be a Transformer-based SentenceTransformer model loaded with `backend="onnx"`.'
  135. )
  136. ort_model: ORTModelForFeatureExtraction = model[0].auto_model
  137. quantizer = ORTQuantizer.from_pretrained(ort_model)
  138. if isinstance(quantization_config, str):
  139. if quantization_config not in ["arm64", "avx2", "avx512", "avx512_vnni"]:
  140. raise ValueError(
  141. "quantization_config must be an QuantizationConfig instance or one of 'arm64', 'avx2', 'avx512', or 'avx512_vnni'."
  142. )
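
        # Keep a copy of the preset name; `quantization_config` is rebound to a config object below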
        quantization_config_name = quantization_config[:]
        quantization_config = getattr(AutoQuantizationConfig, quantization_config)(is_static=False)
        file_suffix = file_suffix or f"{quantization_config.weights_dtype.name.lower()}_{quantization_config_name}"

    if file_suffix is None:
        file_suffix = f"{quantization_config.weights_dtype.name.lower()}_quantized"

    save_or_push_to_hub_onnx_model(
        export_function=lambda save_dir: quantizer.quantize(quantization_config, save_dir, file_suffix=file_suffix),
        export_function_name="export_dynamic_quantized_onnx_model",
        config=quantization_config,
        model_name_or_path=model_name_or_path,
        push_to_hub=push_to_hub,
        create_pr=create_pr,
        file_suffix=file_suffix,
    )


def save_or_push_to_hub_onnx_model(
    export_function: Callable,
    export_function_name: str,
    config,
    model_name_or_path: str,
    push_to_hub: bool = False,
    create_pr: bool = False,
    file_suffix: str | None = None,
):
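    """
    Save an exported ONNX model to a local directory or upload it to the Hugging Face Hub.

    Args:
        export_function (Callable): A function that writes `model_{file_suffix}.onnx` into the directory it is given.
        export_function_name (str): The name of the calling export function, referenced in the pull request description.
        config: The optimization or quantization configuration, echoed in the pull request description.
        model_name_or_path (str): A local directory or the Hugging Face Hub repository to save the model to.
        push_to_hub (bool, optional): Whether to upload the exported model to the Hugging Face Hub. Defaults to False.
        create_pr (bool, optional): Whether to open a pull request instead of committing directly when pushing. Defaults to False.
        file_suffix (str | None, optional): The suffix used in the exported model file name. Defaults to None.
    """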
    if push_to_hub:
        with tempfile.TemporaryDirectory() as save_dir:
            export_function(save_dir)

            file_name = f"model_{file_suffix}.onnx"
            source = (Path(save_dir) / file_name).as_posix()
            destination = (Path("onnx") / file_name).as_posix()

            commit_description = ""
            if create_pr:
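                # Pretty-print the config's repr across multiple lines for the PR description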
                opt_config_string = repr(config).replace("(", "(\n\t").replace(", ", ",\n\t").replace(")", "\n)")
                commit_description = f"""\
Hello!

*This pull request has been automatically generated from the [`{export_function_name}`](https://sbert.net/docs/package_reference/util.html#sentence_transformers.backend.{export_function_name}) function from the Sentence Transformers library.*

## Config
```python
{opt_config_string}
```

## Tip:
Consider testing this pull request before merging by loading the model from this PR with the `revision` argument:
```python
from sentence_transformers import SentenceTransformer

# TODO: Fill in the PR number
pr_number = 2
model = SentenceTransformer(
    "{model_name_or_path}",
    revision=f"refs/pr/{{pr_number}}",
    backend="onnx",
    model_kwargs={{"file_name": "{destination}"}},
)

# Verify that everything works as expected
embeddings = model.encode(["The weather is lovely today.", "It's so sunny outside!", "He drove to the stadium."])
print(embeddings.shape)

similarities = model.similarity(embeddings, embeddings)
print(similarities)
```
"""

            huggingface_hub.upload_file(
                path_or_fileobj=source,
                path_in_repo=destination,
                repo_id=model_name_or_path,
                repo_type="model",
                commit_message=f"Add exported ONNX model {file_name!r}",
                commit_description=commit_description,
                create_pr=create_pr,
            )

    else:
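        # Saving locally: export into a temporary directory, then copy the ONNX file into
        # the `onnx/` subdirectory of `model_name_or_path`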
        with tempfile.TemporaryDirectory() as save_dir:
            export_function(save_dir)

            file_name = f"model_{file_suffix}.onnx"
            source = os.path.join(save_dir, file_name)
            destination = os.path.join(model_name_or_path, "onnx", file_name)

            # Create the destination directory if it does not exist
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            shutil.copy(source, destination)