# schema_check_mode.py

# mypy: ignore-errors
from collections import namedtuple
from copy import deepcopy
from itertools import combinations

import torch
from torch.fx.operator_schemas import normalize_function
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map

# Named Tuples used within SchemaCheckMode
Mutation = namedtuple("Mutation", ["op_name", "arg_name"])
Aliasing = namedtuple("Aliasing", ["op_name", "arg_name", "output_number"])

# Simplified naming for C++ classes
SchemaArgument = torch._C._SchemaArgument
SchemaArgType = torch._C._SchemaArgType
SchemaInfo = torch._C._SchemaInfo

# This TorchDispatchMode subclass is used to verify op schemas.
# It currently:
# - Records the called ops
# - Checks for mutations on all inputs
# - Checks for aliasing on all inputs


# These two functions live here, rather than in
# testing/_internal/common_utils.py, to avoid a numpy dependency there.
def is_iterable_of_tensors(iterable):
    # Tensor itself is iterable so we check this first
    if isinstance(iterable, torch.Tensor):
        return False
    try:
        if len(iterable) == 0:
            return False
        for t in iter(iterable):
            if not isinstance(t, torch.Tensor):
                return False
    except TypeError:
        return False
    return True
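

# Snapshot inputs before the op runs: tensors (and iterables of tensors) are
# detached and cloned so post-call values can be compared against the
# originals; non-tensor arguments pass through unchanged.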
def clone_inputs(args):
    inputs = []
    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)
    return inputs


class SchemaCheckMode(TorchDispatchMode):
    def __init__(self):
        # Information recorded for testing purposes. For example:
        # - incorrect schemas
        # - overly conservative schemas
        self.ops = []
        self.mutated = []
        self.aliasing = []

    def reset_cache(self):
        self.ops.clear()
        self.mutated.clear()
        self.aliasing.clear()

    def display_ops(self):
        print(*self.ops, sep=",")
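
    # Called for every op dispatched while the mode is active: runs the op,
    # then compares the cloned pre-call inputs (and their stride/storage
    # metadata) against the post-call state, flagging any mutation or
    # aliasing that the op's schema fails to declare.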
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        def bitwise_equal(lhs, rhs):
            if lhs.is_quantized:
                # TODO: this is only OK if quantized tensors can't hold NaN;
                # it is unclear whether that is actually true.
                return torch.equal(lhs, rhs)
            else:
                return torch.allclose(lhs, rhs, equal_nan=True)
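
        # A tensor counts as mutated if its values, size, stride, or storage
        # pointer changed in place. Sparse CSR tensors are skipped because
        # they do not expose strides or storage.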
        def has_mutated(before, after, md):
            are_tensors = type(before) == torch.Tensor and type(after) == torch.Tensor
            if (
                are_tensors
                and before.layout != torch.sparse_csr
                and after.layout != torch.sparse_csr
            ):
                return not (
                    before.size() == after.size()
                    and bitwise_equal(before, after)
                    and md[0] == after.stride()
                    and md[1] == after._typed_storage()._cdata
                )
            return False
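
        # Two tensors alias when their storages overlap; values that
        # torch._C._overlaps cannot inspect are treated as non-aliasing.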
        def has_aliased(lhs, rhs):
            try:
                return torch._C._overlaps(lhs, rhs)
            except Exception as exception:
                if str(exception).startswith("Cannot inspect value of type "):
                    return False
                else:
                    raise exception
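
        # Schemas name the first argument "self", while the normalized kwargs
        # produced by normalize_function use "input"; map between the two.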
        def standardize_name(name):
            return name if name != "self" else "input"
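
        # Wrapper tensor subclasses conventionally hold the real tensor in
        # ".elem" (an assumption about the wrappers used in testing); unwrap
        # so the comparisons below see the underlying data.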
        def unwrap(e):
            if isinstance(e, torch.Tensor) and not type(e) == torch.Tensor:
                try:
                    return e.elem
                except AttributeError:
                    return e
            return e
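
        # Capture (stride, storage cdata) for a tensor so has_mutated can
        # later detect in-place restrides or storage replacement.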
        def parse_metadata(e):
            if isinstance(e, torch.Tensor):
                if not type(e) == torch.Tensor:
                    try:
                        current = e.elem
                        return (
                            deepcopy(current.stride()),
                            current._typed_storage()._cdata,
                        )
                    except AttributeError:
                        return None
                # Sparse CSR tensors do not have strides or storage
                elif e.layout != torch.sparse_csr:
                    return (deepcopy(e.stride()), e._typed_storage()._cdata)
            return None

        self.ops.append(func._schema.name)

        # Clone and process arguments and outputs
        pre_arguments = normalize_function(
            func, args, kwargs, normalize_to_only_use_kwargs=True
        ).kwargs

        c_p_args = dict(zip(pre_arguments.keys(), clone_inputs(pre_arguments.values())))
        cloned_arguments = {
            name: tree_map(unwrap, c_p_args.get(name)) for name in c_p_args
        }
        cloned_metadata = {
            name: [
                parse_metadata(a) for a in pytree.tree_leaves(pre_arguments.get(name))
            ]
            for name in pre_arguments
        }

        out = func(*args, **kwargs)
        arguments = {
            name: tree_map(unwrap, pre_arguments.get(name)) for name in pre_arguments
        }
        tuple_out = out if isinstance(out, tuple) else (out,)
        tuple_out = tree_map(unwrap, tuple_out)
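
        # SchemaInfo combines the static schema with the concrete argument
        # values, so the alias/mutation queries below can account for
        # value-dependent behavior.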
        schema_info = SchemaInfo(func._schema)
        schema_info.add_argument_values(pre_arguments)

        # Process arguments with outputs
        for i in range(len(func._schema.arguments)):
            arg = func._schema.arguments[i]
            name = standardize_name(arg.name)
            if arguments.get(name) is not None:
                before = cloned_arguments.get(name)
                md = cloned_metadata.get(name)
                after = arguments.get(name)
                for j in range(len(tuple_out)):
                    # aten::_unsafe_view is intended to have incorrect aliasing notation (hence unsafe)
                    unsafe_ops = ("aten::_unsafe_view", "aten::unsafe_split")
                    if (
                        has_aliased(tuple_out[j], after)
                        and func._schema.name not in unsafe_ops
                    ):
                        if not schema_info.may_contain_alias(
                            SchemaArgument(SchemaArgType.output, j),
                            SchemaArgument(SchemaArgType.input, i),
                        ):
                            raise RuntimeError(
                                f"Argument {name} is not defined to alias output but was aliasing"
                            )
                        else:
                            self.aliasing.append(
                                Aliasing(func._schema.name, name, f"output_{j}")
                            )
                    if after is tuple_out[j] and isinstance(after, torch.Tensor):
                        # Only mutable ops e.g. (add_, add.out) are allowed to directly return inputs.
                        if not schema_info.is_mutable(
                            SchemaArgument(SchemaArgType.input, i)
                        ) and func not in [
                            torch.ops.aten.lift.default,
                            torch.ops.aten.lift_fresh.default,
                        ]:
                            raise RuntimeError(
                                f"""\
Dispatcher operators below autograd are not allowed to directly return inputs.
However, we found that `outputs[{str(j)}] is {name}`"""
                            )
                if any(
                    has_mutated(a, b, c)
                    for a, b, c in zip(
                        pytree.tree_leaves(before), pytree.tree_leaves(after), md
                    )
                ):
                    if not schema_info.is_mutable(
                        SchemaArgument(SchemaArgType.input, i)
                    ):
                        raise RuntimeError(
                            f"Argument {name} is not defined as mutable but was mutated"
                        )
                    else:
                        self.mutated.append(Mutation(func._schema.name, name))

        # Aliasing between outputs
        for i, j in combinations(range(len(func._schema.returns)), 2):
            if has_aliased(tuple_out[i], tuple_out[j]):
                if not schema_info.may_contain_alias(
                    SchemaArgument(SchemaArgType.output, i),
                    SchemaArgument(SchemaArgType.output, j),
                ):
                    raise RuntimeError(f"Outputs {i} and {j} alias unexpectedly")
        return out
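

# Minimal usage sketch (illustrative only, not part of the module). Ops run
# under the mode are recorded; declared-mutable ops such as add_ land in
# `mutated` rather than raising:
#
#     with SchemaCheckMode() as mode:
#         x = torch.rand(2, 2)
#         x.add_(1)         # in-place op whose schema declares the mutation
#     mode.display_ops()    # e.g. prints "aten::rand,aten::add_"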