from ..utils import is_accelerate_available, is_torch_available, logging


if is_accelerate_available():
    from accelerate import init_empty_weights

if is_torch_available():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F


logger = logging.get_logger(__name__)

# The weights are ternary, so each value can be represented with 2 bits; they are
# packed into uint8 tensors, hence the number of values per item is 4.
VALUES_PER_ITEM = 4

def pack_weights(quantized_weights: torch.Tensor) -> torch.Tensor:
    """
    Packs a tensor of quantized weights into a compact format using 2 bits per value.

    Parameters:
    -----------
    quantized_weights : torch.Tensor
        A tensor containing ternary quantized weights with values in {-1, 0, 1}. These values are adjusted to
        {0, 1, 2} before being packed.

    Returns:
    --------
    torch.Tensor
        A packed tensor where each element stores 4 quantized values (each using 2 bits) in an 8-bit format.
    """
    original_shape = quantized_weights.shape
    row_dim = (original_shape[0] + VALUES_PER_ITEM - 1) // VALUES_PER_ITEM

    if len(original_shape) == 1:
        packed_tensor_shape = (row_dim,)
    else:
        packed_tensor_shape = (row_dim, *original_shape[1:])

    # Shift {-1, 0, 1} to {0, 1, 2}; avoid `+=` so the caller's tensor is not mutated in place.
    quantized_weights = quantized_weights + 1
    packed = torch.zeros(packed_tensor_shape, device=quantized_weights.device, dtype=torch.uint8)
    unpacked = quantized_weights.to(torch.uint8)

    # Number of 2-bit slots needed to cover all rows (at most VALUES_PER_ITEM).
    it = min(VALUES_PER_ITEM, (original_shape[0] // row_dim) + 1)
    for i in range(it):
        start = i * row_dim
        end = min(start + row_dim, original_shape[0])
        packed[: (end - start)] |= unpacked[start:end] << 2 * i

    return packed
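
# Illustrative sketch of the packing layout (the tensor values below are made up
# for the example): the ternary vector [-1, 0, 1, 1] is shifted to [0, 1, 2, 2]
# and all four values land in a single uint8, least-significant bit pair first.
#
#   w = torch.tensor([-1, 0, 1, 1], dtype=torch.int8)
#   packed = pack_weights(w)  # shape (1,), since row_dim = ceil(4 / VALUES_PER_ITEM) = 1
#   # bit pairs from the LSB: 00 | 01 | 10 | 10  ->  packed[0] == 0b10100100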

@torch.compile
def unpack_weights(packed: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """
    Unpacks a tensor of quantized weights that were stored in a packed format using 2 bits per value.

    Parameters:
    -----------
    packed : torch.Tensor
        A tensor containing packed weights where each element represents 4 quantized values (using 2 bits per value).
    dtype : torch.dtype
        The dtype of the returned tensor.

    Returns:
    --------
    torch.Tensor
        A tensor of unpacked weights, where each value is converted from its packed 2-bit representation.

    Example:
    --------
    packed = torch.tensor([[0b10100001, 0b00011000],
                           [0b10010000, 0b00001010]], dtype=torch.uint8)

    # Unpack the values
    unpacked = unpack_weights(packed, dtype=torch.int8)

    # Resulting unpacked tensor
    print(unpacked)
    # Output: tensor([[ 0, -1],
    #                 [-1,  1],
    #                 [-1,  1],
    #                 [-1,  1],
    #                 [ 1,  0],
    #                 [ 0, -1],
    #                 [ 1, -1],
    #                 [ 1, -1]])

    Explanation of the example:
    ---------------------------
    Take the first packed value, 0b10100001 at [0][0]. We only follow the first column, because every
    element is unpacked across the first dimension:
    - First 2 bits (LSB): `01` → 1, minus 1 → 0 at [0][0]
    - Second 2 bits: `00` → 0, minus 1 → -1 at [2][0]
    - Third 2 bits: `10` → 2, minus 1 → 1 at [4][0]
    - Fourth 2 bits: `10` → 2, minus 1 → 1 at [6][0]
    The first value of the second row, 0b10010000 at [1][0], gives the values at [1][0], [3][0], [5][0] and [7][0].

    We subtract 1 because, during packing, it is easier to work with the values 0, 1, and 2; to make this
    possible, we add 1 to the original ternary weights (-1, 0, and 1) when packing them. Unpacking reverses
    this by subtracting 1 to restore the original ternary values.
    """
    packed_shape = packed.shape

    if len(packed_shape) == 1:
        original_row_dim = packed_shape[0] * VALUES_PER_ITEM
        unpacked_shape = (original_row_dim,)
    else:
        original_row_dim = packed_shape[0] * VALUES_PER_ITEM
        unpacked_shape = (original_row_dim, *packed_shape[1:])

    unpacked = torch.zeros(unpacked_shape, device=packed.device, dtype=torch.uint8)

    # Select each 2-bit field with a mask and shift it down to a value in {0, 1, 2}.
    for i in range(VALUES_PER_ITEM):
        start = i * packed_shape[0]
        end = start + packed_shape[0]
        mask = 3 << (2 * i)
        unpacked[start:end] = (packed & mask) >> (2 * i)

    # Shift {0, 1, 2} back to the ternary values {-1, 0, 1}.
    return unpacked.to(dtype) - 1
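
# A minimal round-trip sanity check (illustrative): for ternary inputs whose first
# dimension is a multiple of VALUES_PER_ITEM, unpack_weights inverts pack_weights exactly.
#
#   w = torch.randint(-1, 2, (16, 8), dtype=torch.int8)
#   restored = unpack_weights(pack_weights(w), dtype=torch.int8)
#   assert torch.equal(restored, w)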

class BitLinear(nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool, device=None, dtype=None):
        super().__init__()
        self.dtype = dtype
        # Packed ternary weights: 4 values per uint8 along the output dimension,
        # so `out_features` is expected to be a multiple of VALUES_PER_ITEM.
        self.register_buffer(
            "weight",
            torch.zeros(
                (out_features // VALUES_PER_ITEM, in_features),
                dtype=torch.uint8,
                device=device,
            ),
        )
        # Single per-tensor scale applied when dequantizing the output.
        self.register_buffer(
            "weight_scale",
            torch.ones(
                (1,),
                dtype=dtype,
                device=device,
            ),
        )
        if bias:
            self.register_buffer("bias", torch.zeros((out_features,), dtype=dtype, device=device))
        else:
            self.bias = None
    @torch.compile
    def activation_quant(self, input, num_bits=8):
        """
        Performs symmetric, per-token quantization on the input activations.

        Parameters:
        -----------
        input : torch.Tensor
            Input activations to be quantized.
        num_bits : int, optional (default=8)
            Number of bits to use for quantization, determining the quantization range.

        Returns:
        --------
        result : torch.Tensor
            Quantized activation tensor, with values mapped to an `int8` range.
        scale : torch.Tensor
            The per-token scaling factors used to quantize the tensor.
        """
        Qn = -(2 ** (num_bits - 1))
        Qp = 2 ** (num_bits - 1) - 1
        scale = Qp / input.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        result = (input * scale).round().clamp(Qn, Qp)
        return result.to(torch.int8), scale
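
    # Worked example with illustrative numbers: with num_bits=8, Qp = 127. For a
    # token whose largest |activation| is 2.0, scale = 127 / 2.0 = 63.5, so an
    # activation of 1.0 maps to round(1.0 * 63.5) = 64. Dividing by the scale
    # afterwards (see post_quant_process) recovers 64 / 63.5 ≈ 1.008, i.e. the
    # quantization error here is under 1%.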
    @torch.compile
    def post_quant_process(self, input, input_scale, weight_scale):
        # Rescale the accumulator back to real units by undoing both the
        # activation scale and the weight scale.
        out = input / (input_scale * weight_scale)
        return out
    def forward(self, input):
        w = self.weight
        w_quant = unpack_weights(w, dtype=self.dtype)
        input_quant, input_scale = self.activation_quant(input)
        y = F.linear(input_quant.to(self.dtype), w_quant)
        # Dequantize using the per-token activation scale and the weight scale.
        y = self.post_quant_process(y, input_scale, self.weight_scale)
        if self.bias is not None:
            y += self.bias.view(1, -1).expand_as(y)
        return y
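
# Illustrative usage with made-up sizes: a freshly constructed layer holds zeroed
# packed buffers (which unpack to all -1 weights), so meaningful outputs require
# loading packed weights and `weight_scale` from a pre-quantized checkpoint.
#
#   layer = BitLinear(in_features=8, out_features=8, bias=False, dtype=torch.float32)
#   y = layer(torch.randn(2, 8))  # shape (2, 8)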

def _replace_with_bitnet_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    has_been_replaced=False,
    pre_quantized=False,
):
    """
    Private method that wraps the recursion for module replacement.

    Returns the converted model and a boolean that indicates if the conversion has been successful or not.
    """
    if current_key_name is None:
        current_key_name = []

    for name, module in model.named_children():
        current_key_name.append(name)

        # Check if the current key is not in the `modules_to_not_convert`
        if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
            with init_empty_weights():
                if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
                    in_features = module.in_features
                    out_features = module.out_features
                    model._modules[name] = BitLinear(
                        in_features=in_features,
                        out_features=out_features,
                        bias=module.bias is not None,
                        device=module.weight.device,
                        dtype=module.weight.dtype,
                    )
                    has_been_replaced = True
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = _replace_with_bitnet_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
                pre_quantized=pre_quantized,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)

    return model, has_been_replaced

def replace_with_bitnet_linear(
    model,
    modules_to_not_convert=None,
    current_key_name=None,
    quantization_config=None,
    pre_quantized=False,
):
    """
    A helper function to replace all `torch.nn.Linear` modules with `BitLinear` modules.

    The function is run recursively and replaces all `torch.nn.Linear` modules except for the `lm_head`, which
    should be kept as a `torch.nn.Linear` module. The replacement is done under the `init_empty_weights` context
    manager, so no CPU/GPU memory is required to run this function. Each weight is quantized along the channel.

    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module`, as the function is run recursively.
        modules_to_not_convert (`List[str]`, *optional*, defaults to `["lm_head"]`):
            Names of the modules to not convert to `BitLinear`. In practice, we keep the `lm_head` in full
            precision for numerical stability reasons.
        current_key_name (`List[str]`, *optional*):
            An array to track the current key of the recursion. This is used to check whether the current key
            (or part of it) is not in the list of modules to not convert (for instance, modules that are
            offloaded to `cpu` or `disk`).
    """
    modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert

    if quantization_config and quantization_config.modules_to_not_convert is not None:
        modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
    modules_to_not_convert = list(set(modules_to_not_convert))

    model, has_been_replaced = _replace_with_bitnet_linear(
        model,
        modules_to_not_convert,
        current_key_name,
        quantization_config,
        pre_quantized=pre_quantized,
    )

    if not has_been_replaced:
        logger.warning(
            "You are loading your model using bitnet but no linear modules were found in your model."
            " Please double check your model architecture, or submit an issue on GitHub if you think this is"
            " a bug."
        )

    return model
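
# Illustrative sketch of the conversion on a toy module (the model below is made
# up for the example, and `accelerate` must be installed for `init_empty_weights`;
# `lm_head` is kept in full precision by default):
#
#   toy = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 4))
#   toy = replace_with_bitnet_linear(toy)
#   print(toy)  # both Linear layers are now BitLinear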