# fbgemm_fp8.py (7.3 KB)
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
from ..utils import is_accelerate_available, is_fbgemm_gpu_available, is_torch_available, logging


# Optional dependencies: each import is guarded by the matching availability check.
if is_torch_available():
    import torch
    from torch import nn

if is_accelerate_available():
    from accelerate import init_empty_weights

if is_fbgemm_gpu_available():
    # Not referenced by name (hence the noqa); presumably imported for its side effect of registering
    # the `torch.ops.fbgemm.*` ops used by FbgemmFp8Linear.forward -- TODO confirm.
    import fbgemm_gpu.experimental.gen_ai  # noqa: F401

# Module-level logger (used to warn when no linear layer was converted).
logger = logging.get_logger(__name__)


  23. class FbgemmFp8Linear(torch.nn.Module):
  24. def __init__(self, in_features, out_features, bias, weight_dtype=torch.float32):
  25. super().__init__()
  26. self.in_features = in_features
  27. self.out_features = out_features
  28. self.register_buffer("weight", torch.zeros((out_features, in_features), dtype=torch.float8_e4m3fn))
  29. self.register_buffer("weight_scale", torch.zeros((out_features, 1), dtype=weight_dtype))
  30. self.register_buffer("input_scale_ub", torch.zeros([1], dtype=torch.float), persistent=False)
  31. if bias:
  32. self.register_buffer("bias", torch.zeros((self.out_features), dtype=weight_dtype))
  33. else:
  34. self.bias = None
  35. def forward(self, x):
  36. num_tokens = None
  37. # quantize_fp8_per_row will squash the leading dimensions, so save the desired shape here
  38. output_shape = (*x.shape[:-1], -1)
  39. # x_quantized and x_scale are not necessarily on the same device as x, this is an issue.
  40. # https://github.com/pytorch/FBGEMM/blob/e08af8539c391437f447173863df0f3f6f6f1855/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu#L1237C3-L1237C45
  41. x_quantized, x_scale = torch.ops.fbgemm.quantize_fp8_per_row(
  42. x.view(-1, x.shape[-1]), num_tokens, self.input_scale_ub
  43. )
  44. # moving x_quantized, x_scale here creates glibberish output ... However, if we move the output, it works
  45. # x_quantized, x_scale = x_quantized.to(x.device), x_scale.to(x.device)
  46. # The computation still happens on the device where self.weight is even if x_quantized is not on the same device as self.weight
  47. output = torch.ops.fbgemm.f8f8bf16_rowwise(
  48. x_quantized, self.weight, x_scale, self.weight_scale, use_fast_accum=True
  49. )
  50. output = output + self.bias if self.bias is not None else output
  51. # Hacky for now, we have the output to the device of x
  52. output = output.to(x.device)
  53. output = output.reshape(output_shape)
  54. del x_quantized, x_scale
  55. return output
  56. def _replace_with_fbgemm_fp8_linear(
  57. model,
  58. modules_to_not_convert=None,
  59. current_key_name=None,
  60. quantization_config=None,
  61. has_been_replaced=False,
  62. pre_quantized=False,
  63. ):
  64. """
  65. Private method that wraps the recursion for module replacement.
  66. Returns the converted model and a boolean that indicates if the conversion has been successfull or not.
  67. """
  68. if current_key_name is None:
  69. current_key_name = []
  70. for name, module in model.named_children():
  71. current_key_name.append(name)
  72. if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
  73. # Check if the current key is not in the `modules_to_not_convert`
  74. current_key_name_str = ".".join(current_key_name)
  75. if not any(
  76. (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
  77. ):
  78. with init_empty_weights(include_buffers=True):
  79. in_features = module.in_features
  80. out_features = module.out_features
  81. model._modules[name] = FbgemmFp8Linear(
  82. in_features,
  83. out_features,
  84. module.bias is not None,
  85. )
  86. has_been_replaced = True
  87. # Force requires grad to False to avoid unexpected errors
  88. model._modules[name].requires_grad_(False)
  89. # set non persistant buffer outside of init_empty_weights
  90. model._modules[name].input_scale_ub = torch.tensor(
  91. [quantization_config.activation_scale_ub], dtype=torch.float
  92. )
  93. if len(list(module.children())) > 0:
  94. _, has_been_replaced = _replace_with_fbgemm_fp8_linear(
  95. module,
  96. modules_to_not_convert,
  97. current_key_name,
  98. quantization_config,
  99. has_been_replaced=has_been_replaced,
  100. pre_quantized=pre_quantized,
  101. )
  102. # Remove the last key for recursion
  103. current_key_name.pop(-1)
  104. return model, has_been_replaced
  105. def replace_with_fbgemm_fp8_linear(
  106. model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
  107. ):
  108. """
  109. A helper function to replace all `torch.nn.Linear` modules by `FbgemmFp8Linear` modules.
  110. This will enable running your models using high performance fp8 kernel from FBGEMM library.
  111. The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
  112. be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
  113. CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
  114. Parameters:
  115. model (`torch.nn.Module`):
  116. Input model or `torch.nn.Module` as the function is run recursively.
  117. modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
  118. Names of the modules to not convert in `FP8Linear`. In practice we keep the `lm_head` in full precision
  119. for numerical stability reasons.
  120. current_key_name (`List[`str`]`, *optional*):
  121. An array to track the current key of the recursion. This is used to check whether the current key (part of
  122. it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
  123. `disk`).
  124. """
  125. modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
  126. if quantization_config.modules_to_not_convert is not None:
  127. modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
  128. modules_to_not_convert = list(set(modules_to_not_convert))
  129. model, has_been_replaced = _replace_with_fbgemm_fp8_linear(
  130. model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
  131. )
  132. if not has_been_replaced:
  133. logger.warning(
  134. "You are loading your model using FP8 quantization but no linear modules were found in your model."
  135. " Please double check your model architecture, or submit an issue on github if you think this is"
  136. " a bug."
  137. )
  138. return model