# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from queue import Queue
from typing import TYPE_CHECKING, Optional


if TYPE_CHECKING:
    from ..models.auto import AutoTokenizer

class BaseStreamer:
    """
    Base class from which `.generate()` streamers should inherit.
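
    Examples:

        A minimal sketch of a custom streamer (the class name and behavior are illustrative, not
        part of the public API): it records every token id that `.generate()` pushes, so the ids
        can be inspected after the call.

        ```python
        >>> from transformers.generation.streamers import BaseStreamer

        >>> class TokenCollector(BaseStreamer):
        ...     def __init__(self):
        ...         self.tokens = []
        ...
        ...     def put(self, value):
        ...         # `value` is a tensor of token ids (2D for the initial prompt)
        ...         self.tokens.extend(value.flatten().tolist())
        ...
        ...     def end(self):
        ...         pass
        ```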
- """
- def put(self, value):
- """Function that is called by `.generate()` to push new tokens"""
- raise NotImplementedError()
- def end(self):
- """Function that is called by `.generate()` to signal the end of generation"""
- raise NotImplementedError()
class TextStreamer(BaseStreamer):
    """
    Simple text streamer that prints the token(s) to stdout as soon as entire words are formed.

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextStreamer(tok)

        >>> # Despite returning the usual output, the streamer will also print the generated text to stdout.
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,
        ```
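
        With `skip_prompt=True`, the prompt tokens are not echoed, which is handy for chat
        applications (a sketch reusing the objects above; only the newly generated words are
        printed):

        ```python
        >>> streamer = TextStreamer(tok, skip_prompt=True)
        >>> _ = model.generate(**inputs, streamer=streamer, max_new_tokens=20)
        ```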
- """
    def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
        self.tokenizer = tokenizer
        self.skip_prompt = skip_prompt
        self.decode_kwargs = decode_kwargs

        # variables used in the streaming process
        self.token_cache = []
        self.print_len = 0
        self.next_tokens_are_prompt = True

    def put(self, value):
        """
        Receives tokens, decodes them, and prints them to stdout as soon as they form entire words.
        """
        if len(value.shape) > 1 and value.shape[0] > 1:
            raise ValueError("TextStreamer only supports batch size 1")
        elif len(value.shape) > 1:
            value = value[0]

        if self.skip_prompt and self.next_tokens_are_prompt:
            self.next_tokens_are_prompt = False
            return

        # Add the new token to the cache and decode the entire thing.
        self.token_cache.extend(value.tolist())
        text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)

        # After the symbol for a new line, we flush the cache.
        if text.endswith("\n"):
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        # If the last token is a CJK character, we print the characters.
        elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
            printable_text = text[self.print_len :]
            self.print_len += len(printable_text)
        # Otherwise, print up to and including the last space character (a simple heuristic to avoid
        # printing incomplete words, which may change with the subsequent token -- there are probably
        # smarter ways to do this!)
        else:
            printable_text = text[self.print_len : text.rfind(" ") + 1]
            self.print_len += len(printable_text)

        self.on_finalized_text(printable_text)

    def end(self):
        """Flushes any remaining cache and prints a newline to stdout."""
        # Flush the cache, if it exists
        if len(self.token_cache) > 0:
            text = self.tokenizer.decode(self.token_cache, **self.decode_kwargs)
            printable_text = text[self.print_len :]
            self.token_cache = []
            self.print_len = 0
        else:
            printable_text = ""

        self.next_tokens_are_prompt = True
        self.on_finalized_text(printable_text, stream_end=True)

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Prints the new text to stdout. If the stream is ending, also prints a newline."""
        print(text, flush=True, end="" if not stream_end else None)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as are Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and are handled
        # like all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)  # CJK Unified Ideographs
            or (cp >= 0x3400 and cp <= 0x4DBF)  # CJK Unified Ideographs Extension A
            or (cp >= 0x20000 and cp <= 0x2A6DF)  # CJK Unified Ideographs Extension B
            or (cp >= 0x2A700 and cp <= 0x2B73F)  # CJK Unified Ideographs Extension C
            or (cp >= 0x2B740 and cp <= 0x2B81F)  # CJK Unified Ideographs Extension D
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  # CJK Unified Ideographs Extension E
            or (cp >= 0xF900 and cp <= 0xFAFF)  # CJK Compatibility Ideographs
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  # CJK Compatibility Ideographs Supplement
        ):
            return True

        return False


class TextIteratorStreamer(TextStreamer):
    """
    Streamer that stores print-ready text in a queue, to be used by a downstream application as an iterator. This is
    useful for applications that benefit from accessing the generated text in a non-blocking way (e.g. in an
    interactive Gradio demo).

    <Tip warning={true}>

    The API for the streamer classes is still under development and may change in the future.

    </Tip>

    Parameters:
        tokenizer (`AutoTokenizer`):
            The tokenizer used to decode the tokens.
        skip_prompt (`bool`, *optional*, defaults to `False`):
            Whether to skip the prompt to `.generate()` or not. Useful e.g. for chatbots.
        timeout (`float`, *optional*):
            The timeout for the text queue. If `None`, the queue will block indefinitely. Useful to handle exceptions
            in `.generate()`, when it is called in a separate thread.
        decode_kwargs (`dict`, *optional*):
            Additional keyword arguments to pass to the tokenizer's `decode` method.

    Examples:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
        >>> from threading import Thread

        >>> tok = AutoTokenizer.from_pretrained("openai-community/gpt2")
        >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        >>> inputs = tok(["An increasing sequence: one,"], return_tensors="pt")
        >>> streamer = TextIteratorStreamer(tok)

        >>> # Run the generation in a separate thread, so that we can fetch the generated text in a non-blocking way.
        >>> generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=20)
        >>> thread = Thread(target=model.generate, kwargs=generation_kwargs)
        >>> thread.start()
        >>> generated_text = ""
        >>> for new_text in streamer:
        ...     generated_text += new_text
        >>> generated_text
        'An increasing sequence: one, two, three, four, five, six, seven, eight, nine, ten, eleven,'
        ```
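
        If `.generate()` may raise in its thread, passing a `timeout` keeps the consuming loop from
        blocking forever; on timeout, the underlying queue raises `queue.Empty` (a sketch reusing
        the objects above):

        ```python
        >>> from queue import Empty

        >>> streamer = TextIteratorStreamer(tok, timeout=10.0)
        >>> thread = Thread(target=model.generate, kwargs=dict(inputs, streamer=streamer, max_new_tokens=20))
        >>> thread.start()
        >>> generated_text = ""
        >>> try:
        ...     for new_text in streamer:
        ...         generated_text += new_text
        ... except Empty:
        ...     pass  # the generation thread raised or stalled past the timeout
        ```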
- """
    def __init__(
        self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, timeout: Optional[float] = None, **decode_kwargs
    ):
        super().__init__(tokenizer, skip_prompt, **decode_kwargs)
        self.text_queue = Queue()
        self.stop_signal = None
        self.timeout = timeout

    def on_finalized_text(self, text: str, stream_end: bool = False):
        """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue."""
        self.text_queue.put(text, timeout=self.timeout)
        if stream_end:
            self.text_queue.put(self.stop_signal, timeout=self.timeout)

    def __iter__(self):
        return self

    def __next__(self):
        value = self.text_queue.get(timeout=self.timeout)
        if value == self.stop_signal:
            raise StopIteration()
        else:
            return value