common_jit.py

# mypy: ignore-errors

# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized

# Testing utils
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
    freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401

# Standard library
from itertools import chain
from typing import List, Union

from torch._C import TensorType

import io


def check_output_types(self, func, ref_outputs, args, kwargs):
    # Check that the single output of the checked function's last executed graph
    # has a type consistent with the reference output.
    graph = getattr(func, 'last_graph', None)
    types = [o.type() for o in graph.outputs()]
    self.assertTrue(len(types) == 1)
    t = types[0]
    torch._C._jit_assert_is_instance(ref_outputs, t)

# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
    'pdist',
    'multilabel_margin_loss',
    'max_unpool3d',
    'multi_margin_loss',
    'binary_cross_entropy',
    'binary_cross_entropy_size_average',
    'ctc_loss',
    'grid_sample',
])


def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
                            allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
    """Verifies that a function performs identically to a reference implementation.

    Commonly, this is used to verify that a JIT-compiled implementation (func)
    matches the behavior of the eager implementation (reference_func);
    output_func post-processes the outputs of both before gradients are
    computed and compared.
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        # Reduce the output(s) to a single scalar so autograd.grad can be applied
        # uniformly. Each output is weighted by its position so that distinct
        # outputs contribute distinctly, and non-floating-point outputs are skipped.
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_tensor(t, preserve_requires_grad):
        require_grad = preserve_requires_grad and t.requires_grad
        return t.detach().clone().requires_grad_(require_grad)

    def clone_inputs(preserve_requires_grad: bool):
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs

    # Returns tensors in args that requires_grad, including tensors in TensorList args
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))

        return recording_tensors

    # test no gradients case
    nograd_inputs = clone_inputs(preserve_requires_grad=False)
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)
    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)

        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
            return

        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)

        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)

        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)
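

# Minimal usage sketch (not part of the original file; the wrapped function, input
# shape, and helper name are assumptions for illustration, and `self` is expected to
# be a JitCommonTestCase instance). A caller typically compiles an eager function and
# checks that outputs and first/second derivatives match. check_types=False is passed
# because a plain scripted function does not carry the `last_graph` attribute that
# check_output_types inspects.
def _example_check_against_reference(self):
    def fn(x):
        return torch.sigmoid(x) * 2

    scripted = torch.jit.script(fn)
    x = torch.randn(3, 3, requires_grad=True)
    check_against_reference(self, scripted, fn, lambda outputs: outputs,
                            (x,), check_types=False)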


class JitCommonTestCase(TestCase):
    def createFunctionFromGraph(self, trace):
        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        return torch._C._create_function_from_graph("forward", graph)

    def assertExportImport(self, trace, inputs):
        m = self.createFunctionFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
        m_import = self.getExportImportCopy(m)
        a = self.runAndSaveRNG(m, inputs)
        b = self.runAndSaveRNG(m_import, inputs)
        self.assertEqual(a, b, "Results of original model and "
                               "exported/imported version of model differed")

    def runAndSaveRNG(self, func, inputs, kwargs=None):
        kwargs = kwargs if kwargs else {}
        with freeze_rng_state():
            results = func(*inputs, **kwargs)
        return results
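
    # Minimal usage sketch (not part of the original suite; the stochastic op and
    # input shape are assumptions for illustration). Because freeze_rng_state()
    # restores the RNG state on exit, two runAndSaveRNG calls over the same
    # stochastic function produce identical results and can be compared directly.
    def _example_run_and_save_rng(self):
        x = torch.randn(4, 4)
        a = self.runAndSaveRNG(torch.nn.functional.dropout, (x,))
        b = self.runAndSaveRNG(torch.nn.functional.dropout, (x,))
        self.assertEqual(a, b)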

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)

        if not also_test_file:
            return imported

        with TemporaryFileName() as fname:
            torch.jit.save(imported, fname)
            return torch.jit.load(fname, map_location=map_location)
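
    # Minimal usage sketch (not part of the original suite; the module and input
    # shape are assumptions for illustration). A scripted module is serialized,
    # reloaded (in memory and via a temporary file), and checked to behave
    # identically to the original.
    def _example_export_import_roundtrip(self):
        m = torch.jit.script(torch.nn.Linear(4, 2))
        self.assertExportImportModule(m, (torch.randn(3, 4),))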

    def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
                             fusion_nodes_not_found, non_fusible_nodes_being_fused,
                             fusion_nodes_found, nodes_in_diff_graph):
        err_msg = "\nFailure in testing nodes' autodifferentiation. "
        if should_autodiff_node:
            err_msg += "One or more nodes were expected to be autodiffed, " \
                       "but were not found in specified fusible/nonfusible " \
                       "DifferentiableGraph groups. \nSpecifically:"
            # The node is intended to appear in a differentiable graph but doesn't
            diff_nodes_missing = []
            # The node is intended to appear in a differentiable graph
            # outside of a fusion group but instead is in a fusion group
            diff_nodes_in_fusion = []
            # The node is intended to appear in a fusion group but doesn't
            fusion_nodes_missing = []
            # The node is intended to appear in a fusion group but instead
            # is just in an outer differentiable graph
            fusion_nodes_in_diff = []
            for node in nodes_not_in_diff_graph:
                if node in non_fusible_nodes_being_fused:
                    diff_nodes_in_fusion.append(node)
                else:
                    diff_nodes_missing.append(node)
            for node in fusion_nodes_not_found:
                if node in nodes_in_diff_graph:
                    fusion_nodes_in_diff.append(node)
                else:
                    fusion_nodes_missing.append(node)
            if len(diff_nodes_missing) > 0:
                err_msg += f"\n {diff_nodes_missing} were not in one of the " \
                           "DifferentiableGraphs when they were expected to be. " \
                           "Did you intend for these nodes to be autodiffed? " \
                           "If not, remove them from the list of nonfusible nodes."
            if len(diff_nodes_in_fusion) > 0:
                err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \
                           "when they were expected to be just in a DifferentiableGraph. If it was " \
                           "intended for these nodes to be in FusionGroups, reclassify these nodes as " \
                           "fusible nodes. If these nodes were not intended to be fused, your " \
                           "autodifferentiation logic might be wrong."
            if len(fusion_nodes_missing) > 0:
                err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \
                           "of the DifferentiableGraphs when they were expected to be. " \
                           "They were also not found in an outer DifferentiableGraph. Did you " \
                           "intend for these nodes to be autodifferentiated? If not, you should " \
                           "remove these nodes from the test's fusible nodes. Otherwise your " \
                           "autodifferentiation logic might be wrong."
            if len(fusion_nodes_in_diff) > 0:
                err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \
                           "of the DifferentiableGraphs when they were expected to be; " \
                           "instead they were found just in an outer DifferentiableGraph. " \
                           "Did you intend for these nodes to be fused? If not, you should " \
                           "move these nodes into the test's nonfusible nodes. Otherwise your " \
                           "autodifferentiation logic might be wrong."
        else:
            err_msg += "One or more nodes were not expected to be autodiffed " \
                       "but were found in a DifferentiableGraph or in a FusionGroup " \
                       "of a DifferentiableGraph. Did you intend for these nodes to be " \
                       "autodiffed? If so, change this test to expect autodifferentiation. " \
                       "\nSpecifically:"
            if len(fusion_nodes_found) > 0:
                err_msg += f"\n {fusion_nodes_found} were not expected to be in " \
                           "one of the DifferentiableGraphs, but appeared in a FusionGroup " \
                           "of a DifferentiableGraph. "
            if len(nodes_in_diff_graph) > 0:
                err_msg += f"\n {nodes_in_diff_graph} were not expected to " \
                           "be in one of the DifferentiableGraphs but were."
        return err_msg

    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]

        # Note: currently no tests have fusible_nodes
        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]

        # For any non-fusible node, it must show up in one of the DifferentiableGraphs.
        nodes_in_diff_graph = []
        nodes_not_in_diff_graph = []
        non_fusible_nodes_being_fused = []
        for node in nonfusible_nodes:
            if any(g.findNode(node) is not None for g in diff_subgraphs):
                nodes_in_diff_graph.append(node)
            else:
                nodes_not_in_diff_graph.append(node)
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                non_fusible_nodes_being_fused.append(node)
        found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)

        # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
        fusion_nodes_found = []
        fusion_nodes_not_found = []
        for node in fusible_nodes:
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                fusion_nodes_found.append(node)
            else:
                fusion_nodes_not_found.append(node)
        found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)

        if should_autodiff_node is not None:
            err_msg = self.autoDiffErrorMessage(should_autodiff_node,
                                                nodes_not_in_diff_graph,
                                                fusion_nodes_not_found,
                                                non_fusible_nodes_being_fused,
                                                fusion_nodes_found,
                                                nodes_in_diff_graph)
            self.assertEqual(should_autodiff_node,
                             found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)
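
    # Minimal usage sketch (not part of the original suite; the scripted function,
    # warm-up strategy, and node name are assumptions for illustration). Under the
    # profiling executor, the optimized graph for a differentiable input is expected
    # to place the op inside a prim::DifferentiableGraph; whether this actually holds
    # depends on the active executor configuration.
    def _example_assert_autodiff_node(self):
        def fn(x):
            return torch.tanh(x)

        scripted = torch.jit.script(fn)
        x = torch.randn(4, 4, requires_grad=True)
        with enable_profiling_mode_for_profiling_tests():
            # warm-up runs so the profiling executor can specialize and differentiate
            scripted(x)
            scripted(x)
            graph = scripted.graph_for(x)
        self.assertAutodiffNode(graph, True, ['aten::tanh'], [])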

    def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
                           traced_graph, assert_propagation, constant_prop=True):
        # repropagate the input shapes provided by tracing
        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
        for enable_test_mode in [True, False]:
            # Here we test both allowing and disallowing the substitution of complete
            # shapes as constants; disallowing constants helps stress-test the partial
            # evaluation and substitution pipeline.
            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
            torch._C._jit_erase_non_input_shape_information(traced_graph)
            if constant_prop:
                torch._C._jit_pass_constant_propagation(traced_graph)
            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
            # Compare against a fresh TensorType built only from sizes, to avoid
            # checking properties outside the scope of this test and to sidestep
            # extra information the tracer leaves in the tensor type.
            output = next(traced_graph.outputs()).type()

            def test_type(type, actual_size):
                sizes = type.symbolic_sizes()
                out_type = TensorType.get().with_sizes(sizes)
                actual_type = TensorType.get().with_sizes(actual_size)

                # always check that the actual shape is a subtype of the propagated output
                self.assertTrue(actual_type.isSubtypeOf(out_type))

                # and then, if the assertion flag is provided, check that shape analysis
                # was successful
                if assert_propagation:
                    self.assertEqual(out_type.sizes(), actual_size)

            if output.isSubtypeOf(torch._C.TensorType.get()):
                test_type(output, out_sizes)
            else:
                tuple_elements = output.elements()
                for i in range(len(tuple_elements)):
                    test_type(tuple_elements[i], out_sizes[i])

        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)
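
    # Minimal usage sketch (not part of the original suite; the traced function and
    # input shape are assumptions for illustration). The symbolic shapes propagated
    # through the traced graph are checked against the concrete output size of an
    # eager run; assert_propagation=True assumes the op has a registered shape function.
    def _example_check_shape_analysis(self):
        def fn(x):
            return x + x

        traced = torch.jit.trace(fn, (torch.randn(2, 3),))
        out_size = list(fn(torch.randn(2, 3)).size())
        self.checkShapeAnalysis(out_size, traced.graph, assert_propagation=True)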