# mypy: allow-untyped-defs
import os.path as _osp
import torch
from .throughput_benchmark import ThroughputBenchmark
from .cpp_backtrace import get_cpp_backtrace
from .backend_registration import rename_privateuse1_backend, generate_methods_for_privateuse1_backend
from . import deterministic
from . import collect_env
import weakref
import copyreg


def set_module(obj, mod):
    """
    Set the module attribute on a Python object for nicer printing.
    """
    if not isinstance(mod, str):
        raise TypeError("The mod argument should be a string")
    obj.__module__ = mod
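
# Illustrative sketch (not part of the original module): set_module is typically
# used so that an object defined in an internal module reports a public module
# path when printed or inspected. The `_demo` function below is hypothetical.
#
#     def _demo():
#         pass
#     set_module(_demo, "torch.utils")
#     assert _demo.__module__ == "torch.utils"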

if torch._running_with_deploy():
    # not valid inside the torch_deploy interpreter; no paths exist for frozen modules
    cmake_prefix_path = None
else:
    cmake_prefix_path = _osp.join(_osp.dirname(_osp.dirname(__file__)), 'share', 'cmake')
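
# Minimal usage sketch (not part of the original module): when it is not None,
# cmake_prefix_path points at the directory holding Torch's CMake package
# configuration, so it can be forwarded to an external CMake build, e.g.:
#
#     import torch.utils
#     print(torch.utils.cmake_prefix_path)
#     # then, from a build directory: cmake -DCMAKE_PREFIX_PATH=<that path> ..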


def swap_tensors(t1, t2):
    """
    This function swaps the content of the two Tensor objects.
    At a high level, this will make t1 have the content of t2 while preserving
    its identity.

    This will not work if t1 and t2 have different slots.
    """
    # Ensure there are no weakrefs
    if weakref.getweakrefs(t1):
        raise RuntimeError("Cannot swap t1 because it has weakref associated with it")
    if weakref.getweakrefs(t2):
        raise RuntimeError("Cannot swap t2 because it has weakref associated with it")
    t1_slots = set(copyreg._slotnames(t1.__class__))  # type: ignore[attr-defined]
    t2_slots = set(copyreg._slotnames(t2.__class__))  # type: ignore[attr-defined]
    if t1_slots != t2_slots:
        raise RuntimeError("Cannot swap t1 and t2 if they have different slots")

    def swap_attr(name):
        # Exchange a single attribute between t1 and t2
        tmp = getattr(t1, name)
        setattr(t1, name, getattr(t2, name))
        setattr(t2, name, tmp)

    def error_pre_hook(grad_outputs):
        raise RuntimeError("Trying to execute AccumulateGrad node that was poisoned by swap_tensors. "
                           "This can happen when you try to run backward on a tensor that was swapped. "
                           "For a module m with `torch.__future__.set_swap_module_params_on_conversion(True)` "
                           "you should not change the device or dtype of the module (e.g. `m.cpu()` or `m.half()`) "
                           "between running forward and backward. To resolve this, please only change the "
                           "device/dtype before running forward (or after both forward and backward).")

    def check_use_count(t, name='t1'):
        use_count = t._use_count()
        error_str = (f"Expected use_count of {name} to be 1 or 2 with an AccumulateGrad node but got {use_count}; "
                     f"make sure you are not holding references to the tensor in other places.")
        if use_count > 1:
            if use_count == 2 and t.is_leaf:
                accum_grad_node = torch.autograd.graph.get_gradient_edge(t).node
                # Make sure that the accumulate_grad node was not lazy_init-ed by get_gradient_edge
                if t._use_count() == 2:
                    # The extra reference is the AccumulateGrad node; poison it so
                    # a later backward through it raises a clear error
                    accum_grad_node.register_prehook(error_pre_hook)
                else:
                    raise RuntimeError(error_str)
            else:
                raise RuntimeError(error_str)

    check_use_count(t1, 't1')
    check_use_count(t2, 't2')

    # Swap the types
    # Note that this will fail if there are mismatched slots
    swap_attr("__class__")

    # Swap the dynamic attributes
    swap_attr("__dict__")

    # Swap the slots
    for slot in t1_slots:
        if hasattr(t1, slot) and hasattr(t2, slot):
            swap_attr(slot)
        elif hasattr(t1, slot):
            setattr(t2, slot, getattr(t1, slot))
            delattr(t1, slot)
        elif hasattr(t2, slot):
            setattr(t1, slot, getattr(t2, slot))
            delattr(t2, slot)

    # Swap the at::Tensor they point to
    torch._C._swap_tensor_impl(t1, t2)
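
# Illustrative usage sketch (not part of the original module): after the call,
# each variable still refers to the same Python object but carries the other
# tensor's type, attributes, slots, and underlying at::Tensor.
#
#     a = torch.ones(2)
#     b = torch.zeros(3)
#     swap_tensors(a, b)
#     # `a` is now a 3-element tensor of zeros and `b` a 2-element tensor of
#     # ones, while id(a) and id(b) are unchanged.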