# replicate.py

import torch
from ..modules import Module
from . import comm
from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Sequence, Set, TypeVar, Union, cast
from torch._utils import _get_device_index
from collections import OrderedDict

if TYPE_CHECKING:
    from torch.jit import ScriptModule
    from torch.jit._state import EnabledProxy

__all__ = ['replicate']


def _is_script_module(module: Module) -> bool:
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module: Module) -> bool:
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module() -> "ScriptModule":
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled() -> "EnabledProxy":
    import torch.jit._state
    return torch.jit._state._enabled

# Check if we can safely replicate the module.
# There are two types of module:
# 1. python modules
# 2. ScriptModule
#
# Currently a module cannot be replicated properly if the descendants of
# any ScriptModule contain a python module (type 1 above).
def _replicatable_module(module: Module, memo: Optional[Set[Module]] = None) -> bool:

    # module.modules() contains module itself as the first element
    def descendant_modules(module: Module) -> Iterator[Module]:
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
        # since any unreplicatable module will cause the check to return
        # False early, visited modules here can be safely ignored.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True
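# For example, a plain tree of python modules is replicatable, and so is a
# fully scripted module; the check above only rejects a python module sitting
# underneath a ScriptModule.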

def _broadcast_coalesced_reshape(
    tensors: Sequence[torch.Tensor],
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[List[torch.Tensor]]:
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function to broadcast if not detach
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []
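# Illustration (informal sketch, not part of the library API): with tensors
# [a, b] and devices [0, 1], Broadcast.apply(devices, a, b) returns copies in
# device-major order (a_0, b_0, a_1, b_1); slicing in steps of len(tensors)
# regroups them per device as [[a_0, b_0], [a_1, b_1]], matching the nested
# layout returned by comm.broadcast_coalesced.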

T = TypeVar("T", bound=Module)


def replicate(
    network: T,
    devices: Sequence[Union[int, torch.device]],
    detach: bool = False,
) -> List[T]:
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where python modules are "
                           "children of ScriptModule")

    if not devices:
        return []

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
    buffers_rg: List[torch.Tensor] = []
    buffers_not_rg: List[torch.Tensor] = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)
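
    # First pass: create one replica of every submodule for every target
    # device; parameters, buffers and child links are rewired to the
    # per-device copies in the second pass below.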
    modules = list(network.modules())
    module_copies: List[List[Module]] = [[] for _ in devices]
    module_indices: Dict[Module, int] = {}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            replica = module._replicate_for_data_parallel()
            # This is a temporary fix for DDP. DDP needs to access the
            # replicated model parameters. It used to do so through
            # `module.parameters()`. The fix added in #33907 for DP stops the
            # `parameters()` API from exposing the replicated parameters.
            # Hence, we add a `_former_parameters` dict here to support DDP.
            replica._former_parameters = OrderedDict()

            module_copies[j].append(replica)
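
    # Second pass: re-link each replica's child modules, parameters and
    # buffers to the per-device copies created above.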
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, module_copies[j][module_idx])
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    param_copy = param_copies[j][param_idx]
                    # parameters in replicas are no longer leaves,
                    # so setattr them as non-parameter attributes
                    setattr(replica, key, param_copy)
                    # expose the parameter for DDP
                    replica._former_parameters[key] = param_copy
        for key, buf in module._buffers.items():  # type: ignore[assignment]
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    setattr(replica, key, buffer_copies[j][buffer_idx])

    return [cast(T, module_copies[j][0]) for j in range(num_replicas)]
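
# Usage sketch (illustrative only, not part of the original module): replicate()
# is normally invoked internally by torch.nn.DataParallel, and the target
# devices must be CUDA devices.
#
#     import torch
#     from torch.nn.parallel import replicate
#
#     model = torch.nn.Linear(4, 2).cuda(0)
#     replicas = replicate(model, devices=[0, 1])  # assumes >= 2 CUDA devices
#     # replicas[j] holds parameter copies that live on device j; with the
#     # default detach=False, gradients flow back to the original parameters
#     # through the Broadcast autograd function.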