- import copy
- import importlib.metadata
- import json
- import os
- from dataclasses import dataclass
- from typing import Any, Dict, List, Optional, Tuple, Union
- import torch
- from packaging import version
- from .configuration_utils import PretrainedConfig
- from .utils import (
- is_hqq_available,
- is_optimum_quanto_available,
- is_quanto_available,
- is_torchdynamo_compiling,
- logging,
- )
- from .utils.deprecation import deprecate_kwarg
- if is_hqq_available():
- from hqq.core.quantize import Quantizer as HQQQuantizer
- logger = logging.get_logger(__name__)
- class Cache(torch.nn.Module):
- """
- Base, abstract class for all caches. The actual data structure is specific to each subclass.
- """
- def __init__(self):
- super().__init__()
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
- cache to be created.
- Return:
- A tuple containing the updated key and value states.
- """
- raise NotImplementedError("Make sure to implement `update` in a subclass.")
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # TODO: deprecate this function in favor of `cache_position`
- raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
- # Deprecated in favor of `get_max_cache_shape` because we want to be specific about what we mean by "max_length".
- # Previously, some cache objects didn't have a "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
- # an infinite amount of tokens. In the codebase, what we really need to check is the max capacity of certain cache instances, so
- # we change the naming to be more explicit.
- def get_max_length(self) -> Optional[int]:
- logger.warning_once(
- "`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. "
- "Calling `get_max_cache()` will raise error from v4.48"
- )
- return self.get_max_cache_shape()
- def get_max_cache_shape(self) -> Optional[int]:
- """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
- raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.")
- def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int:
- """Given the sequence length of the new inputs, returns the usable length of the cache."""
- # Cache without size limit -> all cache is usable
- # Cache with size limit -> if the cached length plus the length of the new inputs is larger than the maximum cache
- # length, we will need to evict part of the cache (and thus not all cache is usable)
- max_length = self.get_max_cache_shape()
- previous_seq_length = self.get_seq_length(layer_idx)
- if max_length is not None and previous_seq_length + new_seq_length > max_length:
- return max_length - new_seq_length
- return previous_seq_length
- def reorder_cache(self, beam_idx: torch.LongTensor):
- """Reorders the cache for beam search, given the selected beam indices."""
- for layer_idx in range(len(self.key_cache)):
- if self.key_cache[layer_idx] != []:
- device = self.key_cache[layer_idx].device
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
- if self.value_cache[layer_idx] != []:
- device = self.value_cache[layer_idx].device
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
- @property
- def seen_tokens(self):
- logger.warning_once(
- "The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` "
- "model input instead."
- )
- if hasattr(self, "_seen_tokens"):
- return self._seen_tokens
- else:
- return None
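- # Illustrative sketch (not part of this module's API): a minimal, hypothetical `Cache` subclass showing the
- # contract described by the abstract methods above, and how `get_usable_length` caps the usable cache once
- # the new tokens would overflow the maximum capacity.
- def _example_minimal_cache_contract() -> None:  # hypothetical helper, for illustration only
-     class _ToyCache(Cache):
-         def __init__(self, max_len: int):
-             super().__init__()
-             self._len = 0
-             self._max_len = max_len
-         def update(self, key_states, value_states, layer_idx, cache_kwargs=None):
-             self._len += key_states.shape[-2]
-             return key_states, value_states
-         def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-             return self._len
-         def get_max_cache_shape(self) -> Optional[int]:
-             return self._max_len
-     cache = _ToyCache(max_len=8)
-     cache.update(torch.zeros(1, 1, 6, 4), torch.zeros(1, 1, 6, 4), layer_idx=0)
-     # 6 cached + 4 new tokens exceed the capacity of 8, so only max_len - new_seq_length = 4 entries are usable
-     assert cache.get_usable_length(new_seq_length=4) == 4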
- @dataclass
- class CacheConfig:
- """
- Base class for cache configs
- """
- cache_implementation: None
- @classmethod
- def from_dict(cls, config_dict, **kwargs):
- """
- Constructs a CacheConfig instance from a dictionary of parameters.
- Args:
- config_dict (Dict[str, Any]): Dictionary containing configuration parameters.
- **kwargs: Additional keyword arguments to override dictionary values.
- Returns:
- CacheConfig: Instance of CacheConfig constructed from the dictionary.
- """
- config = cls(**config_dict)
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(config, key):
- setattr(config, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
- return config
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_json_file
- def to_json_file(self, json_file_path: Union[str, os.PathLike]):
- """
- Save this instance to a JSON file.
- Args:
- json_file_path (`str` or `os.PathLike`):
- Path to the JSON file in which this configuration instance's parameters will be saved.
- """
- with open(json_file_path, "w", encoding="utf-8") as writer:
- config_dict = self.to_dict()
- json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
- writer.write(json_string)
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.to_dict
- def to_dict(self) -> Dict[str, Any]:
- """
- Serializes this instance to a Python dictionary.
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
- """
- return copy.deepcopy(self.__dict__)
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__iter__
- def __iter__(self):
- """allows `dict(obj)` for situations where obj may be a dict or QuantizationConfigMixin"""
- for attr, value in copy.deepcopy(self.__dict__).items():
- yield attr, value
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.__repr__
- def __repr__(self):
- return f"{self.__class__.__name__} {self.to_json_string()}"
- def to_json_string(self):
- """
- Serializes this instance to a JSON formatted string.
- Returns:
- str: JSON formatted string representing the configuration instance.
- """
- return json.dumps(self.__dict__, indent=2) + "\n"
- # Copied from transformers.utils.quantization_config.QuantizationConfigMixin.update
- def update(self, **kwargs):
- """
- Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
- returning all the unused kwargs.
- Args:
- kwargs (`Dict[str, Any]`):
- Dictionary of attributes to tentatively update this class.
- Returns:
- `Dict[str, Any]`: Dictionary containing all the key-value pairs that were not used to update the instance.
- """
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(self, key):
- setattr(self, key, value)
- to_remove.append(key)
- # Remove all the attributes that were updated, without modifying the input dict
- unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
- return unused_kwargs
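- # Illustrative sketch (not part of this module's API): how `CacheConfig.from_dict` builds an instance and how
- # `update` consumes the kwargs that match existing attributes while handing back the unused ones.
- def _example_cache_config_roundtrip() -> None:  # hypothetical helper, for illustration only
-     config = CacheConfig.from_dict({"cache_implementation": "dynamic"})
-     unused = config.update(cache_implementation="static", not_a_field=42)
-     assert config.cache_implementation == "static"
-     assert unused == {"not_a_field": 42}
-     # `to_json_string` / `__iter__` expose the same underlying attribute dict
-     assert dict(config) == json.loads(config.to_json_string())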
- @dataclass
- class QuantizedCacheConfig(CacheConfig):
- """
- Configuration class for quantized cache settings.
- Attributes:
- backend (`str`, *optional*, defaults to `"quanto"`):
- Backend to use when performing quantization. Can be one of [`quanto`, `HQQ`].
- nbits (`Optional[int]`, *optional*, defaults to 4):
- Number of bits, can be 2 or 4 for the `quanto` backend and one of [1, 2, 3, 4, 8] for the `HQQ` backend.
- axis_key (`int`, *optional*, defaults to 0):
- Axis over which to perform grouping for the key tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
- axis_value (`int`, *optional*, defaults to 0):
- Axis over which to perform grouping for the value tensors. Can be [0, -1] for `quanto` backend and [0, 1] for `HQQ` backend.
- q_group_size (`Optional[int]`, *optional*, defaults to 64):
- Size of the quantization group, should be a divisor of the model's hidden dimension.
- Defaults to 64.
- residual_length (`Optional[int]`, *optional*, defaults to 128):
- Length of the residual cache, which will always be stored in the original precision.
- Defaults to 128.
- compute_dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
- The default dtype used for computations in the model. Keys and Values will be cast to this dtype after dequantization.
- device (`str`, *optional*, defaults to `"cpu"`):
- Device on which to perform computations; should be the same as the model's device.
- """
- def __init__(
- self,
- backend: str = "quanto",
- nbits: Optional[int] = 4,
- axis_key: Optional[int] = 0,
- axis_value: Optional[int] = 0,
- q_group_size: Optional[int] = 64,
- residual_length: Optional[int] = 128,
- compute_dtype: Optional[torch.dtype] = torch.float16,
- device: Optional[str] = "cpu",
- ):
- self.backend = backend
- self.nbits = nbits
- self.axis_key = axis_key
- self.axis_value = axis_value
- self.q_group_size = q_group_size
- self.residual_length = residual_length
- self.compute_dtype = compute_dtype
- self.device = device
- def validate(self):
- """Validates if the arguments passed are correct"""
- incorrect_arg_msg = (
- "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
- "but found {found_value}"
- )
- # Check that the values are reasonable in general (nbits, axis)
- # Later in QuantizedCache init we check if they are supported for that particular backend
- if self.nbits not in [1, 2, 3, 4, 8]:
- raise ValueError(
- incorrect_arg_msg.format(
- key="nbits",
- correct_value="2 or 4 or 8",
- found_value=self.nbits,
- ),
- )
- if self.q_group_size <= 0:
- raise ValueError(
- incorrect_arg_msg.format(
- key="q_group_size",
- correct_value="a positive integer",
- found_value=self.q_group_size,
- ),
- )
- if self.residual_length < 0:
- raise ValueError(
- incorrect_arg_msg.format(
- key="residual_length",
- correct_value="a positive integer",
- found_value=self.residual_length,
- ),
- )
- if self.axis_key not in [0, 1, -1]:
- raise ValueError(
- incorrect_arg_msg.format(
- key="axis_key",
- correct_value="`1` or `0`, `-1`",
- found_value=self.axis_key,
- ),
- )
- if self.axis_value not in [0, 1, -1]:
- raise ValueError(
- incorrect_arg_msg.format(
- key="axis_value",
- correct_value="`1` or `0` or `-1`",
- found_value=self.axis_value,
- ),
- )
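- # Illustrative sketch (not part of this module's API): `QuantizedCacheConfig.validate` only checks that the
- # values are generally reasonable; backend-specific constraints are enforced later by the `QuantizedCache`
- # subclasses further below.
- def _example_quantized_cache_config_validation() -> None:  # hypothetical helper, for illustration only
-     config = QuantizedCacheConfig(backend="quanto", nbits=4, q_group_size=64, residual_length=128)
-     config.validate()  # passes: nbits, axes and group size are all within the generally accepted ranges
-     try:
-         QuantizedCacheConfig(nbits=5).validate()  # 5 bits is not supported by any backend
-     except ValueError as exc:
-         assert "nbits" in str(exc)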
- @dataclass
- class StaticCacheConfig(CacheConfig):
- """
- Configuration class for static cache settings.
- """
- cache_implementation = "static"
- def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
- self.batch_size = batch_size
- self.max_cache_len = max_cache_len
- self.device = device
- def validate(self):
- """Validates if the arguments passed are correct"""
- incorrect_arg_msg = (
- "Some of the keys in `cache_config` are defined incorrectly. `{key}` should be {correct_value}` "
- "but found {found_value}"
- )
- if self.batch_size <= 0:
- raise ValueError(
- incorrect_arg_msg.format(
- key="batch_size",
- correct_value="> 0",
- found_value=self.batch_size,
- ),
- )
- if self.max_cache_len <= 0:
- raise ValueError(
- incorrect_arg_msg.format(
- key="max_cache_len",
- correct_value="> 0",
- found_value=self.max_cache_len,
- ),
- )
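- # Illustrative sketch (not part of this module's API): `StaticCacheConfig` only carries the shape arguments
- # for a pre-allocated cache and rejects non-positive sizes.
- def _example_static_cache_config_validation() -> None:  # hypothetical helper, for illustration only
-     StaticCacheConfig(batch_size=1, max_cache_len=256, device="cpu").validate()  # fine: both sizes are > 0
-     try:
-         StaticCacheConfig(batch_size=0, max_cache_len=256).validate()
-     except ValueError as exc:
-         assert "batch_size" in str(exc)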
- class DynamicCache(Cache):
- """
- A cache that grows dynamically as more tokens are generated. This is the default for generative models.
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
- `[batch_size, num_heads, seq_len, head_dim]`.
- Example:
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
- >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> past_key_values = DynamicCache()
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- DynamicCache()
- ```
- """
- @deprecate_kwarg("num_hidden_layers", version="4.47.0")
- def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
- super().__init__()
- self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
- """
- Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
- sequence length.
- """
- if layer_idx < len(self):
- return (self.key_cache[layer_idx], self.value_cache[layer_idx])
- else:
- raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
- def __iter__(self):
- """
- Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
- keys and values
- """
- for layer_idx in range(len(self)):
- yield (self.key_cache[layer_idx], self.value_cache[layer_idx])
- def __len__(self):
- """
- Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
- to the number of layers in the model.
- """
- return len(self.key_cache)
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.
- Return:
- A tuple containing the updated key and value states.
- """
- # Update the number of seen tokens
- if layer_idx == 0:
- self._seen_tokens += key_states.shape[-2]
- # Update the cache
- if len(self.key_cache) <= layer_idx:
- # There may be skipped layers, fill them with empty lists
- for _ in range(len(self.key_cache), layer_idx):
- self.key_cache.append([])
- self.value_cache.append([])
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- elif len(self.key_cache[layer_idx]) == 0: # fills previously skipped layers; checking for tensor causes errors
- self.key_cache[layer_idx] = key_states
- self.value_cache[layer_idx] = value_states
- else:
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # TODO: deprecate this function in favor of `cache_position`
- is_empty_layer = (
- len(self.key_cache) == 0 # no cache in any layer
- or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
- or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
- )
- layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
- return layer_seq_length
- def get_max_cache_shape(self) -> Optional[int]:
- """Returns the maximum sequence length of the cache object. DynamicCache does not have a maximum length."""
- return None
- def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
- """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
- backward compatibility."""
- legacy_cache = ()
- for layer_idx in range(len(self)):
- legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
- return legacy_cache
- @classmethod
- @deprecate_kwarg("num_hidden_layers", version="4.47.0")
- def from_legacy_cache(
- cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, num_hidden_layers: int = None
- ) -> "DynamicCache":
- """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
- backward compatibility."""
- cache = cls()
- if past_key_values is not None:
- for layer_idx in range(len(past_key_values)):
- key_states, value_states = past_key_values[layer_idx]
- cache.update(key_states, value_states, layer_idx)
- return cache
- def crop(self, max_length: int):
- """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
- negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
- # In case it is negative
- if max_length < 0:
- max_length = self.get_seq_length() - abs(max_length)
- if self.get_seq_length() <= max_length:
- return
- self._seen_tokens = max_length
- for idx in range(len(self.key_cache)):
- if self.key_cache[idx] != []:
- self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
- self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
- @deprecate_kwarg("num_hidden_layers", version="4.47.0")
- def batch_split(
- self, full_batch_size: int, split_size: int, num_hidden_layers: int = None
- ) -> List["DynamicCache"]:
- """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
- `_split_model_inputs()` in `generation.utils`"""
- out = []
- for i in range(0, full_batch_size, split_size):
- current_split = DynamicCache()
- current_split._seen_tokens = self._seen_tokens
- current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
- current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
- out.append(current_split)
- return out
- @classmethod
- @deprecate_kwarg("num_hidden_layers", version="4.47.0")
- def from_batch_splits(cls, splits: List["DynamicCache"], num_hidden_layers: int = None) -> "DynamicCache":
- """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
- `generation.utils`"""
- cache = cls()
- for idx in range(len(splits[0])):
- key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
- value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
- if key_cache != []:
- layer_keys = torch.cat(key_cache, dim=0)
- layer_values = torch.cat(value_cache, dim=0)
- cache.update(layer_keys, layer_values, idx)
- return cache
- def batch_repeat_interleave(self, repeats: int):
- """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
- for layer_idx in range(len(self)):
- self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
- self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
- def batch_select_indices(self, indices: torch.Tensor):
- """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
- for layer_idx in range(len(self)):
- self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
- self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]
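- # Illustrative sketch (not part of this module's API): driving `DynamicCache.update` directly with random
- # tensors of shape `[batch_size, num_heads, seq_len, head_dim]` to show how the cache grows along the
- # sequence axis (`dim=-2`), how the legacy-format round trip works, and what `crop` does.
- def _example_dynamic_cache_mechanics() -> None:  # hypothetical helper, for illustration only
-     cache = DynamicCache()
-     cache.update(torch.randn(1, 2, 5, 8), torch.randn(1, 2, 5, 8), layer_idx=0)  # "prompt" of 5 tokens
-     cache.update(torch.randn(1, 2, 1, 8), torch.randn(1, 2, 1, 8), layer_idx=0)  # one generated token
-     assert cache.get_seq_length(0) == 6
-     legacy = cache.to_legacy_cache()  # tuple of per-layer (key, value) tuples
-     assert DynamicCache.from_legacy_cache(legacy).get_seq_length(0) == 6
-     cache.crop(4)  # keep only the first 4 tokens, e.g. when rolling back rejected candidate tokens
-     assert cache.get_seq_length(0) == 4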
- class OffloadedCache(DynamicCache):
- """
- A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
- Useful for generating from models with very long context.
- In addition to the default CUDA stream, where all forward() computations happen,
- this class uses another stream, the prefetch stream, which it creates itself.
- Since scheduling of operations on separate streams happens independently, this class uses
- the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
- The movement of the layer k-1 cache to the CPU is handled by the default stream as a simple way to
- ensure the eviction is scheduled after all computations on that cache are finished.
- """
- def __init__(self) -> None:
- if not torch.cuda.is_available():
- raise RuntimeError("OffloadedCache can only be used with a GPU")
- super().__init__()
- self.original_device = []
- self.prefetch_stream = torch.cuda.Stream()
- self.beam_idx = None # used to delay beam search operations
- def prefetch_layer(self, layer_idx: int):
- "Starts prefetching the next layer cache"
- if layer_idx < len(self):
- with torch.cuda.stream(self.prefetch_stream):
- # Prefetch next layer tensors to GPU
- device = self.original_device[layer_idx]
- self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
- self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device, non_blocking=True)
- def evict_previous_layer(self, layer_idx: int):
- "Moves the previous layer cache to the CPU"
- if len(self) > 2:
- # We do it on the default stream so it occurs after all earlier computations on these tensors are done
- prev_layer_idx = (layer_idx - 1) % len(self)
- self.key_cache[prev_layer_idx] = self.key_cache[prev_layer_idx].to("cpu", non_blocking=True)
- self.value_cache[prev_layer_idx] = self.value_cache[prev_layer_idx].to("cpu", non_blocking=True)
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
- "Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
- if layer_idx < len(self):
- # Evict the previous layer if necessary
- torch.cuda.current_stream().synchronize()
- self.evict_previous_layer(layer_idx)
- # Load current layer cache to its original device if not already there
- original_device = self.original_device[layer_idx]
- self.prefetch_stream.synchronize()
- key_tensor = self.key_cache[layer_idx]
- value_tensor = self.value_cache[layer_idx]
- # Now deal with beam search ops which were delayed
- if self.beam_idx is not None:
- self.beam_idx = self.beam_idx.to(original_device)
- key_tensor = key_tensor.index_select(0, self.beam_idx)
- value_tensor = value_tensor.index_select(0, self.beam_idx)
- # Prefetch the next layer
- self.prefetch_layer((layer_idx + 1) % len(self))
- return (key_tensor, value_tensor)
- else:
- raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
- def reorder_cache(self, beam_idx: torch.LongTensor):
- """Saves the beam indices and reorders the cache when the tensor is back to its device."""
- # We delay this operation until the tensors are back to their original
- # device because performing torch.index_select on the CPU is very slow
- del self.beam_idx
- self.beam_idx = beam_idx.clone()
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. No additional arguments are used in `OffloadedCache`.
- Return:
- A tuple containing the updated key and value states.
- """
- # Update the number of seen tokens
- if layer_idx == 0:
- self._seen_tokens += key_states.shape[-2]
- # Update the cache
- if len(self.key_cache) < layer_idx:
- raise ValueError("OffloadedCache does not support model usage where layers are skipped. Use DynamicCache.")
- elif len(self.key_cache) == layer_idx:
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- self.original_device.append(key_states.device)
- self.evict_previous_layer(layer_idx)
- else:
- key_tensor, value_tensor = self[layer_idx]
- self.key_cache[layer_idx] = torch.cat([key_tensor, key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([value_tensor, value_states], dim=-2)
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
- # According to https://docs.python.org/3/library/exceptions.html#NotImplementedError
- # if a method is not supposed to be supported in a subclass we should set it to None
- from_legacy_cache = None
- to_legacy_cache = None
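- # Illustrative sketch (not part of this module's API): `OffloadedCache` is a drop-in replacement for
- # `DynamicCache`, so only the cache object changes for the caller; the per-layer prefetch/evict scheduling
- # happens inside `update` and `__getitem__`. A CUDA device is required.
- def _example_offloaded_cache_usage() -> None:  # hypothetical helper, for illustration only
-     if not torch.cuda.is_available():
-         return  # OffloadedCache refuses to be built without a GPU
-     cache = OffloadedCache()
-     k = torch.randn(1, 2, 4, 8, device="cuda")
-     cache.update(k, k.clone(), layer_idx=0)
-     assert cache.get_seq_length(0) == 4
-     assert cache.original_device[0].type == "cuda"  # remembered so evicted layers can be prefetched back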
- class QuantizedCache(DynamicCache):
- """
- A quantized cache similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://arxiv.org/abs/2402.02750).
- It allows the model to generate longer sequences without allocating too much memory for the Key and Value cache by applying quantization.
- The cache has two types of storage, one for the original precision and one for the quantized cache. A `residual_length` is set as the maximum capacity for the
- original precision cache. When the length goes beyond this maximum capacity, the original precision cache is discarded and moved into the quantized cache. The
- quantization is done per-channel with a set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.
- It stores Keys and Values as a list of quantized tensors (tuples when we need to store metadata), one for each layer. Additionally, it stores the Key and
- Value states in original precision as a list of tensors, one for each layer. The size of each tensor
- is `[batch_size, num_heads, seq_len - residual_length, head_dim]`.
- """
- def __init__(self, cache_config: QuantizedCacheConfig) -> None:
- super().__init__()
- self._quantized_key_cache: List[torch.Tensor] = []
- self._quantized_value_cache: List[torch.Tensor] = []
- self.nbits = cache_config.nbits
- self.residual_length = cache_config.residual_length
- self.q_group_size = cache_config.q_group_size
- self.axis_key = cache_config.axis_key
- self.axis_value = cache_config.axis_value
- self.compute_dtype = cache_config.compute_dtype
- self.device = cache_config.device
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- # Update the number of seen tokens
- if layer_idx == 0:
- self._seen_tokens += key_states.shape[-2]
- if len(self.key_cache) < layer_idx:
- raise ValueError("QuantizedCache does not support model usage where layers are skipped. Use DynamicCache.")
- elif len(self.key_cache) == layer_idx:
- self._quantized_key_cache.append(self._quantize(key_states.contiguous(), axis=self.axis_key))
- self._quantized_value_cache.append(self._quantize(value_states.contiguous(), axis=self.axis_value))
- self.key_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
- self.value_cache.append(torch.zeros(0, dtype=key_states.dtype, device=key_states.device))
- keys_to_return, values_to_return = key_states, value_states
- else:
- dequant_key = self._dequantize(self._quantized_key_cache[layer_idx])
- dequant_value = self._dequantize(self._quantized_value_cache[layer_idx])
- keys_to_return = [dequant_key, self.key_cache[layer_idx], key_states]
- values_to_return = [dequant_value, self.value_cache[layer_idx], value_states]
- keys_to_return = torch.cat(keys_to_return, dim=-2)
- values_to_return = torch.cat(values_to_return, dim=-2)
- if (
- self.key_cache[layer_idx].dim() == 4
- and self.key_cache[layer_idx].shape[-2] + 1 >= self.residual_length
- ):
- self._quantized_key_cache[layer_idx] = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
- self._quantized_value_cache[layer_idx] = self._quantize(
- values_to_return.contiguous(), axis=self.axis_value
- )
- self.key_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
- self.value_cache[layer_idx] = torch.zeros(0, dtype=key_states.dtype, device=key_states.device)
- else:
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
- return keys_to_return, values_to_return
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- if len(self.key_cache) <= layer_idx:
- return 0
- # Since we cannot get the seq_length of each layer directly and instead rely on `_seen_tokens`, which is
- # updated every time `layer_idx == 0`, this is a hack to get the actual seq_length for the given layer_idx.
- # This part of the code otherwise fails when used to verify the attn_weight shape in some models.
- return self._seen_tokens if layer_idx == 0 else self._seen_tokens - 1
- def _quantize(self, tensor, axis):
- """Quantizes a key/value using a defined quantization method."""
- raise NotImplementedError("Make sure to implement `_quantize` in a subclass.")
- def _dequantize(self, q_tensor):
- """Dequantizes back the tensor that was quantized by `self._quantize()`"""
- raise NotImplementedError("Make sure to implement `_dequantize` in a subclass.")
- class QuantoQuantizedCache(QuantizedCache):
- """
- Quantized Cache class that uses `quanto` as a backend to perform quantization. Current implementation supports `int2` and `int4` dtypes only.
- Parameters:
- cache_config (`QuantizedCacheConfig`):
- A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
- Example:
- ```python
- >>> # Run pip install quanto first if you don't have it yet
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, QuantoQuantizedCache, QuantizedCacheConfig
- >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> cache_config = QuantizedCacheConfig(nbits=4)
- >>> past_key_values = QuantoQuantizedCache(cache_config=cache_config)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- QuantoQuantizedCache()
- ```
- """
- def __init__(self, cache_config: CacheConfig) -> None:
- super().__init__(cache_config)
- if is_optimum_quanto_available():
- from optimum.quanto import MaxOptimizer, qint2, qint4
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- quanto_version = version.parse(importlib.metadata.version("quanto"))
- if quanto_version < version.parse("0.2.0"):
- raise ImportError(
- f"You need quanto package version to be greater or equal than 0.2.0 to use `QuantoQuantizedCache`. Detected version {quanto_version}. "
- f"Since quanto will be deprecated, please install optimum-quanto instead with `pip install -U optimum-quanto`"
- )
- from quanto import MaxOptimizer, qint2, qint4
- if self.nbits not in [2, 4]:
- raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
- if self.axis_key not in [0, -1]:
- raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
- if self.axis_value not in [0, -1]:
- raise ValueError(
- f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
- )
- self.qtype = qint4 if self.nbits == 4 else qint2
- self.optimizer = MaxOptimizer() # hardcode as it's the only one for per-channel quantization
- def _quantize(self, tensor, axis):
- # We have two different APIs since, in optimum-quanto, we don't use AffineQuantizer anymore
- if is_optimum_quanto_available():
- from optimum.quanto import quantize_weight
- qtensor = quantize_weight(tensor, self.qtype, axis, self.q_group_size)
- return qtensor
- elif is_quanto_available():
- logger.warning_once(
- "Importing from quanto will be deprecated in v4.47. Please install optimum-quanto instead `pip install optimum-quanto`"
- )
- from quanto import AffineQuantizer
- scale, zeropoint = self.optimizer(tensor, self.qtype.bits, axis, self.q_group_size)
- qtensor = AffineQuantizer.apply(tensor, self.qtype, axis, self.q_group_size, scale, zeropoint)
- return qtensor
- def _dequantize(self, qtensor):
- return qtensor.dequantize()
- class HQQQuantizedCache(QuantizedCache):
- """
- Quantized Cache class that uses `HQQ` as a backend to perform quantization. Current implementation supports `int2`, `int4`, `int8` dtypes.
- Parameters:
- cache_config (`QuantizedCacheConfig`):
- A configuration containing all the arguments to be used by the quantizer, including axis, qtype and group size.
- Example:
- ```python
- >>> # Run pip install hqq first if you don't have it yet
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HQQQuantizedCache, QuantizedCacheConfig
- >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> cache_config = QuantizedCacheConfig(nbits=4, axis_key=1, axis_value=1)
- >>> past_key_values = HQQQuantizedCache(cache_config=cache_config)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- HQQQuantizedCache()
- ```
- """
- def __init__(self, cache_config: CacheConfig) -> None:
- super().__init__(cache_config)
- if self.nbits not in [1, 2, 3, 4, 8]:
- raise ValueError(
- f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
- )
- if self.axis_key not in [0, 1]:
- raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
- if self.axis_value not in [0, 1]:
- raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")
- self.quantizer = HQQQuantizer
- def _quantize(self, tensor, axis):
- qtensor, meta = self.quantizer.quantize(
- tensor,
- axis=axis,
- device=self.device,
- compute_dtype=self.compute_dtype,
- nbits=self.nbits,
- group_size=self.q_group_size,
- )
- meta["compute_dtype"] = self.compute_dtype
- self.quantizer.cuda(qtensor, meta=meta, device=self.device) # Move to device and cast to dtype
- return qtensor, meta
- def _dequantize(self, qtensor):
- quant_tensor, meta = qtensor
- tensor = self.quantizer.dequantize(quant_tensor, meta)
- return tensor
- class SinkCache(Cache):
- """
- A cache as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453). It allows the model to
- generate beyond the length of its context window, without losing fluency in the conversation. As it discards past
- tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.
- It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
- `[batch_size, num_heads, seq_len, head_dim]`.
- Parameters:
- window_length (`int`):
- The length of the context window.
- num_sink_tokens (`int`):
- The number of sink tokens. See the original paper for more information.
- Example:
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
- >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> past_key_values = SinkCache(window_length=256, num_sink_tokens=4)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- SinkCache()
- ```
- """
- is_sliding = True
- def __init__(self, window_length: int, num_sink_tokens: int) -> None:
- super().__init__()
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- self.window_length = window_length
- self.num_sink_tokens = num_sink_tokens
- self.cos_sin_rerotation_cache = {}
- self._cos_cache = None
- self._sin_cache = None
- self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen
- @staticmethod
- def _rotate_half(x):
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
- def _apply_key_rotary_pos_emb(
- self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
- ) -> torch.Tensor:
- rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
- return rotated_key_states
- def _get_rerotation_cos_sin(
- self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- if key_states.shape[-2] not in self.cos_sin_rerotation_cache:
- # Upcast to float32 temporarily for better accuracy
- cos = cos.to(torch.float32)
- sin = sin.to(torch.float32)
- # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
- original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
- shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
- original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
- shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
- rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
- rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin
- self.cos_sin_rerotation_cache[key_states.shape[-2]] = (
- rerotation_cos.to(key_states.dtype).unsqueeze(0),
- rerotation_sin.to(key_states.dtype).unsqueeze(0),
- )
- return self.cos_sin_rerotation_cache[key_states.shape[-2]]
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # TODO: deprecate this function in favor of `cache_position`
- # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
- if len(self.key_cache) <= layer_idx:
- return 0
- return self.key_cache[layer_idx].shape[-2]
- def get_max_cache_shape(self) -> Optional[int]:
- """Returns the maximum sequence length of the cache object, in case of SinkCache it is the window length."""
- return self.window_length
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
- `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
- rotation as the tokens are shifted.
- Return:
- A tuple containing the updated key and value states.
- """
- # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
- # with partially rotated position embeddings, like Phi or Persimmon.
- sin = cache_kwargs.get("sin")
- cos = cache_kwargs.get("cos")
- partial_rotation_size = cache_kwargs.get("partial_rotation_size")
- using_rope = cos is not None and sin is not None
- # Update the number of seen tokens
- if layer_idx == 0:
- self._seen_tokens += key_states.shape[-2]
- # Update the sin/cos cache, which holds sin/cos values for all possible positions
- if using_rope and layer_idx == 0:
- # BC: some models still pass `sin`/`cos` with 2 dims. In those models, they are the full sin/cos. Remove
- # after all RoPE models have a llama-like cache utilization.
- if cos.dim() == 2:
- self._cos_cache = cos
- self._sin_cache = sin
- else:
- if self._cos_cache is None:
- self._cos_cache = cos[0, ...]
- self._sin_cache = sin[0, ...]
- elif self._cos_cache.shape[0] < self.window_length:
- self._cos_cache = torch.cat([self._cos_cache, cos[0, ...]], dim=0)
- self._sin_cache = torch.cat([self._sin_cache, sin[0, ...]], dim=0)
- # [bsz, num_heads, seq_len, head_dim]
- if len(self.key_cache) <= layer_idx:
- # Empty cache
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
- # Growing cache
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
- else:
- # Shifting cache
- keys_to_keep = self.key_cache[layer_idx][
- :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
- ]
- # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
- if using_rope:
- rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(
- key_states, self._cos_cache[: self.window_length], self._sin_cache[: self.window_length]
- )
- if partial_rotation_size is not None:
- keys_to_keep, keys_pass = (
- keys_to_keep[..., :partial_rotation_size],
- keys_to_keep[..., partial_rotation_size:],
- )
- keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
- if partial_rotation_size is not None:
- keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)
- # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
- sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
- self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)
- sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
- values_to_keep = self.value_cache[layer_idx][
- :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
- ]
- self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)
- return self.key_cache[layer_idx], self.value_cache[layer_idx]
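- # Illustrative sketch (not part of this module's API): `SinkCache` caps its length at `window_length`,
- # keeping the first `num_sink_tokens` entries plus the most recent ones. Models without RoPE can pass an
- # empty `cache_kwargs`; RoPE models must provide `sin`/`cos` so the kept keys can be re-rotated after the shift.
- def _example_sink_cache_window() -> None:  # hypothetical helper, for illustration only
-     cache = SinkCache(window_length=8, num_sink_tokens=2)
-     k = torch.randn(1, 2, 6, 4)
-     cache.update(k, k.clone(), layer_idx=0, cache_kwargs={})
-     assert cache.get_seq_length(0) == 6
-     new_k = torch.randn(1, 2, 4, 4)
-     cache.update(new_k, new_k.clone(), layer_idx=0, cache_kwargs={})
-     # 6 + 4 tokens exceed the window, so the cache shifts: 2 sink tokens + 2 kept + 4 new = 8
-     assert cache.get_seq_length(0) == cache.get_max_cache_shape() == 8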
- class StaticCache(Cache):
- """
- Static Cache class to be used with `torch.compile(model)` and `torch.export()`.
- Parameters:
- config (`PretrainedConfig`):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- device (`torch.device` or `str`):
- The device on which the cache should be initialized. Should be the same as the layer.
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
- The default `dtype` to use when initializing the layer.
- layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
- Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split across different GPUs.
- You can find out which layer is mapped to which device by checking the associated device_map: `model.hf_device_map`.
- Example:
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
- >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
- >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
- >>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- StaticCache()
- ```
- """
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
- def __init__(
- self,
- config: PretrainedConfig,
- batch_size: int = None,
- max_cache_len: int = None,
- device: torch.device = None,
- dtype: torch.dtype = torch.float32,
- max_batch_size: Optional[int] = None,
- layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
- ) -> None:
- super().__init__()
- if max_batch_size is not None:
- logger.warning_once(
- f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.46. Use the more precisely named 'batch_size' argument instead."
- )
- self.batch_size = batch_size or max_batch_size
- self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
- # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
- self.head_dim = (
- config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- )
- self.dtype = dtype
- self.num_key_value_heads = (
- config.num_attention_heads
- if getattr(config, "num_key_value_heads", None) is None
- else config.num_key_value_heads
- )
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- # Note: There will be significant perf decrease if switching to use 5D tensors instead.
- cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
- for idx in range(config.num_hidden_layers):
- if layer_device_map is not None:
- layer_device = layer_device_map[idx]
- else:
- layer_device = device
- new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
- new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
- # Notes:
- # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
- # breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
- # it is not needed anyway)
- # 2. `torch.export()` requires mutations to be registered as buffers.
- if not is_torchdynamo_compiling():
- self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
- self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
- new_layer_key_cache = getattr(self, f"key_cache_{idx}")
- new_layer_value_cache = getattr(self, f"value_cache_{idx}")
- torch._dynamo.mark_static_address(new_layer_key_cache)
- torch._dynamo.mark_static_address(new_layer_value_cache)
- self.key_cache.append(new_layer_key_cache)
- self.value_cache.append(new_layer_value_cache)
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, `optional`):
- Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
- to know where to write in the cache.
- Return:
- A tuple containing the updated key and value states.
- """
- cache_position = cache_kwargs.get("cache_position")
- k_out = self.key_cache[layer_idx]
- v_out = self.value_cache[layer_idx]
- if cache_position is None:
- k_out.copy_(key_states)
- v_out.copy_(value_states)
- else:
- # Note: here we use `tensor.index_copy_(dim, index, tensor)` that is equivalent to
- # `tensor[:, :, index] = tensor`, but the first one is compile-friendly and explicitly performs an in-place
- # operation, which avoids copies and uses less memory.
- try:
- k_out.index_copy_(2, cache_position, key_states)
- v_out.index_copy_(2, cache_position, value_states)
- except NotImplementedError:
- # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
- return k_out, v_out
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states that were seen by the model."""
- # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
- # limit the check to the first batch member and head dimension.
- # TODO: deprecate this function in favor of `cache_position`
- return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
- def get_max_cache_shape(self) -> Optional[int]:
- return self.max_cache_len
- def reset(self):
- """Resets the cache values while preserving the objects"""
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address
- self.key_cache[layer_idx].zero_()
- self.value_cache[layer_idx].zero_()
- class SlidingWindowCache(StaticCache):
- """
- Sliding Window Cache class to be used with `torch.compile` for models like Mistral that support sliding window attention.
- Every time we try to update the cache, we compute `indices` based on `cache_position >= self.config.sliding_window - 1`.
- If true (meaning the cache cannot hold all the old key/value states and the new states together because of the sliding window constraint),
- we need to do a cyclic shift based on `indices` to replace the oldest states with the new key/value states passed in.
- The `to_shift` condition is only true once we are above `sliding_window`. Thus with `sliding_window==64`:
- indices = (slicing + to_shift[-1].int()-1) % self.config.sliding_window
- tensor([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
- 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,
- 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
- 55, 56, 57, 58, 59, 60, 61, 62, 63, 0])
- We overwrite the cache using these indices, then we always write at `cache_position` (clamped to `sliding_window`).
- Parameters:
- config (`PretrainedConfig`):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- device (`torch.device` or `str`):
- The device on which the cache should be initialized. Should be the same as the layer.
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
- The default `dtype` to use when initializing the layer.
- layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
- Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
- You can find out which layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.
- Example:
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, SlidingWindowCache
- >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
- >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
- >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
- >>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- SlidingWindowCache()
- ```
- """
- is_sliding = True
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
- def __init__(
- self,
- config: PretrainedConfig,
- batch_size: int = None,
- max_cache_len: int = None,
- device: torch.device = None,
- dtype: torch.dtype = torch.float32,
- max_batch_size: Optional[int] = None,
- layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
- ) -> None:
- if not hasattr(config, "sliding_window") or config.sliding_window is None:
- raise ValueError(
- "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
- "sliding window attention, please check if there is a `sliding_window` field in the model "
- "config and it's not set to None."
- )
- max_cache_len = min(config.sliding_window, max_cache_len)
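- # The attention window bounds how far back a token can attend, so the cache never needs to be longer
- # than `config.sliding_window`.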
- super().__init__(
- config=config,
- batch_size=batch_size,
- max_cache_len=max_cache_len,
- device=device,
- dtype=dtype,
- max_batch_size=max_batch_size,
- layer_device_map=layer_device_map,
- )
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- cache_position = cache_kwargs.get("cache_position")
- k_out = self.key_cache[layer_idx]
- v_out = self.value_cache[layer_idx]
- # Assume this only happens in the prefill phase, when the prompt length > sliding window size (= max_cache_len)
- if cache_position.shape[0] > self.max_cache_len:
- k_out = key_states[:, :, -self.max_cache_len :, :]
- v_out = value_states[:, :, -self.max_cache_len :, :]
- # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- # We return the whole states instead of k_out, v_out so that the whole prompt is taken into account
- # when building the kv cache, instead of just throwing away the tokens outside of the window
- return key_states, value_states
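- # Decode path: `indices` is the identity while the window is filling; once `to_shift` is true it rotates
- # the window left by one, and the new token is then written at the clamped `cache_position`.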
- slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
- cache_position = cache_position.clamp(0, self.max_cache_len - 1)
- to_shift = cache_position >= self.max_cache_len - 1
- indices = (slicing + to_shift[-1].int() - 1) % self.max_cache_len
- k_out = k_out[:, :, indices]
- v_out = v_out[:, :, indices]
- try:
- k_out.index_copy_(2, cache_position, key_states)
- v_out.index_copy_(2, cache_position, value_states)
- except NotImplementedError:
- # The operator 'aten::index_copy.out' is not currently implemented for the MPS device.
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
- # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
- self.key_cache[layer_idx].zero_()
- self.value_cache[layer_idx].zero_()
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- return k_out, v_out
- def get_max_cache_shape(self) -> Optional[int]:
- return self.max_cache_len
- def reset(self):
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address
- self.key_cache[layer_idx].zero_()
- self.value_cache[layer_idx].zero_()
- class EncoderDecoderCache(Cache):
- """
- Cache class for encoder-decoder models. Can be used to hold combinations of self-attention and
- cross-attention caches.
- Example:
- ```python
- >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
- >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
- >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
- >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
- >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
- >>> self_attention_cache = DynamicCache()
- >>> cross_attention_cache = DynamicCache()
- >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- EncoderDecoderCache()
- ```
- """
- def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
- super().__init__()
- self.self_attention_cache = self_attention_cache
- self.cross_attention_cache = cross_attention_cache
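- # `is_updated` records, per layer, whether the cross-attention key/value states have already been written,
- # so they can be computed once from the encoder outputs and reused on subsequent decoding steps.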
- self.is_updated = {}
- for layer_idx in range(len(cross_attention_cache.key_cache)):
- self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
- def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
- """
- Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
- sequence length.
- """
- if layer_idx < len(self):
- return (
- self.self_attention_cache.key_cache[layer_idx],
- self.self_attention_cache.value_cache[layer_idx],
- self.cross_attention_cache.key_cache[layer_idx],
- self.cross_attention_cache.value_cache[layer_idx],
- )
- else:
- raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
- def __len__(self):
- """
- Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
- to the number of layers in the model.
- """
- return len(self.self_attention_cache)
- def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
- """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
- legacy_cache = ()
- if len(self.cross_attention_cache) > 0:
- for self_attn, cross_attn in zip(
- self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
- ):
- legacy_cache += (self_attn + cross_attn,)
- else:
- legacy_cache = self.self_attention_cache.to_legacy_cache()
- return legacy_cache
- @classmethod
- def from_legacy_cache(
- cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
- ) -> "EncoderDecoderCache":
- """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
- cache = cls(
- self_attention_cache=DynamicCache(),
- cross_attention_cache=DynamicCache(),
- )
- if past_key_values is not None:
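- # Legacy format: each per-layer tuple holds (self_attn_key, self_attn_value) and, when cross-attention
- # states were cached, additionally (cross_attn_key, cross_attn_value).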
- for layer_idx in range(len(past_key_values)):
- key_states, value_states = past_key_values[layer_idx][:2]
- cache.self_attention_cache.update(key_states, value_states, layer_idx)
- if len(past_key_values[layer_idx]) > 2:
- key_states, value_states = past_key_values[layer_idx][2:]
- cache.cross_attention_cache.update(key_states, value_states, layer_idx)
- cache.is_updated[layer_idx] = True
- return cache
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states. A layer index can be optionally passed."""
- # Check if the list is empty, because in case of a static cache the entries will be tensors and we can't check `if not torch.Tensor`
- return self.self_attention_cache.get_seq_length(layer_idx)
- def reset(self):
- if hasattr(self.self_attention_cache, "reset"):
- self.self_attention_cache.reset()
- if hasattr(self.cross_attention_cache, "reset"):
- self.cross_attention_cache.reset()
- elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
- raise ValueError(
- "Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
- "only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
- f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
- f"{self.cross_attention_cache.__str__()} for the cross attention cache."
- )
- for layer_idx in self.is_updated:
- self.is_updated[layer_idx] = False
- def reorder_cache(self, beam_idx: torch.LongTensor):
- """Reorders the cache for beam search, given the selected beam indices."""
- self.self_attention_cache.reorder_cache(beam_idx)
- self.cross_attention_cache.reorder_cache(beam_idx)
- def check_dynamic_cache(self, method: str):
- if not (
- isinstance(self.self_attention_cache, DynamicCache)
- and isinstance(self.cross_attention_cache, DynamicCache)
- ):
- raise ValueError(
- f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
- f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
- )
- # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
- def crop(self, maximum_length: int):
- """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
- negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
- self.check_dynamic_cache(self.crop.__name__)
- self.self_attention_cache.crop(maximum_length)
- def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
- """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
- `_split_model_inputs()` in `generation.utils`"""
- self.check_dynamic_cache(self.batch_split.__name__)
- self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
- cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
- out = []
- for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
- out.append(EncoderDecoderCache(self_attn, cross_attn))
- return out
- @classmethod
- def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
- """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
- `generation.utils`"""
- self_attention_cache = DynamicCache()
- cross_attention_cache = DynamicCache()
- for idx in range(len(splits[0])):
- layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
- layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
- self_attention_cache.update(layer_keys, layer_values, idx)
- layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
- layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
- cross_attention_cache.update(layer_keys, layer_values, idx)
- return cls(self_attention_cache, cross_attention_cache)
- def batch_repeat_interleave(self, repeats: int):
- """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
- self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
- self.self_attention_cache.batch_repeat_interleave(repeats)
- self.cross_attention_cache.batch_repeat_interleave(repeats)
- def batch_select_indices(self, indices: torch.Tensor):
- """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
- self.check_dynamic_cache(self.batch_select_indices.__name__)
- self.self_attention_cache.batch_select_indices(indices)
- self.cross_attention_cache.batch_select_indices(indices)
- class HybridCache(Cache):
- """
- Hybrid Cache class to be used with `torch.compile` for Gemma2 models that alternate between local sliding window attention
- and global attention in every other layer. Under the hood, Hybrid Cache leverages [`SlidingWindowCache`] for sliding window attention
- and [`StaticCache`] for global attention. For more information, see the documentation of each subcomponent cache class.
- Parameters:
- config (`PretrainedConfig`):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
- The device on which the cache should be initialized. Should be the same as the layer.
- dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
- The default `dtype` to use when initializing the layer.
- layer_device_map (`Dict[int, Union[str, torch.device, int]]`, *optional*):
- Mapping between the layers and their devices. This is required when you are manually initializing the cache and the model is split between different GPUs.
- You can find out which layers are mapped to which device by checking the associated device_map: `model.hf_device_map`.
- Example:
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, HybridCache
- >>> model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
- >>> tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
- >>> inputs = tokenizer(text="My name is Gemma", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
- >>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values # access cache filled with key/values from generation
- HybridCache()
- ```
- """
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
- def __init__(
- self,
- config: PretrainedConfig,
- batch_size: int = None,
- max_cache_len: int = None,
- device: Union[torch.device, str] = "cpu",
- dtype: torch.dtype = torch.float32,
- max_batch_size: Optional[int] = None,
- layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
- ) -> None:
- super().__init__()
- if max_batch_size is not None:
- logger.warning_once(
- f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.46. Use the more precisely named 'batch_size' argument instead."
- )
- if not hasattr(config, "sliding_window") or config.sliding_window is None:
- raise ValueError(
- "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
- "sliding window attention, please check if there is a `sliding_window` field in the model "
- "config and it's not set to None."
- )
- self.max_cache_len = max_cache_len
- self.batch_size = batch_size or max_batch_size
- # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
- self.head_dim = (
- config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- )
- self.dtype = dtype
- self.num_key_value_heads = (
- config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
- )
- self.is_sliding = torch.tensor(
- [not bool(i % 2) for i in range(config.num_hidden_layers)], dtype=torch.bool, device=device
- )
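- # `not bool(i % 2)` marks even-indexed layers as sliding-window layers and odd-indexed layers as
- # global-attention layers, matching the alternating pattern described in the class docstring.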
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
- sliding_cache_shape = (
- self.batch_size,
- self.num_key_value_heads,
- min(config.sliding_window, max_cache_len),
- self.head_dim,
- )
- for i in range(config.num_hidden_layers):
- if layer_device_map is not None:
- layer_device = layer_device_map[i]
- else:
- layer_device = device
- # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing CUDA graph
- # breaks when updating the cache.
- cache_shape = global_cache_shape if not self.is_sliding[i] else sliding_cache_shape
- new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
- new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
- torch._dynamo.mark_static_address(new_layer_key_cache)
- torch._dynamo.mark_static_address(new_layer_value_cache)
- self.key_cache.append(new_layer_key_cache)
- self.value_cache.append(new_layer_value_cache)
- def _sliding_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
- if cache_position.shape[0] > max_cache_len:
- k_out = key_states[:, :, -max_cache_len:, :]
- v_out = value_states[:, :, -max_cache_len:, :]
- # Assumption: caches are all zeros at this point, `+=` is equivalent to `=` but compile-friendly
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- # We return the whole states instead of k_out, v_out so that the whole prompt is taken into account
- # when building the kv cache, instead of just throwing away the tokens outside of the window
- return key_states, value_states
- slicing = torch.ones(max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
- cache_position = cache_position.clamp(0, max_cache_len - 1)
- to_shift = cache_position >= max_cache_len - 1
- indices = (slicing + to_shift[-1].int() - 1) % max_cache_len
- k_out = k_out[:, :, indices]
- v_out = v_out[:, :, indices]
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
- # `_.zero_()` followed by `+=` is equivalent to `=`, but compile-friendly (without graph breaks due to assignment)
- self.key_cache[layer_idx].zero_()
- self.value_cache[layer_idx].zero_()
- self.key_cache[layer_idx] += k_out
- self.value_cache[layer_idx] += v_out
- return k_out, v_out
- def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
- self.key_cache[layer_idx] = k_out
- self.value_cache[layer_idx] = v_out
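- # `k_out`/`v_out` are the cache tensors themselves, so the in-place writes above have already updated the
- # cache; the reassignments simply rebind the same tensor objects.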
- return k_out, v_out
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- cache_position = cache_kwargs.get("cache_position")
- sliding_window = cache_kwargs.get("sliding_window")
- k_out = self.key_cache[layer_idx]
- v_out = self.value_cache[layer_idx]
- if sliding_window:
- update_fn = self._sliding_update
- else:
- update_fn = self._static_update
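- # Sliding-window layers pass a truthy `sliding_window` value in `cache_kwargs` and take the rolling
- # update path; global-attention layers write in place at `cache_position`.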
- return update_fn(
- cache_position,
- layer_idx,
- key_states,
- value_states,
- k_out,
- v_out,
- k_out.shape[2],
- )
- def get_max_cache_shape(self) -> Optional[int]:
- return self.max_cache_len
- def get_seq_length(self, layer_idx: Optional[int] = 0):
- # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
- # limit the check to the first batch member and head dimension.
- # TODO: deprecate this function in favor of `cache_position`
- if layer_idx != 0:
- raise ValueError(
- "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
- "Using the `layer_idx` argument is not supported."
- )
- return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
- def reset(self):
- """Resets the cache values while preserving the objects"""
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address
- self.key_cache[layer_idx].zero_()
- self.value_cache[layer_idx].zero_()
- class MambaCache:
- """
- Cache for the Mamba model, which does not have an attention mechanism or key/value states.
- Arguments:
- config (`PretrainedConfig`):
- The configuration file defining the shape-related attributes required to initialize the static cache.
- batch_size (`int`):
- The batch size with which the model will be used. Note that a new instance must be instantiated if a
- smaller batch size is used.
- dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
- The default `dtype` to use when initializing the layer.
- device (`torch.device` or `str`, *optional*):
- The device on which the cache should be initialized. Should be the same as the layer.
- Attributes:
- dtype: (`torch.dtype`):
- The default `dtype` used to initialize the cache.
- intermediate_size: (`int`):
- Model's intermediate_size taken from config.
- ssm_state_size: (`int`):
- Model's state_size taken from config.
- conv_kernel_size: (`int`):
- Model's convolution kernel size taken from config
- conv_states: (`torch.Tensor`):
- A tensor of shape `[layer_idx, batch_size, intermediate_size, conv_kernel_size]` that holds convolutional states.
- ssm_states: (`torch.Tensor`):
- A tensor of shape `[layer_idx, batch_size, intermediate_size, ssm_state_size]` that holds ssm states
- Example:
- ```python
- >>> from transformers import AutoTokenizer, MambaForCausalLM, MambaCache
- >>> model = MambaForCausalLM.from_pretrained("state-spaces/mamba-130m-hf")
- >>> tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-130m-hf")
- >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> outputs.past_key_values
- MambaCache()
- ```
- """
- # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
- def __init__(
- self,
- config: PretrainedConfig,
- batch_size: int = None,
- dtype: torch.dtype = torch.float16,
- device: Optional[Union[torch.device, str]] = None,
- max_batch_size: Optional[int] = None,
- ):
- if max_batch_size is not None:
- logger.warning_once(
- f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
- "v4.46. Use the more precisely named 'batch_size' argument instead."
- )
- self.dtype = dtype
- self.batch_size = batch_size or max_batch_size
- self.intermediate_size = config.intermediate_size
- self.ssm_state_size = config.state_size
- self.conv_kernel_size = config.conv_kernel
- self.conv_states: torch.Tensor = torch.zeros(
- config.num_hidden_layers,
- self.batch_size,
- self.intermediate_size,
- self.conv_kernel_size,
- device=device,
- dtype=dtype,
- )
- self.ssm_states: torch.Tensor = torch.zeros(
- config.num_hidden_layers,
- self.batch_size,
- self.intermediate_size,
- self.ssm_state_size,
- device=device,
- dtype=dtype,
- )
- torch._dynamo.mark_static_address(self.conv_states)
- torch._dynamo.mark_static_address(self.ssm_states)
- def update_conv_state(
- self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
- ) -> torch.Tensor:
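- # The conv state is a fixed-size window over the last `conv_kernel_size` inputs: shift it left by one,
- # write the newest input at the (clamped) cache position, then copy it back in place so the static
- # address of `self.conv_states` is preserved.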
- conv_state = self.conv_states[layer_idx]
- cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)
- conv_state = conv_state.roll(shifts=-1, dims=-1)
- conv_state[:, :, cache_position] = new_conv_state.to(device=conv_state.device, dtype=conv_state.dtype)
- self.conv_states[layer_idx].zero_()
- self.conv_states[layer_idx] += conv_state
- return self.conv_states[layer_idx]
- def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
- self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states.device)
- return self.ssm_states[layer_idx]
- def reset(self):
- self.conv_states.zero_()
- self.ssm_states.zero_()
- class OffloadedStaticCache(StaticCache):
- """
- Static cache class to be used with `torch.compile(model)` that offloads to the CPU or
- another device.
- Args:
- config (`PretrainedConfig`):
- The configuration file defining the shape-related attributes required to initialize
- the static cache.
- max_batch_size (`int`):
- The maximum batch size with which the model will be used.
- max_cache_len (`int`):
- The maximum sequence length with which the model will be used.
- device (`Union[str, torch.device]`):
- The device on which the cache should be initialized. Should be the same as the
- layer device.
- dtype (`torch.dtype`, *optional*):
- The default `dtype` to use when initializing the cache.
- offload_device (`Union[str, torch.device]`, *optional*, defaults to `cpu`):
- The device to offload to. Defaults to CPU.
- Attributes:
- key_cache (`List[torch.Tensor]`):
- Off-loaded key cache tensors. The first one will be on device, whereas the others are
- off-loaded.
- value_cache (`List[torch.Tensor]`):
- Off-loaded value cache tensors. The first one will be on device, whereas the others are
- off-loaded.
- max_batch_size (`int`):
- The maximum batch size with which this cache can be used.
- max_cache_len (`int`):
- The maximum sequence length with which this cache can be used.
- device (`torch.device`):
- The device on which the cache is used.
- offload_device (`torch.device`):
- The device used to offload to.
- dtype (`torch.dtype`):
- The `dtype` used to initialize the cache.
- Example:
- ```python
- >>> from transformers import AutoTokenizer, AutoModelForCausalLM, OffloadedStaticCache
- >>> model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
- >>> tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
- >>> inputs = tokenizer(text="My name is GPT2", return_tensors="pt")
- >>> # Prepare a cache class and pass it to model's forward
- >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
- >>> max_generated_length = inputs.input_ids.shape[1] + 10
- >>> past_key_values = OffloadedStaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
- >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
- >>> past_key_values = outputs.past_key_values # access cache filled with key/values from generation
- ```
- """
- def __init__(
- self,
- config: PretrainedConfig,
- max_batch_size: int,
- max_cache_len: Optional[int],
- device: Union[str, torch.device],
- dtype: Optional[torch.dtype] = None,
- offload_device: Union[str, torch.device] = torch.device("cpu"),
- ) -> None:
- self.max_batch_size = max_batch_size
- self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
- self.device = torch.device(device)
- self.offload_device = torch.device(offload_device)
- self.dtype = dtype if dtype is not None else torch.float32
- # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
- head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
- num_key_value_heads = (
- config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
- )
- cache_shape = (max_batch_size, num_key_value_heads, self.max_cache_len, head_dim)
- # Create offloaded CPU tensors.
- self.key_cache: List[torch.Tensor] = []
- self.value_cache: List[torch.Tensor] = []
- for i in range(config.num_hidden_layers):
- # First layer is always on-device.
- device = self.device if i == 0 else self.offload_device
- key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, device)
- self.key_cache.append(key_cache)
- self.value_cache.append(value_cache)
- # Create device tensors.
- self._device_key_cache: List[torch.Tensor] = []
- self._device_value_cache: List[torch.Tensor] = []
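- # Two on-device staging buffers are alternated via `layer_idx & 1`, so layer `i + 1` can be prefetched
- # from the offload device while layer `i` is being read.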
- for i in range(2):
- key_cache, value_cache = self._create_key_value_cache_tensors(cache_shape, self.device)
- self._device_key_cache.append(key_cache)
- self._device_value_cache.append(value_cache)
- # For backwards compatibility.
- # TODO(gante): Remove this.
- self._seen_tokens = 0
- # Create new CUDA stream for parallel prefetching.
- self._prefetch_stream = torch.cuda.Stream() if self.device.type == "cuda" else None
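- # Prefetch copies run on their own CUDA stream so host-to-device transfers can overlap with compute on
- # the default stream.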
- def update(
- self,
- key_states: torch.Tensor,
- value_states: torch.Tensor,
- layer_idx: int,
- cache_kwargs: Optional[Dict[str, Any]] = None,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
- It is VERY important to index using a tensor, otherwise you introduce a copy to the device.
- Parameters:
- key_states (`torch.Tensor`):
- The new key states to cache.
- value_states (`torch.Tensor`):
- The new value states to cache.
- layer_idx (`int`):
- The index of the layer to cache the states for.
- cache_kwargs (`Dict[str, Any]`, *optional*):
- Additional arguments for the cache subclass. The `OffloadedStaticCache` needs the
- `cache_position` input to know where to write in the cache.
- Return:
- A tuple containing the updated key and value states.
- """
- if layer_idx == 0:
- # Update seen tokens.
- # TODO(gante): Remove this.
- self._seen_tokens += key_states.shape[-2]
- # Always there.
- k_out = self.key_cache[0]
- v_out = self.value_cache[0]
- else:
- # Wait for prefetch stream.
- if self._prefetch_stream is not None:
- torch.cuda.default_stream(self.device).wait_stream(self._prefetch_stream)
- k_out = self._device_key_cache[layer_idx & 1]
- v_out = self._device_value_cache[layer_idx & 1]
- self._prefetch_layer(layer_idx + 1)
- cache_position = cache_kwargs.get("cache_position") if cache_kwargs is not None else None
- if cache_position is None:
- k_out.copy_(key_states)
- v_out.copy_(value_states)
- # Copy the values to the offloaded device as well.
- if layer_idx == 0:
- self.key_cache[layer_idx].copy_(key_states.to(self.offload_device))
- self.value_cache[layer_idx].copy_(value_states.to(self.offload_device))
- else:
- # Note: here we use `tensor.index_copy_(dim, index, tensor)`, which is equivalent to
- # `tensor[:, :, index] = tensor`, but the former is compile-friendly and explicitly
- # performs an in-place operation, avoiding copies and using less memory.
- try:
- k_out.index_copy_(2, cache_position, key_states)
- v_out.index_copy_(2, cache_position, value_states)
- except NotImplementedError:
- # The operator 'aten::index_copy.out' is not currently implemented for the MPS
- # device.
- k_out[:, :, cache_position] = key_states
- v_out[:, :, cache_position] = value_states
- # Copy the values to the offloaded device as well.
- if layer_idx != 0:
- cache_position = cache_position.to(self.offload_device)
- key_states = key_states.to(self.offload_device)
- value_states = value_states.to(self.offload_device)
- try:
- self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
- self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
- except NotImplementedError:
- # The operator 'aten::index_copy.out' is not currently implemented for the MPS
- # device.
- self.key_cache[layer_idx][:, :, cache_position] = key_states
- self.value_cache[layer_idx][:, :, cache_position] = value_states
- return k_out, v_out
- def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
- """Returns the sequence length of the cached states that were seen by the model."""
- # TODO(gante): Remove this.
- return self._seen_tokens
- def get_max_cache_shape(self) -> Optional[int]:
- """Returns the maximum sequence length of the cached states."""
- return self.max_cache_len
- def reset(self) -> None:
- """Resets the cache values while preserving the objects."""
- # For backwards compatibility.
- # TODO(gante): Remove this.
- self._seen_tokens = 0
- # Zero out cache.
- for layer_idx in range(len(self.key_cache)):
- # In-place ops prevent breaking the static address.
- self.key_cache[layer_idx].zero_()
- self.value_cache[layer_idx].zero_()
- @property
- def seen_tokens(self) -> int:
- # For backwards compatibility.
- # TODO(gante): Remove this.
- return self._seen_tokens
- def _create_key_value_cache_tensors(
- self, shape: Tuple[int, ...], device: torch.device
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Creates K/V cache tensors on a device. Pins memory for CPU tensors. Marks them as static
- addresses for non-CPU tensors.
- Args:
- shape (`Tuple[int, ...]`): Shape.
- device (`torch.device`): Device.
- Returns:
- Key and value cache tensors as a tuple.
- """
- is_cpu_device = device == torch.device("cpu")
- key_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
- value_cache = torch.zeros(shape, dtype=self.dtype, device=device, pin_memory=is_cpu_device)
- # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
- # preventing compiled graph breaks when updating the cache.
- torch._dynamo.mark_static_address(key_cache)
- torch._dynamo.mark_static_address(value_cache)
- return key_cache, value_cache
- def _prefetch_layer(self, layer_idx: int) -> None:
- """Prefetch a layer to the device. Needs to be called in order of layer indices."""
- # Don't fetch layers that do not exist.
- if layer_idx >= len(self.key_cache):
- return
- # Alternate between two on-device caches.
- if self._prefetch_stream is not None:
- with torch.cuda.stream(self._prefetch_stream):
- self._prefetch_layer_in_context(layer_idx)
- else:
- self._prefetch_layer_in_context(layer_idx)
- def _prefetch_layer_in_context(self, layer_idx: int) -> None:
- """Performs the actual copy of the layer to device cache."""
- self._device_key_cache[layer_idx & 1].copy_(self.key_cache[layer_idx], non_blocking=True)
- self._device_value_cache[layer_idx & 1].copy_(self.value_cache[layer_idx], non_blocking=True)
|