- # coding=utf-8
- # Copyright 2023 The HuggingFace Inc. team.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
- import numpy as np
- import torch
- from ..cache_utils import DynamicCache
- from ..pytorch_utils import isin_mps_friendly
- from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor
- if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
- from ..tokenization_utils_base import PreTrainedTokenizerBase
- from .configuration_utils import GenerationConfig
- class CandidateGenerator:
- """Abstract base class for all candidate generators that can be applied during assisted generation."""
- def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
- """
- Fetches the candidates to be tried for the current input.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- Return:
- `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
- assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
- vocabulary_size)` containing the logits associated with each candidate.
- """
- raise NotImplementedError(
- f"{self.__class__} is an abstract class. Only classes inheriting this class can call `get_candidates`."
- )
- def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
- """
- Updates the candidate generation strategy based on the outcomes.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
- Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
- using beam search, or log softmax scores for each vocabulary token when using beam search.
- num_matches (`int`):
- The number of matches between the candidate sequences and the model predictions.
- """
- raise NotImplementedError(
- f"{self.__class__} is an abstract class. Only classes inheriting this class can call "
- "`update_candidate_strategy`."
- )
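- # As an illustration of the interface above, a minimal sketch (hypothetical, not part of the
- # library) of a custom generator that always proposes the same fixed continuation:
- #
- #     class StaticCandidateGenerator(CandidateGenerator):
- #         def __init__(self, candidate_ids: torch.LongTensor):
- #             self.candidate_ids = candidate_ids  # fixed tokens to propose every round
- #
- #         def get_candidates(self, input_ids: torch.LongTensor):
- #             # append the fixed proposal to the running sequence; no candidate logits available
- #             candidates = torch.cat([input_ids, self.candidate_ids.to(input_ids.device)], dim=-1)
- #             return candidates, None
- #
- #         def update_candidate_strategy(self, input_ids, scores, num_matches):
- #             pass  # static proposals: nothing to adapt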
- class AssistedCandidateGenerator(CandidateGenerator):
- """
- `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
- candidates through the use of a smaller model. Read the following blog post for more information:
- https://huggingface.co/blog/assisted-generation
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- assistant_model (`PreTrainedModel`):
- The model to be used for generating candidates. This model should be smaller than the main model.
- generation_config (`~generation.GenerationConfig`, *optional*):
- The generation configuration to be used as base parametrization for the generation call.
- logits_processor (`LogitsProcessorList`):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- model_kwargs (`Dict`):
- The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
- model as well.
- inputs_tensor (`torch.Tensor`, *optional*):
- The model input tensor. In encoder-decoder models, this is the encoder input.
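- Example (an illustrative sketch; the checkpoints below are placeholders for any main/assistant pair
- sharing the same tokenizer):
- ```python
- >>> from transformers import AutoModelForCausalLM, AutoTokenizer
- >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-1.4b-deduped")
- >>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b-deduped")
- >>> assistant_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")
- >>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
- >>> # passing `assistant_model` makes `generate` build this candidate generator internally
- >>> outputs = model.generate(**inputs, assistant_model=assistant_model)
- ```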
- """
- def __init__(
- self,
- input_ids: torch.LongTensor,
- assistant_model: "PreTrainedModel",
- generation_config: "GenerationConfig",
- model_kwargs: Dict,
- inputs_tensor: Optional[torch.Tensor] = None,
- logits_processor: Optional["LogitsProcessorList"] = None,
- ):
- # Make sure all data is on the same device as the assistant model
- device = assistant_model.device
- input_ids = input_ids.to(device)
- if inputs_tensor is not None:
- inputs_tensor = inputs_tensor.to(device)
- # Prepare the assistant and the starting number of candidate tokens
- self.assistant_model = assistant_model
- self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
- self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold
- # Set the eos token in the assistant to match the target model
- self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id
- # Prepare the kwargs for the assistant model
- assistant_kwargs = {}
- for key, value in model_kwargs.items(): # deepcopy crashes if we attempt to copy encoder outputs with grads
- if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
- assistant_kwargs[key] = (
- value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
- )
- # Remove potential default "num_logits_to_keep" key
- if "num_logits_to_keep" in assistant_kwargs.keys() and not assistant_model._supports_num_logits_to_keep():
- del assistant_kwargs["num_logits_to_keep"]
- if "assistant_encoder_outputs" in model_kwargs:
- assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
- elif assistant_model.config.is_encoder_decoder:
- inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
- inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
- )
- assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
- inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
- )
- elif "encoder_outputs" in model_kwargs:
- assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
- self.assistant_kwargs = assistant_kwargs
- # Prepare the assistant model's input keys
- if assistant_model.config.is_encoder_decoder:
- # both are encoder-decoder
- self.input_ids_key = "decoder_input_ids"
- elif "encoder_outputs" in assistant_kwargs:
- # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
- self.input_ids_key = "input_ids"
- self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
- "decoder_attention_mask",
- torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
- )
- else:
- # both are decoder-only
- self.input_ids_key = "input_ids"
- # Prepare generation-related options.
- self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
- self.generation_config = copy.deepcopy(generation_config)
- self.generation_config.return_dict_in_generate = True
- self.generation_config.output_scores = True
- self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
- # this flag allows us to set the confidence stopping criterion for assistant model generation.
- self.generation_config.is_assistant = True
- # avoid unnecessary warnings that min_length is larger than max_new_tokens
- # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
- self.main_model_min_length = self.generation_config.min_length
- self.generation_config.min_length = 0
- self.generation_config.min_new_tokens = None
- for processor in self.logits_processor:
- if isinstance(processor, MinLengthLogitsProcessor):
- raise ValueError(
- "Passing `MinLengthLogitsProcessor` when using `assisted_generation is disabled. "
- "Please pass in `min_length` into `.generate()` instead"
- )
- # We need to roll back the cache in assisted generation; only DynamicCache is supported
- self.generation_config.cache_implementation = None
- def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
- """
- Fetches the candidates to be tried for the current input.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- Return:
- `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
- assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
- vocabulary_size)` containing the logits associated with each candidate.
- """
- input_ids = input_ids.to(self.assistant_model.device)
- # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
- new_cur_len = input_ids.shape[-1]
- max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
- min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
- if max_new_tokens == 0:
- return input_ids, None
- # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
- # (which implicitly contains the number of accepted candidates from the previous round)
- has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
- if has_past_key_values:
- new_cache_size = new_cur_len - 1
- self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
- self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
- ) # the assistant does not have the token after the last match, hence the -1
- self.assistant_kwargs = _prepare_attention_mask(
- self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
- )
- self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
- # 2. Forecast next N tokens using the assistant model.
- assistant_generation_kwargs = {
- self.input_ids_key: input_ids,
- "min_new_tokens": min_new_tokens,
- "max_new_tokens": max_new_tokens,
- "generation_config": self.generation_config,
- "logits_processor": self.logits_processor,
- }
- assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
- # 3. Update variables for the next round of candidate generation
- self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
- # 4. Prepare variables for output
- candidate_logits = torch.stack(assistant_output.scores, dim=1)
- candidate_ids = assistant_output.sequences
- return candidate_ids, candidate_logits
- def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
- """
- Updates the candidate generation strategy based on the outcomes.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
- Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
- using beam search, or log softmax scores for each vocabulary token when using beam search.
- num_matches (`int`):
- The number of matches between the candidate sequences and the model predictions.
- """
- # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
- # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
- # cost of forecasting incorrect assistant tokens.
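- # For example, with `num_assistant_tokens=5`: if all 5 candidates are accepted, the next round
- # proposes 7 tokens; if any candidate is rejected, it proposes max(1.0, 5 - 1) = 4 tokens.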
- if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
- "heuristic",
- "heuristic_transient",
- }:
- if num_matches == int(self.num_assistant_tokens):
- self.num_assistant_tokens += 2.0
- else:
- self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)
- class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
- """
- `CandidateGenerator` class to be used for Universal Assisted Generation (UAG): assisted generation with different
- tokenizers for the assistant and main models. This class generates candidates through the use of a smaller model.
- The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the
- assistant encoding, which are in turn re-encoded into main model candidate tokens. Validation then proceeds as in
- standard assisted generation.
- The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
- Since re-encoding the tokens may result in tokenization discrepancies, UAG finds the longest common subsequence
- between the source and target encodings, to ensure the new tokens include the correct prompt suffix.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- assistant_model (`PreTrainedModel`):
- The model to be used for generating candidates. This model should be smaller than the main model.
- target_tokenizer (`PreTrainedTokenizerBase`):
- The tokenizer used for the target model.
- assistant_tokenizer (`PreTrainedTokenizerBase`):
- The tokenizer used for the assistant model.
- generation_config (`~generation.GenerationConfig`, *optional*):
- The generation configuration to be used as base parametrization for the generation call.
- logits_processor (`LogitsProcessorList`):
- An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
- used to modify the prediction scores of the language modeling head applied at each generation step.
- model_kwargs (`Dict`):
- The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
- model as well.
- inputs_tensor (`torch.Tensor`, *optional*):
- The model input tensor. In encoder-decoder models, this is the encoder input.
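- Example (an illustrative sketch; the checkpoints below are placeholders for any main/assistant pair
- with different tokenizers):
- ```python
- >>> from transformers import AutoModelForCausalLM, AutoTokenizer
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b")
- >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b")
- >>> assistant_tokenizer = AutoTokenizer.from_pretrained("double7/vicuna-68m")
- >>> assistant_model = AutoModelForCausalLM.from_pretrained("double7/vicuna-68m")
- >>> inputs = tokenizer("Alice and Bob", return_tensors="pt")
- >>> # passing both tokenizers triggers universal assisted generation in `generate`
- >>> outputs = model.generate(
- ...     **inputs, assistant_model=assistant_model, tokenizer=tokenizer, assistant_tokenizer=assistant_tokenizer
- ... )
- ```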
- """
- def __init__(
- self,
- input_ids: torch.LongTensor,
- assistant_model: "PreTrainedModel",
- target_tokenizer: "PreTrainedTokenizerBase",
- assistant_tokenizer: "PreTrainedTokenizerBase",
- generation_config: "GenerationConfig",
- model_kwargs: Dict,
- inputs_tensor: Optional[torch.Tensor] = None,
- logits_processor: Optional["LogitsProcessorList"] = None,
- ):
- super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)
- self.target_tokenizer = target_tokenizer
- self.assistant_tokenizer = assistant_tokenizer
- self.prev_tokens = None
- self.prev_assistant_ids = None
- self.target_lookbehind = 10
- self.assistant_lookbehind = 10
- @staticmethod
- def _get_longest_diag_dict(input_matrix, nonzero_idx):
- """
- Calculates the length of the diagonal run of matching values starting at each given nonzero index of a matrix.
- Args:
- input_matrix (torch.Tensor): The input matrix.
- nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix.
- Returns:
- dict: A dictionary where the keys are the indices of the non-zero elements and the values are the lengths of the longest diagonal sequences starting from those indices.
- """
- visited = set()
- diags = {}
- for idx in nonzero_idx:
- start_idx = torch.clone(idx)
- tuple_start_idx = tuple(start_idx.tolist())
- if tuple_start_idx in visited:
- continue
- visited.add(tuple_start_idx)
- cur_diag_len = 1
- start_idx += 1
- while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
- tuple_start_idx = tuple(start_idx.tolist())
- visited.add(tuple_start_idx)
- if input_matrix[start_idx[0], start_idx[1]] == 1:
- cur_diag_len += 1
- start_idx += 1
- else:
- break
- diags[idx] = cur_diag_len
- return diags
- @staticmethod
- def _get_longest_diag_index(input_matrix):
- """
- Returns the start index and length of the longest diagonal in the given input.
- Args:
- input_matrix (torch.Tensor): The input matrix.
- Returns:
- tuple: A tuple containing the start index and length of the longest diagonal.
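- Example (worked by hand; an identity matrix has a single diagonal run of length 3 starting at (0, 0)):
- ```python
- >>> import torch
- >>> m = torch.eye(3, dtype=torch.int)
- >>> AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(m)
- (tensor([0, 0]), 3)
- ```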
- """
- diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
- input_matrix, input_matrix.nonzero()
- )
- diags_values = list(diags.values())
- diags_keys = list(diags.keys())
- best_diag = np.argmax(diags_values)
- diag_start_index = diags_keys[best_diag]
- diag_start_length = diags_values[best_diag]
- return diag_start_index, diag_start_length
- @staticmethod
- def _get_tokens_diag(prompt, prompt_plus_new_tokens):
- """
- Input:
- prompt: 2D array of shape (batch_size, prompt_length), represents the original prompt tokens
- prompt_plus_new_tokens: 2D array of shape (batch_size, prompt_length + new_token_length), represents a suffix of the original prompt, followed by the newly generated tokens.
- Output:
- discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
- new_tokens_only: 2D array of shape (batch_size, new_token_length), represents the new tokens that are not in prompt
- discrepancy_only: 2D array of shape (batch_size, discrepancy_length), represents the tokens in prompt_plus_new_tokens that replace the conflicting tail tokens of prompt
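- Example (worked by hand; the window [3, 4, 5, 6] overlaps the prompt tail [3, 4], so only [5, 6] are new
- and there is no discrepancy):
- ```python
- >>> import torch
- >>> prompt = torch.tensor([[1, 2, 3, 4]])
- >>> window = torch.tensor([[3, 4, 5, 6]])
- >>> d_len, new_only, d_only = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt, window)
- >>> (d_len, new_only, d_only.shape[1])
- (0, tensor([[5, 6]]), 0)
- ```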
- """
- compare_mat = prompt_plus_new_tokens.T == prompt
- if not torch.is_tensor(compare_mat):
- compare_mat = torch.tensor(compare_mat)
- compare_mat_int = compare_mat.to(int)
- if not compare_mat_int.any().item():
- # empty intersection between prompt and prompt_plus_new_tokens
- return None, None, None
- longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
- compare_mat_int
- )
- new_token_start_index = longest_location[0] + longest_diag_length
- discrepancy_with_old = longest_location[1] + longest_diag_length
- discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
- new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
- discrepancy_only = prompt_plus_new_tokens[
- :, new_token_start_index : new_token_start_index + discrepancy_length
- ]
- return discrepancy_length, new_tokens_only, discrepancy_only
- def convert_source_tokens_to_target_tokens(
- self,
- input_ids,
- source_tokenizer,
- destination_tokenizer,
- ):
- """
- Convert token IDs from one tokenizer to another.
- Args:
- input_ids: The input token IDs.
- source_tokenizer: The source tokenizer.
- destination_tokenizer: The destination tokenizer.
- Returns:
- The converted token IDs.
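- Example (an illustrative sketch; assumes `tok_a` and `tok_b` are two already-loaded tokenizers and
- `generator` is an instance of this class):
- ```python
- >>> ids_b = generator.convert_source_tokens_to_target_tokens(
- ...     ids_a, source_tokenizer=tok_a, destination_tokenizer=tok_b
- ... )
- >>> # ids_b need not have the same length as ids_a: the two vocabularies split text differently
- ```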
- """
- text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
- dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
- return dest_ids.to(input_ids.device)
- def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
- """
- Fetches the candidates to be tried for the current input.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- Return:
- `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
- assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
- vocabulary_size)` containing the logits associated with each candidate.
- """
- max_new_tokens = int(self.num_assistant_tokens)
- if max_new_tokens == 0:
- return input_ids, None
- input_ids = input_ids.to(self.assistant_model.device)
- convert_kwargs = {
- "source_tokenizer": self.target_tokenizer,
- "destination_tokenizer": self.assistant_tokenizer,
- }
- remove_from_pkv = 0
- # Since re-encoding the tokens may result in tokenization discrepancies, we use two lookbehind values
- # (one for each conversion) which mark where to start looking for the overlap between the
- # source and target encodings, to ensure the new tokens include the correct prompt suffix.
- if self.prev_tokens is not None and self.prev_target_ids.shape[1] > self.target_lookbehind:
- # input_ids contains all target prompt input ids and some new target input ids
- start_index_in_target_window = self.prev_target_ids.shape[1] - self.target_lookbehind
- new_assistant_ids = self.convert_source_tokens_to_target_tokens(
- input_ids[:, start_index_in_target_window:], **convert_kwargs
- )
- prompt_use_length = new_assistant_ids.shape[1]
- prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]
- discrepancy_length, new_tokens_only, discrepancy_only = (
- AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(prompt_use, new_assistant_ids)
- )
- assistant_input_ids = self.prev_assistant_ids
- if new_tokens_only is not None:
- if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
- if discrepancy_length == discrepancy_only.shape[1]:
- assistant_input_ids[:, -discrepancy_length:] = discrepancy_only
- elif discrepancy_length > discrepancy_only.shape[1]:
- discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
- assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
- assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only
- remove_from_pkv = discrepancy_length
- if new_tokens_only.shape[1] > 0:
- assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
- else:
- # edge case: in case of no intersection between prompt and new_assistant_ids
- assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
- else:
- assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
- self.prev_target_ids = input_ids
- self.prev_assistant_ids = assistant_input_ids
- new_cur_len = assistant_input_ids.shape[-1]
- min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
- # 1. If it is not the first round of candidate generation, prepare the inputs based on the input_ids length
- # (which implicitly contains the number of accepted candidates from the previous round)
- has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
- if has_past_key_values:
- new_cache_size = new_cur_len - 1 - remove_from_pkv
- self.assistant_kwargs["past_key_values"] = _crop_past_key_values(
- self.assistant_model, self.assistant_kwargs["past_key_values"], new_cache_size - 1
- ) # the assistant does not have the token after the last match, hence the -1
- self.assistant_kwargs = _prepare_attention_mask(
- self.assistant_kwargs, new_cur_len, self.assistant_model.config.is_encoder_decoder
- )
- self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, new_cur_len)
- # 2. Forecast next N tokens using the assistant model.
- assistant_generation_kwargs = {
- self.input_ids_key: assistant_input_ids,
- "min_new_tokens": min_new_tokens,
- "max_new_tokens": max_new_tokens,
- "generation_config": self.generation_config,
- "logits_processor": self.logits_processor,
- }
- self.assistant_kwargs.pop("attention_mask", None)
- assistant_output = self.assistant_model.generate(**assistant_generation_kwargs, **self.assistant_kwargs)
- num_prev_assistant = self.prev_assistant_ids.shape[1]
- start_assistant_look_index = num_prev_assistant - self.assistant_lookbehind
- new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
- assistant_output.sequences[:, start_assistant_look_index:],
- source_tokenizer=self.assistant_tokenizer,
- destination_tokenizer=self.target_tokenizer,
- )
- target_prompt_use_length = new_target_ids_from_window.shape[1]
- target_prompt_use = input_ids[:, -target_prompt_use_length:]
- _, target_new_tokens_only, _ = AssistedCandidateGeneratorDifferentTokenizers._get_tokens_diag(
- target_prompt_use, new_target_ids_from_window
- )
- new_target_ids = input_ids
- if target_new_tokens_only is not None:
- if target_new_tokens_only.shape[1] > 0:
- new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
- else:
- # edge case: in case of no intersection between prompt and new_target_ids
- new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)
- self.prev_target_ids = input_ids
- if hasattr(self.generation_config, "max_length"):
- new_target_ids = new_target_ids[:, : self.generation_config.max_length]
- # 3. Update variables for the next round of candidate generation
- self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
- self.prev_tokens = assistant_output.sequences
- # 4. Prepare variables for output
- if input_ids.shape[1] >= new_target_ids.shape[1]:
- return input_ids, None
- return new_target_ids, None
- class PromptLookupCandidateGenerator(CandidateGenerator):
- """
- `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
- likely continuations in the provided prompt (input_ids) itself.
- Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding
- Args:
- eos_token_id (`torch.Tensor`, *optional*):
- The token id(s) of the end-of-sequence token; candidate sequences are truncated at the first eos found.
- num_output_tokens (`int`, *optional*, defaults to 10):
- The number of tokens to be output as candidate tokens.
- max_matching_ngram_size (`int`, *optional*, defaults to 2):
- The maximum ngram size to be considered for matching in the prompt.
- max_length (`int`, *optional*, defaults to 20):
- The number of total maximum tokens that can be generated. For decoder-only models that includes the prompt
- length. Defaults to 20, which is the max length used as default in the generation config.
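- Example (an illustrative sketch; the checkpoint is a placeholder for any decoder-only model):
- ```python
- >>> from transformers import AutoModelForCausalLM, AutoTokenizer
- >>> tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m-deduped")
- >>> model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-160m-deduped")
- >>> inputs = tokenizer("The quick brown fox jumps over the quick brown", return_tensors="pt")
- >>> # passing `prompt_lookup_num_tokens` makes `generate` build this candidate generator internally
- >>> outputs = model.generate(**inputs, prompt_lookup_num_tokens=10)
- ```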
- """
- def __init__(
- self,
- eos_token_id: Optional[torch.Tensor] = None,
- num_output_tokens: int = 10,
- max_matching_ngram_size: Optional[int] = None,
- max_length: int = 20,
- ):
- self.num_output_tokens = num_output_tokens
- self.max_matching_ngram_size = max_matching_ngram_size if max_matching_ngram_size else 2
- self.max_length = max_length
- self.eos_token_id = eos_token_id
- if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
- raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")
- def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
- """
- Fetches the candidates to be tried for the current input.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- Return:
- `torch.LongTensor` of shape `(num_candidates, candidate_length)`: The candidate sequences to be tried.
- """
- input_length = input_ids.size(1)
- # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
- if self.max_length == input_length + 1:
- return input_ids, None
- chosen_ids = None
- match_found = False
- for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
- # Create sliding windows of size ngram_size
- windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
- # Convert ngram to a tensor for comparison
- ngram_tensor = input_ids[0, -ngram_size:]
- # Find where the windows match the ngram
- matches = (windows == ngram_tensor).all(dim=2)
- # Get the indices of matches
- match_indices = matches.nonzero(as_tuple=True)[1]
- # Iterate through match indices to find a valid continuation
- for idx in match_indices:
- start_idx = idx + ngram_size
- end_idx = start_idx + self.num_output_tokens
- end_idx = min(end_idx, input_length, self.max_length)
- if start_idx < end_idx:
- chosen_ids = input_ids[0, start_idx:end_idx]
- match_found = True
- # remove remaining candidate ids if an "eos" token is found, otherwise the target model may
- # accept eos and the rest as valid, thus not stopping generation after "eos"
- # NOTE: the code below relies on the fact that assisted decoding supports only batch size 1
- mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
- match_indices_eos = torch.nonzero(mask)
- if match_indices_eos.numel() > 0:
- first_eos_index = match_indices_eos[0].item()
- chosen_ids = chosen_ids[:first_eos_index]
- break
- if match_found:
- break
- if chosen_ids is None or len(chosen_ids) == 0:
- # In case we didn't find a match, return the input sequence unchanged and revert to autoregressive decoding
- return input_ids, None
- # Now we need to extend input_ids with chosen_ids
- chosen_ids = chosen_ids.unsqueeze(0)
- candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
- # assisted_generation expects logits as well, but we don't have those here, so returning None
- return candidate_input_ids, None
- def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
- """
- Updates the candidate generation strategy based on the outcomes.
- Args:
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
- Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
- scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
- Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
- using beam search, or log softmax scores for each vocabulary token when using beam search.
- num_matches (`int`):
- The number of matches between the candidate sequences and the model predictions.
- """
- # Currently does nothing
- return
- def _crop_past_key_values(model, past_key_values, max_length):
- """Crops the past key values up to a certain maximum length."""
- new_past = []
- if model.config.is_encoder_decoder:
- for idx in range(len(past_key_values)):
- new_past.append(
- (
- past_key_values[idx][0][:, :, :max_length, :],
- past_key_values[idx][1][:, :, :max_length, :],
- past_key_values[idx][2],
- past_key_values[idx][3],
- )
- )
- past_key_values = tuple(new_past)
- # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
- elif "gptbigcode" in model.__class__.__name__.lower() or (
- model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
- ):
- if model.config.multi_query:
- for idx in range(len(past_key_values)):
- past_key_values[idx] = past_key_values[idx][:, :max_length, :]
- else:
- for idx in range(len(past_key_values)):
- past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
- elif isinstance(past_key_values, DynamicCache):
- past_key_values.crop(max_length)
- elif past_key_values is not None:
- for idx in range(len(past_key_values)):
- if past_key_values[idx] != ([], []):
- new_past.append(
- (
- past_key_values[idx][0][:, :, :max_length, :],
- past_key_values[idx][1][:, :, :max_length, :],
- )
- )
- else:
- new_past.append((past_key_values[idx][0], past_key_values[idx][1]))
- past_key_values = tuple(new_past)
- return past_key_values
- def _prepare_attention_mask(model_kwargs: Dict[str, Any], new_length: int, is_encoder_decoder: bool) -> Dict[str, Any]:
- """Expands or crops the model's mask for decoding purposes, to the defined length"""
- mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
- if mask_key not in model_kwargs:
- return model_kwargs
- mask = model_kwargs[mask_key]
- mask_length_diff = new_length - mask.shape[1]
- if mask_length_diff < 0:
- model_kwargs[mask_key] = mask[:, :mask_length_diff]
- elif mask_length_diff > 0:
- model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
- # Handle cross attention models
- if "cross_attention_mask" in model_kwargs:
- # Mllama case
- cross_mask = model_kwargs["cross_attention_mask"]
- if mask_length_diff < 0:
- model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
- elif mask_length_diff > 0:
- new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
- model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
- elif "image_attention_mask" in model_kwargs:
- # IDEFICS case
- cross_mask = model_kwargs["image_attention_mask"]
- if mask_length_diff < 0:
- model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
- elif mask_length_diff > 0:
- new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
- model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
- return model_kwargs
- def _prepare_token_type_ids(model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]:
- """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
- if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
- return model_kwargs
- token_type_ids = model_kwargs["token_type_ids"]
- final_token_type = token_type_ids[:, -1].unsqueeze(-1)
- type_length_diff = new_length - token_type_ids.shape[1]
- if type_length_diff < 0:
- model_kwargs["token_type_ids"] = token_type_ids[:, :type_length_diff]  # write back so the crop takes effect
- elif type_length_diff > 0:
- token_type_copies = final_token_type.repeat(1, type_length_diff)
- model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
- return model_kwargs
|