# mypy: allow-untyped-defs
from collections import Counter
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch._functorch.utils import exposed_in


@exposed_in("torch.func")
def functional_call(
    module: "torch.nn.Module",
    parameter_and_buffer_dicts: Union[Dict[str, Tensor], Sequence[Dict[str, Tensor]]],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
  18. r"""Performs a functional call on the module by replacing the module parameters
  19. and buffers with the provided ones.
  20. .. note:: If the module has active parametrizations, passing a value in the
  21. :attr:`parameter_and_buffer_dicts` argument with the name set to the regular parameter
  22. name will completely disable the parametrization.
  23. If you want to apply the parametrization function to the value passed
  24. please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``.
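
        For instance, a minimal sketch using a toy ``Square`` parametrization
        (the names here are illustrative, not part of the API):

        Example::

            >>> # xdoctest: +SKIP
            >>> import torch.nn.utils.parametrize as parametrize
            >>> class Square(nn.Module):
            ...     def forward(self, x):
            ...         return x * x
            >>> lin = nn.Linear(3, 3)
            >>> parametrize.register_parametrization(lin, "weight", Square())
            >>> w = torch.randn(3, 3)
            >>> # key 'weight' bypasses the parametrization and uses w as-is
            >>> functional_call(lin, {'weight': w}, torch.randn(3))
            >>> # the '.original' key applies Square to w first
            >>> functional_call(lin, {'parametrizations.weight.original': w}, torch.randn(3))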

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameter_and_buffer_dicts`` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        ``tie_weights`` flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    An example of passing multiple dictionaries:

    .. code-block:: python

        a = ({'weight': torch.ones(1, 1)}, {'buffer': torch.zeros(1)})  # two separate dictionaries
        mod = nn.Bar(1, 1)  # return self.weight @ x + self.buffer
        print(mod.weight)  # tensor(...)
        print(mod.buffer)  # tensor(...)
        x = torch.randn((1, 1))
        print(x)
        functional_call(mod, a, x)  # same as x
        print(mod.weight)  # same as before functional_call

    And here is an example of applying the grad transform over the parameters
    of a model:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from torch.func import functional_call, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)

        def compute_loss(params, x, t):
            y = functional_call(model, params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(dict(model.named_parameters()), x, t)

    .. note:: If the user does not need grad tracking outside of grad transforms, they can detach all of the
        parameters for better performance and memory usage.

        Example::

            >>> detached_params = {k: v.detach() for k, v in model.named_parameters()}
            >>> grad_weights = grad(compute_loss)(detached_params, x, t)
            >>> grad_weights['weight'].grad_fn  # None--it's not tracking gradients outside of grad

        This means that the user cannot call ``grad_weights.backward()``. However, if they don't need autograd
        tracking outside of the transforms, this will result in less memory usage and faster speeds.

    Args:
        module (torch.nn.Module): the module to call
        parameter_and_buffer_dicts (Dict[str, Tensor] or tuple of Dict[str, Tensor]): the parameters that will be used in
            the module call. If given a tuple of dictionaries, they must have distinct keys so that all dictionaries can
            be used together
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparameterized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error (see the example below). Default: False.

    Returns:
        Any: the result of calling ``module``.
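
    A minimal sketch of ``strict``, assuming a plain ``nn.Linear`` (which has
    both a ``weight`` and a ``bias``):

    Example::

        >>> # xdoctest: +SKIP
        >>> mod = nn.Linear(3, 3)
        >>> # 'bias' is missing from the dict, so strict=True raises an error
        >>> functional_call(mod, {'weight': torch.randn(3, 3)}, torch.randn(3), strict=True)
        >>> # with the default strict=False, the module's own bias is used instead
        >>> functional_call(mod, {'weight': torch.randn(3, 3)}, torch.randn(3))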
  94. """
  95. if isinstance(parameter_and_buffer_dicts, dict):
  96. parameters_and_buffers = parameter_and_buffer_dicts
  97. elif isinstance(parameter_and_buffer_dicts, Sequence):
  98. if not all(isinstance(d, dict) for d in parameter_and_buffer_dicts):
  99. raise ValueError(
  100. "Expected all elements of parameter_and_buffer_dicts to be dictionaries"
  101. )
  102. all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
  103. repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1]
  104. if len(repeated_keys) > 0:
  105. raise ValueError(
  106. f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
  107. )
  108. parameters_and_buffers = {
  109. k: v for d in parameter_and_buffer_dicts for k, v in d.items()
  110. }
  111. else:
  112. raise ValueError(
  113. f"Expected parameter_and_buffer_dicts to be a dict, or a list/tuple of dicts, "
  114. f"but got {type(parameter_and_buffer_dicts)}"
  115. )
  116. return nn.utils.stateless._functional_call(
  117. module,
  118. parameters_and_buffers,
  119. args,
  120. kwargs,
  121. tie_weights=tie_weights,
  122. strict=strict,
  123. )


@exposed_in("torch.func")
def stack_module_state(
    models: List[nn.Module],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """stack_module_state(models) -> params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, returns two dictionaries
    that stack all of their parameters and buffers together, indexed by name.
    The stacked parameters are optimizable (i.e. they are new leaf nodes in the
    autograd history that are unrelated to the original parameters and can be
    passed directly to an optimizer).

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, 3)

        def wrapper(params, buffers, data):
            return torch.func.functional_call(models[0], (params, buffers), data)

        params, buffers = stack_module_state(models)
        output = vmap(wrapper, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)
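
    Since the stacked parameters are fresh leaf tensors, they can be handed
    straight to an optimizer. A minimal sketch, continuing the example above:

    .. code-block:: python

        # params are new leaves, so gradients flow to them, not to the
        # original models' parameters
        opt = torch.optim.SGD(params.values(), lr=1e-2)
        loss = output.sum()
        loss.backward()
        opt.step()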

    When there are submodules, this follows state dict naming conventions:

    .. code-block:: python

        import torch.nn as nn

        class Foo(nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                hidden = 4
                self.l1 = nn.Linear(in_features, hidden)
                self.l2 = nn.Linear(hidden, out_features)

            def forward(self, x):
                return self.l2(self.l1(x))

        num_models = 5
        in_features, out_features = 3, 3
        models = [Foo(in_features, out_features) for i in range(num_models)]
        params, buffers = stack_module_state(models)
        print(list(params.keys()))  # "l1.weight", "l1.bias", "l2.weight", "l2.bias"

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).
    """
    if len(models) == 0:
        raise RuntimeError("stack_module_state: Expected at least one model, got 0.")
    if not (all(m.training for m in models) or all(not m.training for m in models)):
        raise RuntimeError(
            "stack_module_state: Expected all models to have the same training/eval mode."
        )
    model0_typ = type(models[0])
    if not all(type(m) == model0_typ for m in models):
        raise RuntimeError(
            "stack_module_state: Expected all models to be of the same class."
        )
    all_params = [dict(model.named_parameters()) for model in models]
    params = {
        k: construct_stacked_leaf(tuple(params[k] for params in all_params), k)
        for k in all_params[0]
    }
    all_buffers = [dict(model.named_buffers()) for model in models]
    buffers = {
        k: construct_stacked_leaf(tuple(buffers[k] for buffers in all_buffers), k)
        for k in all_buffers[0]
    }

    return params, buffers


def construct_stacked_leaf(
    tensors: Union[Tuple[Tensor, ...], List[Tensor]], name: str
) -> Tensor:
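    """Stack the per-model tensors for parameter/buffer ``name`` along a new
    leading dimension.

    All tensors must agree on ``requires_grad``; otherwise this errors. If they
    require grad, the stacked result is detached and re-marked as a leaf so the
    ensemble's parameters can be optimized independently of the originals.
    ``name`` is only used for the error message.
    """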
    all_requires_grad = all(t.requires_grad for t in tensors)
    none_requires_grad = all(not t.requires_grad for t in tensors)
    if not all_requires_grad and not none_requires_grad:
        raise RuntimeError(
            f"Expected {name} from each model to have the same .requires_grad"
        )
    result = torch.stack(tensors)
    if all_requires_grad:
        result = result.detach().requires_grad_()
    return result