WhitespaceTokenizer.py

from __future__ import annotations

import collections
import json
import os
import string
from typing import Iterable

from .WordTokenizer import ENGLISH_STOP_WORDS, WordTokenizer


class WhitespaceTokenizer(WordTokenizer):
    """
    Simple and fast whitespace tokenizer. Splits a sentence on whitespace;
    punctuation is stripped from tokens.
    """

    def __init__(
        self, vocab: Iterable[str] = [], stop_words: Iterable[str] = ENGLISH_STOP_WORDS, do_lower_case: bool = False
    ):
        self.stop_words = set(stop_words)
        self.do_lower_case = do_lower_case
        self.set_vocab(vocab)

    def get_vocab(self):
        return self.vocab

    def set_vocab(self, vocab: Iterable[str]):
        self.vocab = vocab
        self.word2idx = collections.OrderedDict([(word, idx) for idx, word in enumerate(vocab)])

    def tokenize(self, text: str, **kwargs) -> list[int]:
        if self.do_lower_case:
            text = text.lower()

        tokens = text.split()

        tokens_filtered = []
        for token in tokens:
            # 1) Try the token exactly as it appears.
            if token in self.stop_words:
                continue
            elif token in self.word2idx:
                tokens_filtered.append(self.word2idx[token])
                continue

            # 2) Retry with surrounding punctuation stripped.
            token = token.strip(string.punctuation)
            if token in self.stop_words:
                continue
            elif len(token) > 0 and token in self.word2idx:
                tokens_filtered.append(self.word2idx[token])
                continue

            # 3) Retry lowercased; tokens that still do not match are dropped.
            token = token.lower()
            if token in self.stop_words:
                continue
            elif token in self.word2idx:
                tokens_filtered.append(self.word2idx[token])
                continue

        return tokens_filtered

    def save(self, output_path: str):
        with open(os.path.join(output_path, "whitespacetokenizer_config.json"), "w") as fOut:
            json.dump(
                {
                    "vocab": list(self.word2idx.keys()),
                    "stop_words": list(self.stop_words),
                    "do_lower_case": self.do_lower_case,
                },
                fOut,
            )

    @staticmethod
    def load(input_path: str):
        with open(os.path.join(input_path, "whitespacetokenizer_config.json")) as fIn:
            config = json.load(fIn)

        return WhitespaceTokenizer(**config)
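

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). Because of the relative
# import above, this block only runs when the file is executed inside its
# package (e.g. via `python -m ...`); the vocabulary, sentence, and expected
# indices below are illustrative assumptions, not guaranteed outputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    vocab = ["hello", "world", "sentence", "tokenizer"]
    tokenizer = WhitespaceTokenizer(vocab=vocab, do_lower_case=True)

    # Stop words are dropped and the remaining tokens are mapped to vocab
    # indices, so this prints something like [0, 1, 3] with the default
    # English stop-word list (assuming "this", "is", "the" are in it).
    print(tokenizer.tokenize("Hello, world! This is the tokenizer."))

    # Round-trip the configuration through save()/load().
    with tempfile.TemporaryDirectory() as tmp_dir:
        tokenizer.save(tmp_dir)
        restored = WhitespaceTokenizer.load(tmp_dir)
        print(restored.tokenize("hello world"))  # -> [0, 1]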