hypothesis_utils.py

# mypy: ignore-errors
from collections import defaultdict
from collections.abc import Iterable
from functools import reduce

import numpy as np
import torch

import hypothesis
from hypothesis import assume
from hypothesis import settings
from hypothesis import strategies as st
from hypothesis.extra import numpy as stnp
from hypothesis.strategies import SearchStrategy

from torch.testing._internal.common_quantized import (
    _calculate_dynamic_qparams,
    _calculate_dynamic_per_channel_qparams,
)
# Setup for the hypothesis tests.
# The tuples are (torch_quantized_dtype, zero_point_enforce), where the last
# element is the enforced zero_point. If None, any zero_point within the
# range of the data type is OK.

# Tuple with all quantized data types.
_ALL_QINT_TYPES = (
    torch.quint8,
    torch.qint8,
    torch.qint32,
)

# Enforced zero point for every quantized data type.
# If None, any zero_point within the range of the data type is OK.
_ENFORCED_ZERO_POINT = defaultdict(lambda: None, {
    torch.quint8: None,
    torch.qint8: None,
    torch.qint32: 0,
})

def _get_valid_min_max(qparams):
    """Compute the float range that quantizes under `qparams` without
    overflowing `torch.long` in intermediate computations."""
    scale, zero_point, quantized_type = qparams
    adjustment = 1 + torch.finfo(torch.float).eps
    _long_type_info = torch.iinfo(torch.long)
    long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment
    # make sure intermediate results are within the range of long
    min_value = max((long_min - zero_point) * scale, (long_min / scale + zero_point))
    max_value = min((long_max - zero_point) * scale, (long_max / scale + zero_point))
    return np.float32(min_value), np.float32(max_value)

# This wrapper wraps around `st.floats` and checks the version of `hypothesis`.
# If it is too old, it removes the `width` parameter (which was introduced in
# 3.67.0).
def _floats_wrapper(*args, **kwargs):
    if 'width' in kwargs and hypothesis.version.__version_info__ < (3, 67, 0):
        # As long as nan, inf, min, max are not specified, reimplement the width
        # parameter for older versions of hypothesis.
        no_nan_and_inf = (
            (('allow_nan' in kwargs and not kwargs['allow_nan']) or
             'allow_nan' not in kwargs) and
            (('allow_infinity' in kwargs and not kwargs['allow_infinity']) or
             'allow_infinity' not in kwargs))
        min_and_max_not_specified = (
            len(args) == 0 and
            'min_value' not in kwargs and
            'max_value' not in kwargs
        )
        if no_nan_and_inf and min_and_max_not_specified:
            if kwargs['width'] == 16:
                kwargs['min_value'] = torch.finfo(torch.float16).min
                kwargs['max_value'] = torch.finfo(torch.float16).max
            elif kwargs['width'] == 32:
                kwargs['min_value'] = torch.finfo(torch.float32).min
                kwargs['max_value'] = torch.finfo(torch.float32).max
            elif kwargs['width'] == 64:
                kwargs['min_value'] = torch.finfo(torch.float64).min
                kwargs['max_value'] = torch.finfo(torch.float64).max
        kwargs.pop('width')
    return st.floats(*args, **kwargs)


def floats(*args, **kwargs):
    if 'width' not in kwargs:
        kwargs['width'] = 32
    return _floats_wrapper(*args, **kwargs)
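
# Illustrative usage sketch (not part of the original file): `floats()` defaults
# to width=32, so on hypothesis >= 3.67 every drawn value is exactly
# representable as a float32. The test name below is hypothetical.
from hypothesis import given

@given(x=floats(min_value=-1e6, max_value=1e6))
def _example_floats_are_float32(x):
    assert np.float32(x) == x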

"""Hypothesis filter to avoid overflows with quantized tensors.

Args:
    tensor: Tensor of floats to filter
    qparams: Quantization parameters as returned by the `qparams` strategy.

Returns:
    True

Raises:
    hypothesis.UnsatisfiedAssumption

Note: This filter is slow. Use it only when filtering of the test cases is
      absolutely necessary!
"""
def assume_not_overflowing(tensor, qparams):
    min_value, max_value = _get_valid_min_max(qparams)
    assume(tensor.min() >= min_value)
    assume(tensor.max() <= max_value)
    return True
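
# Illustrative sketch (not part of the original file): filter a drawn tensor
# against fixed, arbitrarily chosen qparams so quantization cannot overflow.
# The test name is hypothetical.
from hypothesis import given

@given(t=st.lists(floats(-1e3, 1e3), min_size=1).map(torch.tensor))
def _example_filtered_tensor(t):
    qp = (0.1, 0, torch.quint8)  # (scale, zero_point, dtype), chosen arbitrarily
    assume_not_overflowing(t, qp)  # rejects inputs that would overflow long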

"""Strategy for generating the quantization parameters.

Args:
    dtypes: quantized data types to sample from.
    scale_min / scale_max: Min and max scales. If None, set to
        `torch.finfo(torch.float).eps` / `torch.finfo(torch.float).max`.
    zero_point_min / zero_point_max: Min and max for the zero point. If None,
        set to the minimum and maximum of the quantized data type.
        Note: The min and max are only valid if the zero_point is not enforced
              by the data type itself.

Generates:
    scale: Sampled scale.
    zero_point: Sampled zero point.
    quantized_type: Sampled quantized type.
"""
@st.composite
def qparams(draw, dtypes=None, scale_min=None, scale_max=None,
            zero_point_min=None, zero_point_max=None):
    if dtypes is None:
        dtypes = _ALL_QINT_TYPES
    if not isinstance(dtypes, (list, tuple)):
        dtypes = (dtypes,)
    quantized_type = draw(st.sampled_from(dtypes))

    _type_info = torch.iinfo(quantized_type)
    qmin, qmax = _type_info.min, _type_info.max
    # TODO: Maybe embed the enforced zero_point in the `torch.iinfo`.
    _zp_enforced = _ENFORCED_ZERO_POINT[quantized_type]
    if _zp_enforced is not None:
        zero_point = _zp_enforced
    else:
        _zp_min = qmin if zero_point_min is None else zero_point_min
        _zp_max = qmax if zero_point_max is None else zero_point_max
        zero_point = draw(st.integers(min_value=_zp_min, max_value=_zp_max))

    if scale_min is None:
        scale_min = torch.finfo(torch.float).eps
    if scale_max is None:
        scale_max = torch.finfo(torch.float).max
    scale = draw(floats(min_value=scale_min, max_value=scale_max, width=32))

    return scale, zero_point, quantized_type
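
# Illustrative sketch (not part of the original file): every draw yields a
# positive scale and a zero point inside the dtype's representable range.
# The test name is hypothetical.
from hypothesis import given

@given(qp=qparams())
def _example_qparams_valid(qp):
    scale, zero_point, dtype = qp
    info = torch.iinfo(dtype)
    assert scale > 0
    assert info.min <= zero_point <= info.max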

"""Strategy to create different shapes.

Args:
    min_dims / max_dims: minimum and maximum rank.
    min_side / max_side: minimum and maximum dimensions per rank.
    max_numel: maximum number of elements in the generated shape.

Generates:
    Possible shapes for a tensor, constrained to the rank and dimensionality.

Example:
    # Generates 3D and 4D tensors.
    @given(Q=tensor(shapes=array_shapes(min_dims=3, max_dims=4)))
    def some_test(self, Q): ...
"""
@st.composite
def array_shapes(draw, min_dims=1, max_dims=None, min_side=1, max_side=None, max_numel=None):
    """Return a strategy for array shapes (tuples of int >= 1)."""
    assert min_dims < 32
    if max_dims is None:
        max_dims = min(min_dims + 2, 32)
    assert max_dims < 32
    if max_side is None:
        max_side = min_side + 5
    candidate = st.lists(st.integers(min_side, max_side), min_size=min_dims, max_size=max_dims)
    if max_numel is not None:
        candidate = candidate.filter(lambda x: reduce(int.__mul__, x, 1) <= max_numel)
    return draw(candidate.map(tuple))
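
# Illustrative sketch (not part of the original file): `max_numel` bounds the
# total number of elements of the generated shapes. The test name is
# hypothetical.
from hypothesis import given

@given(shape=array_shapes(min_dims=2, max_dims=4, max_side=5, max_numel=64))
def _example_array_shapes(shape):
    assert 2 <= len(shape) <= 4
    assert all(1 <= side <= 5 for side in shape)
    assert reduce(int.__mul__, shape, 1) <= 64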

"""Strategy for generating test cases for tensors.
The resulting tensor is in float32 format.

Args:
    shapes: Shapes under test for the tensor. Could be either a hypothesis
            strategy, or an iterable of different shapes to sample from.
    elements: Elements to generate from for the returned data type.
              If None, the strategy resolves to floats within the range
              [-1e6, 1e6].
    qparams: Instance of the qparams strategy. This is used to filter the
             tensor such that overflow does not happen.

Generates:
    X: Tensor of type float32. Note that NaN and +/-inf are not included.
    qparams: (If the `qparams` arg is set) Quantization parameters for X.
        The returned parameters are `(scale, zero_point, quantization_type)`.
        If the `qparams` arg is None, `None` is returned instead.
"""
@st.composite
def tensor(draw, shapes=None, elements=None, qparams=None, dtype=np.float32):
    if isinstance(shapes, SearchStrategy):
        _shape = draw(shapes)
    else:
        _shape = draw(st.sampled_from(shapes))
    if qparams is None:
        if elements is None:
            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
        X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
        assume(not (np.isnan(X).any() or np.isinf(X).any()))
        return X, None
    qparams = draw(qparams)
    if elements is None:
        min_value, max_value = _get_valid_min_max(qparams)
        elements = floats(min_value, max_value, allow_infinity=False,
                          allow_nan=False, width=32)
    X = draw(stnp.arrays(dtype=dtype, elements=elements, shape=_shape))
    # Recompute the scale and zero_point according to the X statistics.
    scale, zp = _calculate_dynamic_qparams(X, qparams[2])
    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
    if enforced_zp is not None:
        zp = enforced_zp
    return X, (scale, zp, qparams[2])
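
# Illustrative sketch (not part of the original file): draw a float32 tensor
# together with matching qparams and quantize it. The test name is hypothetical.
from hypothesis import given

@given(X=tensor(shapes=array_shapes(min_dims=1, max_dims=3, max_side=4),
                qparams=qparams(dtypes=torch.quint8)))
def _example_tensor_quantizes(X):
    X, (scale, zero_point, dtype) = X
    qX = torch.quantize_per_tensor(
        torch.from_numpy(X), float(scale), int(zero_point), dtype)
    assert qX.shape == X.shape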

@st.composite
def per_channel_tensor(draw, shapes=None, elements=None, qparams=None):
    """Same as `tensor`, but the quantization parameters are computed per
    channel; generates `X, (scale, zero_point, axis, quantized_type)`."""
    if isinstance(shapes, SearchStrategy):
        _shape = draw(shapes)
    else:
        _shape = draw(st.sampled_from(shapes))
    if qparams is None:
        if elements is None:
            elements = floats(-1e6, 1e6, allow_nan=False, width=32)
        X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
        assume(not (np.isnan(X).any() or np.isinf(X).any()))
        return X, None
    qparams = draw(qparams)
    if elements is None:
        min_value, max_value = _get_valid_min_max(qparams)
        elements = floats(min_value, max_value, allow_infinity=False,
                          allow_nan=False, width=32)
    X = draw(stnp.arrays(dtype=np.float32, elements=elements, shape=_shape))
    # Recompute the scale and zero_point according to the X statistics.
    scale, zp = _calculate_dynamic_per_channel_qparams(X, qparams[2])
    enforced_zp = _ENFORCED_ZERO_POINT.get(qparams[2], None)
    if enforced_zp is not None:
        zp = enforced_zp
    # Permute to model quantization along an axis
    axis = int(np.random.randint(0, X.ndim, 1))
    permute_axes = np.arange(X.ndim)
    permute_axes[0] = axis
    permute_axes[axis] = 0
    X = np.transpose(X, permute_axes)
    return X, (scale, zp, axis, qparams[2])
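
# Illustrative sketch (not part of the original file): the per-channel qparams
# include the randomly chosen channel axis. The test name is hypothetical.
from hypothesis import given

@given(X=per_channel_tensor(shapes=((2, 3, 4),),
                            qparams=qparams(dtypes=torch.quint8)))
def _example_per_channel_axis(X):
    X, (scale, zero_point, axis, dtype) = X
    assert dtype is torch.quint8
    assert 0 <= axis < X.ndim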

"""Strategy for generating test cases for tensors used in Conv.
The resulting tensors are in float32 format.

Args:
    spatial_dim: Spatial Dim for feature maps. If given as an iterable, randomly
                 picks one from the pool to make it the spatial dimension.
    batch_size_range: Range to generate `batch_size`.
                      Must be tuple of `(min, max)`.
    input_channels_per_group_range:
        Range to generate `input_channels_per_group`.
        Must be tuple of `(min, max)`.
    output_channels_per_group_range:
        Range to generate `output_channels_per_group`.
        Must be tuple of `(min, max)`.
    feature_map_range: Range to generate feature map size for each spatial_dim.
                       Must be tuple of `(min, max)`.
    kernel_range: Range to generate kernel size for each spatial_dim. Must be
                  tuple of `(min, max)`.
    max_groups: Maximum number of groups to generate.
    can_be_transposed: If True, the weight layout may be drawn as a transposed
                       convolution's.
    elements: Elements to generate from for the returned data type.
              If None, the strategy resolves to floats within range [-1e6, 1e6].
    qparams: Strategy for quantization parameters for X, w, and b.
             Could be either a single strategy (used for all) or a list of
             three strategies for X, w, b.

Generates:
    (X, W, b, groups, tr): Tensors of type `float32` of the following drawn
                           shapes:
        X: `(batch_size, input_channels) + feature_map_shape`
        W: `(output_channels, input_channels_per_group) + kernel_shape`
        b: `(output_channels,)`
        groups: Number of groups the input is divided into
        tr: Whether the weights are laid out for a transposed convolution

Note: X, W, b are tuples of (Tensor, qparams), where qparams could be either
      None or (scale, zero_point, quantized_type)

Example:
    @given(tensor_conv(
        spatial_dim=2,
        batch_size_range=(1, 3),
        input_channels_per_group_range=(1, 7),
        output_channels_per_group_range=(1, 7),
        feature_map_range=(6, 12),
        kernel_range=(3, 5),
        max_groups=4,
        elements=st.floats(-1.0, 1.0),
        qparams=qparams()
    ))
"""
@st.composite
def tensor_conv(
    draw, spatial_dim=2, batch_size_range=(1, 4),
    input_channels_per_group_range=(3, 7),
    output_channels_per_group_range=(3, 7), feature_map_range=(6, 12),
    kernel_range=(3, 7), max_groups=1, can_be_transposed=False,
    elements=None, qparams=None
):
    # Resolve the minibatch, in_channels, out_channels, iH/iW, and kernel sizes
    batch_size = draw(st.integers(*batch_size_range))
    input_channels_per_group = draw(
        st.integers(*input_channels_per_group_range))
    output_channels_per_group = draw(
        st.integers(*output_channels_per_group_range))
    groups = draw(st.integers(1, max_groups))
    input_channels = input_channels_per_group * groups
    output_channels = output_channels_per_group * groups

    if isinstance(spatial_dim, Iterable):
        spatial_dim = draw(st.sampled_from(spatial_dim))

    feature_map_shape = []
    for i in range(spatial_dim):
        feature_map_shape.append(draw(st.integers(*feature_map_range)))

    kernels = []
    for i in range(spatial_dim):
        kernels.append(draw(st.integers(*kernel_range)))

    tr = False
    weight_shape = (output_channels, input_channels_per_group) + tuple(kernels)
    bias_shape = output_channels
    if can_be_transposed:
        tr = draw(st.booleans())
        if tr:
            weight_shape = (input_channels, output_channels_per_group) + tuple(kernels)
            bias_shape = output_channels

    # Resolve the tensors
    if qparams is not None:
        if isinstance(qparams, (list, tuple)):
            assert len(qparams) == 3, "Need 3 qparams for X, w, b"
        else:
            qparams = [qparams] * 3
    else:
        # Without this, `qparams[i]` below would raise when qparams is None.
        qparams = [None] * 3

    X = draw(tensor(shapes=(
        (batch_size, input_channels) + tuple(feature_map_shape),),
        elements=elements, qparams=qparams[0]))
    W = draw(tensor(shapes=(weight_shape,), elements=elements,
                    qparams=qparams[1]))
    b = draw(tensor(shapes=(bias_shape,), elements=elements,
                    qparams=qparams[2]))
    return X, W, b, groups, tr
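
# Illustrative sketch (not part of the original file): the drawn conv tensors
# have mutually consistent shapes. The test name is hypothetical.
from hypothesis import given

@given(args=tensor_conv(spatial_dim=2, max_groups=2, can_be_transposed=True,
                        qparams=qparams()))
def _example_tensor_conv_shapes(args):
    (X, _), (W, _), (b, _), groups, transposed = args
    assert X.ndim == 4 and X.shape[1] % groups == 0
    if transposed:
        assert W.shape[0] == X.shape[1]           # (in_ch, out_ch/groups, kH, kW)
        assert b.shape[0] == W.shape[1] * groups
    else:
        assert W.shape[1] * groups == X.shape[1]  # (out_ch, in_ch/groups, kH, kW)
        assert b.shape[0] == W.shape[0]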

# We set the deadline in the currently loaded profile.
# Creating (and loading) a separate profile overrides any settings the user
# already specified.
hypothesis_version = hypothesis.version.__version_info__
current_settings = settings._profiles[settings._current_profile].__dict__
current_settings['deadline'] = None
if (3, 16, 0) <= hypothesis_version < (5, 0, 0):
    current_settings['timeout'] = hypothesis.unlimited


def assert_deadline_disabled():
    if hypothesis_version < (3, 27, 0):
        import warnings
        warning_message = (
            "Your version of hypothesis is outdated. "
            "To avoid `DeadlineExceeded` errors, please update. "
            f"Current hypothesis version: {hypothesis.__version__}"
        )
        warnings.warn(warning_message)
    else:
        assert settings().deadline is None
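
# Illustrative sketch (not part of the original file): importing this module
# patches the active profile, after which the deadline check passes.
if __name__ == "__main__":
    assert_deadline_disabled()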