# visualizer.py
import html
import itertools
import os
import re
from string import Template
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple

from tokenizers import Encoding, Tokenizer
  7. dirname = os.path.dirname(__file__)
  8. css_filename = os.path.join(dirname, "visualizer-styles.css")
  9. with open(css_filename) as f:
  10. css = f.read()
  11. class Annotation:
  12. start: int
  13. end: int
  14. label: int
  15. def __init__(self, start: int, end: int, label: str):
  16. self.start = start
  17. self.end = end
  18. self.label = label
  19. AnnotationList = List[Annotation]
  20. PartialIntList = List[Optional[int]]
  21. class CharStateKey(NamedTuple):
  22. token_ix: Optional[int]
  23. anno_ix: Optional[int]
  24. class CharState:
  25. char_ix: Optional[int]
  26. def __init__(self, char_ix):
  27. self.char_ix = char_ix
  28. self.anno_ix: Optional[int] = None
  29. self.tokens: List[int] = []
  30. @property
  31. def token_ix(self):
  32. return self.tokens[0] if len(self.tokens) > 0 else None
  33. @property
  34. def is_multitoken(self):
  35. """
  36. BPE tokenizers can output more than one token for a char
  37. """
  38. return len(self.tokens) > 1
  39. def partition_key(self) -> CharStateKey:
  40. return CharStateKey(
  41. token_ix=self.token_ix,
  42. anno_ix=self.anno_ix,
  43. )
  44. class Aligned:
  45. pass
  46. class EncodingVisualizer:
  47. """
  48. Build an EncodingVisualizer
  49. Args:
  50. tokenizer (:class:`~tokenizers.Tokenizer`):
  51. A tokenizer instance
  52. default_to_notebook (:obj:`bool`):
  53. Whether to render html output in a notebook by default
  54. annotation_converter (:obj:`Callable`, `optional`):
  55. An optional (lambda) function that takes an annotation in any format and returns
  56. an Annotation object
  57. """
  58. unk_token_regex = re.compile("(.{1}\b)?(unk|oov)(\b.{1})?", flags=re.IGNORECASE)
  59. def __init__(
  60. self,
  61. tokenizer: Tokenizer,
  62. default_to_notebook: bool = True,
  63. annotation_converter: Optional[Callable[[Any], Annotation]] = None,
  64. ):
  65. if default_to_notebook:
  66. try:
  67. from IPython.core.display import HTML, display
  68. except ImportError:
  69. raise Exception(
  70. """We couldn't import IPython utils for html display.
  71. Are you running in a notebook?
  72. You can also pass `default_to_notebook=False` to get back raw HTML
  73. """
  74. )
  75. self.tokenizer = tokenizer
  76. self.default_to_notebook = default_to_notebook
  77. self.annotation_coverter = annotation_converter
  78. pass
  79. def __call__(
  80. self,
  81. text: str,
  82. annotations: AnnotationList = [],
  83. default_to_notebook: Optional[bool] = None,
  84. ) -> Optional[str]:
  85. """
  86. Build a visualization of the given text
  87. Args:
  88. text (:obj:`str`):
  89. The text to tokenize
  90. annotations (:obj:`List[Annotation]`, `optional`):
  91. An optional list of annotations of the text. The can either be an annotation class
  92. or anything else if you instantiated the visualizer with a converter function
  93. default_to_notebook (:obj:`bool`, `optional`, defaults to `False`):
  94. If True, will render the html in a notebook. Otherwise returns an html string.
  95. Returns:
  96. The HTML string if default_to_notebook is False, otherwise (default) returns None and
  97. renders the HTML in the notebook
  98. """
  99. final_default_to_notebook = self.default_to_notebook
  100. if default_to_notebook is not None:
  101. final_default_to_notebook = default_to_notebook
  102. if final_default_to_notebook:
  103. try:
  104. from IPython.core.display import HTML, display
  105. except ImportError:
  106. raise Exception(
  107. """We couldn't import IPython utils for html display.
  108. Are you running in a notebook?"""
  109. )
  110. if self.annotation_coverter is not None:
  111. annotations = list(map(self.annotation_coverter, annotations))
  112. encoding = self.tokenizer.encode(text)
  113. html = EncodingVisualizer.__make_html(text, encoding, annotations)
  114. if final_default_to_notebook:
  115. display(HTML(html))
  116. else:
  117. return html
  118. @staticmethod
  119. def calculate_label_colors(annotations: AnnotationList) -> Dict[str, str]:
  120. """
  121. Generates a color palette for all the labels in a given set of annotations
  122. Args:
  123. annotations (:obj:`Annotation`):
  124. A list of annotations
  125. Returns:
  126. :obj:`dict`: A dictionary mapping labels to colors in HSL format
  127. """
  128. if len(annotations) == 0:
  129. return {}
  130. labels = set(map(lambda x: x.label, annotations))
  131. num_labels = len(labels)
  132. h_step = int(255 / num_labels)
  133. if h_step < 20:
  134. h_step = 20
  135. s = 32
  136. l = 64 # noqa: E741
  137. h = 10
  138. colors = {}
  139. for label in sorted(labels): # sort so we always get the same colors for a given set of labels
  140. colors[label] = f"hsl({h},{s}%,{l}%"
  141. h += h_step
  142. return colors
  143. @staticmethod
  144. def consecutive_chars_to_html(
  145. consecutive_chars_list: List[CharState],
  146. text: str,
  147. encoding: Encoding,
  148. ):
  149. """
  150. Converts a list of "consecutive chars" into a single HTML element.
  151. Chars are consecutive if they fall under the same word, token and annotation.
  152. The CharState class is a named tuple with a "partition_key" method that makes it easy to
  153. compare if two chars are consecutive.
  154. Args:
  155. consecutive_chars_list (:obj:`List[CharState]`):
  156. A list of CharStates that have been grouped together
  157. text (:obj:`str`):
  158. The original text being processed
  159. encoding (:class:`~tokenizers.Encoding`):
  160. The encoding returned from the tokenizer
  161. Returns:
  162. :obj:`str`: The HTML span for a set of consecutive chars
  163. """
  164. first = consecutive_chars_list[0]
  165. if first.char_ix is None:
  166. # its a special token
  167. stoken = encoding.tokens[first.token_ix]
  168. # special tokens are represented as empty spans. We use the data attribute and css
  169. # magic to display it
  170. return f'<span class="special-token" data-stoken={stoken}></span>'
  171. # We're not in a special token so this group has a start and end.
  172. last = consecutive_chars_list[-1]
  173. start = first.char_ix
  174. end = last.char_ix + 1
  175. span_text = text[start:end]
  176. css_classes = [] # What css classes will we apply on the resulting span
  177. data_items = {} # What data attributes will we apply on the result span
  178. if first.token_ix is not None:
  179. # We can either be in a token or not (e.g. in white space)
  180. css_classes.append("token")
  181. if first.is_multitoken:
  182. css_classes.append("multi-token")
  183. if first.token_ix % 2:
  184. # We use this to color alternating tokens.
  185. # A token might be split by an annotation that ends in the middle of it, so this
  186. # lets us visually indicate a consecutive token despite its possible splitting in
  187. # the html markup
  188. css_classes.append("odd-token")
  189. else:
  190. # Like above, but a different color so we can see the tokens alternate
  191. css_classes.append("even-token")
  192. if EncodingVisualizer.unk_token_regex.search(encoding.tokens[first.token_ix]) is not None:
  193. # This is a special token that is in the text. probably UNK
  194. css_classes.append("special-token")
  195. # TODO is this the right name for the data attribute ?
  196. data_items["stok"] = encoding.tokens[first.token_ix]
  197. else:
  198. # In this case we are looking at a group/single char that is not tokenized.
  199. # e.g. white space
  200. css_classes.append("non-token")
  201. css = f'''class="{' '.join(css_classes)}"'''
  202. data = ""
  203. for key, val in data_items.items():
  204. data += f' data-{key}="{val}"'
  205. return f"<span {css} {data} >{span_text}</span>"
  206. @staticmethod
  207. def __make_html(text: str, encoding: Encoding, annotations: AnnotationList) -> str:
  208. char_states = EncodingVisualizer.__make_char_states(text, encoding, annotations)
  209. current_consecutive_chars = [char_states[0]]
  210. prev_anno_ix = char_states[0].anno_ix
  211. spans = []
  212. label_colors_dict = EncodingVisualizer.calculate_label_colors(annotations)
  213. cur_anno_ix = char_states[0].anno_ix
  214. if cur_anno_ix is not None:
  215. # If we started in an annotation make a span for it
  216. anno = annotations[cur_anno_ix]
  217. label = anno.label
  218. color = label_colors_dict[label]
  219. spans.append(f'<span class="annotation" style="color:{color}" data-label="{label}">')
  220. for cs in char_states[1:]:
  221. cur_anno_ix = cs.anno_ix
  222. if cur_anno_ix != prev_anno_ix:
  223. # If we've transitioned in or out of an annotation
  224. spans.append(
  225. # Create a span from the current consecutive characters
  226. EncodingVisualizer.consecutive_chars_to_html(
  227. current_consecutive_chars,
  228. text=text,
  229. encoding=encoding,
  230. )
  231. )
  232. current_consecutive_chars = [cs]
  233. if prev_anno_ix is not None:
  234. # if we transitioned out of an annotation close it's span
  235. spans.append("</span>")
  236. if cur_anno_ix is not None:
  237. # If we entered a new annotation make a span for it
  238. anno = annotations[cur_anno_ix]
  239. label = anno.label
  240. color = label_colors_dict[label]
  241. spans.append(f'<span class="annotation" style="color:{color}" data-label="{label}">')
  242. prev_anno_ix = cur_anno_ix
  243. if cs.partition_key() == current_consecutive_chars[0].partition_key():
  244. # If the current charchter is in the same "group" as the previous one
  245. current_consecutive_chars.append(cs)
  246. else:
  247. # Otherwise we make a span for the previous group
  248. spans.append(
  249. EncodingVisualizer.consecutive_chars_to_html(
  250. current_consecutive_chars,
  251. text=text,
  252. encoding=encoding,
  253. )
  254. )
  255. # An reset the consecutive_char_list to form a new group
  256. current_consecutive_chars = [cs]
  257. # All that's left is to fill out the final span
  258. # TODO I think there is an edge case here where an annotation's span might not close
  259. spans.append(
  260. EncodingVisualizer.consecutive_chars_to_html(
  261. current_consecutive_chars,
  262. text=text,
  263. encoding=encoding,
  264. )
  265. )
  266. res = HTMLBody(spans) # Send the list of spans to the body of our html
  267. return res
  268. @staticmethod
  269. def __make_anno_map(text: str, annotations: AnnotationList) -> PartialIntList:
  270. """
  271. Args:
  272. text (:obj:`str`):
  273. The raw text we want to align to
  274. annotations (:obj:`AnnotationList`):
  275. A (possibly empty) list of annotations
  276. Returns:
  277. A list of length len(text) whose entry at index i is None if there is no annotation on
  278. charachter i or k, the index of the annotation that covers index i where k is with
  279. respect to the list of annotations
  280. """
  281. annotation_map = [None] * len(text)
  282. for anno_ix, a in enumerate(annotations):
  283. for i in range(a.start, a.end):
  284. annotation_map[i] = anno_ix
  285. return annotation_map
  286. @staticmethod
  287. def __make_char_states(text: str, encoding: Encoding, annotations: AnnotationList) -> List[CharState]:
  288. """
  289. For each character in the original text, we emit a tuple representing it's "state":
  290. * which token_ix it corresponds to
  291. * which word_ix it corresponds to
  292. * which annotation_ix it corresponds to
  293. Args:
  294. text (:obj:`str`):
  295. The raw text we want to align to
  296. annotations (:obj:`List[Annotation]`):
  297. A (possibly empty) list of annotations
  298. encoding: (:class:`~tokenizers.Encoding`):
  299. The encoding returned from the tokenizer
  300. Returns:
  301. :obj:`List[CharState]`: A list of CharStates, indicating for each char in the text what
  302. it's state is
  303. """
  304. annotation_map = EncodingVisualizer.__make_anno_map(text, annotations)
  305. # Todo make this a dataclass or named tuple
  306. char_states: List[CharState] = [CharState(char_ix) for char_ix in range(len(text))]
  307. for token_ix, token in enumerate(encoding.tokens):
  308. offsets = encoding.token_to_chars(token_ix)
  309. if offsets is not None:
  310. start, end = offsets
  311. for i in range(start, end):
  312. char_states[i].tokens.append(token_ix)
  313. for char_ix, anno_ix in enumerate(annotation_map):
  314. char_states[char_ix].anno_ix = anno_ix
  315. return char_states
  316. def HTMLBody(children: List[str], css_styles=css) -> str:
  317. """
  318. Generates the full html with css from a list of html spans
  319. Args:
  320. children (:obj:`List[str]`):
  321. A list of strings, assumed to be html elements
  322. css_styles (:obj:`str`, `optional`):
  323. Optional alternative implementation of the css
  324. Returns:
  325. :obj:`str`: An HTML string with style markup
  326. """
  327. children_text = "".join(children)
  328. return f"""
  329. <html>
  330. <head>
  331. <style>
  332. {css_styles}
  333. </style>
  334. </head>
  335. <body>
  336. <div class="tokenized-text" dir=auto>
  337. {children_text}
  338. </div>
  339. </body>
  340. </html>
  341. """