# mypy: allow-untyped-defs
from contextlib import contextmanager

import torch
import torch._custom_ops
from torch._C import DispatchKey
from torch._higher_order_ops.strict_mode import strict_mode
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils import _pytree as pytree


_export_tracepoint = HigherOrderOperator("_export_tracepoint")
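

# The registrations below give _export_tracepoint a behavior under each dispatch
# key it can hit during export. At runtime the op is effectively an identity on
# its inputs; its purpose is to leave marker nodes in the traced graph recording
# which flattened values cross a module-call boundary (see _wrap_submodule below).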
@_export_tracepoint.py_impl(ProxyTorchDispatchMode)
def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
    if not mode.enable_tracing:
        return _export_tracepoint(*args, **kwargs)
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    proxy = mode.tracer.create_proxy(
        "call_function", _export_tracepoint, p_args, p_kwargs
    )
    return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
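

# Under FakeTensorMode the tracepoint simply passes its inputs through unchanged,
# so it has no effect on shapes or metadata.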
@_export_tracepoint.py_impl(FakeTensorMode)
def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
    with mode:
        return args


@_export_tracepoint.py_functionalize_impl
def export_tracepoint_functional(ctx, *args, **kwargs):
    unwrapped_args = ctx.unwrap_tensors(args)
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

    with ctx.redispatch_to_next():
        out = _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
        return ctx.wrap_tensors(out)
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)


@_export_tracepoint.py_impl(DispatchKey.CPU)
def export_tracepoint_cpu(*args, **kwargs):
    return args
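

# _wrap_submodule installs a forward pre-hook and a forward hook on the submodule
# found at the dotted `path` inside `mod`. The hooks flatten the call's inputs and
# outputs with pytree, run the flat values through _export_tracepoint so the
# boundary is recorded in the trace, and store the input/output TreeSpecs in
# `module_call_specs` keyed by the path.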
def _wrap_submodule(mod, path, module_call_specs):
    assert isinstance(mod, torch.nn.Module)
    assert path != ""

    submodule = mod
    for name in path.split("."):
        if not hasattr(submodule, name):
            raise RuntimeError(f"Couldn't find submodule at path {path}")
        submodule = getattr(submodule, name)

    def update_module_call_signatures(path, in_spec, out_spec):
        if path in module_call_specs:
            assert module_call_specs[path]["in_spec"] == in_spec
            assert module_call_specs[path]["out_spec"] == out_spec
        module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}

    def check_flattened(flat_args):
        for a in flat_args:
            if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
                raise AssertionError(
                    f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
                )

    def pre_hook(module, args, kwargs):
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
        check_flattened(flat_args)
        flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
        args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
        return args, kwargs

    def post_hook(module, args, kwargs, res):
        _, in_spec = pytree.tree_flatten((args, kwargs))
        flat_res, out_spec = pytree.tree_flatten(res)
        check_flattened(flat_res)
        flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
        update_module_call_signatures(path, in_spec, out_spec)
        return pytree.tree_unflatten(flat_res, out_spec)

    pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
    post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
    return pre_handle, post_handle
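

# _wrap_submodules applies _wrap_submodule to every path in `preserve_signature`
# for the duration of the `with` block, and removes the hook handles afterwards
# even if tracing raises.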
@contextmanager
def _wrap_submodules(f, preserve_signature, module_call_signatures):
    handles = []

    try:
        for path in preserve_signature:
            handles.extend(_wrap_submodule(f, path, module_call_signatures))
        yield
    finally:
        for handle in handles:
            handle.remove()
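

# Illustrative sketch only (the model, path, and make_fx call are assumptions for
# exposition, not something this module runs itself):
#
#     from torch.fx.experimental.proxy_tensor import make_fx
#
#     specs = {}  # path -> {"in_spec": TreeSpec, "out_spec": TreeSpec}
#     with _wrap_submodules(model, ["encoder.block0"], specs):
#         gm = make_fx(model)(*example_inputs)
#
# After tracing, `gm` contains _export_tracepoint nodes marking the inputs and
# outputs of `model.encoder.block0`, and `specs["encoder.block0"]` holds the
# recorded pytree specs.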
def _mark_strict_experimental(cls):
    def call(self, *args):
        return strict_mode(self, args)

    cls.__call__ = call
    return cls
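

# Illustrative sketch only (MyModule is a hypothetical class, not defined here):
#
#     @_mark_strict_experimental
#     class MyModule(torch.nn.Module):
#         def forward(self, x):
#             return x.sin()
#
# Instances of MyModule now route __call__ through the strict_mode higher-order
# op instead of the usual nn.Module calling machinery; note that the replacement
# __call__ forwards positional arguments only.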