# mypy: allow-untyped-defs
import functools
from typing import Any, Callable, List, Optional, Tuple, Union

from typing_extensions import deprecated

import torch
from torch import Tensor
from torch.utils._pytree import _broadcast_to_and_flatten, tree_flatten, tree_unflatten

in_dims_t = Union[int, Tuple]
out_dims_t = Union[int, Tuple[int, ...]]


# Checks that all args-to-be-batched have the same batch dim size
def _validate_and_get_batch_size(
    flat_in_dims: List[Optional[int]], flat_args: List
) -> int:
    batch_sizes = [
        arg.size(in_dim)
        for in_dim, arg in zip(flat_in_dims, flat_args)
        if in_dim is not None
    ]
    if batch_sizes and any(size != batch_sizes[0] for size in batch_sizes):
        raise ValueError(
            f"vmap: Expected all tensors to have the same size in the mapped "
            f"dimension, got sizes {batch_sizes} for the mapped dimension"
        )
    return batch_sizes[0]
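
# Illustrative sketch (not part of the original module): how the helper above
# behaves. The tensors and shapes here are made up for the example.
#
#   >>> x, y = torch.randn(3, 5), torch.randn(4, 5)
#   >>> _validate_and_get_batch_size([0, None], [x, y])
#   3
#   >>> _validate_and_get_batch_size([0, 0], [x, y])   # sizes 3 vs 4 along dim 0
#   ValueError: vmap: Expected all tensors to have the same size in the mapped dimension, ...
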
def _num_outputs(batched_outputs: Union[Tensor, Tuple[Tensor, ...]]) -> int:
    if isinstance(batched_outputs, tuple):
        return len(batched_outputs)
    return 1


# If value is a tuple, check it has length `num_elements`.
# If value is not a tuple, make a tuple with `value` repeated `num_elements` times
def _as_tuple(
    value: Any, num_elements: int, error_message_lambda: Callable[[], str]
) -> Tuple:
    if not isinstance(value, tuple):
        return (value,) * num_elements
    if len(value) != num_elements:
        raise ValueError(error_message_lambda())
    return value
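
# Illustrative sketch (not part of the original module): `_as_tuple` normalizes
# `out_dims`-style arguments so that there is exactly one entry per output.
#
#   >>> _as_tuple(0, 3, lambda: "unused")        # scalar is broadcast
#   (0, 0, 0)
#   >>> _as_tuple((0, 1), 2, lambda: "unused")   # matching tuple passes through
#   (0, 1)
#   >>> _as_tuple((0,), 2, lambda: "length mismatch")
#   ValueError: length mismatch
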
# Creates BatchedTensors for every Tensor in arg that should be batched.
# Returns the (potentially) batched arguments and the batch_size.
def _create_batched_inputs(
    in_dims: in_dims_t, args: Tuple, vmap_level: int, func: Callable
) -> Tuple[Tuple, int]:
    if not isinstance(in_dims, int) and not isinstance(in_dims, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"expected `in_dims` to be int or a (potentially nested) tuple "
            f"matching the structure of inputs, got: {type(in_dims)}."
        )
    if len(args) == 0:
        raise ValueError(
            f"vmap({_get_name(func)})(<inputs>): got no inputs. Maybe you forgot to add "
            f"inputs, or you are trying to vmap over a function with no inputs. "
            f"The latter is unsupported."
        )

    flat_args, args_spec = tree_flatten(args)
    flat_in_dims = _broadcast_to_and_flatten(in_dims, args_spec)
    if flat_in_dims is None:
        raise ValueError(
            f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
            f"in_dims is not compatible with the structure of `inputs`. "
            f"in_dims has structure {tree_flatten(in_dims)[1]} but inputs "
            f"has structure {args_spec}."
        )

    for arg, in_dim in zip(flat_args, flat_in_dims):
        if not isinstance(in_dim, int) and in_dim is not None:
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but in_dim must be either "
                f"an integer dimension or None."
            )
        if isinstance(in_dim, int) and not isinstance(arg, Tensor):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for an input but the input is of type "
                f"{type(arg)}. We cannot vmap over non-Tensor arguments, "
                f"please use None as the respective in_dim"
            )
        if in_dim is not None and (in_dim < 0 or in_dim >= arg.dim()):
            raise ValueError(
                f"vmap({_get_name(func)}, in_dims={in_dims}, ...)(<inputs>): "
                f"Got in_dim={in_dim} for some input, but that input is a Tensor "
                f"of dimensionality {arg.dim()} so expected in_dim to satisfy "
                f"0 <= in_dim < {arg.dim()}."
            )

    batch_size = _validate_and_get_batch_size(flat_in_dims, flat_args)
    # See NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    batched_inputs = [
        arg if in_dim is None else torch._add_batch_dim(arg, in_dim, vmap_level)
        for in_dim, arg in zip(flat_in_dims, flat_args)
    ]
    return tree_unflatten(batched_inputs, args_spec), batch_size
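
# Illustrative sketch (not part of the original module): the pytree call above
# broadcasts a scalar `in_dims` across the (possibly nested) input structure;
# an incompatible structure yields None, which triggers the error branch.
#
#   >>> _, spec = tree_flatten((torch.randn(2, 3), {"a": torch.randn(2, 3)}))
#   >>> _broadcast_to_and_flatten(0, spec)
#   [0, 0]
#   >>> _broadcast_to_and_flatten((0, 1, 2), spec)   # wrong structure
#   None
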
# Undoes the batching (and any batch dimensions) associated with the `vmap_level`.
def _unwrap_batched(
    batched_outputs: Union[Tensor, Tuple[Tensor, ...]],
    out_dims: out_dims_t,
    vmap_level: int,
    batch_size: int,
    func: Callable,
    allow_none_pass_through: bool = False,
) -> Tuple:
    num_outputs = _num_outputs(batched_outputs)
    out_dims_as_tuple = _as_tuple(
        out_dims,
        num_outputs,
        lambda: f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must "
        f"have one dim per output (got {num_outputs} outputs) of {_get_name(func)}.",
    )

    # NOTE [Ignored _remove_batch_dim, _add_batch_dim]
    # There is something wrong with our type bindings for functions that begin
    # with '_', see #40397.
    if isinstance(batched_outputs, Tensor):
        out_dim = out_dims_as_tuple[0]
        return torch._remove_batch_dim(batched_outputs, vmap_level, batch_size, out_dim)  # type: ignore[return-value]
    if allow_none_pass_through:
        return tuple(
            (
                torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
                if out is not None
                else None
            )
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
        )
    else:
        return tuple(
            torch._remove_batch_dim(out, vmap_level, batch_size, out_dim)
            for out, out_dim in zip(batched_outputs, out_dims_as_tuple)
        )
# Checks that `fn` returned one or more Tensors and nothing else.
# NB: A python function that returns multiple values returns a single tuple,
# so we are effectively checking that `outputs` is a single Tensor or a tuple of
# Tensors.
def _validate_outputs(outputs: Any, func: Callable) -> None:
    if isinstance(outputs, Tensor):
        return
    if not isinstance(outputs, tuple):
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(outputs)} as the return."
        )
    for idx, output in enumerate(outputs):
        if isinstance(output, Tensor):
            continue
        raise ValueError(
            f"vmap({_get_name(func)}, ...): `{_get_name(func)}` must only return "
            f"Tensors, got type {type(output)} for return {idx}."
        )


def _check_out_dims_is_int_or_int_tuple(out_dims: out_dims_t, func: Callable) -> None:
    if isinstance(out_dims, int):
        return
    if not isinstance(out_dims, tuple) or not all(
        isinstance(out_dim, int) for out_dim in out_dims
    ):
        raise ValueError(
            f"vmap({_get_name(func)}, ..., out_dims={out_dims}): `out_dims` must be "
            f"an int or a tuple of ints representing where in the outputs the "
            f"vmapped dimension should appear."
        )
def _get_name(func: Callable):
    if hasattr(func, "__name__"):
        return func.__name__

    # Not all callables have a __name__; in fact, only regular functions and
    # static methods do. A callable created via functools.partial or an
    # nn.Module, to name a couple of examples, doesn't have a __name__.
    return repr(func)
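
# Illustrative sketch (not part of the original module); the second output is
# abbreviated since the exact repr of a partial is not stable:
#
#   >>> _get_name(torch.sum)
#   'sum'
#   >>> _get_name(functools.partial(torch.sum, dim=0))   # no __name__, falls back to repr
#   'functools.partial(<built-in method sum ...>, dim=0)'
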
# vmap(func)(inputs) wraps all Tensor inputs to be batched in BatchedTensors,
# sends those into func, and then unwraps the output BatchedTensors. Operations
# on BatchedTensors perform the batched operations that the user is asking for.
@deprecated(
    "Please use `torch.vmap` instead of `torch._vmap_internals.vmap`.",
    category=FutureWarning,
)
def vmap(func: Callable, in_dims: in_dims_t = 0, out_dims: out_dims_t = 0) -> Callable:
    """
    Please use torch.vmap instead of this API.
    """
    return _vmap(func, in_dims, out_dims)
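
# Illustrative usage sketch (not part of the original module), assuming the
# legacy vmap mode supports the ops involved; shapes are made up for the
# example, and calling the deprecated wrapper emits a FutureWarning.
#
#   >>> x = torch.randn(4, 3)
#   >>> vmap(lambda xi: xi * 2)(x).shape                            # map over dim 0 of x
#   torch.Size([4, 3])
#   >>> vmap(torch.mul, in_dims=(0, None))(x, torch.randn(3)).shape # second arg not mapped
#   torch.Size([4, 3])
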
# A version of vmap but without the initial "experimental prototype" warning
def _vmap(
    func: Callable,
    in_dims: in_dims_t = 0,
    out_dims: out_dims_t = 0,
    allow_none_pass_through: bool = False,
) -> Callable:
    # The `allow_none_pass_through` argument is a temporary workaround and may be
    # removed. Currently it lets us wrap calls into the autograd engine (e.g. via
    # `autograd.grad`), which may return None if some of the inputs are unused.
    # See the issue discussing this:
    # https://github.com/facebookresearch/functorch/issues/159.
    @functools.wraps(func)
    def wrapped(*args):
        _check_out_dims_is_int_or_int_tuple(out_dims, func)
        vmap_level = torch._C._vmapmode_increment_nesting()
        try:
            batched_inputs, batch_size = _create_batched_inputs(
                in_dims, args, vmap_level, func
            )
            batched_outputs = func(*batched_inputs)
            if not allow_none_pass_through:
                _validate_outputs(batched_outputs, func)
            return _unwrap_batched(
                batched_outputs,
                out_dims,
                vmap_level,
                batch_size,
                func,
                allow_none_pass_through=allow_none_pass_through,
            )
        finally:
            torch._C._vmapmode_decrement_nesting()

    return wrapped
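
# Illustrative sketch (not part of the original module): with
# `allow_none_pass_through=True`, a None "output" (e.g. the grad of an unused
# input returned by the autograd engine) is passed through untouched instead of
# failing output validation.
#
#   >>> def f(x, y):
#   ...     return x.sum(), None   # pretend `y` was unused, so its "grad" is None
#   >>> out, maybe_grad = _vmap(f, allow_none_pass_through=True)(
#   ...     torch.randn(3, 4), torch.randn(3, 4)
#   ... )
#   >>> out.shape, maybe_grad
#   (torch.Size([3]), None)
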