# Transformer.py

from __future__ import annotations

import json
import logging
import os
from fnmatch import fnmatch
from pathlib import Path
from typing import Any, Callable

import huggingface_hub
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoTokenizer, MT5Config, T5Config

logger = logging.getLogger(__name__)


def _save_pretrained_wrapper(_save_pretrained_fn: Callable, subfolder: str) -> Callable[..., None]:
    def wrapper(save_directory: str | Path, **kwargs) -> None:
        os.makedirs(Path(save_directory) / subfolder, exist_ok=True)
        return _save_pretrained_fn(Path(save_directory) / subfolder, **kwargs)

    return wrapper


class Transformer(nn.Module):
    """Hugging Face AutoModel to generate token embeddings.
    Loads the correct class, e.g. BERT / RoBERTa etc.

    Args:
        model_name_or_path: Hugging Face model name
            (https://huggingface.co/models)
        max_seq_length: Truncate any inputs longer than max_seq_length
        model_args: Keyword arguments passed to the Hugging Face
            Transformers model
        tokenizer_args: Keyword arguments passed to the Hugging Face
            Transformers tokenizer
        config_args: Keyword arguments passed to the Hugging Face
            Transformers config
        cache_dir: Cache dir for Hugging Face Transformers to store/load
            models
        do_lower_case: If true, lowercases the input (regardless of
            whether the model is cased or not)
        tokenizer_name_or_path: Name or path of the tokenizer. When
            None, then model_name_or_path is used
        backend: Backend used for model inference. Can be `torch`, `onnx`,
            or `openvino`. Default is `torch`.
    """

    save_in_root: bool = True

    def __init__(
        self,
        model_name_or_path: str,
        max_seq_length: int | None = None,
        model_args: dict[str, Any] | None = None,
        tokenizer_args: dict[str, Any] | None = None,
        config_args: dict[str, Any] | None = None,
        cache_dir: str | None = None,
        do_lower_case: bool = False,
        tokenizer_name_or_path: str | None = None,
        backend: str = "torch",
    ) -> None:
        super().__init__()
        self.config_keys = ["max_seq_length", "do_lower_case"]
        self.do_lower_case = do_lower_case
        self.backend = backend
        if model_args is None:
            model_args = {}
        if tokenizer_args is None:
            tokenizer_args = {}
        if config_args is None:
            config_args = {}

        config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
        self._load_model(model_name_or_path, config, cache_dir, backend, **model_args)

        if max_seq_length is not None and "model_max_length" not in tokenizer_args:
            tokenizer_args["model_max_length"] = max_seq_length
        self.tokenizer = AutoTokenizer.from_pretrained(
            tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
            cache_dir=cache_dir,
            **tokenizer_args,
        )

        # No max_seq_length set. Try to infer from model
        if max_seq_length is None:
            if (
                hasattr(self.auto_model, "config")
                and hasattr(self.auto_model.config, "max_position_embeddings")
                and hasattr(self.tokenizer, "model_max_length")
            ):
                max_seq_length = min(self.auto_model.config.max_position_embeddings, self.tokenizer.model_max_length)

        self.max_seq_length = max_seq_length

        if tokenizer_name_or_path is not None:
            self.auto_model.config.tokenizer_class = self.tokenizer.__class__.__name__

    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
        """Loads the transformer model"""
        if backend == "torch":
            if isinstance(config, T5Config):
                self._load_t5_model(model_name_or_path, config, cache_dir, **model_args)
            elif isinstance(config, MT5Config):
                self._load_mt5_model(model_name_or_path, config, cache_dir, **model_args)
            else:
                self.auto_model = AutoModel.from_pretrained(
                    model_name_or_path, config=config, cache_dir=cache_dir, **model_args
                )
        elif backend == "onnx":
            self._load_onnx_model(model_name_or_path, config, cache_dir, **model_args)
        elif backend == "openvino":
            self._load_openvino_model(model_name_or_path, config, cache_dir, **model_args)
        else:
            raise ValueError(f"Unsupported backend '{backend}'. `backend` should be `torch`, `onnx`, or `openvino`.")

    def _load_openvino_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        if isinstance(config, T5Config) or isinstance(config, MT5Config):
            raise ValueError("T5 models are not yet supported by the OpenVINO backend.")

        try:
            from optimum.intel import OVModelForFeatureExtraction
            from optimum.intel.openvino import OV_XML_FILE_NAME
        except ModuleNotFoundError:
            raise Exception(
                "Using the OpenVINO backend requires installing Optimum and OpenVINO. "
                "You can install them with pip: `pip install optimum[openvino]`."
            )

        load_path = Path(model_name_or_path)
        is_local = load_path.exists()
        backend_name = "OpenVINO"
        target_file_glob = "openvino*.xml"

        # Determine whether the model should be exported or whether we can load it directly
        export, model_args = self._backend_should_export(
            load_path, is_local, model_args, OV_XML_FILE_NAME, target_file_glob, backend_name
        )

        # If we're exporting, then there's no need for a file_name to load the model from
        if export:
            model_args.pop("file_name", None)

        # ov_config can be either a dictionary, or point to a json file with an OpenVINO config
        if "ov_config" in model_args:
            ov_config = model_args["ov_config"]
            if not isinstance(ov_config, dict):
                if not Path(ov_config).exists():
                    raise ValueError(
                        "ov_config should be a dictionary or a path to a .json file containing an OpenVINO config"
                    )
                with open(ov_config, encoding="utf-8") as f:
                    model_args["ov_config"] = json.load(f)
        else:
            model_args["ov_config"] = {}
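
        # Illustrative sketch (the config key below is an assumption, not something this module checks):
        # passing model_args={"ov_config": {"INFERENCE_PRECISION_HINT": "f32"}} reaches this point as a
        # plain dict and is forwarded unchanged to OVModelForFeatureExtraction.from_pretrained below.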

        # Either load an exported model, or export the model to OpenVINO
        self.auto_model: OVModelForFeatureExtraction = OVModelForFeatureExtraction.from_pretrained(
            model_name_or_path,
            config=config,
            cache_dir=cache_dir,
            export=export,
            **model_args,
        )

        # Wrap the save_pretrained method to save the model in the correct subfolder
        self.auto_model._save_pretrained = _save_pretrained_wrapper(self.auto_model._save_pretrained, self.backend)

        # Warn the user to save the model if they haven't already
        if export:
            self._backend_warn_to_save(model_name_or_path, is_local, backend_name)

    def _load_onnx_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        try:
            import onnxruntime as ort
            from optimum.onnxruntime import ONNX_WEIGHTS_NAME, ORTModelForFeatureExtraction
        except ModuleNotFoundError:
            raise Exception(
                "Using the ONNX backend requires installing Optimum and ONNX Runtime. "
                "You can install them with pip: `pip install optimum[onnxruntime]` "
                "or `pip install optimum[onnxruntime-gpu]`"
            )

        # Default to the highest priority available provider if not specified
        # E.g. TensorRT > CUDA > CPU
        model_args["provider"] = model_args.pop("provider", ort.get_available_providers()[0])
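
        # Illustrative sketch: to pin a specific provider, a caller could pass
        #     Transformer(model_id, backend="onnx", model_args={"provider": "CPUExecutionProvider"})
        # "CPUExecutionProvider" is a standard ONNX Runtime provider; available providers depend on the install.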

        load_path = Path(model_name_or_path)
        is_local = load_path.exists()
        backend_name = "ONNX"
        target_file_glob = "*.onnx"

        # Determine whether the model should be exported or whether we can load it directly
        export, model_args = self._backend_should_export(
            load_path, is_local, model_args, ONNX_WEIGHTS_NAME, target_file_glob, backend_name
        )

        # If we're exporting, then there's no need for a file_name to load the model from
        if export:
            model_args.pop("file_name", None)

        # Either load an exported model, or export the model to ONNX
        self.auto_model: ORTModelForFeatureExtraction = ORTModelForFeatureExtraction.from_pretrained(
            model_name_or_path,
            config=config,
            cache_dir=cache_dir,
            export=export,
            **model_args,
        )

        # Wrap the save_pretrained method to save the model in the correct subfolder
        self.auto_model._save_pretrained = _save_pretrained_wrapper(self.auto_model._save_pretrained, self.backend)

        # Warn the user to save the model if they haven't already
        if export:
            self._backend_warn_to_save(model_name_or_path, is_local, backend_name)

    def _backend_should_export(
        self,
        load_path: Path,
        is_local: bool,
        model_args: dict[str, Any],
        target_file_name: str,
        target_file_glob: str,
        backend_name: str,
    ) -> tuple[bool, dict[str, Any]]:
        """
        Determines whether the model should be exported to the backend, or if it can be loaded directly.
        Also updates the `file_name` and `subfolder` model_args if necessary.

        These are the cases:

        1. If export is set in model_args, just return export
        2. If `<subfolder>/<file_name>` exists; set export to False
        3. If `<backend>/<file_name>` exists; set export to False and set subfolder to the backend (e.g. "onnx")
        4. If `<file_name>` contains a folder, add those folders to the subfolder and set the file_name to the last part

        We will warn if:

        1. The expected file does not exist in the model directory given the optional file_name and subfolder.
           If there are valid files for this backend, but they don't align with file_name, then we give a useful warning.
        2. Multiple files are found in the model directory that match the target file name and the user did not
           specify the desired file name via `model_kwargs={"file_name": "<file_name>"}`

        Args:
            load_path: The model repository or directory, as a Path instance
            is_local: Whether the model is local or remote, i.e. whether load_path is a local directory
            model_args: The model_args dictionary. Notable keys are "export", "file_name", and "subfolder"
            target_file_name: The expected file name in the model directory, e.g. "model.onnx" or "openvino_model.xml"
            target_file_glob: The glob pattern to match the target file name, e.g. "*.onnx" or "openvino*.xml"
            backend_name: The human-readable name of the backend for use in warnings, e.g. "ONNX" or "OpenVINO"

        Returns:
            Tuple[bool, dict[str, Any]]: A tuple of the export boolean and the updated model_args dictionary.
        """

        export = model_args.pop("export", None)
        if export is not None:
            return export, model_args

        file_name = model_args.get("file_name", target_file_name)
        subfolder = model_args.get("subfolder", None)
        primary_full_path = Path(subfolder, file_name).as_posix() if subfolder else file_name
        secondary_full_path = (
            Path(subfolder, self.backend, file_name).as_posix()
            if subfolder
            else Path(self.backend, file_name).as_posix()
        )
        glob_pattern = f"{subfolder}/**/{target_file_glob}" if subfolder else f"**/{target_file_glob}"

        # Get the list of files in the model directory that match the target file name
        if is_local:
            model_file_names = [path.relative_to(load_path).as_posix() for path in load_path.glob(glob_pattern)]
        else:
            all_files = huggingface_hub.list_repo_files(
                load_path.as_posix(),
                repo_type="model",
                revision=model_args.get("revision", None),
                token=model_args.get("token", None),
            )
            model_file_names = [fname for fname in all_files if fnmatch(fname, glob_pattern)]

        # First check if the expected file exists in the root of the model directory
        # If it doesn't, check if it exists in the backend subfolder.
        # If it does, set the subfolder to include the backend
        export = primary_full_path not in model_file_names
        if export and "subfolder" not in model_args:
            export = secondary_full_path not in model_file_names
            if not export:
                if len(model_file_names) > 1 and "file_name" not in model_args:
                    logger.warning(
                        f"Multiple {backend_name} files found in {load_path.as_posix()!r}: {model_file_names}, defaulting to {secondary_full_path!r}. "
                        f'Please specify the desired file name via `model_kwargs={{"file_name": "<file_name>"}}`.'
                    )
                model_args["subfolder"] = self.backend
                model_args["file_name"] = file_name

        # If the file_name contains subfolders, set it as the subfolder instead
        file_name_parts = Path(file_name).parts
        if len(file_name_parts) > 1:
            model_args["file_name"] = file_name_parts[-1]
            model_args["subfolder"] = Path(model_args.get("subfolder", ""), *file_name_parts[:-1]).as_posix()

        if export:
            logger.warning(
                f"No {file_name!r} found in {load_path.as_posix()!r}. Exporting the model to {backend_name}."
            )
            if model_file_names:
                logger.warning(
                    f"If you intended to load one of the {model_file_names} {backend_name} files, "
                    f'please specify the desired file name via `model_kwargs={{"file_name": "{model_file_names[0]}"}}`.'
                )

        return export, model_args

    def _backend_warn_to_save(self, model_name_or_path: str, is_local: bool, backend_name: str) -> None:
        to_log = f"Saving the exported {backend_name} model is heavily recommended to avoid having to export it again."
        if is_local:
            to_log += f" Do so with `model.save_pretrained({model_name_or_path!r})`."
        else:
            to_log += f" Do so with `model.push_to_hub({model_name_or_path!r}, create_pr=True)`."
        logger.warning(to_log)

    def _load_t5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        """Loads the encoder model from T5"""
        from transformers import T5EncoderModel

        T5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
        self.auto_model = T5EncoderModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
        )

    def _load_mt5_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        """Loads the encoder model from MT5"""
        from transformers import MT5EncoderModel

        MT5EncoderModel._keys_to_ignore_on_load_unexpected = ["decoder.*"]
        self.auto_model = MT5EncoderModel.from_pretrained(
            model_name_or_path, config=config, cache_dir=cache_dir, **model_args
        )

    def __repr__(self) -> str:
        return f"Transformer({self.get_config_dict()}) with Transformer model: {self.auto_model.__class__.__name__} "

    def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
        """Returns token_embeddings, cls_token"""
        trans_features = {"input_ids": features["input_ids"], "attention_mask": features["attention_mask"]}
        if "token_type_ids" in features:
            trans_features["token_type_ids"] = features["token_type_ids"]

        output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
        output_tokens = output_states[0]

        features.update({"token_embeddings": output_tokens, "attention_mask": features["attention_mask"]})

        if self.auto_model.config.output_hidden_states and len(output_states) > 2:
            all_layer_idx = 2  # I.e. after last_hidden_states and pooler_output
            if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
                all_layer_idx = 1

            hidden_states = output_states[all_layer_idx]
            features.update({"all_layer_embeddings": hidden_states})

        return features
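
    # Illustrative summary of the forward output (not enforced here): the returned `features` dict
    # contains "token_embeddings" of shape (batch_size, seq_len, hidden_size), the original
    # "attention_mask", and "all_layer_embeddings" when output_hidden_states is enabled.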

    def get_word_embedding_dimension(self) -> int:
        return self.auto_model.config.hidden_size

    def tokenize(
        self, texts: list[str] | list[dict] | list[tuple[str, str]], padding: str | bool = True
    ) -> dict[str, torch.Tensor]:
        """Tokenizes a text and maps tokens to token-ids"""
        output = {}
        if isinstance(texts[0], str):
            to_tokenize = [texts]
        elif isinstance(texts[0], dict):
            to_tokenize = []
            output["text_keys"] = []
            for lookup in texts:
                text_key, text = next(iter(lookup.items()))
                to_tokenize.append(text)
                output["text_keys"].append(text_key)
            to_tokenize = [to_tokenize]
        else:
            batch1, batch2 = [], []
            for text_tuple in texts:
                batch1.append(text_tuple[0])
                batch2.append(text_tuple[1])
            to_tokenize = [batch1, batch2]

        # Strip
        to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize]

        # Lowercase
        if self.do_lower_case:
            to_tokenize = [[s.lower() for s in col] for col in to_tokenize]

        output.update(
            self.tokenizer(
                *to_tokenize,
                padding=padding,
                truncation="longest_first",
                return_tensors="pt",
                max_length=self.max_seq_length,
            )
        )
        return output
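
    # Illustrative inputs accepted by tokenize (assuming `module` is a Transformer instance):
    #     module.tokenize(["A sentence", "Another sentence"])   # list of strings
    #     module.tokenize([{"query": "A sentence"}])            # list of single-key dicts
    #     module.tokenize([("a query", "a passage")])           # list of (text, text) pairs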

    def get_config_dict(self) -> dict[str, Any]:
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path: str, safe_serialization: bool = True) -> None:
        self.auto_model.save_pretrained(output_path, safe_serialization=safe_serialization)
        self.tokenizer.save_pretrained(output_path)
        with open(os.path.join(output_path, "sentence_bert_config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @classmethod
    def load(cls, input_path: str) -> Transformer:
        # Old classes used other config names than 'sentence_bert_config.json'
        for config_name in [
            "sentence_bert_config.json",
            "sentence_roberta_config.json",
            "sentence_distilbert_config.json",
            "sentence_camembert_config.json",
            "sentence_albert_config.json",
            "sentence_xlm-roberta_config.json",
            "sentence_xlnet_config.json",
        ]:
            sbert_config_path = os.path.join(input_path, config_name)
            if os.path.exists(sbert_config_path):
                break

        with open(sbert_config_path) as fIn:
            config = json.load(fIn)

        # Don't allow configs to set trust_remote_code
        if "model_args" in config and "trust_remote_code" in config["model_args"]:
            config["model_args"].pop("trust_remote_code")
        if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
            config["tokenizer_args"].pop("trust_remote_code")
        if "config_args" in config and "trust_remote_code" in config["config_args"]:
            config["config_args"].pop("trust_remote_code")

        return cls(model_name_or_path=input_path, **config)
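

# Minimal end-to-end sketch (assumes the sentence-transformers package is installed;
# "bert-base-uncased" is only a placeholder model id):
#
#     from sentence_transformers import SentenceTransformer, models
#
#     word_embedding_model = models.Transformer("bert-base-uncased", max_seq_length=256)
#     pooling = models.Pooling(word_embedding_model.get_word_embedding_dimension(), pooling_mode="mean")
#     model = SentenceTransformer(modules=[word_embedding_model, pooling])
#     embeddings = model.encode(["Hello world", "Sentence embeddings are useful"])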