fake_utils.py

# mypy: ignore-errors

import functools
import warnings
from typing import Callable, Union

import torch
import torch.utils._pytree as pytree
from torch._ops import OpOverload
from torch._subclasses.fake_tensor import (
    FakeTensorMode,
    tree_flatten_only,
    UnsupportedFakeTensorException,
)
from torch.utils._python_dispatch import TorchDispatchMode

aten = torch._ops.ops.aten
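

# Returns True if any output tensor shares a storage with any of the input tensors.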
def outputs_alias_inputs(outputs, inputs):
    input_storages = {
        inp._typed_storage()._cdata
        for inp in tree_flatten_only(torch.Tensor, inputs)
        if torch._C._has_storage(inp)
    }
    return any(
        torch._C._has_storage(out) and out._typed_storage()._cdata in input_storages
        for out in tree_flatten_only(torch.Tensor, outputs)
    )
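

# Returns True if any output tensor is, by object identity, one of the input tensors.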
def outputs_are_inputs(outputs, inputs):
    input_ids = {id(inp) for inp in tree_flatten_only(torch.Tensor, inputs)}
    return any(id(out) in input_ids for out in tree_flatten_only(torch.Tensor, outputs))
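

# Returns True if any storage appears more than once among the output tensors.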
def output_alias_each_other(outputs):
    storages = set()
    for out in tree_flatten_only(torch.Tensor, outputs):
        if not torch._C._has_storage(out):
            continue
        stor = out._typed_storage()._cdata
        if stor in storages:
            return True
        storages.add(stor)
    return False
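

# The SDPA flash/efficient attention ops are known to report a device mismatch on
# their auxiliary RNG (philox) seed/offset outputs under fake tensor mode; the
# metadata cross-check below tolerates those specific errors.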
def is_sdpa_error(func, idx, e):
    if (
        (
            func is aten._scaled_dot_product_flash_attention.default
            or func is aten._flash_attention_forward.default
        )
        and idx in (6, 7)
        and "Devices" in repr(e)
    ):
        return True
    if (
        (
            func is aten._scaled_dot_product_efficient_attention.default
            or func is aten._efficient_attention_forward.default
        )
        and idx in (2, 3)
        and "Devices" in repr(e)
    ):
        return True
    return False
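

# A TorchDispatchMode that, for every dispatched op, also runs the op under a fresh
# FakeTensorMode and cross-checks the fake result against the real one: number of
# returns, aliasing behavior, requires_grad, storage offsets, and tensor metadata
# (optionally including strides).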
class CrossRefFakeMode(TorchDispatchMode):
    def __init__(
        self,
        ignore_op_fn: Union[Callable[[OpOverload], bool], None] = None,
        *,
        check_strides=True,
        check_aliasing=True,
    ):
        self.ignore_op_fn = (
            ignore_op_fn if ignore_op_fn is not None else lambda fn: False
        )
        self.check_strides = check_strides
        self.check_aliasing = check_aliasing
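
    # First run the op under a fresh FakeTensorMode (unless it is excluded: the
    # lift_fresh ops, ops with dynamic or data-dependent output shapes, inplace
    # views, or anything matched by ignore_op_fn), then run it on the real tensors
    # and assert that the two results agree.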
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        fake_r = None

        # empty_like excluded for now due to sparse complex
        # aten._to_dense.default this one is getting called with csc
        if (
            func
            not in (
                aten.lift_fresh.default,
                aten.lift_fresh_copy.default,
                aten.set_.source_Storage_storage_offset,
            )
            and not self.ignore_op_fn(func)
            and torch.Tag.dynamic_output_shape not in func.tags
            and torch.Tag.inplace_view not in func.tags
            and torch.Tag.data_dependent_output not in func.tags
        ):
            # Do not import symbolic_shapes at the top of the module as it imports sympy and that's slow
            from torch.fx.experimental.symbolic_shapes import ShapeEnv

            try:
                # TODO: enable_python_dispatcher() here
                with FakeTensorMode(shape_env=ShapeEnv()) as fake_mode:
                    fake_args, fake_kwargs = pytree.tree_map_only(
                        torch.Tensor,
                        functools.partial(fake_mode.from_tensor, static_shapes=True),
                        (args, kwargs),
                    )
                    with warnings.catch_warnings():
                        fake_r = func(*fake_args, **fake_kwargs)
            except UnsupportedFakeTensorException:
                pass

        context = (
            f"When comparing the output of {func} on FakeTensor and concrete Tensors, "
            f"found"
        )

        r = func(*args, **kwargs)
        if fake_r is not None:
            r_flat = pytree.tree_leaves(r)
            f_flat = pytree.tree_leaves(fake_r)
            assert len(f_flat) == len(
                r_flat
            ), f"{context} mismatch in number of returns {len(f_flat)} != {len(r_flat)}"

            if self.check_aliasing:
                r_aliasing = outputs_alias_inputs(r, (args, kwargs))
                f_aliasing = outputs_alias_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_aliasing == f_aliasing
                ), f"{context} mismatch in outputs_alias_inputs check {f_aliasing} != {r_aliasing}"

                r_identity_eq = outputs_are_inputs(r, (args, kwargs))
                f_identity_eq = outputs_are_inputs(fake_r, (fake_args, fake_kwargs))
                assert (
                    r_identity_eq == f_identity_eq
                ), f"{context} mismatch in outputs_are_inputs check {f_identity_eq} != {r_identity_eq}"

                r_output_alias_each_other = output_alias_each_other(r)
                f_output_alias_each_other = output_alias_each_other(fake_r)
                assert r_output_alias_each_other == f_output_alias_each_other, (
                    f"{context} mismatch in outputs_alias_each_other check "
                    f"{f_output_alias_each_other} != {r_output_alias_each_other}"
                )

            for idx, (r_out, fake_out) in enumerate(
                zip(pytree.tree_leaves(r), pytree.tree_leaves(fake_r))
            ):
                r_is_ten = isinstance(r_out, torch.Tensor)
                assert r_is_ten == isinstance(
                    fake_out, torch.Tensor
                ), f"{context} mismatched number of tensor outputs"
                if r_is_ten:
                    assert r_out.requires_grad == fake_out.requires_grad, (
                        f"{context} mismatched requires_grad-ness of outputs. "
                        f"This usually means that you have added autograd support "
                        f"for your operator at a dispatch key other than Autograd, "
                        f"which will lead to problems"
                    )
                    if torch._C._has_storage(r_out):
                        r_offset = r_out.storage_offset()
                        f_offset = fake_out.storage_offset()
                        assert (
                            r_offset == f_offset
                        ), f"{context} mismatched storage offset"

                    try:
                        torch._prims.utils.compare_tensor_meta(
                            r_out,
                            fake_out,
                            check_strides=self.check_strides,
                            allow_rhs_unbacked=True,
                        )
                    except Exception as e:
                        if is_sdpa_error(func, idx, e):
                            continue
                        error_message = (
                            f"{context} mismatched tensor metadata: {e}"
                            if len(r_flat) == 1
                            else f"{context} mismatched tensor metadata for output[{idx}]: {e}"
                        )
                        raise RuntimeError(error_message) from e
        return r
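

if __name__ == "__main__":
    # Example usage (a minimal sketch, not part of the original module). While the
    # mode is active, every dispatched op also runs under FakeTensorMode and its
    # fake outputs are cross-checked against the real ones; any mismatch raises.
    # check_strides=False relaxes the stride comparison (strict checking is the default).
    x, w, b = torch.randn(4, 8), torch.randn(16, 8), torch.randn(16)
    with CrossRefFakeMode(check_strides=False):
        out = torch.nn.functional.linear(x, w, b)
    print("cross-ref check passed:", tuple(out.shape))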