from __future__ import annotations

import json
import logging
import os
from typing import Literal

import torch
from torch import Tensor, nn

from .tokenizer import WhitespaceTokenizer

logger = logging.getLogger(__name__)


class BoW(nn.Module):
    """Implements a Bag-of-Words (BoW) model to derive sentence embeddings.

    A weighting can be added to allow the generation of tf-idf vectors. The output vector has the size of the vocab.
    """

    def __init__(
        self,
        vocab: list[str],
        word_weights: dict[str, float] = {},
        unknown_word_weight: float = 1,
        cumulative_term_frequency: bool = True,
    ):
        super().__init__()
        vocab = list(set(vocab))  # Ensure vocab is unique
        self.config_keys = ["vocab", "word_weights", "unknown_word_weight", "cumulative_term_frequency"]
        self.vocab = vocab
        self.word_weights = word_weights
        self.unknown_word_weight = unknown_word_weight
        self.cumulative_term_frequency = cumulative_term_frequency

        # Maps wordIdx -> word weight
        self.weights = []
        num_unknown_words = 0
        for word in vocab:
            weight = unknown_word_weight
            if word in word_weights:
                weight = word_weights[word]
            elif word.lower() in word_weights:
                # Fall back to the lowercased form before treating the word as unweighted
                weight = word_weights[word.lower()]
            else:
                num_unknown_words += 1
            self.weights.append(weight)

        logger.info(
            f"{num_unknown_words} out of {len(vocab)} words without a weighting value. Set weight to {unknown_word_weight}"
        )

        self.tokenizer = WhitespaceTokenizer(vocab, stop_words=set(), do_lower_case=False)
        self.sentence_embedding_dimension = len(vocab)

    def forward(self, features: dict[str, Tensor]):
        # Nothing to do, everything is done in get_sentence_features
        return features

    def tokenize(self, texts: list[str], **kwargs) -> dict[Literal["sentence_embedding"], torch.Tensor]:
        tokenized = [self.tokenizer.tokenize(text, **kwargs) for text in texts]
        return self.get_sentence_features(tokenized)

    def get_sentence_embedding_dimension(self):
        return self.sentence_embedding_dimension

    def get_sentence_features(
        self, tokenized_texts: list[list[int]], pad_seq_length: int = 0
    ) -> dict[Literal["sentence_embedding"], torch.Tensor]:
        # pad_seq_length is not used here: BoW vectors always have a fixed size of len(vocab)
        vectors = []

        for tokens in tokenized_texts:
            vector = torch.zeros(self.get_sentence_embedding_dimension(), dtype=torch.float32)
            for token in tokens:
                if self.cumulative_term_frequency:
                    vector[token] += self.weights[token]
                else:
                    vector[token] = self.weights[token]
            vectors.append(vector)

        return {"sentence_embedding": torch.stack(vectors)}

    def get_config_dict(self):
        return {key: self.__dict__[key] for key in self.config_keys}

    def save(self, output_path):
        with open(os.path.join(output_path, "config.json"), "w") as fOut:
            json.dump(self.get_config_dict(), fOut, indent=2)

    @staticmethod
    def load(input_path):
        with open(os.path.join(input_path, "config.json")) as fIn:
            config = json.load(fIn)

        return BoW(**config)
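

# -----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the relative WhitespaceTokenizer import above resolves, i.e. the
# module is executed as part of its package; the toy sentences, vocabulary and
# the idf-style word_weights below are made-up example values.
if __name__ == "__main__":
    import math

    sentences = ["the cat sat", "the dog barked", "the cat and the dog"]

    # Build a whitespace vocabulary and simple idf-style weights from the toy corpus.
    vocab = sorted({word for sentence in sentences for word in sentence.split()})
    num_docs = len(sentences)
    word_weights = {
        word: math.log(num_docs / sum(word in sentence.split() for sentence in sentences))
        for word in vocab
    }

    bow = BoW(vocab=vocab, word_weights=word_weights, cumulative_term_frequency=True)
    features = bow.tokenize(sentences)

    # One vector per sentence, each with len(vocab) dimensions.
    print(features["sentence_embedding"].shape)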