WordEmbeddings.py

from __future__ import annotations

import gzip
import json
import logging
import os

import numpy as np
import torch
from safetensors.torch import load_file as load_safetensors_file
from safetensors.torch import save_file as save_safetensors_file
from torch import nn
from tqdm import tqdm

from sentence_transformers.util import fullname, http_get, import_from_string

from .tokenizer import WhitespaceTokenizer, WordTokenizer

logger = logging.getLogger(__name__)
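

# WordEmbeddings wraps a fixed vocabulary of static word vectors (e.g. GloVe or
# word2vec) as the first module of a SentenceTransformer; it is typically followed
# by a Pooling module that averages the token embeddings it produces.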
class WordEmbeddings(nn.Module):
    def __init__(
        self,
        tokenizer: WordTokenizer,
        embedding_weights,
        update_embeddings: bool = False,
        max_seq_length: int = 1000000,
    ):
        nn.Module.__init__(self)
        if isinstance(embedding_weights, list):
            embedding_weights = np.asarray(embedding_weights)

        if isinstance(embedding_weights, np.ndarray):
            embedding_weights = torch.from_numpy(embedding_weights)

        num_embeddings, embeddings_dimension = embedding_weights.size()

        self.embeddings_dimension = embeddings_dimension
        self.emb_layer = nn.Embedding(num_embeddings, embeddings_dimension)
        self.emb_layer.load_state_dict({"weight": embedding_weights})
        self.emb_layer.weight.requires_grad = update_embeddings
        self.tokenizer = tokenizer
        self.update_embeddings = update_embeddings
        self.max_seq_length = max_seq_length

    def forward(self, features):
        token_embeddings = self.emb_layer(features["input_ids"])
        cls_tokens = None
        features.update(
            {
                "token_embeddings": token_embeddings,
                "cls_token_embeddings": cls_tokens,
                "attention_mask": features["attention_mask"],
            }
        )
        return features
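
    # forward() receives the dict produced by tokenize() and augments it in place;
    # downstream modules (typically a Pooling layer) combine "token_embeddings"
    # with "attention_mask" into a fixed-size sentence embedding.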

    def tokenize(self, texts: list[str], **kwargs):
        tokenized_texts = [self.tokenizer.tokenize(text, **kwargs) for text in texts]
        sentence_lengths = [len(tokens) for tokens in tokenized_texts]
        max_len = max(sentence_lengths)

        input_ids = []
        attention_masks = []
        for tokens in tokenized_texts:
            padding = [0] * (max_len - len(tokens))
            input_ids.append(tokens + padding)
            attention_masks.append([1] * len(tokens) + padding)

        output = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_masks, dtype=torch.long),
            "sentence_lengths": torch.tensor(sentence_lengths, dtype=torch.long),
        }
        return output
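
    # Example of the batch layout tokenize() produces, assuming a hypothetical
    # vocabulary where "hello" -> 12, "world" -> 87, "hi" -> 45 and id 0 is padding:
    #     tokenize(["hello world", "hi"])
    #     input_ids        -> tensor([[12, 87], [45,  0]])
    #     attention_mask   -> tensor([[ 1,  1], [ 1,  0]])
    #     sentence_lengths -> tensor([2, 1])
    # The concrete ids depend entirely on the loaded vocabulary.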

    def get_word_embedding_dimension(self) -> int:
        return self.embeddings_dimension

    def save(self, output_path: str, safe_serialization: bool = True):
        with open(os.path.join(output_path, "wordembedding_config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

        if safe_serialization:
            save_safetensors_file(self.state_dict(), os.path.join(output_path, "model.safetensors"))
        else:
            torch.save(self.state_dict(), os.path.join(output_path, "pytorch_model.bin"))

        self.tokenizer.save(output_path)

    def get_config_dict(self):
        return {
            "tokenizer_class": fullname(self.tokenizer),
            "update_embeddings": self.update_embeddings,
            "max_seq_length": self.max_seq_length,
        }

    @staticmethod
    def load(input_path: str):
        with open(os.path.join(input_path, "wordembedding_config.json")) as fIn:
            config = json.load(fIn)

        tokenizer_class = import_from_string(config["tokenizer_class"])
        tokenizer = tokenizer_class.load(input_path)

        if os.path.exists(os.path.join(input_path, "model.safetensors")):
            weights = load_safetensors_file(os.path.join(input_path, "model.safetensors"))
        else:
            weights = torch.load(
                os.path.join(input_path, "pytorch_model.bin"), map_location=torch.device("cpu"), weights_only=True
            )

        embedding_weights = weights["emb_layer.weight"]
        model = WordEmbeddings(
            tokenizer=tokenizer, embedding_weights=embedding_weights, update_embeddings=config["update_embeddings"]
        )
        return model
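
    # save() and load() are symmetric: save() writes wordembedding_config.json,
    # the tokenizer vocabulary, and either model.safetensors or pytorch_model.bin
    # into one directory, and load() restores the module from that same directory.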

    @staticmethod
    def from_text_file(
        embeddings_file_path: str,
        update_embeddings: bool = False,
        item_separator: str = " ",
        tokenizer=WhitespaceTokenizer(),
        max_vocab_size: int | None = None,
    ):
        logger.info(f"Read in embeddings file {embeddings_file_path}")

        if not os.path.exists(embeddings_file_path):
            logger.info(f"{embeddings_file_path} does not exist, try to download from server")

            if "/" in embeddings_file_path or "\\" in embeddings_file_path:
                raise ValueError(f"Embeddings file not found: {embeddings_file_path}")

            url = "https://public.ukp.informatik.tu-darmstadt.de/reimers/embeddings/" + embeddings_file_path
            http_get(url, embeddings_file_path)

        embeddings_dimension = None
        vocab = []
        embeddings = []

        with gzip.open(embeddings_file_path, "rt", encoding="utf8") if embeddings_file_path.endswith(".gz") else open(
            embeddings_file_path, encoding="utf8"
        ) as fIn:
            iterator = tqdm(fIn, desc="Load Word Embeddings", unit="Embeddings")
            for line in iterator:
                split = line.rstrip().split(item_separator)
                if not vocab and len(split) == 2:  # Handle Word2vec format
                    continue

                word = split[0]

                if embeddings_dimension is None:
                    embeddings_dimension = len(split) - 1
                    vocab.append("PADDING_TOKEN")
                    embeddings.append(np.zeros(embeddings_dimension))

                if (
                    len(split) - 1
                ) != embeddings_dimension:  # Assure that all lines in the embeddings file are of the same length
                    logger.error(
                        "ERROR: A line in the embeddings file had more or less dimensions than expected. Skip token."
                    )
                    continue

                vector = np.array([float(num) for num in split[1:]])
                embeddings.append(vector)
                vocab.append(word)

                if max_vocab_size is not None and max_vocab_size > 0 and len(vocab) > max_vocab_size:
                    break

        embeddings = np.asarray(embeddings)
        tokenizer.set_vocab(vocab)

        return WordEmbeddings(
            tokenizer=tokenizer, embedding_weights=embeddings, update_embeddings=update_embeddings
        )
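

if __name__ == "__main__":
    # Minimal usage sketch: wire WordEmbeddings and a mean Pooling layer into a
    # SentenceTransformer built from a GloVe-style text file. "glove.6B.300d.txt.gz"
    # is a placeholder name; any (optionally gzipped) "word dim1 dim2 ..." file works.
    from sentence_transformers import SentenceTransformer
    from sentence_transformers.models import Pooling

    word_embeddings = WordEmbeddings.from_text_file("glove.6B.300d.txt.gz")
    pooling = Pooling(word_embeddings.get_word_embedding_dimension(), pooling_mode="mean")
    model = SentenceTransformer(modules=[word_embeddings, pooling])

    sentence_embeddings = model.encode(["This is an example sentence", "Each sentence is converted"])
    print(sentence_embeddings.shape)  # e.g. (2, 300) for 300-dimensional word vectors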