cache_utils.py 97 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122
  1. import copy
  2. import importlib.metadata
  3. import json
  4. import os
  5. from dataclasses import dataclass
  6. from typing import Any, Dict, List, Optional, Tuple, Union
  7. import torch
  8. from packaging import version
  9. from .configuration_utils import PretrainedConfig
  10. from .utils import (
  11. is_hqq_available,
  12. is_optimum_quanto_available,
  13. is_quanto_available,
  14. is_torchdynamo_compiling,
  15. logging,
  16. )
  17. from .utils.deprecation import deprecate_kwarg
  18. if is_hqq_available():
  19. from hqq.core.quantize import Quantizer as HQQQuantizer
  20. logger = logging.get_logger(__name__)
  21. class Cache(torch.nn.Module):
  22. """
  23. Base, abstract class for all caches. The actual data structure is specific to each subclass.
  24. """
  25. def __init__(self):
  26. super().__init__()
  27. def update(
  28. self,
  29. key_states: torch.Tensor,
  30. value_states: torch.Tensor,
  31. layer_idx: int,
  32. cache_kwargs: Optional[Dict[str, Any]] = None,
  33. ) -> Tuple[torch.Tensor, torch.Tensor]:
  34. """
  35. Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
  36. Parameters:
  37. key_states (`torch.Tensor`):
  38. The new key states to cache.
  39. value_states (`torch.Tensor`):
  40. The new value states to cache.
  41. layer_idx (`int`):
  42. The index of the layer to cache the states for.
  43. cache_kwargs (`Dict[str, Any]`, `optional`):
  44. Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
  45. cache to be created.
  46. Return:
  47. A tuple containing the updated key and value states.
  48. """
  49. raise NotImplementedError("Make sure to implement `update` in a subclass.")
  50. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  51. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  52. # TODO: deprecate this function in favor of `cache_position`
  53. raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
  54. # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length"
  55. # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
  56. # infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
  57. # we change naming to be more explicit
  58. def get_max_length(self) -> Optional[int]:
  59. logger.warning_once(
  60. "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
  61. "Calling `get_max_cache()` will raise error from v4.48"
  62. )
  63. return self.get_max_cache_shape()
  64. def get_max_cache_shape(self) -> Optional[int]:
  65. """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
  66. raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
  67. def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
  68. """Given the sequence length of the new inputs, returns the usable length of the cache."""
  69. # Cache without size limit -> all cache is usable
  70. # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache
  71. # length, we will need to evict part of the cache (and thus not all cache is usable)
  72. max_length = self.get_max_cache_shape()
  73. previous_seq_length = self.get_seq_length(layer_idx)
  74. if max_length is not None and previous_seq_length + new_seq_length > max_length:
  75. return max_length - new_seq_length
  76. return previous_seq_length
  77. def reorder_cache(self, beam_idx: torch.LongTensor):
  78. """Reorders the cache for beam search, given the selected beam indices."""
  79. for layer_idx in range(len(self.key_cache)):
  80. if self.key_cache[layer_idx] != []:
  81. device = self.key_cache[layer_idx].device
  82. self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
  83. if self.value_cache[layer_idx] != []:
  84. device = self.value_cache[layer_idx].device
  85. self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
  86. @property
  87. def seen_tokens(self):
  88. logger.warning_once(
  89. "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
  90. "model input instead."
  91. )
  92. if hasattr(self, "_seen_tokens"):
  93. return self._seen_tokens
  94. else:
  95. return None
  96. @dataclass
  97. class CacheConfig:
  98. """
  99. Base class for cache configs
  100. """
  101. cache_implementation: None
  102. @classmethod
  103. def from_dict(cls, config_dict, **kwargs):
  104. """
  105. Constructs a CacheConfig instance from a dictionary of parameters.
  106. Args:
  107. config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
  108. **kwargs: Additional keyword arguments to override dictionary values.
  109. Returns:
  110. CacheConfig: Instance of CacheConfig constructed from the dictionary.
  111. """
  112. config = cls(**config_dict)
  113. to_remove = []
  114. for key, value in kwargs.items():
  115. if hasattr(config, key):
  116. setattr(config, key, value)
  117. to_remove.append(key)
  118. for key in to_remove:
  119. kwargs.pop(key, None)
  120. return config
  121. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
  122. def to_json_file(self, json_file_path: Union[str, os.PathLike]):
  123. """
  124. Save this instance to a JSON file.
  125. Args:
  126. json_file_path (`str` or `os.PathLike`):
  127. Path to the JSON file in which this configuration instance's parameters will be saved.
  128. use_diff (`bool`, *optional*, defaults to `True`):
  129. If set to `True`, only the difference between the config instance and the default
  130. `QuantizationConfig()` is serialized to JSON file.
  131. """
  132. with open(json_file_path, "w", encoding="utf-8") as writer:
  133. config_dict = self.to_dict()
  134. json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
  135. writer.write(json_string)
  136. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
  137. def to_dict(self) -> Dict[str, Any]:
  138. """
  139. Serializes this instance to a Python dictionary. Returns:
  140. `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
  141. """
  142. return copy.deepcopy(self.__dict__)
  143. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
  144. def __iter__(self):
  145. """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
  146. for attr, value in copy.deepcopy(self.__dict__).items():
  147. yield attr, value
  148. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
  149. def __repr__(self):
  150. return f"{self.__class__.__name__} {self.to_json_string()}"
  151. def to_json_string(self):
  152. """
  153. Serializes this instance to a JSON formatted string.
  154. Returns:
  155. str: JSON formatted string representing the configuration instance.
  156. """
  157. return json.dumps(self.__dict__, indent=2) + "\n"
  158. # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
  159. def update(self, **kwargs):
  160. """
  161. Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
  162. returning all the unused kwargs.
  163. Args:
  164. kwargs (`Dict[str, Any]`):
  165. Dictionary of attributes to tentatively update this class.
  166. Returns:
  167. `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
  168. """
  169. to_remove = []
  170. for key, value in kwargs.items():
  171. if hasattr(self, key):
  172. setattr(self, key, value)
  173. to_remove.append(key)
  174. # Remove all the attributes that were updated, without modifying the input dict
  175. unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
  176. return unused_kwargs
  177. @dataclass
  178. class QuantizedCacheConfig(CacheConfig):
  179. """
  180. Configuration class for quantized cache settings.
  181. Attributes:
  182. backend (`str`, *optional*, defaults to `"quanto"`):
  183. Backend to use when performing quantization, Can be one of [`quanto`, `HQQ`]
  184. nbits (`Optional[int]`, *optional*, defaults to 4):
  185. Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend. Defaults to 2.
  186. axis_key (`int`, *optional*, defaults to 0):
  187. Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
  188. axis_value (`int`, *optional*, defaults to 0):
  189. Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
  190. q_group_size (`Optional[int]`, *optional*, defaults to 64):
  191. Size of the quantization group, should be a divisor of the model's hidden dimension.
  192. Defaults to 64.
  193. residual_length (`Optional[int]`, *optional*, defaults to 128):
  194. Length of the residual cache which will always be stored in original presicion.
  195. Defaults to 128.
  196. compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
  197. The defualt dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
  198. device (`str`, *optional*, defaults to `"cpu"`):
  199. Device on which to perform computations, should be same as the model's device.
  200. """
  201. def __init__(
  202. self,
  203. backend: str = "quanto",
  204. nbits: Optional[int] = 4,
  205. axis_key: Optional[int] = 0,
  206. axis_value: Optional[int] = 0,
  207. q_group_size: Optional[int] = 64,
  208. residual_length: Optional[int] = 128,
  209. compute_dtype: Optional[torch.dtype] = torch.float16,
  210. device: Optional[str] = "cpu",
  211. ):
  212. self.backend = backend
  213. self.nbits = nbits
  214. self.axis_key = axis_key
  215. self.axis_value = axis_value
  216. self.q_group_size = q_group_size
  217. self.residual_length = residual_length
  218. self.compute_dtype = compute_dtype
  219. self.device = device
  220. def validate(self):
  221. """Validates if the arguments passed are correct"""
  222. incorrect_arg_msg = (
  223. "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
  224. "but found {found_value}"
  225. )
  226. # Check that the values are reasonable in general (nbits, axis)
  227. # Later in QuantizedCache init we check if they are supported for that particular backend
  228. if self.nbits not in [1, 2, 3, 4, 8]:
  229. raise ValueError(
  230. incorrect_arg_msg.format(
  231. key="nbits",
  232. correct_value="2 or 4 or 8",
  233. found_value=self.nbits,
  234. ),
  235. )
  236. if self.q_group_size <= 0:
  237. raise ValueError(
  238. incorrect_arg_msg.format(
  239. key="q_group_size",
  240. correct_value="a positive integer",
  241. found_value=self.q_group_size,
  242. ),
  243. )
  244. if self.residual_length < 0:
  245. raise ValueError(
  246. incorrect_arg_msg.format(
  247. key="residual_length",
  248. correct_value="a positive integer",
  249. found_value=self.residual_length,
  250. ),
  251. )
  252. if self.axis_key not in [0, 1, -1]:
  253. raise ValueError(
  254. incorrect_arg_msg.format(
  255. key="axis_key",
  256. correct_value="`1` or `0`, `-1`",
  257. found_value=self.axis_key,
  258. ),
  259. )
  260. if self.axis_value not in [0, 1, -1]:
  261. raise ValueError(
  262. incorrect_arg_msg.format(
  263. key="axis_value",
  264. correct_value="`1` or `0` or `-1`",
  265. found_value=self.axis_value,
  266. ),
  267. )
  268. @dataclass
  269. class StaticCacheConfig(CacheConfig):
  270. """
  271. Configuration class for static cache settings.
  272. """
  273. cache_implementation = "static"
  274. def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
  275. self.batch_size = batch_size
  276. self.max_cache_len = max_cache_len
  277. self.device = device
  278. def validate(self):
  279. """Validates if the arguments passed are correct"""
  280. incorrect_arg_msg = (
  281. "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
  282. "but found {found_value}"
  283. )
  284. if self.batch_size <= 0:
  285. raise ValueError(
  286. incorrect_arg_msg.format(
  287. key="batch_size",
  288. correct_value="> 0",
  289. found_value=self.batch_size,
  290. ),
  291. )
  292. if self.max_cache_len <= 0:
  293. raise ValueError(
  294. incorrect_arg_msg.format(
  295. key="max_cache_len",
  296. correct_value="> 0",
  297. found_value=self.max_cache_len,
  298. ),
  299. )
  300. class DynamicCache(Cache):
  301. """
  302. A cache that grows dynamically as more tokens are generated. This is the default for generative models.
  303. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
  304. `[batch_size, num_heads, seq_len, head_dim]`.
  305. Example:
  306. ```python
  307. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
  308. >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
  309. >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
  310. >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
  311. >>> # Prepare a cache class and pass it to model's forward
  312. >>> past_key_values = DynamicCache()
  313. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  314. >>> outputs.past_key_values # access cache filled with key/values from generation
  315. DynamicCache()
  316. ```
  317. """
  318. @deprecate_kwarg("num_hidden_layers", version="4.47.0")
  319. def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
  320. super().__init__()
  321. self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
  322. self.key_cache: List[torch.Tensor] = []
  323. self.value_cache: List[torch.Tensor] = []
  324. def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
  325. """
  326. Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
  327. sequence length.
  328. """
  329. if layer_idx < len(self):
  330. return (self.key_cache[layer_idx], self.value_cache[layer_idx])
  331. else:
  332. raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
  333. def __iter__(self):
  334. """
  335. Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
  336. keys and values
  337. """
  338. for layer_idx in range(len(self)):
  339. yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
  340. def __len__(self):
  341. """
  342. Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
  343. to the number of layers in the model.
  344. """
  345. return len(self.key_cache)
  346. def update(
  347. self,
  348. key_states: torch.Tensor,
  349. value_states: torch.Tensor,
  350. layer_idx: int,
  351. cache_kwargs: Optional[Dict[str, Any]] = None,
  352. ) -> Tuple[torch.Tensor, torch.Tensor]:
  353. """
  354. Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
  355. Parameters:
  356. key_states (`torch.Tensor`):
  357. The new key states to cache.
  358. value_states (`torch.Tensor`):
  359. The new value states to cache.
  360. layer_idx (`int`):
  361. The index of the layer to cache the states for.
  362. cache_kwargs (`Dict[str, Any]`, `optional`):
  363. Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
  364. Return:
  365. A tuple containing the updated key and value states.
  366. """
  367. # Update the number of seen tokens
  368. if layer_idx == 0:
  369. self._seen_tokens += key_states.shape[-2]
  370. # Update the cache
  371. if len(self.key_cache) <= layer_idx:
  372. # There may be skipped layers, fill them with empty lists
  373. for _ in range(len(self.key_cache), layer_idx):
  374. self.key_cache.append([])
  375. self.value_cache.append([])
  376. self.key_cache.append(key_states)
  377. self.value_cache.append(value_states)
  378. elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
  379. self.key_cache[layer_idx] = key_states
  380. self.value_cache[layer_idx] = value_states
  381. else:
  382. self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
  383. self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
  384. return self.key_cache[layer_idx], self.value_cache[layer_idx]
  385. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  386. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  387. # TODO: deprecate this function in favor of `cache_position`
  388. is_empty_layer = (
  389. len(self.key_cache) == 0 # no cache in any layer
  390. or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
  391. or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
  392. )
  393. layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
  394. return layer_seq_length
  395. def get_max_cache_shape(self) -> Optional[int]:
  396. """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
  397. return None
  398. def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
  399. """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
  400. backward compatibility."""
  401. legacy_cache = ()
  402. for layer_idx in range(len(self)):
  403. legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
  404. return legacy_cache
  405. @classmethod
  406. @deprecate_kwarg("num_hidden_layers", version="4.47.0")
  407. def from_legacy_cache(
  408. cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
  409. ) -> "DynamicCache":
  410. """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
  411. backward compatibility."""
  412. cache = cls()
  413. if past_key_values is not None:
  414. for layer_idx in range(len(past_key_values)):
  415. key_states, value_states = past_key_values[layer_idx]
  416. cache.update(key_states, value_states, layer_idx)
  417. return cache
  418. def crop(self, max_length: int):
  419. """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
  420. negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
  421. # In case it is negative
  422. if max_length < 0:
  423. max_length = self.get_seq_length() - abs(max_length)
  424. if self.get_seq_length() <= max_length:
  425. return
  426. self._seen_tokens = max_length
  427. for idx in range(len(self.key_cache)):
  428. if self.key_cache[idx] != []:
  429. self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
  430. self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
  431. @deprecate_kwarg("num_hidden_layers", version="4.47.0")
  432. def batch_split(
  433. self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
  434. ) -> List["DynamicCache"]:
  435. """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
  436. `_split_model_inputs()` in `generation.utils`"""
  437. out = []
  438. for i in range(0, full_batch_size, split_size):
  439. current_split = DynamicCache()
  440. current_split._seen_tokens = self._seen_tokens
  441. current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
  442. current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
  443. out.append(current_split)
  444. return out
  445. @classmethod
  446. @deprecate_kwarg("num_hidden_layers", version="4.47.0")
  447. def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
  448. """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
  449. `generation.utils`"""
  450. cache = cls()
  451. for idx in range(len(splits[0])):
  452. key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
  453. value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
  454. if key_cache != []:
  455. layer_keys = torch.cat(key_cache, dim=0)
  456. layer_values = torch.cat(value_cache, dim=0)
  457. cache.update(layer_keys, layer_values, idx)
  458. return cache
  459. def batch_repeat_interleave(self, repeats: int):
  460. """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
  461. for layer_idx in range(len(self)):
  462. self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
  463. self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
  464. def batch_select_indices(self, indices: torch.Tensor):
  465. """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
  466. for layer_idx in range(len(self)):
  467. self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
  468. self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
  469. class OffloadedCache(DynamicCache):
  470. """
  471. A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
  472. Useful for generating from models with very long context.
  473. In addition to the default CUDA stream, where all forward() computations happen,
  474. this class uses another stream, the prefetch stream, which it creates itself.
  475. Since scheduling of operations on separate streams happens independently, this class uses
  476. the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
  477. The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
  478. ensure the eviction is scheduled after all computations on that cache are finished.
  479. """
  480. def __init__(self) -> None:
  481. if not torch.cuda.is_available():
  482. raise RuntimeError("OffloadedCache can only be used with a GPU")
  483. super().__init__()
  484. self.original_device = []
  485. self.prefetch_stream = torch.cuda.Stream()
  486. self.beam_idx = None # used to delay beam search operations
  487. def prefetch_layer(self, layer_idx: int):
  488. "Starts prefetching the next layer cache"
  489. if layer_idx < len(self):
  490. with torch.cuda.stream(self.prefetch_stream):
  491. # Prefetch next layer tensors to GPU
  492. device = self.original_device[layer_idx]
  493. self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
  494. self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
  495. def evict_previous_layer(self, layer_idx: int):
  496. "Moves the previous layer cache to the CPU"
  497. if len(self) > 2:
  498. # We do it on the default stream so it occurs after all earlier computations on these tensors are done
  499. prev_layer_idx = (layer_idx - 1) % len(self)
  500. self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
  501. self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
  def __getitem__(self, layer_idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
      """Return `(key, value)` for `layer_idx` on its original device, evicting the previous layer and
      prefetching the next one.

      Any pending beam reordering (see `reorder_cache`) is applied to the returned tensors. The reordered
      tensors are not written back here; `update` stores them when it concatenates new states.

      Raises:
          KeyError: when `layer_idx` is not below the number of cached layers.
      """
      if layer_idx >= len(self):
          raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
      # Wait for pending default-stream work before queueing the eviction of the previous layer.
      torch.cuda.current_stream().synchronize()
      self.evict_previous_layer(layer_idx)
      original_device = self.original_device[layer_idx]
      # Make sure the prefetch of this layer (queued on the side stream) has finished.
      self.prefetch_stream.synchronize()
      key_tensor = self.key_cache[layer_idx]
      value_tensor = self.value_cache[layer_idx]
      # Apply the delayed beam-search reordering now that the tensors are on their device.
      if self.beam_idx is not None:
          self.beam_idx = self.beam_idx.to(original_device)
          key_tensor = key_tensor.index_select(0, self.beam_idx)
          value_tensor = value_tensor.index_select(0, self.beam_idx)
      # Overlap the next layer's host->device copy with this layer's computation.
      self.prefetch_layer((layer_idx + 1) % len(self))
      return (key_tensor, value_tensor)
  def reorder_cache(self, beam_idx: torch.LongTensor):
      """Remember `beam_idx` so each layer is reordered lazily when `__getitem__` brings it back to its device.

      The reordering is deferred because `torch.index_select` on CPU-resident tensors is very slow.
      """
      # Release the previous indices before keeping a private copy of the new ones.
      del self.beam_idx
      self.beam_idx = beam_idx.clone()
  def update(
      self,
      key_states: torch.Tensor,
      value_states: torch.Tensor,
      layer_idx: int,
      cache_kwargs: Optional[Dict[str, Any]] = None,
  ) -> Tuple[torch.Tensor, torch.Tensor]:
      """
      Append `key_states` / `value_states` to the cache of layer `layer_idx`.

      Parameters:
          key_states (`torch.Tensor`):
              The new key states to cache.
          value_states (`torch.Tensor`):
              The new value states to cache.
          layer_idx (`int`):
              The index of the layer to cache the states for. Layers must be filled in order, without gaps.
          cache_kwargs (`Dict[str, Any]`, `optional`):
              Unused by `OffloadedCache`.

      Return:
          A tuple containing the updated key and value states.

      Raises:
          ValueError: if `layer_idx` skips over a layer that has never been cached.
      """
      # Token tally is advanced once per forward pass, on the first layer.
      if layer_idx == 0:
          self._seen_tokens += key_states.shape[-2]

      known_layers = len(self.key_cache)
      if layer_idx > known_layers:
          raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")

      if layer_idx == known_layers:
          # First visit of this layer: store as-is, remember its home device, and offload the previous layer.
          self.key_cache.append(key_states)
          self.value_cache.append(value_states)
          self.original_device.append(key_states.device)
          self.evict_previous_layer(layer_idx)
      else:
          # Fetch the (possibly beam-reordered) past states on device and extend them along the sequence axis.
          past_keys, past_values = self[layer_idx]
          self.key_cache[layer_idx] = torch.cat([past_keys, key_states], dim=-2)
          self.value_cache[layer_idx] = torch.cat([past_values, value_states], dim=-2)

      return self.key_cache[layer_idx], self.value_cache[layer_idx]
  # Legacy tuple conversion is not supported by this cache. Following
  # https://docs.python.org/3/library/exceptions.html#NotImplementedError, an unsupported method in a
  # subclass is set to None rather than overridden to raise.
  from_legacy_cache = to_legacy_cache = None
  class QuantizedCache(DynamicCache):
      """
      A quantized cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
      It allows the model to generate longer sequences without allocating too much memory for the Key and Value cache, by applying quantization.

      The cache has two types of storage: one in original precision and one quantized. A `residual_length` bounds the
      original-precision storage. Once it (plus the incoming token) reaches that length, all states of the layer are
      quantized together and the original-precision storage is emptied. The quantization is done per-channel with a set
      `q_group_size` for both Keys and Values, in contrast to what was described in the paper.

      It stores Keys and Values as a list of quantized tensors (tuples in case we need to store metadata), one for each
      layer. Additionally, it stores the most recent, not-yet-quantized Key and Value states in original precision as a
      list of tensors, one for each layer.

      Subclasses must implement `_quantize` and `_dequantize`.
      """

      def __init__(self, cache_config: QuantizedCacheConfig) -> None:
          # A single base-class init is enough; the original repeated it at the end, which only re-created
          # the (still empty) DynamicCache lists.
          super().__init__()
          self._quantized_key_cache: List[torch.Tensor] = []
          self._quantized_value_cache: List[torch.Tensor] = []
          self.nbits = cache_config.nbits
          self.residual_length = cache_config.residual_length
          self.q_group_size = cache_config.q_group_size
          self.axis_key = cache_config.axis_key
          self.axis_value = cache_config.axis_value
          self.compute_dtype = cache_config.compute_dtype
          self.device = cache_config.device

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          layer_idx: int,
          cache_kwargs: Optional[Dict[str, Any]] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Add `key_states` / `value_states` for `layer_idx` and return the full (dequantized + residual + new) states.

          Raises:
              ValueError: if `layer_idx` skips over a layer that has never been cached.
          """
          # Update the number of seen tokens
          if layer_idx == 0:
              self._seen_tokens += key_states.shape[-2]

          if len(self.key_cache) < layer_idx:
              raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
          elif len(self.key_cache) == layer_idx:
              # First visit: quantize everything and start with an empty residual buffer.
              self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
              self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
              self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
              self.value_cache.append(torch.zeros(0, dtype=value_states.dtype, device=value_states.device))
              keys_to_return, values_to_return = key_states, value_states
          else:
              dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
              dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
              keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
              values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]

              keys_to_return = torch.cat(keys_to_return, dim=-2)
              values_to_return = torch.cat(values_to_return, dim=-2)
              # An empty residual buffer is the 1-D `torch.zeros(0)` placeholder, hence the `dim() == 4` check.
              if (
                  self.key_cache[layer_idx].dim() == 4
                  and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
              ):
                  # Residual buffer is full: re-quantize all states of the layer and reset the residual buffer.
                  self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
                  self._quantized_value_cache[layer_idx] = self._quantize(
                      values_to_return.contiguous(), axis=self.axis_value
                  )
                  self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
                  self.value_cache[layer_idx] = torch.zeros(0, dtype=value_states.dtype, device=value_states.device)
              else:
                  self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                  self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

          return keys_to_return, values_to_return

      def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
          """Returns the sequence length of the cached states. A layer index can be optionally passed."""
          if len(self.key_cache) <= layer_idx:
              return 0
          # since we cannot get the seq_length of each layer directly and rely on `_seen_tokens` which is
          # updated every "layer_idx" == 0, this is a hack to get the actual seq_length for the given layer_idx
          # this part of code otherwise fails when used to verify attn_weight shape in some models
          return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1

      def _quantize(self, tensor, axis):
          """Quantizes a key/value using a defined quantization method."""
          raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")

      def _dequantize(self, q_tensor):
          """Dequantizes back the tensor that was quantized by `self._quantize()`"""
          raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
  class QuantoQuantizedCache(QuantizedCache):
      """
      Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.

      Parameters:
          cache_config (`QuantizedCacheConfig`):
              A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.

      Raises:
          ImportError: if neither `optimum-quanto` nor `quanto>=0.2.0` is installed.
          ValueError: if `nbits`, `axis_key` or `axis_value` is not supported by the `quanto` backend.

      Example:

          ```python
          >>> # Run pip install quanto first if you don't have it yet
          >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig

          >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
          >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

          >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

          >>> # Prepare a cache class and pass it to model's forward
          >>> cache_config = QuantizedCacheConfig(nbits=4)
          >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
          >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
          >>> outputs.past_key_values # access cache filled with key/values from generation
          QuantoQuantizedCache()
          ```
      """

      def __init__(self, cache_config: CacheConfig) -> None:
          super().__init__(cache_config)

          if is_optimum_quanto_available():
              from optimum.quanto import MaxOptimizer, qint2, qint4
          elif is_quanto_available():
              logger.warning_once(
                  "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
              )
              quanto_version = version.parse(importlib.metadata.version("quanto"))
              if quanto_version < version.parse("0.2.0"):
                  raise ImportError(
                      f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
                      f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`"
                  )
              from quanto import MaxOptimizer, qint2, qint4
          else:
              # Without this branch `MaxOptimizer`/`qint2`/`qint4` would be unbound below and the constructor
              # would fail with an opaque UnboundLocalError.
              raise ImportError(
                  "`QuantoQuantizedCache` requires the `optimum-quanto` package. "
                  "Please install it with `pip install optimum-quanto`."
              )

          if self.nbits not in [2, 4]:
              raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

          if self.axis_key not in [0, -1]:
              raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

          if self.axis_value not in [0, -1]:
              raise ValueError(
                  f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
              )

          self.qtype = qint4 if self.nbits == 4 else qint2
          self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization

      def _quantize(self, tensor, axis):
          """Quantize `tensor` along `axis` with `self.qtype` and `self.q_group_size`."""
          # We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
          if is_optimum_quanto_available():
              from optimum.quanto import quantize_weight

              qtensor = quantize_weight(tensor, self.qtype, axis, self.q_group_size)
              return qtensor
          elif is_quanto_available():
              logger.warning_once(
                  "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
              )
              from quanto import AffineQuantizer

              scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
              qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
              return qtensor
          # Mirrors the constructor check instead of silently returning None.
          raise ImportError(
              "`QuantoQuantizedCache` requires the `optimum-quanto` package. "
              "Please install it with `pip install optimum-quanto`."
          )

      def _dequantize(self, qtensor):
          """Return the dequantized tensor of a quanto quantized tensor."""
          return qtensor.dequantize()
  class HQQQuantizedCache(QuantizedCache):
      """
      Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports
      `nbits` of 1, 2, 3, 4 and 8.

      Parameters:
          cache_config (`QuantizedCacheConfig`):
              A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.

      Raises:
          ValueError: if `nbits` is not one of [1, 2, 3, 4, 8], or `axis_key` / `axis_value` is not one of [0, 1].

      Example:

          ```python
          >>> # Run pip install hqq first if you don't have it yet
          >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig

          >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
          >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

          >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

          >>> # Prepare a cache class and pass it to model's forward
          >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
          >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
          >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
          >>> outputs.past_key_values # access cache filled with key/values from generation
          HQQQuantizedCache()
          ```
      """

      def __init__(self, cache_config: CacheConfig) -> None:
          super().__init__(cache_config)
          if self.nbits not in [1, 2, 3, 4, 8]:
              raise ValueError(
                  f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
              )

          if self.axis_key not in [0, 1]:
              raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

          if self.axis_value not in [0, 1]:
              raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

          self.quantizer = HQQQuantizer

      def _quantize(self, tensor, axis):
          """Quantize `tensor` with HQQ and return `(quantized_tensor, meta)`."""
          qtensor, meta = self.quantizer.quantize(
              tensor,
              axis=axis,
              device=self.device,
              compute_dtype=self.compute_dtype,
              nbits=self.nbits,
              group_size=self.q_group_size,
          )
          meta["compute_dtype"] = self.compute_dtype
          self.quantizer.cuda(qtensor, meta=meta, device=self.device)  # Move to device and cast to dtype
          return qtensor, meta

      def _dequantize(self, qtensor):
          """Dequantize a `(quantized_tensor, meta)` pair produced by `_quantize`."""
          quant_tensor, meta = qtensor
          tensor = self.quantizer.dequantize(quant_tensor, meta)
          return tensor
  class SinkCache(Cache):
      """
      A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
      generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
      tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

      It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
      `[batch_size, num_heads, seq_len, head_dim]`.

      Parameters:
          window_length (`int`):
              The length of the context window.
          num_sink_tokens (`int`):
              The number of sink tokens. See the original paper for more information.

      Example:

          ```python
          >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache

          >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
          >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

          >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

          >>> # Prepare a cache class and pass it to model's forward
          >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
          >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
          >>> outputs.past_key_values # access cache filled with key/values from generation
          SinkCache()
          ```
      """

      is_sliding = True

      def __init__(self, window_length: int, num_sink_tokens: int) -> None:
          super().__init__()
          self.key_cache: List[torch.Tensor] = []
          self.value_cache: List[torch.Tensor] = []
          self.window_length = window_length
          self.num_sink_tokens = num_sink_tokens
          # Keyed by the number of incoming tokens; holds the (cos, sin) used to re-rotate the kept keys.
          self.cos_sin_rerotation_cache = {}
          self._cos_cache = None
          self._sin_cache = None
          self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

      @staticmethod
      def _rotate_half(x):
          """Split the last dimension in two halves `(x1, x2)` and return `cat(-x2, x1)`."""
          x1 = x[..., : x.shape[-1] // 2]
          x2 = x[..., x.shape[-1] // 2 :]
          return torch.cat((-x2, x1), dim=-1)

      def _apply_key_rotary_pos_emb(
          self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
      ) -> torch.Tensor:
          """Apply the rotary embedding defined by `cos`/`sin` to `key_states`."""
          rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
          return rotated_key_states

      def _get_rerotation_cos_sin(
          self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Return (and memoize per incoming length) the cos/sin that shift kept keys back by `key_states.shape[-2]` positions."""
          if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
              # Upcast to float32 temporarily for better accuracy
              cos = cos.to(torch.float32)
              sin = sin.to(torch.float32)

              # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
              original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
              shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
              original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
              shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
              rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
              rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

              self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
                  rerotation_cos.to(key_states.dtype).unsqueeze(0),
                  rerotation_sin.to(key_states.dtype).unsqueeze(0),
              )
          return self.cos_sin_rerotation_cache[key_states.shape[-2]]

      def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
          """Returns the sequence length of the cached states. A layer index can be optionally passed."""
          # TODO: deprecate this function in favor of `cache_position`
          # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
          if len(self.key_cache) <= layer_idx:
              return 0
          return self.key_cache[layer_idx].shape[-2]

      def get_max_cache_shape(self) -> Optional[int]:
          """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
          return self.window_length

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          layer_idx: int,
          cache_kwargs: Optional[Dict[str, Any]] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

          Parameters:
              key_states (`torch.Tensor`):
                  The new key states to cache.
              value_states (`torch.Tensor`):
                  The new value states to cache.
              layer_idx (`int`):
                  The index of the layer to cache the states for.
              cache_kwargs (`Dict[str, Any]`, `optional`):
                  Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                  `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                  rotation as the tokens are shifted. When omitted, no re-rotation is performed.

          Return:
              A tuple containing the updated key and value states.
          """
          # `cache_kwargs` defaults to None in the signature; treat that as "no RoPE information" instead of crashing.
          if cache_kwargs is None:
              cache_kwargs = {}

          # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
          # with partially rotated position embeddings, like Phi or Persimmon.
          sin = cache_kwargs.get("sin")
          cos = cache_kwargs.get("cos")
          partial_rotation_size = cache_kwargs.get("partial_rotation_size")
          using_rope = cos is not None and sin is not None

          # Update the number of seen tokens
          if layer_idx == 0:
              self._seen_tokens += key_states.shape[-2]

          # Update the sin/cos cache, which holds sin/cos values for all possible positions
          if using_rope and layer_idx == 0:
              # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
              # after all RoPE models have a llama-like cache utilization.
              if cos.dim() == 2:
                  self._cos_cache = cos
                  self._sin_cache = sin
              else:
                  if self._cos_cache is None:
                      self._cos_cache = cos[0, ...]
                      self._sin_cache = sin[0, ...]
                  elif self._cos_cache.shape[0] < self.window_length:
                      self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
                      self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)

          # [bsz, num_heads, seq_len, head_dim]
          if len(self.key_cache) <= layer_idx:
              # Empty cache
              self.key_cache.append(key_states)
              self.value_cache.append(value_states)

          elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
              # Growing cache
              self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
              self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

          else:
              # Shifting cache
              keys_to_keep = self.key_cache[layer_idx][
                  :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
              ]

              # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
              if using_rope:
                  rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
                      key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
                  )
                  if partial_rotation_size is not None:
                      keys_to_keep, keys_pass = (
                          keys_to_keep[..., :partial_rotation_size],
                          keys_to_keep[..., partial_rotation_size:],
                      )
                  keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                  if partial_rotation_size is not None:
                      keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

              # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
              sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
              self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

              sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
              values_to_keep = self.value_cache[layer_idx][
                  :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
              ]
              self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

          return self.key_cache[layer_idx], self.value_cache[layer_idx]
  class StaticCache(Cache):
      """
      Static Cache class to be used with `torch.compile(model)` and `torch.export()`.

      Parameters:
          config (`PretrainedConfig`):
              The configuration file defining the shape-related attributes required to initialize the static cache.
          batch_size (`int`):
              The batch size with which the model will be used. Note that a new instance must be instantiated if a
              smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
          max_cache_len (`int`):
              The maximum sequence length with which the model will be used. Defaults to `config.max_position_embeddings`
              when not given.
          device (`torch.device` or `str`):
              The device on which the cache should be initialized. Should be the same as the layer.
          dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
              The default `dtype` to use when initializing the layer.
          max_batch_size (`int`, *optional*):
              Deprecated alias of `batch_size`.
          layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
              Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split across different GPUs.
              You can find out which layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.

      Example:

          ```python
          >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

          >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
          >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

          >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

          >>> # Prepare a cache class and pass it to model's forward
          >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
          >>> max_generated_length = inputs.input_ids.shape[1] + 10
          >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
          >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
          >>> outputs.past_key_values # access cache filled with key/values from generation
          StaticCache()
          ```
      """

      # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
      def __init__(
          self,
          config: PretrainedConfig,
          batch_size: int = None,
          max_cache_len: int = None,
          device: torch.device = None,
          dtype: torch.dtype = torch.float32,
          max_batch_size: Optional[int] = None,
          layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
      ) -> None:
          super().__init__()
          if max_batch_size is not None:
              logger.warning_once(
                  f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
                  "v4.46. Use the more precisely named 'batch_size' argument instead."
              )

          self.batch_size = batch_size or max_batch_size
          self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

          # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
          self.head_dim = (
              config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
          )

          self.dtype = dtype
          self.num_key_value_heads = (
              config.num_attention_heads
              if getattr(config, "num_key_value_heads", None) is None
              else config.num_key_value_heads
          )

          self.key_cache: List[torch.Tensor] = []
          self.value_cache: List[torch.Tensor] = []
          # Note: There will be significant perf decrease if switching to use 5D tensors instead.
          cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
          for idx in range(config.num_hidden_layers):
              layer_device = layer_device_map[idx] if layer_device_map is not None else device
              new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
              new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
              # Notes:
              # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
              #     breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
              #     it is not needed anyway)
              # 2. `torch.export()` requires mutations to be registered as buffers.
              if not is_torchdynamo_compiling():
                  # Register the tensors allocated above instead of allocating a second, identical zero pair.
                  self.register_buffer(f"key_cache_{idx}", new_layer_key_cache)
                  self.register_buffer(f"value_cache_{idx}", new_layer_value_cache)
                  # Read back through the module so any buffer-registration hook result is honoured.
                  new_layer_key_cache = getattr(self, f"key_cache_{idx}")
                  new_layer_value_cache = getattr(self, f"value_cache_{idx}")
                  torch._dynamo.mark_static_address(new_layer_key_cache)
                  torch._dynamo.mark_static_address(new_layer_value_cache)
              self.key_cache.append(new_layer_key_cache)
              self.value_cache.append(new_layer_value_cache)

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          layer_idx: int,
          cache_kwargs: Optional[Dict[str, Any]] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """
          Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
          It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

          Parameters:
              key_states (`torch.Tensor`):
                  The new key states to cache.
              value_states (`torch.Tensor`):
                  The new value states to cache.
              layer_idx (`int`):
                  The index of the layer to cache the states for.
              cache_kwargs (`Dict[str, Any]`, `optional`):
                  Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                  to know where to write in the cache. When it is missing, the whole layer buffer is overwritten
                  with `key_states` / `value_states`.

          Return:
              A tuple containing the updated key and value states.
          """
          # `cache_kwargs` defaults to None in the signature; treat that like a missing `cache_position`.
          cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None

          k_out = self.key_cache[layer_idx]
          v_out = self.value_cache[layer_idx]

          if cache_position is None:
              k_out.copy_(key_states)
              v_out.copy_(value_states)
          else:
              # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
              # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does explicitly an in-place
              # operation, that avoids copies and uses less memory.
              try:
                  k_out.index_copy_(2, cache_position, key_states)
                  v_out.index_copy_(2, cache_position, value_states)
              except NotImplementedError:
                  # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
                  k_out[:, :, cache_position] = key_states
                  v_out[:, :, cache_position] = value_states

          return k_out, v_out

      def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
          """Returns the sequence length of the cached states that were seen by the model."""
          # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
          # limit the check to the first batch member and head dimension.
          # NOTE: `.sum()` yields a 0-dim tensor, not a Python int, despite the annotation.
          # TODO: deprecate this function in favor of `cache_position`
          return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()

      def get_max_cache_shape(self) -> Optional[int]:
          """Returns the maximum sequence length the cache was allocated for."""
          return self.max_cache_len

      def reset(self):
          """Resets the cache values while preserving the objects"""
          for layer_idx in range(len(self.key_cache)):
              # In-place ops prevent breaking the static address
              self.key_cache[layer_idx].zero_()
              self.value_cache[layer_idx].zero_()
  1055. class SlidingWindowCache(StaticCache):
  1056. """
  1057. Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
  1058. Every time when we try to update the cache, we compute the `indices` based on `cache_position >= self.config.sliding_window - 1`,
  1059. if true(which means the cache can not hold all the old key value states and new states together because of the sliding window constraint),
  1060. we need to do a cycle shift based on `indices` to replace the oldest states by the new key value states passed in.
  1061. The `to_shift` is only true once we are above sliding_window. Thus with `sliding_window==64`:
  1062. indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
  1063. tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
  1064. 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
  1065. 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
  1066. 55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
  1067. We overwrite the cache using these, then we always write at cache_position (clamped to `sliding_window`)
  1068. Parameters:
  1069. config (`PretrainedConfig`):
  1070. The configuration file defining the shape-related attributes required to initialize the static cache.
  1071. batch_size (`int`):
  1072. The batch size with which the model will be used. Note that a new instance must be instantiated if a
  1073. smaller batch size is used.
  1074. max_cache_len (`int`):
  1075. The maximum sequence length with which the model will be used.
  1076. device (`torch.device` or `str`):
  1077. The device on which the cache should be initialized. Should be the same as the layer.
  1078. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
  1079. The default `dtype` to use when initializing the layer.
  1080. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
  1081. Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
  1082. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
  1083. Example:
  1084. ```python
  1085. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache
  1086. >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
  1087. >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
  1088. >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")
  1089. >>> # Prepare a cache class and pass it to model's forward
  1090. >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
  1091. >>> max_generated_length = inputs.input_ids.shape[1] + 10
  1092. >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
  1093. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  1094. >>> outputs.past_key_values # access cache filled with key/values from generation
  1095. SlidingWindowCache()
  1096. ```
  1097. """
  1098. is_sliding = True
  1099. # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
  1100. def __init__(
  1101. self,
  1102. config: PretrainedConfig,
  1103. batch_size: int = None,
  1104. max_cache_len: int = None,
  1105. device: torch.device = None,
  1106. dtype: torch.dtype = torch.float32,
  1107. max_batch_size: Optional[int] = None,
  1108. layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
  1109. ) -> None:
  1110. if not hasattr(config, "sliding_window") or config.sliding_window is None:
  1111. raise ValueError(
  1112. "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
  1113. "sliding window attention, please check if there is a `sliding_window` field in the model "
  1114. "config and it's not set to None."
  1115. )
  1116. max_cache_len = min(config.sliding_window, max_cache_len)
  1117. super().__init__(
  1118. config=config,
  1119. batch_size=batch_size,
  1120. max_cache_len=max_cache_len,
  1121. device=device,
  1122. dtype=dtype,
  1123. max_batch_size=max_batch_size,
  1124. layer_device_map=layer_device_map,
  1125. )
  1126. def update(
  1127. self,
  1128. key_states: torch.Tensor,
  1129. value_states: torch.Tensor,
  1130. layer_idx: int,
  1131. cache_kwargs: Optional[Dict[str, Any]] = None,
  1132. ) -> Tuple[torch.Tensor]:
  1133. cache_position = cache_kwargs.get("cache_position")
  1134. k_out = self.key_cache[layer_idx]
  1135. v_out = self.value_cache[layer_idx]
  1136. # assume this only happens in prefill phase when prompt length > sliding_window_size (= max_cache_len)
  1137. if cache_position.shape[0] > self.max_cache_len:
  1138. k_out = key_states[:, :, -self.max_cache_len :, :]
  1139. v_out = value_states[:, :, -self.max_cache_len :, :]
  1140. # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
  1141. self.key_cache[layer_idx] += k_out
  1142. self.value_cache[layer_idx] += v_out
  1143. # we should return the whole states instead of k_out, v_out to take the whole prompt
  1144. # into consideration when building kv cache instead of just throwing away tokens outside of the window
  1145. return key_states, value_states
  1146. slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
  1147. cache_position = cache_position.clamp(0, self.max_cache_len - 1)
  1148. to_shift = cache_position >= self.max_cache_len - 1
  1149. indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
  1150. k_out = k_out[:, :, indices]
  1151. v_out = v_out[:, :, indices]
  1152. try:
  1153. k_out.index_copy_(2, cache_position, key_states)
  1154. v_out.index_copy_(2, cache_position, value_states)
  1155. except NotImplementedError:
  1156. # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
  1157. k_out[:, :, cache_position] = key_states
  1158. v_out[:, :, cache_position] = value_states
  1159. # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
  1160. self.key_cache[layer_idx].zero_()
  1161. self.value_cache[layer_idx].zero_()
  1162. self.key_cache[layer_idx] += k_out
  1163. self.value_cache[layer_idx] += v_out
  1164. return k_out, v_out
  1165. def get_max_cache_shape(self) -> Optional[int]:
  1166. return self.max_cache_len
  1167. def reset(self):
  1168. for layer_idx in range(len(self.key_cache)):
  1169. # In-place ops prevent breaking the static address
  1170. self.key_cache[layer_idx].zero_()
  1171. self.value_cache[layer_idx].zero_()
  1172. class EncoderDecoderCache(Cache):
  1173. """
  1174. Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
  1175. cross-attention caches.
  1176. Example:
  1177. ```python
  1178. >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
  1179. >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
  1180. >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
  1181. >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
  1182. >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
  1183. >>> self_attention_cache = DynamicCache()
  1184. >>> cross_attention_cache = DynamicCache()
  1185. >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
  1186. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  1187. >>> outputs.past_key_values # access cache filled with key/values from generation
  1188. EncoderDecoderCache()
  1189. ```
  1190. """
  1191. def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
  1192. super().__init__()
  1193. self.self_attention_cache = self_attention_cache
  1194. self.cross_attention_cache = cross_attention_cache
  1195. self.is_updated = {}
  1196. for layer_idx in range(len(cross_attention_cache.key_cache)):
  1197. self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
  1198. def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
  1199. """
  1200. Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
  1201. sequence length.
  1202. """
  1203. if layer_idx < len(self):
  1204. return (
  1205. self.self_attention_cache.key_cache[layer_idx],
  1206. self.self_attention_cache.value_cache[layer_idx],
  1207. self.cross_attention_cache.key_cache[layer_idx],
  1208. self.cross_attention_cache.value_cache[layer_idx],
  1209. )
  1210. else:
  1211. raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
  1212. def __len__(self):
  1213. """
  1214. Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
  1215. to the number of layers in the model.
  1216. """
  1217. return len(self.self_attention_cache)
  1218. def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
  1219. """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
  1220. legacy_cache = ()
  1221. if len(self.cross_attention_cache) > 0:
  1222. for self_attn, cross_attn in zip(
  1223. self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
  1224. ):
  1225. legacy_cache += (self_attn + cross_attn,)
  1226. else:
  1227. legacy_cache = self.self_attention_cache.to_legacy_cache()
  1228. return legacy_cache
  1229. @classmethod
  1230. def from_legacy_cache(
  1231. cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
  1232. ) -> "EncoderDecoderCache":
  1233. """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
  1234. cache = cls(
  1235. self_attention_cache=DynamicCache(),
  1236. cross_attention_cache=DynamicCache(),
  1237. )
  1238. if past_key_values is not None:
  1239. for layer_idx in range(len(past_key_values)):
  1240. key_states, value_states = past_key_values[layer_idx][:2]
  1241. cache.self_attention_cache.update(key_states, value_states, layer_idx)
  1242. if len(past_key_values[layer_idx]) > 2:
  1243. key_states, value_states = past_key_values[layer_idx][2:]
  1244. cache.cross_attention_cache.update(key_states, value_states, layer_idx)
  1245. cache.is_updated[layer_idx] = True
  1246. return cache
  1247. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  1248. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  1249. # check if empty list because in case of static cache it will be a tensors and we can't check `if not torch.Tensor`
  1250. return self.self_attention_cache.get_seq_length(layer_idx)
  1251. def reset(self):
  1252. if hasattr(self.self_attention_cache, "reset"):
  1253. self.self_attention_cache.reset()
  1254. if hasattr(self.cross_attention_cache, "reset"):
  1255. self.cross_attention_cache.reset()
  1256. elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
  1257. raise ValueError(
  1258. "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
  1259. "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
  1260. f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
  1261. f"{self.cross_attention_cache.__str__()} for the cross attention cache."
  1262. )
  1263. for layer_idx in self.is_updated:
  1264. self.is_updated[layer_idx] = False
  1265. def reorder_cache(self, beam_idx: torch.LongTensor):
  1266. """Reorders the cache for beam search, given the selected beam indices."""
  1267. self.self_attention_cache.reorder_cache(beam_idx)
  1268. self.cross_attention_cache.reorder_cache(beam_idx)
  1269. def check_dynamic_cache(self, method: str):
  1270. if not (
  1271. isinstance(self.self_attention_cache, DynamicCache)
  1272. and isinstance(self.cross_attention_cache, DynamicCache)
  1273. ):
  1274. raise ValueError(
  1275. f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
  1276. f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
  1277. )
  1278. # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
  1279. def crop(self, maximum_length: int):
  1280. """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
  1281. negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
  1282. self.check_dynamic_cache(self.crop.__name__)
  1283. self.self_attention_cache.crop(maximum_length)
  1284. def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
  1285. """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
  1286. `_split_model_inputs()` in `generation.utils`"""
  1287. self.check_dynamic_cache(self.batch_split.__name__)
  1288. self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
  1289. cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
  1290. out = []
  1291. for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
  1292. out.append(EncoderDecoderCache(self_attn, cross_attn))
  1293. return out
  1294. @classmethod
  1295. def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
  1296. """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
  1297. `generation.utils`"""
  1298. self_attention_cache = DynamicCache()
  1299. cross_attention_cache = DynamicCache()
  1300. for idx in range(len(splits[0])):
  1301. layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
  1302. layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
  1303. self_attention_cache.update(layer_keys, layer_values, idx)
  1304. layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
  1305. layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
  1306. cross_attention_cache.update(layer_keys, layer_values, idx)
  1307. return cls(self_attention_cache, cross_attention_cache)
  1308. def batch_repeat_interleave(self, repeats: int):
  1309. """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
  1310. self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
  1311. self.self_attention_cache.batch_repeat_interleave(repeats)
  1312. self.cross_attention_cache.batch_repeat_interleave(repeats)
  1313. def batch_select_indices(self, indices: torch.Tensor):
  1314. """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
  1315. self.check_dynamic_cache(self.batch_select_indices.__name__)
  1316. self.self_attention_cache.batch_select_indices(indices)
  1317. self.cross_attention_cache.batch_select_indices(indices)
  1318. class HybridCache(Cache):
  1319. """
  1320. Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between a local sliding window attention
  1321. and global attention in every other layer. Under the hood, Hybrid Cache leverages ["SlidingWindowCache"] for sliding window attention
  1322. and ["StaticCache"] for global attention. For more information, see the documentation of each subcomponeent cache class.
  1323. Parameters:
  1324. config (`PretrainedConfig):
  1325. The configuration file defining the shape-related attributes required to initialize the static cache.
  1326. batch_size (`int`):
  1327. The batch size with which the model will be used. Note that a new instance must be instantiated if a
  1328. smaller batch size is used.
  1329. max_cache_len (`int`):
  1330. The maximum sequence length with which the model will be used.
  1331. device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
  1332. The device on which the cache should be initialized. Should be the same as the layer.
  1333. dtype (torch.dtype, *optional*, defaults to `torch.float32`):
  1334. The default `dtype` to use when initializing the layer.
  1335. layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
  1336. Mapping between the layers and its device. This is required when you are manually initializing the cache and the model is splitted between differents gpus.
  1337. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`.
  1338. Example:
  1339. ```python
  1340. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
  1341. >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
  1342. >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
  1343. >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
  1344. >>> # Prepare a cache class and pass it to model's forward
  1345. >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
  1346. >>> max_generated_length = inputs.input_ids.shape[1] + 10
  1347. >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
  1348. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  1349. >>> outputs.past_key_values # access cache filled with key/values from generation
  1350. HybridCache()
  1351. ```
  1352. """
  1353. # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
  1354. def __init__(
  1355. self,
  1356. config: PretrainedConfig,
  1357. batch_size: int = None,
  1358. max_cache_len: int = None,
  1359. device: Union[torch.device, str] = "cpu",
  1360. dtype: torch.dtype = torch.float32,
  1361. max_batch_size: Optional[int] = None,
  1362. layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
  1363. ) -> None:
  1364. super().__init__()
  1365. if max_batch_size is not None:
  1366. logger.warning_once(
  1367. f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
  1368. "v4.46. Use the more precisely named 'batch_size' argument instead."
  1369. )
  1370. if not hasattr(config, "sliding_window") or config.sliding_window is None:
  1371. raise ValueError(
  1372. "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
  1373. "sliding window attention, please check if there is a `sliding_window` field in the model "
  1374. "config and it's not set to None."
  1375. )
  1376. self.max_cache_len = max_cache_len
  1377. self.batch_size = batch_size or max_batch_size
  1378. # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
  1379. self.head_dim = (
  1380. config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
  1381. )
  1382. self.dtype = dtype
  1383. self.num_key_value_heads = (
  1384. config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
  1385. )
  1386. self.is_sliding = torch.tensor(
  1387. [not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
  1388. )
  1389. self.key_cache: List[torch.Tensor] = []
  1390. self.value_cache: List[torch.Tensor] = []
  1391. global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
  1392. sliding_cache_shape = (
  1393. self.batch_size,
  1394. self.num_key_value_heads,
  1395. min(config.sliding_window, max_cache_len),
  1396. self.head_dim,
  1397. )
  1398. for i in range(config.num_hidden_layers):
  1399. if layer_device_map is not None:
  1400. layer_device = layer_device_map[i]
  1401. else:
  1402. layer_device = device
  1403. # Note: `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
  1404. # breaks when updating the cache.
  1405. cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
  1406. new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
  1407. new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
  1408. torch._dynamo.mark_static_address(new_layer_key_cache)
  1409. torch._dynamo.mark_static_address(new_layer_value_cache)
  1410. self.key_cache.append(new_layer_key_cache)
  1411. self.value_cache.append(new_layer_value_cache)
  1412. def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
  1413. if cache_position.shape[0] > max_cache_len:
  1414. k_out = key_states[:, :, -max_cache_len:, :]
  1415. v_out = value_states[:, :, -max_cache_len:, :]
  1416. # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
  1417. self.key_cache[layer_idx] += k_out
  1418. self.value_cache[layer_idx] += v_out
  1419. # we should return the whole states instead of k_out, v_out to take the whole prompt
  1420. # into consideration when building kv cache instead of just throwing away tokens outside of the window
  1421. return key_states, value_states
  1422. slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
  1423. cache_position = cache_position.clamp(0, max_cache_len - 1)
  1424. to_shift = cache_position >= max_cache_len - 1
  1425. indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
  1426. k_out = k_out[:, :, indices]
  1427. v_out = v_out[:, :, indices]
  1428. k_out[:, :, cache_position] = key_states
  1429. v_out[:, :, cache_position] = value_states
  1430. # `_.zero()` followed by `+=` is equivalent `=`, but compile-friendly (without graph breaks due to assignment)
  1431. self.key_cache[layer_idx].zero_()
  1432. self.value_cache[layer_idx].zero_()
  1433. self.key_cache[layer_idx] += k_out
  1434. self.value_cache[layer_idx] += v_out
  1435. return k_out, v_out
  1436. def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
  1437. k_out[:, :, cache_position] = key_states
  1438. v_out[:, :, cache_position] = value_states
  1439. self.key_cache[layer_idx] = k_out
  1440. self.value_cache[layer_idx] = v_out
  1441. return k_out, v_out
  1442. def update(
  1443. self,
  1444. key_states: torch.Tensor,
  1445. value_states: torch.Tensor,
  1446. layer_idx: int,
  1447. cache_kwargs: Optional[Dict[str, Any]] = None,
  1448. ) -> Tuple[torch.Tensor]:
  1449. cache_position = cache_kwargs.get("cache_position")
  1450. sliding_window = cache_kwargs.get("sliding_window")
  1451. k_out = self.key_cache[layer_idx]
  1452. v_out = self.value_cache[layer_idx]
  1453. if sliding_window:
  1454. update_fn = self._sliding_update
  1455. else:
  1456. update_fn = self._static_update
  1457. return update_fn(
  1458. cache_position,
  1459. layer_idx,
  1460. key_states,
  1461. value_states,
  1462. k_out,
  1463. v_out,
  1464. k_out.shape[2],
  1465. )
  1466. def get_max_cache_shape(self) -> Optional[int]:
  1467. return self.max_cache_len
  1468. def get_seq_length(self, layer_idx: Optional[int] = 0):
  1469. # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
  1470. # limit the check to the first batch member and head dimension.
  1471. # TODO: deprecate this function in favor of `cache_position`
  1472. if layer_idx != 0:
  1473. raise ValueError(
  1474. "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
  1475. "Using the `layer_idx` argument is not supported."
  1476. )
  1477. return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
  1478. def reset(self):
  1479. """Resets the cache values while preserving the objects"""
  1480. for layer_idx in range(len(self.key_cache)):
  1481. # In-place ops prevent breaking the static address
  1482. self.key_cache[layer_idx].zero_()
  1483. self.value_cache[layer_idx].zero_()
  1484. class MambaCache:
  1485. """
  1486. Cache for mamba model which does not have attention mechanism and key value states.
  1487. Arguments:
  1488. config (`PretrainedConfig):
  1489. The configuration file defining the shape-related attributes required to initialize the static cache.
  1490. batch_size (`int`):
  1491. The batch size with which the model will be used. Note that a new instance must be instantiated if a
  1492. smaller batch size is used.
  1493. dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
  1494. The default `dtype` to use when initializing the layer.
  1495. device (`torch.device` or `str`, *optional*):
  1496. The device on which the cache should be initialized. Should be the same as the layer.
  1497. Attributes:
  1498. dtype: (`torch.dtype`):
  1499. The default `dtype` used to initializing the cache.
  1500. intermediate_size: (`int`):
  1501. Model's intermediate_size taken from config.
  1502. ssm_state_size: (`int`):
  1503. Model's state_size taken from config.
  1504. conv_kernel_size: (`int`):
  1505. Model's convolution kernel size taken from config
  1506. conv_states: (`torch.Tensor`):
  1507. A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
  1508. ssm_states: (`torch.Tensor`):
  1509. A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states
  1510. Example:
  1511. ```python
  1512. >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
  1513. >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
  1514. >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
  1515. >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
  1516. >>> # Prepare a cache class and pass it to model's forward
  1517. >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
  1518. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  1519. >>> outputs.past_key_values
  1520. MambaCache()
  1521. ```
  1522. """
  1523. # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
  1524. def __init__(
  1525. self,
  1526. config: PretrainedConfig,
  1527. batch_size: int = None,
  1528. dtype: torch.dtype = torch.float16,
  1529. device: Optional[Union[torch.device, str]] = None,
  1530. max_batch_size: Optional[int] = None,
  1531. ):
  1532. if max_batch_size is not None:
  1533. logger.warning_once(
  1534. f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
  1535. "v4.46. Use the more precisely named 'batch_size' argument instead."
  1536. )
  1537. self.dtype = dtype
  1538. self.batch_size = batch_size or max_batch_size
  1539. self.intermediate_size = config.intermediate_size
  1540. self.ssm_state_size = config.state_size
  1541. self.conv_kernel_size = config.conv_kernel
  1542. self.conv_states: torch.Tensor = torch.zeros(
  1543. config.num_hidden_layers,
  1544. self.batch_size,
  1545. self.intermediate_size,
  1546. self.conv_kernel_size,
  1547. device=device,
  1548. dtype=dtype,
  1549. )
  1550. self.ssm_states: torch.Tensor = torch.zeros(
  1551. config.num_hidden_layers,
  1552. self.batch_size,
  1553. self.intermediate_size,
  1554. self.ssm_state_size,
  1555. device=device,
  1556. dtype=dtype,
  1557. )
  1558. torch._dynamo.mark_static_address(self.conv_states)
  1559. torch._dynamo.mark_static_address(self.ssm_states)
  1560. def update_conv_state(
  1561. self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
  1562. ) -> torch.Tensor:
  1563. conv_state = self.conv_states[layer_idx]
  1564. cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
  1565. conv_state = conv_state.roll(shifts=-1, dims=-1)
  1566. conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
  1567. self.conv_states[layer_idx].zero_()
  1568. self.conv_states[layer_idx] += conv_state
  1569. return self.conv_states[layer_idx]
  1570. def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
  1571. self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
  1572. return self.ssm_states[layer_idx]
  1573. def reset(self):
  1574. self.conv_states.zero_()
  1575. self.ssm_states.zero_()
  1576. class OffloadedStaticCache(StaticCache):
  1577. """
  1578. Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
  1579. another device.
  1580. Args:
  1581. config (`PretrainedConfig):
  1582. The configuration file defining the shape-related attributes required to initialize
  1583. the static cache.
  1584. max_batch_size (`int`):
  1585. The maximum batch size with which the model will be used.
  1586. max_cache_len (`int`):
  1587. The maximum sequence length with which the model will be used.
  1588. device (`Union[str, torch.device]`):
  1589. The device on which the cache should be initialized. Should be the same as the
  1590. layer device.
  1591. dtype (`torch.dtype`, *optional*):
  1592. The default `dtype` to use when initializing the cache.
  1593. offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
  1594. The device to offload to. Defaults to CPU.
  1595. Attributes:
  1596. key_cache (`List[torch.Tensor]`):
  1597. Off-loaded key cache tensors. First one will be on device, where-as the others are
  1598. off-loaded.
  1599. value_cache (`List[torch.Tensor]`):
  1600. Off-loaded value cache tensors. First one will be on device, where-as the others are
  1601. off-loaded.
  1602. max_batch_size (`int`):
  1603. The maximum batch size with which this cache can be used.
  1604. max_cache_len (`int`):
  1605. The maximum sequence length with which this cache can be used.
  1606. device (`torch.device`):
  1607. The device on which the cache is used.
  1608. offload_device (`torch.device`):
  1609. The device used to offload to.
  1610. dtype (`torch.dtype`):
  1611. The `dtype` used to initializing the cache.
  1612. Example:
  1613. ```python
  1614. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
  1615. >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
  1616. >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
  1617. >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
  1618. >>> # Prepare a cache class and pass it to model's forward
  1619. >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
  1620. >>> max_generated_length = inputs.input_ids.shape[1] + 10
  1621. >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
  1622. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  1623. >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
  1624. ```
  1625. """
  1626. def __init__(
  1627. self,
  1628. config: PretrainedConfig,
  1629. max_batch_size: int,
  1630. max_cache_len: Optional[int],
  1631. device: Union[str, torch.device],
  1632. dtype: Optional[torch.dtype] = None,
  1633. offload_device: Union[str, torch.device] = torch.device("cpu"),
  1634. ) -> None:
  1635. self.max_batch_size = max_batch_size
  1636. self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
  1637. self.device = torch.device(device)
  1638. self.offload_device = torch.device(offload_device)
  1639. self.dtype = dtype if dtype is not None else torch.float32
  1640. # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
  1641. head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
  1642. num_key_value_heads = (
  1643. config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
  1644. )
  1645. cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
  1646. # Create offloaded CPU tensors.
  1647. self.key_cache: List[torch.Tensor] = []
  1648. self.value_cache: List[torch.Tensor] = []
  1649. for i in range(config.num_hidden_layers):
  1650. # First layer is always on-device.
  1651. device = self.device if i == 0 else self.offload_device
  1652. key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device)
  1653. self.key_cache.append(key_cache)
  1654. self.value_cache.append(value_cache)
  1655. # Create device tensors.
  1656. self._device_key_cache: List[torch.Tensor] = []
  1657. self._device_value_cache: List[torch.Tensor] = []
  1658. for i in range(2):
  1659. key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device)
  1660. self._device_key_cache.append(key_cache)
  1661. self._device_value_cache.append(value_cache)
  1662. # For backwards compatibility.
  1663. # TODO(gante): Remove this.
  1664. self._seen_tokens = 0
  1665. # Create new CUDA stream for parallel prefetching.
  1666. self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None
  1667. def update(
  1668. self,
  1669. key_states: torch.Tensor,
  1670. value_states: torch.Tensor,
  1671. layer_idx: int,
  1672. cache_kwargs: Optional[Dict[str, Any]] = None,
  1673. ) -> Tuple[torch.Tensor, torch.Tensor]:
  1674. """
  1675. Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
  1676. It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
  1677. Parameters:
  1678. key_states (`torch.Tensor`):
  1679. The new key states to cache.
  1680. value_states (`torch.Tensor`):
  1681. The new value states to cache.
  1682. layer_idx (`int`):
  1683. The index of the layer to cache the states for.
  1684. cache_kwargs (`Dict[str, Any]`, *optional*):
  1685. Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the
  1686. `cache_position` input to know how where to write in the cache.
  1687. Return:
  1688. A tuple containing the updated key and value states.
  1689. """
  1690. if layer_idx == 0:
  1691. # Update seen tokens.
  1692. # TODO(gante): Remove this.
  1693. self._seen_tokens += key_states.shape[-2]
  1694. # Always there.
  1695. k_out = self.key_cache[0]
  1696. v_out = self.value_cache[0]
  1697. else:
  1698. # Wait for prefetch stream.
  1699. if self._prefetch_stream is not None:
  1700. torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)
  1701. k_out = self._device_key_cache[layer_idx & 1]
  1702. v_out = self._device_value_cache[layer_idx & 1]
  1703. self._prefetch_layer(layer_idx + 1)
  1704. cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
  1705. if cache_position is None:
  1706. k_out.copy_(key_states)
  1707. v_out.copy_(value_states)
  1708. # Copy the values to the offloaded device as well.
  1709. if layer_idx == 0:
  1710. self.key_cache[layer_idx].copy_(key_states.to(self.offload_device))
  1711. self.value_cache[layer_idx].copy_(value_states.to(self.offload_device))
  1712. else:
  1713. # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
  1714. # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and it does
  1715. # explicitly an in-place operation, that avoids copies and uses less memory.
  1716. try:
  1717. k_out.index_copy_(2, cache_position, key_states)
  1718. v_out.index_copy_(2, cache_position, value_states)
  1719. except NotImplementedError:
  1720. # The operator 'aten::index_copy.out' is not currently implemented for the MPS
  1721. # device.
  1722. k_out[:, :, cache_position] = key_states
  1723. v_out[:, :, cache_position] = value_states
  1724. # Copy the values to the offloaded device as well.
  1725. if layer_idx != 0:
  1726. cache_position = cache_position.to(self.offload_device)
  1727. key_states = key_states.to(self.offload_device)
  1728. value_states = value_states.to(self.offload_device)
  1729. try:
  1730. self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
  1731. self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
  1732. except NotImplementedError:
  1733. # The operator 'aten::index_copy.out' is not currently implemented for the MPS
  1734. # device.
  1735. self.key_cache[layer_idx][:, :, cache_position] = key_states
  1736. self.value_cache[layer_idx][:, :, cache_position] = value_states
  1737. return k_out, v_out
  1738. def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
  1739. """Returns the sequence length of the cached states that were seen by the model."""
  1740. # TODO(gante): Remove this.
  1741. return self._seen_tokens
  1742. def get_max_cache_shape(self) -> Optional[int]:
  1743. """Returns the maximum sequence length of the cached states."""
  1744. return self.max_cache_len
  1745. def reset(self) -> None:
  1746. """Resets the cache values while preserving the objects."""
  1747. # For backwards compatibility.
  1748. # TODO(gante): Remove this.
  1749. self._seen_tokens = 0
  1750. # Zero out cache.
  1751. for layer_idx in range(len(self.key_cache)):
  1752. # In-place ops prevent breaking the static address.
  1753. self.key_cache[layer_idx].zero_()
  1754. self.value_cache[layer_idx].zero_()
  1755. @property
  1756. def seen_tokens(self) -> int:
  1757. # For backwards compatibility.
  1758. # TODO(gante): Remove this.
  1759. return self._seen_tokens
  1760. def _create_key_value_cache_tensors(
  1761. self, shape: Tuple[int, ...], device: torch.device
  1762. ) -> Tuple[torch.Tensor, torch.Tensor]:
  1763. """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static
  1764. addresses for non-CPU tensors.
  1765. Args:
  1766. shape (`Tuple[int, ...]`): Shape.
  1767. device (`torch.device`): Device.
  1768. Returns:
  1769. Key and value cache tensors as a tuple.
  1770. """
  1771. is_cpu_device = device == torch.device("cpu")
  1772. key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
  1773. value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
  1774. # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
  1775. # preventing compiled graph breaks when updating the cache.
  1776. torch._dynamo.mark_static_address(key_cache)
  1777. torch._dynamo.mark_static_address(value_cache)
  1778. return key_cache, value_cache
  1779. def _prefetch_layer(self, layer_idx: int) -> None:
  1780. """Prefetch a layer to the device. Needs to be called in order of layer indices."""
  1781. # Don't fetch layers that do not exist.
  1782. if layer_idx >= len(self.key_cache):
  1783. return
  1784. # Alternate between two on-device caches.
  1785. if self._prefetch_stream is not None:
  1786. with torch.cuda.stream(self._prefetch_stream):
  1787. self._prefetch_layer_in_context(layer_idx)
  1788. else:
  1789. self._prefetch_layer_in_context(layer_idx)
  1790. def _prefetch_layer_in_context(self, layer_idx: int) -> None:
  1791. """Performs the actual copy of the layer to device cache."""
  1792. self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
  1793. self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)