quant_modules.py

# coding=utf-8
# Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
# Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import decimal

import numpy as np
import torch
from torch import nn
from torch.autograd import Function

from ...utils import logging


logger = logging.get_logger(__name__)


class QuantEmbedding(nn.Module):
    """
    Quantized version of `torch.nn.Embedding`. Adds quantization-specific arguments on top of `torch.nn.Embedding`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        _weight=None,
        weight_bit=8,
        momentum=0.95,
        quant_mode=False,
    ):
        super().__init__()
        self.num_ = num_embeddings
        self.dim = embedding_dim
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        self.weight = nn.Parameter(torch.zeros([num_embeddings, embedding_dim]))
        self.register_buffer("weight_scaling_factor", torch.zeros(1))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))

        self.weight_bit = weight_bit
        self.momentum = momentum
        self.quant_mode = quant_mode
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def forward(self, x, positions=None, incremental_state=None):
        if not self.quant_mode:
            return (
                nn.functional.embedding(
                    x,
                    self.weight,
                    self.padding_idx,
                    self.max_norm,
                    self.norm_type,
                    self.scale_grad_by_freq,
                    self.sparse,
                ),
                None,
            )

        w = self.weight
        w_transform = w.data.detach()
        w_min = w_transform.min().expand(1)
        w_max = w_transform.max().expand(1)

        self.weight_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, False)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.weight_scaling_factor
        )

        emb_int = nn.functional.embedding(
            x,
            self.weight_integer,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        return emb_int * self.weight_scaling_factor, self.weight_scaling_factor
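

# Illustrative sketch (not part of the original module): a minimal smoke test of QuantEmbedding in
# quantized mode. The helper name `_example_quant_embedding` and the shapes below are hypothetical
# and chosen only for demonstration.
def _example_quant_embedding():
    emb = QuantEmbedding(num_embeddings=16, embedding_dim=8, weight_bit=8, quant_mode=True)
    input_ids = torch.tensor([[1, 2, 3]])
    # forward returns the dequantized embedding (integer codes * scale) and the weight scale
    output, scale = emb(input_ids)
    return output.shape, scale.shape  # expected: (1, 3, 8) and (1,)
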
class QuantAct(nn.Module):
    """
    Quantizes the given activation.

    Args:
        activation_bit (`int`):
            Bitwidth for the quantized activation.
        act_range_momentum (`float`, *optional*, defaults to `0.95`):
            Momentum for updating the activation quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        channel_len (`int`, *optional*):
            Specify the channel length when *per_channel* is set to `True`.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(self, activation_bit, act_range_momentum=0.95, per_channel=False, channel_len=None, quant_mode=False):
        super().__init__()

        self.activation_bit = activation_bit
        self.act_range_momentum = act_range_momentum
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.percentile = False
        self.act_function = SymmetricQuantFunction.apply

        if not self.per_channel:
            self.register_buffer("x_min", torch.zeros(1))
            self.register_buffer("x_max", torch.zeros(1))
            self.register_buffer("act_scaling_factor", torch.zeros(1))
            self.x_min -= 1e-5
            self.x_max += 1e-5
        else:
            raise NotImplementedError("per-channel mode is not currently supported for activation.")

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(activation_bit={self.activation_bit}, "
            f"quant_mode: {self.quant_mode}, Act_min: {self.x_min.item():.2f}, "
            f"Act_max: {self.x_max.item():.2f})"
        )

    def forward(
        self,
        x,
        pre_act_scaling_factor=None,
        identity=None,
        identity_scaling_factor=None,
        specified_min=None,
        specified_max=None,
    ):
        x_act = x if identity is None else identity + x
        # collect running stats if training
        if self.training:
            assert not self.percentile, "percentile mode is not currently supported for activation."
            assert not self.per_channel, "per-channel mode is not currently supported for activation."
            x_min = x_act.data.min()
            x_max = x_act.data.max()

            assert (
                x_max.isnan().sum() == 0 and x_min.isnan().sum() == 0
            ), "NaN detected when computing min/max of the activation"

            # Initialization
            if self.x_min.min() > -1.1e-5 and self.x_max.max() < 1.1e-5:
                self.x_min = self.x_min + x_min
                self.x_max = self.x_max + x_max

            # exponential moving average (EMA)
            # use momentum to prevent the quantized values from changing greatly at every iteration
            elif self.act_range_momentum == -1:
                self.x_min = torch.min(self.x_min, x_min)
                self.x_max = torch.max(self.x_max, x_max)
            else:
                self.x_min = self.x_min * self.act_range_momentum + x_min * (1 - self.act_range_momentum)
                self.x_max = self.x_max * self.act_range_momentum + x_max * (1 - self.act_range_momentum)

        if not self.quant_mode:
            return x_act, None

        x_min = self.x_min if specified_min is None else specified_min
        x_max = self.x_max if specified_max is None else specified_max

        self.act_scaling_factor = symmetric_linear_quantization_params(
            self.activation_bit, x_min, x_max, per_channel=self.per_channel
        )

        if pre_act_scaling_factor is None:
            # this is for the input quantization
            quant_act_int = self.act_function(x, self.activation_bit, self.percentile, self.act_scaling_factor)
        else:
            quant_act_int = FixedPointMul.apply(
                x,
                pre_act_scaling_factor,
                self.activation_bit,
                self.act_scaling_factor,
                identity,
                identity_scaling_factor,
            )

        correct_output_scale = self.act_scaling_factor.view(-1)

        return quant_act_int * correct_output_scale, self.act_scaling_factor
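

# Illustrative sketch (not part of the original module): calibrating and applying QuantAct. In
# training mode the layer tracks the activation range with an EMA; in quantized mode it returns the
# activation snapped to the symmetric 8-bit grid together with its scaling factor. The helper name
# `_example_quant_act` is hypothetical.
def _example_quant_act():
    act = QuantAct(activation_bit=8, quant_mode=True)
    act.train()  # collect min/max running statistics on this forward pass
    x = torch.randn(2, 4, 8)
    q_x, scale = act(x)
    # q_x is already dequantized (integer codes * scale); scale has shape (1,)
    return q_x.shape, scale.shape
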
class QuantLinear(nn.Module):
    """
    Quantized version of `torch.nn.Linear`. Adds quantization-specific arguments on top of `torch.nn.Linear`.

    Args:
        weight_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the quantized weight.
        bias_bit (`int`, *optional*, defaults to `32`):
            Bitwidth for the quantized bias.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
    """

    def __init__(
        self, in_features, out_features, bias=True, weight_bit=8, bias_bit=32, per_channel=False, quant_mode=False
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight = nn.Parameter(torch.zeros([out_features, in_features]))
        self.register_buffer("weight_integer", torch.zeros_like(self.weight))
        self.register_buffer("fc_scaling_factor", torch.zeros(self.out_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
            self.register_buffer("bias_integer", torch.zeros_like(self.bias))

        self.weight_bit = weight_bit
        self.quant_mode = quant_mode
        self.per_channel = per_channel
        self.bias_bit = bias_bit
        self.percentile_mode = False
        self.weight_function = SymmetricQuantFunction.apply

    def __repr__(self):
        s = super().__repr__()
        s = f"({s} weight_bit={self.weight_bit}, quant_mode={self.quant_mode})"
        return s

    def forward(self, x, prev_act_scaling_factor=None):
        if not self.quant_mode:
            return nn.functional.linear(x, weight=self.weight, bias=self.bias), None

        # assert that prev_act_scaling_factor is a scalar tensor
        assert prev_act_scaling_factor is not None and prev_act_scaling_factor.shape == (1,), (
            "Input activation to the QuantLinear layer should be globally (non-channel-wise) quantized. "
            "Please add a QuantAct layer with `per_channel = False` before this QuantLinear layer"
        )

        w = self.weight
        w_transform = w.data.detach()
        if self.per_channel:
            w_min, _ = torch.min(w_transform, dim=1, out=None)
            w_max, _ = torch.max(w_transform, dim=1, out=None)
        else:
            w_min = w_transform.min().expand(1)
            w_max = w_transform.max().expand(1)

        self.fc_scaling_factor = symmetric_linear_quantization_params(self.weight_bit, w_min, w_max, self.per_channel)
        self.weight_integer = self.weight_function(
            self.weight, self.weight_bit, self.percentile_mode, self.fc_scaling_factor
        )

        bias_scaling_factor = self.fc_scaling_factor * prev_act_scaling_factor

        if self.bias is not None:
            self.bias_integer = self.weight_function(self.bias, self.bias_bit, False, bias_scaling_factor)

        prev_act_scaling_factor = prev_act_scaling_factor.view(1, -1)
        x_int = x / prev_act_scaling_factor

        return (
            nn.functional.linear(x_int, weight=self.weight_integer, bias=self.bias_integer) * bias_scaling_factor,
            bias_scaling_factor,
        )
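

# Illustrative sketch (not part of the original module): chaining QuantAct and QuantLinear, which
# mirrors how a globally quantized activation (and its scalar scale) is fed into a quantized linear
# layer. The helper name and shapes are hypothetical.
def _example_quant_linear():
    act = QuantAct(activation_bit=8, quant_mode=True)
    act.train()
    fc = QuantLinear(in_features=8, out_features=4, bias=True, weight_bit=8, quant_mode=True)
    x = torch.randn(2, 8)
    q_x, act_scale = act(x)            # act_scale has shape (1,), as QuantLinear asserts
    y, out_scale = fc(q_x, act_scale)  # y is dequantized; out_scale = weight scale * act_scale
    return y.shape, out_scale.shape
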
class IntGELU(nn.Module):
    """
    Quantized version of `torch.nn.GELU`. Adds quantization-specific arguments on top of `torch.nn.GELU`.

    Args:
        quant_mode (`bool`, *optional*, defaults to `True`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "gelu" or "nonlinear" is given.
    """

    def __init__(self, quant_mode=True, force_dequant="none"):
        super().__init__()
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "gelu"]:
            logger.info("Force dequantize gelu")
            self.quant_mode = False

        if not self.quant_mode:
            self.activation_fn = nn.GELU()

        self.k = 1.4142
        self.const = 14  # dummy integer constant
        self.coeff = [-0.2888, -1.769, 1]  # a(x+b)**2 + c
        self.coeff[2] /= self.coeff[0]

    def int_erf(self, x_int, scaling_factor):
        b_int = torch.floor(self.coeff[1] / scaling_factor)
        c_int = torch.floor(self.coeff[2] / scaling_factor**2)
        sign = torch.sign(x_int)

        abs_int = torch.min(torch.abs(x_int), -b_int)
        y_int = sign * ((abs_int + b_int) ** 2 + c_int)
        scaling_factor = scaling_factor**2 * self.coeff[0]

        # avoid overflow
        y_int = floor_ste.apply(y_int / 2**self.const)
        scaling_factor = scaling_factor * 2**self.const

        return y_int, scaling_factor

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            return self.activation_fn(x), None

        x_int = x / scaling_factor
        sigmoid_int, sigmoid_scaling_factor = self.int_erf(x_int, scaling_factor / self.k)

        shift_int = 1.0 // sigmoid_scaling_factor

        x_int = x_int * (sigmoid_int + shift_int)
        scaling_factor = scaling_factor * sigmoid_scaling_factor / 2

        return x_int * scaling_factor, scaling_factor
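

# Illustrative sketch (not part of the original module): IntGELU consumes an already-quantized
# activation and its scaling factor (here produced by a QuantAct layer) and evaluates GELU through
# the second-order polynomial approximation of erf implemented in `int_erf`. The helper name is
# hypothetical.
def _example_int_gelu():
    act = QuantAct(activation_bit=8, quant_mode=True)
    act.train()
    gelu = IntGELU(quant_mode=True)
    x = torch.randn(2, 4)
    q_x, scale = act(x)
    y, y_scale = gelu(q_x, scale)  # y = integer result * y_scale
    return y.shape
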
class IntSoftmax(nn.Module):
    """
    Quantized version of `torch.nn.Softmax`. Adds quantization-specific arguments on top of `torch.nn.Softmax`.

    Args:
        output_bit (`int`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "softmax" or "nonlinear" is given.
    """

    def __init__(self, output_bit, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.output_bit = output_bit
        self.max_bit = 32
        self.quant_mode = quant_mode

        if force_dequant in ["nonlinear", "softmax"]:
            logger.info("Force dequantize softmax")
            self.quant_mode = False

        self.act = QuantAct(16, quant_mode=self.quant_mode)
        self.x0 = -0.6931  # -ln2
        self.const = 30  # dummy integer constant
        self.coef = [0.35815147, 0.96963238, 1.0]  # ax**2 + bx + c
        self.coef[1] /= self.coef[0]
        self.coef[2] /= self.coef[0]

    def int_polynomial(self, x_int, scaling_factor):
        with torch.no_grad():
            b_int = torch.floor(self.coef[1] / scaling_factor)
            c_int = torch.floor(self.coef[2] / scaling_factor**2)
        z = (x_int + b_int) * x_int + c_int
        scaling_factor = self.coef[0] * scaling_factor**2
        return z, scaling_factor

    def int_exp(self, x_int, scaling_factor):
        with torch.no_grad():
            x0_int = torch.floor(self.x0 / scaling_factor)
        x_int = torch.max(x_int, self.const * x0_int)

        q = floor_ste.apply(x_int / x0_int)
        r = x_int - x0_int * q
        exp_int, exp_scaling_factor = self.int_polynomial(r, scaling_factor)
        exp_int = torch.clamp(floor_ste.apply(exp_int * 2 ** (self.const - q)), min=0)
        scaling_factor = exp_scaling_factor / 2**self.const
        return exp_int, scaling_factor

    def forward(self, x, scaling_factor):
        if not self.quant_mode:
            return nn.functional.softmax(x, dim=-1), None

        x_int = x / scaling_factor

        x_int_max, _ = x_int.max(dim=-1, keepdim=True)
        x_int = x_int - x_int_max
        exp_int, exp_scaling_factor = self.int_exp(x_int, scaling_factor)

        # Avoid overflow
        exp, exp_scaling_factor = self.act(exp_int, exp_scaling_factor)
        exp_int = exp / exp_scaling_factor

        exp_int_sum = exp_int.sum(dim=-1, keepdim=True)
        factor = floor_ste.apply(2**self.max_bit / exp_int_sum)
        exp_int = floor_ste.apply(exp_int * factor / 2 ** (self.max_bit - self.output_bit))
        scaling_factor = 1 / 2**self.output_bit
        return exp_int * scaling_factor, scaling_factor
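

# Illustrative sketch (not part of the original module): IntSoftmax applied to quantized attention
# scores. It exponentiates with the shifted integer polynomial in `int_exp`, requantizes through the
# internal 16-bit QuantAct, and normalizes to `output_bit` precision. Names and shapes are
# hypothetical.
def _example_int_softmax():
    act = QuantAct(activation_bit=8, quant_mode=True)
    softmax = IntSoftmax(output_bit=8, quant_mode=True)
    act.train()
    softmax.train()  # lets the internal QuantAct calibrate its range on this pass
    scores = torch.randn(1, 2, 3, 3)  # (batch, heads, queries, keys)
    q_scores, scale = act(scores)
    probs, probs_scale = softmax(q_scores, scale)  # probs_scale is 1 / 2**output_bit
    return probs.shape
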
class IntLayerNorm(nn.Module):
    """
    Quantized version of `torch.nn.LayerNorm`. Adds quantization-specific arguments on top of `torch.nn.LayerNorm`.

    Args:
        output_bit (`int`, *optional*, defaults to `8`):
            Bitwidth for the layer output activation.
        quant_mode (`bool`, *optional*, defaults to `False`):
            Whether or not the layer is quantized.
        force_dequant (`str`, *optional*, defaults to `"none"`):
            Force dequantize the layer if either "layernorm" or "nonlinear" is given.
    """

    def __init__(self, normalized_shape, eps, output_bit=8, quant_mode=False, force_dequant="none"):
        super().__init__()
        self.normalized_shape = normalized_shape
        self.eps = eps

        self.weight = nn.Parameter(torch.zeros(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))

        self.quant_mode = quant_mode
        if force_dequant in ["nonlinear", "layernorm"]:
            logger.info("Force dequantize layernorm")
            self.quant_mode = False

        self.register_buffer("shift", torch.zeros(1))
        self.output_bit = output_bit
        self.max_bit = 32
        self.dim_sqrt = None
        self.activation = QuantAct(self.output_bit, quant_mode=self.quant_mode)

    def set_shift(self, y_int):
        with torch.no_grad():
            y_sq_int = y_int**2
            var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
            shift = (torch.log2(torch.sqrt(var_int / 2**self.max_bit)).ceil()).max()
            shift_old = self.shift
            self.shift = torch.max(self.shift, shift)
            logger.info(f"Dynamic shift adjustment: {int(shift_old)} -> {int(self.shift)}")

    def overflow_fallback(self, y_int):
        """
        This fallback function is called when overflow is detected during training time, and adjusts `self.shift`
        to avoid overflow in the subsequent runs.
        """
        self.set_shift(y_int)  # adjusts `self.shift`
        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)
        return var_int

    def forward(self, x, scaling_factor=None):
        if not self.quant_mode:
            mean = x.mean(axis=2, keepdim=True)
            y = x - mean
            var = torch.mean(y**2, axis=2, keepdim=True)
            x = y / torch.sqrt(self.eps + var)
            x = x * self.weight + self.bias
            return x, None

        # compute sqrt of the feature dimension if it is the first run
        if self.dim_sqrt is None:
            n = torch.tensor(x.shape[2], dtype=torch.float)
            self.dim_sqrt = torch.sqrt(n).to(x.device)

        # Normalization: computes mean and variance (std)
        x_int = x / scaling_factor
        mean_int = round_ste.apply(x_int.mean(axis=2, keepdim=True))
        y_int = x_int - mean_int
        y_int_shifted = floor_ste.apply(y_int / 2**self.shift)
        y_sq_int = y_int_shifted**2
        var_int = torch.sum(y_sq_int, axis=2, keepdim=True)

        # overflow handling in training time
        if self.training:
            # if overflow is detected
            if var_int.max() >= 2**self.max_bit:
                var_int = self.overflow_fallback(y_int)
                assert var_int.max() < 2**self.max_bit + 0.1, (
                    "Error detected in overflow handling: "
                    "`var_int` exceeds `self.max_bit` (the maximum possible bit width)"
                )

        # To be replaced with an integer-sqrt kernel that produces the same output
        std_int = floor_ste.apply(torch.sqrt(var_int)) * 2**self.shift
        factor = floor_ste.apply(2**31 / std_int)
        y_int = floor_ste.apply(y_int * factor / 2)
        scaling_factor = self.dim_sqrt / 2**30

        # scaling and shifting
        bias = self.bias.data.detach() / (self.weight.data.detach())
        bias_int = floor_ste.apply(bias / scaling_factor)

        y_int = y_int + bias_int
        scaling_factor = scaling_factor * self.weight
        x = y_int * scaling_factor

        return x, scaling_factor
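

# Illustrative sketch (not part of the original module): IntLayerNorm in quantized mode. The weight
# is filled with ones here only so that the integer bias term (bias / weight) is well defined; in a
# real model the parameters come from a pretrained checkpoint. Names and shapes are hypothetical.
def _example_int_layernorm():
    act = QuantAct(activation_bit=8, quant_mode=True)
    act.train()
    ln = IntLayerNorm(normalized_shape=8, eps=1e-12, output_bit=8, quant_mode=True)
    ln.weight.data.fill_(1.0)
    x = torch.randn(1, 4, 8)  # (batch, sequence, hidden); normalization runs over axis 2
    q_x, scale = act(x)
    y, y_scale = ln(q_x, scale)
    return y.shape, y_scale.shape
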
def get_percentile_min_max(input, lower_percentile, upper_percentile, output_tensor=False):
    """
    Calculate the percentile min and max values in a given tensor.

    Args:
        input (`torch.Tensor`):
            The target tensor to calculate percentile max and min.
        lower_percentile (`float`):
            If 0.1, return the value of the smallest 0.1% of elements in the tensor as the percentile min.
        upper_percentile (`float`):
            If 99.9, return the value of the largest 0.1% of elements in the tensor as the percentile max.
        output_tensor (`bool`, *optional*, defaults to `False`):
            If True, this function returns tensors, otherwise it returns values.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: Percentile min and max value of *input*
    """
    input_length = input.shape[0]

    lower_index = round(input_length * (1 - lower_percentile * 0.01))
    upper_index = round(input_length * upper_percentile * 0.01)

    upper_bound = torch.kthvalue(input, k=upper_index).values

    if lower_percentile == 0:
        lower_bound = upper_bound * 0
        # lower_index += 1
    else:
        lower_bound = -torch.kthvalue(-input, k=lower_index).values

    if not output_tensor:
        lower_bound = lower_bound.item()
        upper_bound = upper_bound.item()
    return lower_bound, upper_bound
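

# Illustrative sketch (not part of the original module): percentile-based clipping bounds for a
# flattened 1-D tensor, e.g. dropping the most extreme 0.1% on each side before calibration. The
# helper name is hypothetical.
def _example_percentile_range():
    values = torch.randn(1000)
    lower, upper = get_percentile_min_max(values, lower_percentile=0.1, upper_percentile=99.9)
    # `lower` and `upper` are Python floats here because output_tensor defaults to False
    return lower, upper
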
def linear_quantize(input, scale, zero_point, inplace=False):
    """
    Quantize a single-precision input tensor to integers with the given scaling factor and zero point.

    Args:
        input (`torch.Tensor`):
            Single-precision input tensor to be quantized.
        scale (`torch.Tensor`):
            Scaling factor for quantization.
        zero_point (`torch.Tensor`):
            Shift for quantization.
        inplace (`bool`, *optional*, defaults to `False`):
            Whether to compute inplace or not.

    Returns:
        `torch.Tensor`: Linearly quantized value of *input* according to *scale* and *zero_point*.
    """
    # reshape scale and zeropoint for convolutional weights and activation
    if len(input.shape) == 4:
        scale = scale.view(-1, 1, 1, 1)
        zero_point = zero_point.view(-1, 1, 1, 1)
    # reshape scale and zeropoint for linear weights
    elif len(input.shape) == 2:
        scale = scale.view(-1, 1)
        zero_point = zero_point.view(-1, 1)
    else:
        scale = scale.view(-1)
        zero_point = zero_point.view(-1)
    # quantized = float / scale + zero_point
    if inplace:
        input.mul_(1.0 / scale).add_(zero_point).round_()
        return input
    return torch.round(1.0 / scale * input + zero_point)
def symmetric_linear_quantization_params(num_bits, saturation_min, saturation_max, per_channel=False):
    """
    Compute the scaling factor with the given quantization range for symmetric quantization.

    Args:
        saturation_min (`torch.Tensor`):
            Lower bound for quantization range.
        saturation_max (`torch.Tensor`):
            Upper bound for quantization range.
        per_channel (`bool`, *optional*, defaults to `False`):
            Whether or not to use channel-wise quantization.

    Returns:
        `torch.Tensor`: Scaling factor that linearly quantizes the given range between *saturation_min* and
        *saturation_max*.
    """
    # in this part, we do not need any gradient computation;
    # in order to enforce this, we use torch.no_grad()
    with torch.no_grad():
        n = 2 ** (num_bits - 1) - 1

        if per_channel:
            scale, _ = torch.max(torch.stack([saturation_min.abs(), saturation_max.abs()], dim=1), dim=1)
            scale = torch.clamp(scale, min=1e-8) / n
        else:
            scale = max(saturation_min.abs(), saturation_max.abs())
            scale = torch.clamp(scale, min=1e-8) / n

    return scale
class SymmetricQuantFunction(Function):
    """
    Class to quantize the given floating-point values using symmetric quantization with given range and bitwidth.
    """

    @staticmethod
    def forward(ctx, x, k, percentile_mode, scale):
        """
        Args:
            x (`torch.Tensor`):
                Floating point tensor to be quantized.
            k (`int`):
                Quantization bitwidth.
            percentile_mode (`bool`):
                Whether or not to use percentile calibration.
            scale (`torch.Tensor`):
                Pre-calculated scaling factor for *x*. Note that the current implementation of SymmetricQuantFunction
                requires a pre-calculated scaling factor.

        Returns:
            `torch.Tensor`: Symmetric-quantized value of *x*.
        """
        zero_point = torch.tensor(0.0).to(scale.device)

        n = 2 ** (k - 1) - 1
        new_quant_x = linear_quantize(x, scale, zero_point, inplace=False)
        new_quant_x = torch.clamp(new_quant_x, -n, n - 1)

        ctx.scale = scale
        return new_quant_x

    @staticmethod
    def backward(ctx, grad_output):
        scale = ctx.scale
        if len(grad_output.shape) == 4:
            scale = scale.view(-1, 1, 1, 1)
        # reshape scale and zeropoint for linear weights
        elif len(grad_output.shape) == 2:
            scale = scale.view(-1, 1)
        else:
            scale = scale.view(-1)

        return grad_output.clone() / scale, None, None, None, None
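

# Illustrative sketch (not part of the original module): computing a symmetric 8-bit scale for a
# weight tensor and quantizing it with SymmetricQuantFunction; dequantizing recovers the weight up
# to the quantization step. The helper name is hypothetical.
def _example_symmetric_quantization():
    w = torch.randn(4, 8)
    scale = symmetric_linear_quantization_params(
        num_bits=8, saturation_min=w.min().expand(1), saturation_max=w.max().expand(1), per_channel=False
    )
    w_int = SymmetricQuantFunction.apply(w, 8, False, scale)  # integer codes clamped to [-127, 126]
    w_dequant = w_int * scale
    max_err = (w - w_dequant).abs().max()
    return max_err  # on the order of the quantization step `scale`
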
class floor_ste(Function):
    """
    Straight-through Estimator (STE) for torch.floor()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()


class round_ste(Function):
    """
    Straight-through Estimator (STE) for torch.round()
    """

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.clone()
def batch_frexp(inputs, max_bit=31):
    """
    Decompose the scaling factor into mantissa and twos exponent.

    Args:
        inputs (`torch.Tensor`):
            Target scaling factor to decompose.

    Returns:
        `Tuple(torch.Tensor, torch.Tensor)`: mantissa and exponent
    """

    shape_of_input = inputs.size()

    # transform the input into a 1-d tensor
    inputs = inputs.view(-1)

    output_m, output_e = np.frexp(inputs.cpu().numpy())
    tmp_m = []
    for m in output_m:
        int_m_shifted = int(
            decimal.Decimal(m * (2**max_bit)).quantize(decimal.Decimal("1"), rounding=decimal.ROUND_HALF_UP)
        )
        tmp_m.append(int_m_shifted)
    output_m = np.array(tmp_m)
    output_e = float(max_bit) - output_e

    return (
        torch.from_numpy(output_m).to(inputs.device).view(shape_of_input),
        torch.from_numpy(output_e).to(inputs.device).view(shape_of_input),
    )
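

# Illustrative sketch (not part of the original module): batch_frexp decomposes a scaling factor
# into an integer mantissa and an exponent such that the input is approximately m / 2**e, which is
# how FixedPointMul turns a floating-point rescale into an integer multiply and a bit shift. The
# helper name and values are hypothetical.
def _example_batch_frexp():
    scale = torch.tensor([0.75, 0.1])
    m, e = batch_frexp(scale)
    reconstructed = m.double() / 2.0 ** e.double()
    return reconstructed  # approximately equal to `scale`
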
class FixedPointMul(Function):
    """
    Function to perform fixed-point arithmetic that can match integer arithmetic on hardware.

    Args:
        pre_act (`torch.Tensor`):
            Input tensor.
        pre_act_scaling_factor (`torch.Tensor`):
            Scaling factor of the input tensor *pre_act*.
        bit_num (`int`):
            Quantization bitwidth.
        z_scaling_factor (`torch.Tensor`):
            Scaling factor of the output tensor.
        identity (`torch.Tensor`, *optional*):
            Identity tensor, if it exists.
        identity_scaling_factor (`torch.Tensor`, *optional*):
            Scaling factor of the identity tensor *identity*, if it exists.

    Returns:
        `torch.Tensor`: Output tensor (*pre_act* if *identity* is not given, otherwise the addition of *pre_act* and
        *identity*), whose scale is rescaled to *z_scaling_factor*.
    """

    @staticmethod
    def forward(
        ctx,
        pre_act,
        pre_act_scaling_factor,
        bit_num,
        z_scaling_factor,
        identity=None,
        identity_scaling_factor=None,
    ):
        if len(pre_act_scaling_factor.shape) == 3:
            reshape = lambda x: x  # noqa: E731
        else:
            reshape = lambda x: x.view(1, 1, -1)  # noqa: E731
        ctx.identity = identity

        n = 2 ** (bit_num - 1) - 1

        with torch.no_grad():
            pre_act_scaling_factor = reshape(pre_act_scaling_factor)
            if identity is not None:
                identity_scaling_factor = reshape(identity_scaling_factor)

            ctx.z_scaling_factor = z_scaling_factor

            z_int = torch.round(pre_act / pre_act_scaling_factor)
            _A = pre_act_scaling_factor.type(torch.double)
            _B = (z_scaling_factor.type(torch.float)).type(torch.double)
            new_scale = _A / _B
            new_scale = reshape(new_scale)

            m, e = batch_frexp(new_scale)

            output = z_int.type(torch.double) * m.type(torch.double)
            output = torch.round(output / (2.0**e))

            if identity is not None:
                # needs addition of identity activation
                wx_int = torch.round(identity / identity_scaling_factor)

                _A = identity_scaling_factor.type(torch.double)
                _B = (z_scaling_factor.type(torch.float)).type(torch.double)
                new_scale = _A / _B
                new_scale = reshape(new_scale)

                m1, e1 = batch_frexp(new_scale)
                output1 = wx_int.type(torch.double) * m1.type(torch.double)
                output1 = torch.round(output1 / (2.0**e1))

                output = output1 + output

            return torch.clamp(output.type(torch.float), -n - 1, n)

    @staticmethod
    def backward(ctx, grad_output):
        identity_grad = None
        if ctx.identity is not None:
            identity_grad = grad_output.clone() / ctx.z_scaling_factor
        return grad_output.clone() / ctx.z_scaling_factor, None, None, None, None, identity_grad, None
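

# Illustrative sketch (not part of the original module): using FixedPointMul directly to requantize
# a tensor from one scaling factor to another with integer-only arithmetic. The scales below are
# arbitrary example values and the helper name is hypothetical.
def _example_fixed_point_mul():
    x = torch.randn(1, 2, 4)
    s_in = torch.tensor([0.05])  # scaling factor under which `x` is currently represented
    s_out = torch.tensor([0.1])  # target scaling factor
    y_int = FixedPointMul.apply(x, s_in, 8, s_out)
    # y_int holds 8-bit integer codes under s_out, i.e. y_int * s_out approximates x
    return y_int * s_out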