flax_logits_process.py

# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect

import jax
import jax.lax as lax
import jax.numpy as jnp
from jax.experimental import sparse

from ..utils import add_start_docstrings
from ..utils.logging import get_logger


logger = get_logger(__name__)


LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        scores (`jnp.ndarray` of shape `(batch_size, config.vocab_size)`):
            Prediction scores of a language modeling head. These can be logits for each vocabulary token when not
            using beam search, or log softmax for each vocabulary token when using beam search.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional logits processor specific kwargs.

    Return:
        `jnp.ndarray` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.

"""


class FlaxLogitsProcessor:
    """Abstract base class for all logit processors that can be applied during generation."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
        """Flax method for processing logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class FlaxLogitsWarper:
    """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray) -> jnp.ndarray:
        """Flax method for warping logits."""
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )


class FlaxLogitsProcessorList(list):
    """
    This class can be used to create a list of [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to subsequently process
    a `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
    [`FlaxLogitsProcessor`] or [`FlaxLogitsWarper`] to the inputs.
    """

    @add_start_docstrings(LOGITS_PROCESSOR_INPUTS_DOCSTRING)
    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int, **kwargs) -> jnp.ndarray:
        for processor in self:
            function_args = inspect.signature(processor.__call__).parameters
            if len(function_args) > 3:
                if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
                    raise ValueError(
                        f"Make sure that all the required parameters: {list(function_args.keys())} for "
                        f"{processor.__class__} are passed to the logits processor."
                    )
                scores = processor(input_ids, scores, cur_len, **kwargs)
            else:
                scores = processor(input_ids, scores, cur_len)
        return scores
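
# Illustrative usage sketch (not part of the original module; shapes and token ids are
# made up, and the processors referenced below are defined later in this file): the list
# applies every processor in order at each decoding step.
#
#   import jax.numpy as jnp
#
#   processors = FlaxLogitsProcessorList()
#   processors.append(FlaxForcedBOSTokenLogitsProcessor(bos_token_id=0))
#   processors.append(FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=2))
#
#   input_ids = jnp.zeros((1, 1), dtype=jnp.int32)   # batch_size=1, one decoder start token
#   scores = jnp.zeros((1, 32000))                   # batch_size=1, vocab_size=32000
#   scores = processors(input_ids, scores, cur_len=1)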


class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    r"""
    [`FlaxLogitsWarper`] for temperature (exponential scaling output probability distribution).

    Args:
        temperature (`float`):
            The value used to modulate the logits distribution.
    """

    def __init__(self, temperature: float):
        if not isinstance(temperature, float) or not (temperature > 0):
            raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")

        self.temperature = temperature

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        scores = scores / self.temperature
        return scores


class FlaxTopPLogitsWarper(FlaxLogitsWarper):
    """
    [`FlaxLogitsWarper`] that performs top-p filtering, i.e. restricting generation to the smallest set of most
    probable tokens whose cumulative probability reaches `top_p`.

    Args:
        top_p (`float`):
            If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
            higher are kept for generation.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
            raise ValueError(f"`top_p` has to be a float between 0 and 1, but is {top_p}")
        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")

        self.top_p = top_p
        self.filter_value = filter_value
        self.min_tokens_to_keep = min_tokens_to_keep

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        topk_scores, topk_indices = lax.top_k(scores, scores.shape[-1])

        mask_scores = jnp.full_like(scores, self.filter_value)
        cumulative_probs = jax.nn.softmax(topk_scores, axis=-1).cumsum(axis=-1)
        score_mask = cumulative_probs < self.top_p

        # include the token that is higher than top_p as well
        score_mask = jnp.roll(score_mask, 1)
        score_mask |= score_mask.at[:, 0].set(True)

        # min tokens to keep
        score_mask = score_mask.at[:, : self.min_tokens_to_keep].set(True)

        topk_next_scores = jnp.where(score_mask, topk_scores, mask_scores)
        next_scores = jax.lax.sort_key_val(topk_indices, topk_next_scores)[-1]

        return next_scores
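
# Worked example of the masking above (illustrative, not part of the original module):
# for one row with softmax probabilities [0.5, 0.3, 0.15, 0.05] and top_p=0.9, the
# cumulative probabilities are [0.5, 0.8, 0.95, 1.0], so `cumulative_probs < top_p`
# gives [True, True, False, False]. Rolling by one and forcing the first column to True
# yields [True, True, True, False]: the first three tokens (cumulative probability
# 0.95 >= top_p) are kept, and the last one is set to `filter_value`.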


class FlaxTopKLogitsWarper(FlaxLogitsWarper):
    r"""
    [`FlaxLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

    Args:
        top_k (`int`):
            The number of highest probability vocabulary tokens to keep for top-k-filtering.
        filter_value (`float`, *optional*, defaults to -inf):
            All filtered values will be set to this float value.
        min_tokens_to_keep (`int`, *optional*, defaults to 1):
            Minimum number of tokens that cannot be filtered.
    """

    def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
        if not isinstance(top_k, int) or top_k <= 0:
            raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")

        self.top_k = max(top_k, min_tokens_to_keep)
        self.filter_value = filter_value

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        batch_size, vocab_size = scores.shape
        next_scores_flat = jnp.full(batch_size * vocab_size, self.filter_value)

        topk = min(self.top_k, scores.shape[-1])  # Safety check
        topk_scores, topk_indices = lax.top_k(scores, topk)
        shift = jnp.broadcast_to((jnp.arange(batch_size) * vocab_size)[:, None], (batch_size, topk)).flatten()
        topk_scores_flat = topk_scores.flatten()
        topk_indices_flat = topk_indices.flatten() + shift

        next_scores_flat = next_scores_flat.at[topk_indices_flat].set(topk_scores_flat)
        next_scores = next_scores_flat.reshape(batch_size, vocab_size)
        return next_scores
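
# Illustrative sampling pipeline (sketch, not part of the original module): the three
# warpers above are typically chained in a FlaxLogitsProcessorList, mirroring how
# sampling-based generation applies them in sequence.
#
#   warpers = FlaxLogitsProcessorList()
#   warpers.append(FlaxTemperatureLogitsWarper(temperature=0.7))
#   warpers.append(FlaxTopKLogitsWarper(top_k=50))
#   warpers.append(FlaxTopPLogitsWarper(top_p=0.95))
#
#   # `input_ids`, `scores` and `cur_len` come from the decoding loop
#   scores = warpers(input_ids, scores, cur_len)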


class FlaxForcedBOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces the specified token as the first generated token.

    Args:
        bos_token_id (`int`):
            The id of the token to force as the first generated token.
    """

    def __init__(self, bos_token_id: int):
        self.bos_token_id = bos_token_id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        new_scores = jnp.full(scores.shape, -float("inf"))

        # the penalty only applies at the first generation step (cur_len == 1); arithmetic is
        # used instead of a Python `if` so the processor stays jit-compatible
        apply_penalty = 1 - jnp.bool_(cur_len - 1)

        scores = jnp.where(apply_penalty, new_scores.at[:, self.bos_token_id].set(0), scores)

        return scores


class FlaxForcedEOSTokenLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

    Args:
        max_length (`int`):
            The maximum length of the sequence to be generated.
        eos_token_id (`int`):
            The id of the token to force as the last generated token when `max_length` is reached.
    """

    def __init__(self, max_length: int, eos_token_id: int):
        self.max_length = max_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        new_scores = jnp.full(scores.shape, -float("inf"))

        # the penalty only applies at the last step, i.e. when cur_len == max_length - 1, so the
        # forced EOS becomes the final token
        apply_penalty = 1 - jnp.bool_(cur_len - self.max_length + 1)

        scores = jnp.where(apply_penalty, new_scores.at[:, self.eos_token_id].set(0), scores)

        return scores


class FlaxMinLengthLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

    Args:
        min_length (`int`):
            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
        eos_token_id (`int`):
            The id of the *end-of-sequence* token.
    """

    def __init__(self, min_length: int, eos_token_id: int):
        if not isinstance(min_length, int) or min_length < 0:
            raise ValueError(f"`min_length` has to be a non-negative integer, but is {min_length}")
        if not isinstance(eos_token_id, int) or eos_token_id < 0:
            raise ValueError(f"`eos_token_id` has to be a non-negative integer, but is {eos_token_id}")

        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        # create boolean flag to decide if min length penalty should be applied
        apply_penalty = 1 - jnp.clip(cur_len - self.min_length, 0, 1)

        scores = jnp.where(apply_penalty, scores.at[:, self.eos_token_id].set(-float("inf")), scores)

        return scores
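
# Illustrative sketch (not part of the original module; token ids are made up): with
# min_length=5 and eos_token_id=2, the arithmetic above keeps the EOS column at -inf
# while cur_len <= min_length, so generation cannot terminate early.
#
#   processor = FlaxMinLengthLogitsProcessor(min_length=5, eos_token_id=2)
#   scores = processor(input_ids, scores, cur_len=3)   # scores[:, 2] is now -inf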


class FlaxSuppressTokensAtBeginLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] suppressing a list of tokens as soon as the `generate` function starts generating using
    `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are not sampled at the
    beginning of the generation.

    Args:
        begin_suppress_tokens (`List[int]`):
            Tokens to not sample.
        begin_index (`int`):
            Index where the tokens are suppressed.
    """

    def __init__(self, begin_suppress_tokens, begin_index):
        self.begin_suppress_tokens = list(begin_suppress_tokens)
        self.begin_index = begin_index

    def __call__(self, input_ids, scores, cur_len: int):
        apply_penalty = 1 - jnp.bool_(cur_len - self.begin_index)

        scores = jnp.where(apply_penalty, scores.at[:, self.begin_suppress_tokens].set(-float("inf")), scores)

        return scores


class FlaxSuppressTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] suppressing a list of tokens at each decoding step. The processor will set their log probs
    to be `-inf` so they are not sampled.

    Args:
        suppress_tokens (`list`):
            Tokens to not sample.
    """

    def __init__(self, suppress_tokens: list):
        self.suppress_tokens = list(suppress_tokens)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        scores = scores.at[..., self.suppress_tokens].set(-float("inf"))

        return scores


class FlaxForceTokensLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
    token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
    to `-inf` so that they are sampled at their corresponding index.

    Args:
        force_token_map (`list`):
            Map giving token ids and indices where they will be forced to be sampled.
    """

    def __init__(self, force_token_map):
        force_token_map = dict(force_token_map)
        # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
        # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
        # Indices without forced tokens will have a negative value.
        force_token_array = jnp.ones((max(force_token_map.keys()) + 1), dtype=jnp.int32) * -1
        for index, token in force_token_map.items():
            if token is not None:
                force_token_array = force_token_array.at[index].set(token)
        self.force_token_array = jnp.int32(force_token_array)

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        def _force_token(generation_idx):
            batch_size = scores.shape[0]
            current_token = self.force_token_array[generation_idx]

            new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
            updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
            new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
            return new_scores

        scores = lax.cond(
            cur_len >= self.force_token_array.shape[0],
            # If the current length is greater than or equal to the length of force_token_array, the processor does
            # nothing.
            lambda: scores,
            # Otherwise, it may force a certain token.
            lambda: lax.cond(
                self.force_token_array[cur_len] >= 0,
                # Only valid (non-negative) tokens are forced
                lambda: _force_token(cur_len),
                # Otherwise, the processor does nothing.
                lambda: scores,
            ),
        )

        return scores
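
# Illustrative sketch (not part of the original module; token ids are made up):
# `force_token_map` is a list of `[generation_index, token_id]` pairs. With the map
# below, token 50259 is forced at step 1 and token 50359 at step 2, while every other
# step is left untouched.
#
#   processor = FlaxForceTokensLogitsProcessor(force_token_map=[[1, 50259], [2, 50359]])
#   scores = processor(input_ids, scores, cur_len=1)   # only column 50259 is not -inf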


class FlaxWhisperTimeStampLogitsProcessor(FlaxLogitsProcessor):
    r"""
    Whisper-specific processor. It enforces Whisper's timestamp sampling rules at each decoding step by setting the
    log probs of disallowed tokens to `-inf`.

    Args:
        generate_config (`GenerateConfig`):
            The generate config used to generate the output. The following parameters are required:
                eos_token_id (`int`, *optional*, defaults to 50257):
                    The id of the *end-of-sequence* token.
                no_timestamps_token_id (`int`, *optional*, defaults to 50363):
                    The id of the `"<|notimestamps|>"` token.
                max_initial_timestamp_index (`int`, *optional*, defaults to 1):
                    Used to set the maximum value of the initial timestamp. This is used to prevent the model from
                    predicting timestamps that are too far in the future.
    """

    def __init__(self, generate_config, model_config, decoder_input_length):
        self.eos_token_id = generate_config.eos_token_id
        self.no_timestamps_token_id = generate_config.no_timestamps_token_id
        self.timestamp_begin = generate_config.no_timestamps_token_id + 1

        self.begin_index = decoder_input_length + 1

        if generate_config.is_multilingual:
            # room for language token and task token
            self.begin_index += 2
        if hasattr(generate_config, "max_initial_timestamp_index"):
            self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
        else:
            self.max_initial_timestamp_index = model_config.vocab_size
        if self.max_initial_timestamp_index is None:
            self.max_initial_timestamp_index = model_config.vocab_size

    def __call__(self, input_ids, scores, cur_len):
        # suppress <|notimestamps|> which is handled by without_timestamps
        scores = scores.at[:, self.no_timestamps_token_id].set(-float("inf"))

        def handle_pairs(input_ids_k, scores_k):
            last_was_timestamp = jnp.where((cur_len - self.begin_index) >= 1, True, False)
            last_was_timestamp = jnp.where(
                input_ids_k[cur_len - 1] >= self.timestamp_begin,
                True and last_was_timestamp,
                False,
            )

            penultimate_was_timestamp = jnp.where((cur_len - self.begin_index) < 2, True, False)
            penultimate_was_timestamp = jnp.where(
                input_ids_k[cur_len - 2] >= self.timestamp_begin,
                True,
                penultimate_was_timestamp,
            )

            return jnp.where(
                last_was_timestamp,
                jnp.where(
                    penultimate_was_timestamp > 0,
                    scores_k.at[self.timestamp_begin :].set(-float("inf")),
                    scores_k.at[: self.eos_token_id].set(-float("inf")),
                ),
                scores_k,
            )

        scores = jax.vmap(handle_pairs)(input_ids, scores)

        apply_max_initial_timestamp = jnp.where(cur_len == self.begin_index, True, False)
        apply_max_initial_timestamp = jnp.where(
            self.max_initial_timestamp_index is not None,
            True and apply_max_initial_timestamp,
            False,
        )

        last_allowed = self.timestamp_begin + self.max_initial_timestamp_index

        scores = jnp.where(
            apply_max_initial_timestamp,
            scores.at[:, last_allowed + 1 :].set(-float("inf")),
            scores,
        )

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = jax.nn.log_softmax(scores, axis=-1)

        def handle_cumulative_probs(logprobs_k, scores_k):
            timestamp_logprob = jax.nn.logsumexp(logprobs_k[self.timestamp_begin :], axis=-1)
            max_text_token_logprob = jnp.max(logprobs_k[: self.timestamp_begin])
            return jnp.where(
                timestamp_logprob > max_text_token_logprob,
                scores_k.at[: self.timestamp_begin].set(-float("inf")),
                scores_k,
            )

        scores = jax.vmap(handle_cumulative_probs)(logprobs, scores)

        return scores
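
# Summary of the rules enforced above (descriptive note, derived from the code): (1) the
# <|notimestamps|> token is always masked; (2) timestamp tokens come in pairs, so after a
# lone timestamp the text tokens below `eos_token_id` are masked (forcing a second,
# pairing timestamp or EOS), and after a completed pair all timestamp tokens are masked;
# (3) at the first sampled position (cur_len == begin_index) the timestamp may not exceed
# `timestamp_begin + max_initial_timestamp_index`; (4) if the total probability mass on
# timestamp tokens exceeds that of every individual text token, the text tokens are
# masked so that a timestamp is sampled.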


class FlaxNoRepeatNGramLogitsProcessor(FlaxLogitsProcessor):
    r"""
    [`FlaxLogitsProcessor`] that enforces no repetition of n-grams. See
    [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

    Args:
        ngram_size (`int`):
            All ngrams of size `ngram_size` can only occur once.
    """

    def __init__(self, ngram_size: int):
        if not isinstance(ngram_size, int) or ngram_size <= 0:
            raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
        self.ngram_size = ngram_size

    def get_previous_ngrams(self, input_ids: jnp.ndarray, vocab_size: int, cur_len: int):
        """
        Gets a matrix of size (batch_size,) + (vocab_size,) * n (for n-grams) that represents the n-grams that
        occurred previously. The BCOO representation allows storing only the few non-zero entries, instead of the full
        (huge) matrix.
        """
        batch_size, seq_len = input_ids.shape
        # number of n-grams in the whole sequence
        seq_ngrams = seq_len - (self.ngram_size - 1)
        # number of n-grams in the currently generated sequence
        cur_ngrams = cur_len - (self.ngram_size - 1)

        def body_fun(i, val):
            b = i % batch_size
            pos = i // batch_size
            return val.at[i].set(
                jnp.array(
                    [
                        b,
                    ]
                    + [jnp.array(input_ids)[b, pos + j] for j in range(self.ngram_size)]
                )
            )

        shape = (batch_size * seq_ngrams, self.ngram_size + 1)
        all_update_indices = jax.lax.fori_loop(
            0, batch_size * cur_ngrams, body_fun, jnp.zeros(shape, dtype=input_ids.dtype)
        )

        # ignore the n-grams not yet generated
        data = (jnp.arange(batch_size * seq_ngrams) < batch_size * cur_ngrams).astype("float32")

        return sparse.BCOO((data, all_update_indices), shape=(batch_size,) + (vocab_size,) * self.ngram_size)

    def get_banned_tokens_mask(self, latest_tokens: jnp.ndarray, previous_ngrams) -> jnp.ndarray:
        """
        Determines which tokens must be banned given the latest tokens and the previously seen ngrams.
        """

        @sparse.sparsify
        @jax.vmap
        def inner_fn(latest_tokens, previous_ngrams):
            return previous_ngrams[tuple(latest_tokens)]

        return sparse.bcoo_todense(inner_fn(latest_tokens, previous_ngrams))

    def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
        def true_fn():
            _, vocab_size = scores.shape
            # store the previously seen n-grams
            previous_ngrams = self.get_previous_ngrams(input_ids, vocab_size, cur_len)

            # get the n-1 last tokens that prefix the n-gram being generated
            latest_tokens = jnp.zeros((input_ids.shape[0], self.ngram_size - 1), dtype=input_ids.dtype)
            latest_tokens = jax.lax.dynamic_update_slice(
                latest_tokens,
                jax.lax.dynamic_slice(
                    input_ids, (0, cur_len - (self.ngram_size - 1)), (input_ids.shape[0], (self.ngram_size - 1))
                ),
                (0, 0),
            )

            # compute the banned tokens, i.e. all the tokens that, when added to the latest tokens, lead to an n-gram
            # that was previously generated
            banned_tokens_indices_mask = self.get_banned_tokens_mask(latest_tokens, previous_ngrams).astype("bool")
            return jnp.where(banned_tokens_indices_mask, -float("inf"), scores)

        output = jax.lax.cond((cur_len >= self.ngram_size - 1), true_fn, lambda: scores)

        return output
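
# Worked example (illustrative, not part of the original module): with ngram_size=2 and a
# generated prefix [3, 5, 3], the bigram (3, 5) has already occurred. Since the latest
# token is 3 again, token 5 is banned at this step (its score is set to -inf), so the
# bigram (3, 5) cannot be generated a second time.
#
#   processor = FlaxNoRepeatNGramLogitsProcessor(ngram_size=2)
#   input_ids = jnp.array([[3, 5, 3]])
#   scores = processor(input_ids, jnp.zeros((1, 10)), cur_len=3)   # scores[0, 5] == -inf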