streamers.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from queue import Queue
from typing import TYPE_CHECKING, Optional


if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer


class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
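
    A subclass only needs to implement `put` and `end`. As a minimal, hedged sketch (the `TokenCounter` name is
    hypothetical, not part of the library), a streamer that counts how many tokens `.generate()` pushes could look
    like this:

    ```python
    >>> class TokenCounter(BaseStreamer):
    ...     def __init__(self):
    ...         self.count = 0
    ...     def put(self, value):  # called by `.generate()` with a tensor of new token ids
    ...         self.count += value.numel()
    ...     def end(self):  # called once, when generation is complete
    ...         pass
    ```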
  22. """
  23. def put(self, value):
  24. """Function that is called by `.generate()` to push new tokens"""
  25. raise NotImplementedError()
  26. def end(self):
  27. """Function that is called by `.generate()` to signal the end of generation"""
  28. raise NotImplementedError()


class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:
        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
    """

    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, print until the last space char (simple heuristic to avoid printing incomplete words,
        # which may change with the subsequent token -- there are probably smarter ways to do this!)
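        # Illustrative example: if the decoded text so far is "An increasing seq" and nothing has been printed yet,
        # only "An increasing " is emitted below; "seq" stays in the cache until a later token completes the word.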
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""
        self.next_tokens_are_prompt = True

        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
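        # Illustrative check: ord("中") == 0x4E2D falls in the 0x4E00-0x9FFF range below, so it counts as CJK;
        # ord("a") == 0x61 falls in none of the ranges, so it does not.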
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False


class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an
    interactive Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:
        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        >>> from threading import Thread

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok)

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
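
        With a `timeout` set, the consuming loop does not block forever if `.generate()` raises in its thread and
        stops feeding the queue: `Queue.get` raises `queue.Empty` once the timeout expires. A hedged sketch, reusing
        the objects from the example above:

        ```python
        >>> from queue import Empty

        >>> streamer = TextIteratorStreamer(tok, timeout=5.0)
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> try:
        ...     for new_text in streamer:
        ...         pass  # consume the stream, e.g. append to a buffer or update a UI
        ... except Empty:
        ...     pass  # generation stalled for more than 5 seconds
        ```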
  164. """

    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value