BoW.py

from __future__ import annotations

import json
import logging
import os
from typing import Literal

import torch
from torch import Tensor, nn

from .tokenizer import WhitespaceTokenizer

logger = logging.getLogger(__name__)


class BoW(nn.Module):
    """Implements a Bag-of-Words (BoW) model to derive sentence embeddings.

    A weighting can be added to allow the generation of tf-idf vectors. The output vector has the size of the vocab.
    """

    def __init__(
        self,
        vocab: list[str],
        word_weights: dict[str, float] = {},
        unknown_word_weight: float = 1,
        cumulative_term_frequency: bool = True,
    ):
        super().__init__()
        vocab = list(set(vocab))  # Ensure vocab is unique
        self.config_keys = ["vocab", "word_weights", "unknown_word_weight", "cumulative_term_frequency"]
        self.vocab = vocab
        self.word_weights = word_weights
        self.unknown_word_weight = unknown_word_weight
        self.cumulative_term_frequency = cumulative_term_frequency

        # Maps wordIdx -> word weight
        self.weights = []
        num_unknown_words = 0
        for word in vocab:
            weight = unknown_word_weight
            if word in word_weights:
                weight = word_weights[word]
            elif word.lower() in word_weights:
                weight = word_weights[word.lower()]
            else:
                num_unknown_words += 1
            self.weights.append(weight)

        logger.info(
            f"{num_unknown_words} out of {len(vocab)} words without a weighting value. Set weight to {unknown_word_weight}"
        )

        self.tokenizer = WhitespaceTokenizer(vocab, stop_words=set(), do_lower_case=False)
        self.sentence_embedding_dimension = len(vocab)

    def forward(self, features: dict[str, Tensor]):
        # Nothing to do, everything is done in get_sentence_features
        return features
    def tokenize(self, texts: list[str], **kwargs) -> dict[Literal["sentence_embedding"], torch.Tensor]:
        tokenized = [self.tokenizer.tokenize(text, **kwargs) for text in texts]
        return self.get_sentence_features(tokenized)
    def get_sentence_embedding_dimension(self):
        return self.sentence_embedding_dimension

    def get_sentence_features(
        self, tokenized_texts: list[list[int]], pad_seq_length: int = 0
    ) -> dict[Literal["sentence_embedding"], torch.Tensor]:
        vectors = []
        for tokens in tokenized_texts:
            vector = torch.zeros(self.get_sentence_embedding_dimension(), dtype=torch.float32)
            for token in tokens:
                if self.cumulative_term_frequency:
                    vector[token] += self.weights[token]
                else:
                    vector[token] = self.weights[token]
            vectors.append(vector)
        return {"sentence_embedding": torch.stack(vectors)}
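
    # Worked example (illustrative, not from the original source): with vocab
    # ["a", "b", "c"] and per-index weights [1.0, 1.0, 2.0], the token ids [0, 0, 2]
    # yield [2.0, 0.0, 2.0] when cumulative_term_frequency is True (each occurrence
    # adds its weight) and [1.0, 0.0, 2.0] when it is False (a seen token contributes
    # its weight only once).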

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)
        return BoW(**config)
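

# --- Usage sketch (illustrative, not part of the original BoW.py) ---
# A minimal, self-contained example of how this module might be used on its own,
# assuming the relative import above resolves to sentence_transformers.models.BoW.
# The vocabulary, weights, and sentences are invented for illustration; word_weights
# could hold idf-style values so the output approximates tf-idf vectors.
from sentence_transformers.models import BoW

vocab = ["the", "cat", "sat", "mat"]
word_weights = {"the": 0.1, "cat": 2.0}  # "sat" and "mat" fall back to unknown_word_weight

bow = BoW(vocab=vocab, word_weights=word_weights, unknown_word_weight=1.0)

# tokenize() maps each whitespace-split word to its vocab index and returns the
# weighted bag-of-words vectors directly under "sentence_embedding".
features = bow.tokenize(["the cat sat", "the the mat"])
print(features["sentence_embedding"].shape)  # torch.Size([2, 4]): one dimension per vocab word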