# mypy: ignore-errors

r"""Importing this file includes common utility methods for checking quantized
tensors and modules.
"""

import numpy as np
import torch
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_PPC, IS_MACOS, IS_WINDOWS

supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
# QNNPACK is not supported on PPC
# QNNPACK throws ASAN heap-buffer-overflow error.
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_ASAN, TEST_WITH_TSAN, TEST_WITH_UBSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')

def _conv_output_shape(input_size, kernel_size, padding, stride, dilation,
                       output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                    * (dilation - 1)) / stride) + 2 * output_padding + 1
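
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): a 3x3 kernel with padding 1, stride 2 and
# dilation 1 over a 32-wide input halves the spatial size.
def _example_conv_output_shape():
    out = _conv_output_shape(32, kernel_size=3, padding=1, stride=2, dilation=1)
    assert out == 16.0  # np.floor keeps the result a float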

# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx

def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x

def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to given type"""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx
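
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): quantize a float array to uint8 with _quantize, map
# it back with _dequantize, and check that the round-trip error is bounded by
# half a quantization step.
def _example_quantize_roundtrip():
    x = np.linspace(-1.0, 1.0, num=8, dtype=np.float32)
    scale, zero_point = 0.01, 128
    qx = _quantize(x, scale, zero_point)        # uint8 values in [28, 228]
    x_hat = _dequantize(qx, scale, zero_point)  # float approximation of x
    assert np.abs(x - x_hat).max() <= scale / 2 + 1e-6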

def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255

    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]
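
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): derive per-tensor affine qparams for quint8 on a
# small tensor with range [-2, 3] and check the expected scale and zero point.
def _example_dynamic_qparams():
    X = torch.tensor([-2.0, -0.5, 0.0, 1.0, 3.0])
    scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
    # The observed range of 5.0 is spread over 255 levels and 0.0 lands on an
    # integer level (the zero point).
    assert abs(scale - 5.0 / 255) < 1e-7
    assert zero_point == 102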

def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        min_val = X.min()
        max_val = X.max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])
    return scale, zero_point
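
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): compute one (scale, zero_point) pair per output row
# of a 2-D weight-like tensor, using a plain integer dtype that torch.iinfo
# accepts.
def _example_per_channel_qparams():
    W = torch.randn(3, 16)
    scales, zero_points = _calculate_dynamic_per_channel_qparams(W, torch.uint8)
    assert scales.shape == (3,) and zero_points.shape == (3,)
    assert (scales > 0).all() and (zero_points >= 0).all() and (zero_points <= 255).all()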

def _snr(x, x_hat):
    """Calculates the signal to noise ratio and returns the signal and noise
    magnitudes (L2 norms), as well as the SNR in dB.
    If the input is a list/tuple this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR(in dB): Either floats or a nested list of floats
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = []
        for idx in range(len(x)):
            res.append(_snr(x[idx], x_hat[idx]))
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    noise = (x - x_hat).norm()
    if noise == 0:
        return 0.0, float('inf'), float('inf')
    signal = x.norm()
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db
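
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): quantize a random tensor per-tensor to quint8 and
# measure how faithfully the quantized copy reproduces it; 8-bit quantization of
# data whose full range is covered by the qparams gives an SNR well above 20 dB.
def _example_snr():
    x = torch.randn(64)
    scale, zero_point = _calculate_dynamic_qparams(x, torch.quint8)
    x_q = torch.quantize_per_tensor(x, scale, zero_point, torch.quint8)
    signal, noise, snr_db = _snr(x, x_q)
    assert snr_db > 20.0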

@contextmanager
def override_quantized_engine(qengine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous
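
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): temporarily switch the active quantized engine
# inside a with-block; the previous engine is restored on exit.
def _example_override_engine():
    if 'fbgemm' in supported_qengines:
        before = torch.backends.quantized.engine
        with override_quantized_engine('fbgemm'):
            assert torch.backends.quantized.engine == 'fbgemm'
        assert torch.backends.quantized.engine == before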

@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()

# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn
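
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): the decorator runs the wrapped test body once per
# entry in supported_qengines, so a single test exercises every available engine.
@override_qengines
def _example_decorated_check():
    # Each invocation sees one engine from supported_qengines activated.
    assert torch.backends.quantized.engine in supported_qengines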

def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'

def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'

def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'

def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'

# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list

# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)
    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) +
                  per_channel_zero_point[i]), quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]
    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)
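
# Added usage sketch (illustrative, not part of the original file; the
# _example_* name is ours): compare the Python reference against the built-in
# torch.fake_quantize_per_channel_affine kernel on a small per-channel setup
# along axis 0. Values are chosen away from rounding boundaries and the zero
# point is int32, which the kernel accepts.
def _example_fake_quant_reference():
    X = torch.tensor([[-0.30, 0.02, 0.57], [1.20, -0.75, 0.33]])
    scale = torch.tensor([0.1, 0.05])
    zero_point = torch.tensor([10, 5], dtype=torch.int32)
    ref = _fake_quantize_per_channel_affine_reference(X, scale, zero_point, axis=0,
                                                      quant_min=0, quant_max=255)
    out = torch.fake_quantize_per_channel_affine(X, scale, zero_point, 0, 0, 255)
    assert torch.allclose(ref, out, atol=1e-6)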

# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)

def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.clone().detach()
    return X.to(device=torch.device(device), dtype=torch.float32)