contract.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. # mypy: allow-untyped-defs
  2. import uuid
  3. from collections import OrderedDict
  4. from functools import wraps
  5. from typing import Callable, Dict, List, Optional, Type
  6. import torch.nn as nn
  7. from torch.distributed._composable_state import _State
  def generate_state_key(string="__composable_api_state_key"):
      """Return ``string`` followed by an underscore and a fresh random UUID4.

      Each call draws a new UUID, so two calls with the same prefix yield
      different keys.
      """
      unique_suffix = uuid.uuid4()
      return f"{string}_{unique_suffix}"
  # Randomly suffixed names, generated once at import time, under which the
  # per-module state map and the per-module API registry are stored in a
  # module's ``__dict__``.
  STATE_KEY, REGISTRY_KEY = generate_state_key(), generate_state_key()
  # TODO: RegistryItem could carry additional info to share across APIs. E.g.,
  # storing args and kwargs here would let us detect whether fully_shard is
  # combined with reentrant activation checkpointing and error out with a clear
  # message.
  class RegistryItem:
      """Empty placeholder stored in a module's registry for each applied API."""
  def contract(state_cls: Type[_State] = _State):
      r"""
      Decorate a function as a composable distributed API, where the first
      argument of the function must be an :class:`nn.Module` instance.

      The decorator verifies that the wrapped function does not modify
      parameter, buffer or sub-module fully-qualified names (FQN), nor their
      order.

      When a function ``func`` is decorated by ``@contract()``, a
      ``.state(module: nn.Module)`` method will be installed to the decorated
      function. Then you can retrieve and modify the state on a module by calling
      ``func.state(module)``.

      Args:
          state_cls: class instantiated (with no arguments) once per
              ``(func, module)`` pair to hold the state of ``func`` on that
              module.

      Returns:
          A decorator. The decorated function returns the module produced by
          ``func``, or the input module when ``func`` returns ``None``.

      Raises (when the decorated function is called):
          AssertionError: if the stored state/registry is not a ``dict``, if
              the same API was already applied to the module, or if ``func``
              returns something other than ``None`` or an ``nn.Module``.
          RuntimeError: if ``func`` changed the set or the order of parameter,
              buffer or sub-module FQNs.

      Example::
          >>> # xdoctest: +SKIP
          >>> import torch
          >>> import torch.nn as nn
          >>>
          >>> class MyModel(nn.Module):
          >>>     def __init__(self):
          >>>         super().__init__()
          >>>         self.l1 = nn.Linear(10, 10)
          >>>         self.l2 = nn.Linear(10, 10)
          >>>
          >>>     def forward(self, x):
          >>>         return self.l2(self.l1(x))
          >>>
          >>> @contract()
          >>> def my_feature(module: nn.Module) -> nn.Module:
          >>>     my_feature.state(module).some_state = "any value"
          >>>     return module
          >>>
          >>> model = MyModel()
          >>> my_feature(model.l1)
          >>> assert my_feature.state(model.l1).some_state == "any value"
          >>> my_feature(model.l2)
          >>> model(torch.randn(2, 10)).sum().backward()
      """

      # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
      @wraps(state_cls)
      def inner(func):
          @wraps(func)
          def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
              # get existing global states
              default_all_state: Dict[Callable, _State] = OrderedDict()
              all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                  STATE_KEY, default_all_state
              )
              assert isinstance(
                  all_state, dict
              ), "Distributed composable API states corrupted"

              # get global registry
              default_registry: Dict[str, RegistryItem] = OrderedDict()
              registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                  REGISTRY_KEY, default_registry
              )
              assert isinstance(
                  registry, dict
              ), "Distributed composable API registry corrupted"

              # make sure the API func has not been applied to the input module yet.
              assert func not in all_state and func.__name__ not in registry, (
                  "Each distinct composable distributed API can only be applied to a "
                  f"module once. {func.__name__} has already been applied to the "
                  f"following module.\n{module}"
              )

              # install states specific to the wrapped ``func``
              all_state.setdefault(func, state_cls())
              # register ``func`` in the global registry by name
              registry.setdefault(func.__name__, RegistryItem())

              # Snapshot the FQNs before ``func`` runs so they can be compared
              # with the FQNs of whatever module ``func`` hands back.
              orig_named_params = OrderedDict(module.named_parameters())
              orig_named_buffers = OrderedDict(
                  module.named_buffers(remove_duplicate=False)
              )
              orig_named_modules = OrderedDict(
                  module.named_modules(remove_duplicate=False)
              )

              updated = func(module, *args, **kwargs)

              if updated is None:
                  updated = module

              # Validate the return type *before* touching any nn.Module API on
              # it; otherwise a bad return value surfaces as an unrelated
              # AttributeError instead of this message.
              assert isinstance(updated, nn.Module), (
                  "Output of composable distributed APIs must be either None or "
                  f"nn.Module, but got {type(updated)}"
              )

              new_named_params = OrderedDict(updated.named_parameters())
              new_named_buffers = OrderedDict(
                  updated.named_buffers(remove_duplicate=False)
              )
              new_named_modules = OrderedDict(
                  updated.named_modules(remove_duplicate=False)
              )

              def check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str):
                  """Raise ``RuntimeError`` if the two FQN lists differ in content or order."""
                  if orig_fqns == new_fqns:
                      return

                  orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
                  orig_only = orig_fqn_set - new_fqn_set
                  new_only = new_fqn_set - orig_fqn_set
                  if orig_only or new_only:
                      raise RuntimeError(
                          f"{check_key}"
                          "Composable distributed API implementations cannot modify "
                          "FQNs.\n"
                          f"Only in original FQNs: {orig_only},\n"
                          f"Only in new FQNs: {new_only}"
                      )
                  else:
                      # Same names, different order: the set differences are
                      # empty here, so report the ordered lists themselves.
                      raise RuntimeError(
                          f"{check_key}"
                          "Composable distributed API implementations cannot modify "
                          "the order of FQNs.\n"
                          f"Original FQNs: {orig_fqns}\n"
                          f"New FQNs: {new_fqns}"
                      )

              check_fqn(
                  list(orig_named_params.keys()),
                  list(new_named_params.keys()),
                  "Check parameters, ",
              )
              check_fqn(
                  list(orig_named_buffers.keys()),
                  list(new_named_buffers.keys()),
                  "Check buffer, ",
              )
              check_fqn(
                  list(orig_named_modules.keys()),
                  list(new_named_modules.keys()),
                  "Check modules, ",
              )

              # TODO: a stricter verification should also reject changing module
              # types and monkey-patching forward() method implementations.

              # TODO: verify that installed distributed paradigms are compatible with
              # each other.

              return updated

          def get_state(module: nn.Module) -> Optional[_State]:
              """Return the state installed by ``func`` on ``module``, or ``None`` if absent."""
              return module.__dict__.setdefault(  # type: ignore[call-overload]
                  STATE_KEY,
                  {},  # TODO(@yhcharles): this is a temporary fix, need a better way
              ).get(
                  func
              )  # type: ignore[call-overload]

          wrapper.state = get_state  # type: ignore[attr-defined]

          return wrapper

      return inner
  def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
      r"""
      Look up the registry of composable APIs that have been applied to
      ``module``.

      The registry is an ``OrderedDict`` indexed by the API name and stored on
      the module under ``REGISTRY_KEY``. If no API has been applied (the
      attribute is absent), this returns ``None``.
      """
      try:
          applied_apis = getattr(module, REGISTRY_KEY)
      except AttributeError:
          return None
      return applied_apis