batchnorm.py

# mypy: allow-untyped-defs
from typing import Optional, Any

import torch
from torch import Tensor
from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from .. import functional as F
from .. import init
from ._functions import SyncBatchNorm as sync_batch_norm
from .lazy import LazyModuleMixin
from .module import Module

__all__ = ['BatchNorm1d', 'LazyBatchNorm1d', 'BatchNorm2d', 'LazyBatchNorm2d', 'BatchNorm3d',
           'LazyBatchNorm3d', 'SyncBatchNorm']


class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm."""

    _version = 2
    __constants__ = ["track_running_stats", "momentum", "eps", "num_features", "affine"]
    num_features: int
    eps: float
    momentum: Optional[float]
    affine: bool
    track_running_stats: bool
    # WARNING: weight and bias purposely not defined here.
    # See https://github.com/pytorch/pytorch/issues/39670

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.empty(num_features, **factory_kwargs))
            self.bias = Parameter(torch.empty(num_features, **factory_kwargs))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features, **factory_kwargs))
            self.register_buffer('running_var', torch.ones(num_features, **factory_kwargs))
            self.running_mean: Optional[Tensor]
            self.running_var: Optional[Tensor]
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long,
                                              **{k: v for k, v in factory_kwargs.items() if k != 'dtype'}))
            self.num_batches_tracked: Optional[Tensor]
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_batches_tracked", None)
        self.reset_parameters()

    def reset_running_stats(self) -> None:
        if self.track_running_stats:
            # running_mean/running_var/num_batches... are registered at runtime depending
            # on whether self.track_running_stats is on
            self.running_mean.zero_()  # type: ignore[union-attr]
            self.running_var.fill_(1)  # type: ignore[union-attr]
            self.num_batches_tracked.zero_()  # type: ignore[union-attr,operator]

    def reset_parameters(self) -> None:
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)
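
    # ``reset_parameters`` above sets ``weight`` to ones and ``bias`` to zeros, so the
    # affine transform ``x_hat * weight + bias`` starts out as an identity and a freshly
    # constructed layer initially only standardizes its input.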

    def _check_input_dim(self, input):
        raise NotImplementedError

    def extra_repr(self):
        return (
            "{num_features}, eps={eps}, momentum={momentum}, affine={affine}, "
            "track_running_stats={track_running_stats}".format(**self.__dict__)
        )

    def _load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        strict,
        missing_keys,
        unexpected_keys,
        error_msgs,
    ):
        version = local_metadata.get("version", None)

        if (version is None or version < 2) and self.track_running_stats:
            # at version 2: added num_batches_tracked buffer
            #               this should have a default value of 0
            num_batches_tracked_key = prefix + "num_batches_tracked"
            if num_batches_tracked_key not in state_dict:
                state_dict[num_batches_tracked_key] = (
                    self.num_batches_tracked
                    if self.num_batches_tracked is not None
                    and self.num_batches_tracked.device != torch.device('meta')
                    else torch.tensor(0, dtype=torch.long)
                )

        super()._load_from_state_dict(
            state_dict,
            prefix,
            local_metadata,
            strict,
            missing_keys,
            unexpected_keys,
            error_msgs,
        )


class _BatchNorm(_NormBase):
    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
  142. r"""
  143. Decide whether the mini-batch stats should be used for normalization rather than the buffers.
  144. Mini-batch stats are used in training mode, and in eval mode when buffers are None.
  145. """
  146. if self.training:
  147. bn_training = True
  148. else:
  149. bn_training = (self.running_mean is None) and (self.running_var is None)
  150. r"""
  151. Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
  152. passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
  153. used for normalization (i.e. in eval mode when buffers are not None).
  154. """
  155. return F.batch_norm(
  156. input,
  157. # If buffers are not to be tracked, ensure that they won't be updated
  158. self.running_mean
  159. if not self.training or self.track_running_stats
  160. else None,
  161. self.running_var if not self.training or self.track_running_stats else None,
  162. self.weight,
  163. self.bias,
  164. bn_training,
  165. exponential_average_factor,
  166. self.eps,
  167. )
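
# Note: the concrete BatchNorm1d/2d/3d classes defined below reuse this forward
# implementation unchanged and differ only in the dimensionality check done by
# ``_check_input_dim`` (the lazy variants additionally defer choosing
# ``num_features`` until the first forward pass).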


class _LazyNormBase(LazyModuleMixin, _NormBase):

    weight: UninitializedParameter  # type: ignore[assignment]
    bias: UninitializedParameter  # type: ignore[assignment]

    def __init__(self, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            # affine and track_running_stats are hardcoded to False to
            # avoid creating tensors that will soon be overwritten.
            0,
            eps,
            momentum,
            False,
            False,
            **factory_kwargs,
        )
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = UninitializedParameter(**factory_kwargs)
            self.bias = UninitializedParameter(**factory_kwargs)
        if self.track_running_stats:
            self.running_mean = UninitializedBuffer(**factory_kwargs)
            self.running_var = UninitializedBuffer(**factory_kwargs)
            self.num_batches_tracked = torch.tensor(
                0, dtype=torch.long, **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})

    def reset_parameters(self) -> None:
        if not self.has_uninitialized_params() and self.num_features != 0:
            super().reset_parameters()

    def initialize_parameters(self, input) -> None:  # type: ignore[override]
        if self.has_uninitialized_params():
            self.num_features = input.shape[1]
            if self.affine:
                assert isinstance(self.weight, UninitializedParameter)
                assert isinstance(self.bias, UninitializedParameter)
                self.weight.materialize((self.num_features,))
                self.bias.materialize((self.num_features,))
            if self.track_running_stats:
                self.running_mean.materialize((self.num_features,))  # type:ignore[union-attr]
                self.running_var.materialize((self.num_features,))  # type:ignore[union-attr]
            self.reset_parameters()
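
# Rough usage sketch for the lazy variants (an illustrative sketch; shapes are arbitrary):
#
#     m = torch.nn.LazyBatchNorm2d()
#     _ = m(torch.randn(2, 8, 4, 4))   # num_features is inferred from input.shape[1] == 8
#     # After the first forward pass the parameters/buffers are materialized and the
#     # module becomes a regular BatchNorm2d (see ``cls_to_become`` below).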


class BatchNorm1d(_BatchNorm):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the number of features or channels of the input). By default, the
    elements of :math:`\gamma` are set to 1 and the elements of :math:`\beta` are set to 0.
    At train time in the forward pass, the standard-deviation is calculated via the biased estimator,
    equivalent to ``torch.var(input, unbiased=False)``. However, the value stored in the
    moving average of the standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal Batch Normalization.

    Args:
        num_features: number of features or channels :math:`C` of the input
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`, where :math:`N` is the batch size,
          :math:`C` is the number of features or channels, and :math:`L` is the sequence length
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
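
    A 3D input of shape :math:`(N, C, L)` is handled the same way, and the running
    statistics are kept as per-channel vectors of size :math:`C`; a short sketch::

        >>> m = nn.BatchNorm1d(100)
        >>> input = torch.randn(20, 100, 30)
        >>> output = m(input)
        >>> m.running_mean.shape
        torch.Size([100])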
  266. """
  267. def _check_input_dim(self, input):
  268. if input.dim() != 2 and input.dim() != 3:
  269. raise ValueError(
  270. f"expected 2D or 3D input (got {input.dim()}D input)"
  271. )


class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm1d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm1d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError(
                f"expected 2D or 3D input (got {input.dim()}D input)"
            )


class BatchNorm2d(_BatchNorm):
    r"""Applies Batch Normalization over a 4D input.

    4D is a mini-batch of 2D inputs
    with additional channel dimension. Method described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
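
    After training, switching the module to evaluation mode makes it normalize with the
    tracked running statistics instead of per-batch statistics; a short sketch::

        >>> m = nn.BatchNorm2d(100)
        >>> _ = m(torch.randn(20, 100, 35, 45))       # training mode: updates running stats
        >>> m = m.eval()
        >>> output = m(torch.randn(20, 100, 35, 45))  # eval mode: uses the running stats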
  359. """
  360. def _check_input_dim(self, input):
  361. if input.dim() != 4:
  362. raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm2d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm2d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError(f"expected 4D input (got {input.dim()}D input)")


class BatchNorm3d(_BatchNorm):
    r"""Applies Batch Normalization over a 5D input.

    5D is a mini-batch of 3D inputs with additional channel dimension as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and :math:`\gamma` and :math:`\beta` are learnable parameter vectors
    of size `C` (where `C` is the input size). By default, the elements of :math:`\gamma` are set
    to 1 and the elements of :math:`\beta` are set to 0. At train time in the forward pass, the
    standard-deviation is calculated via the biased estimator, equivalent to
    ``torch.var(input, unbiased=False)``. However, the value stored in the moving average of the
    standard-deviation is calculated via the unbiased estimator, equivalent to
    ``torch.var(input, unbiased=True)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric Batch Normalization
    or Spatio-temporal Batch Normalization.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, D, H, W)`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples::

        >>> # With Learnable Parameters
        >>> m = nn.BatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = nn.BatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
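
    With ``track_running_stats=False`` no running estimates are kept and batch statistics
    are used in both training and evaluation mode; a short sketch::

        >>> m = nn.BatchNorm3d(100, track_running_stats=False)
        >>> print(m.running_mean)
        None
        >>> output = m(torch.randn(20, 100, 35, 45, 10))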
  448. """
  449. def _check_input_dim(self, input):
  450. if input.dim() != 5:
  451. raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
    r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

    Lazy initialization is done for the ``num_features`` argument of the :class:`BatchNorm3d` that is inferred
    from the ``input.size(1)``.
    The attributes that will be lazily initialized are `weight`, `bias`,
    `running_mean` and `running_var`.

    Check the :class:`torch.nn.modules.lazy.LazyModuleMixin` for further documentation
    on lazy modules and their limitations.

    Args:
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
    """

    cls_to_become = BatchNorm3d  # type: ignore[assignment]

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError(f"expected 5D input (got {input.dim()}D input)")


class SyncBatchNorm(_BatchNorm):
    r"""Applies Batch Normalization over an N-Dimensional input.

    The N-D input is a mini-batch of [N-2]D inputs with additional channel dimension, as described in the paper
    `Batch Normalization: Accelerating Deep Network Training by Reducing
    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    The mean and standard-deviation are calculated per-dimension over all
    mini-batches of the same process groups. :math:`\gamma` and :math:`\beta`
    are learnable parameter vectors of size `C` (where `C` is the input size).
    By default, the elements of :math:`\gamma` are set to 1 and the elements of
    :math:`\beta` are set to 0. The standard-deviation is calculated via the
    biased estimator, equivalent to ``torch.var(input, unbiased=False)``.

    Also by default, during training this layer keeps running estimates of its
    computed mean and variance, which are then used for normalization during
    evaluation. The running estimates are kept with a default :attr:`momentum`
    of 0.1.

    If :attr:`track_running_stats` is set to ``False``, this layer then does not
    keep running estimates, and batch statistics are instead used during
    evaluation time as well.

    .. note::
        This :attr:`momentum` argument is different from one used in optimizer
        classes and the conventional notion of momentum. Mathematically, the
        update rule for running statistics here is
        :math:`\hat{x}_\text{new} = (1 - \text{momentum}) \times \hat{x} + \text{momentum} \times x_t`,
        where :math:`\hat{x}` is the estimated statistic and :math:`x_t` is the
        new observed value.

    Because the Batch Normalization is done for each channel in the ``C`` dimension, computing
    statistics on ``(N, +)`` slices, it's common terminology to call this Volumetric Batch
    Normalization or Spatio-temporal Batch Normalization.

    Currently :class:`SyncBatchNorm` only supports
    :class:`~torch.nn.DistributedDataParallel` (DDP) with single GPU per process. Use
    :meth:`torch.nn.SyncBatchNorm.convert_sync_batchnorm()` to convert
    :attr:`BatchNorm*D` layers to :class:`SyncBatchNorm` before wrapping
    the network with DDP.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, +)`
        eps: a value added to the denominator for numerical stability.
            Default: ``1e-5``
        momentum: the value used for the running_mean and running_var
            computation. Can be set to ``None`` for cumulative moving average
            (i.e. simple average). Default: 0.1
        affine: a boolean value that when set to ``True``, this module has
            learnable affine parameters. Default: ``True``
        track_running_stats: a boolean value that when set to ``True``, this
            module tracks the running mean and variance, and when set to ``False``,
            this module does not track such statistics, and initializes statistics
            buffers :attr:`running_mean` and :attr:`running_var` as ``None``.
            When these buffers are ``None``, this module always uses batch statistics
            in both training and eval modes. Default: ``True``
        process_group: synchronization of stats happen within each process group
            individually. Default behavior is synchronization across the whole
            world

    Shape:
        - Input: :math:`(N, C, +)`
        - Output: :math:`(N, C, +)` (same shape as input)

    .. note::
        Synchronization of batchnorm statistics occurs only while training, i.e.
        synchronization is disabled when ``model.eval()`` is set or if
        ``self.training`` is otherwise ``False``.

    Examples::

        >>> # xdoctest: +SKIP
        >>> # With Learnable Parameters
        >>> m = nn.SyncBatchNorm(100)
        >>> # creating process group (optional)
        >>> # ranks is a list of int identifying rank ids.
        >>> ranks = list(range(8))
        >>> r1, r2 = ranks[:4], ranks[4:]
        >>> # Note: every rank calls into new_group for every
        >>> # process group created, even if that rank is not
        >>> # part of the group.
        >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
        >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
        >>> # Without Learnable Parameters
        >>> m = nn.SyncBatchNorm(100, affine=False, process_group=process_group)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
        >>> # network is nn.BatchNorm layer
        >>> sync_bn_network = nn.SyncBatchNorm.convert_sync_batchnorm(network, process_group)
        >>> # only single gpu per process is currently supported
        >>> ddp_sync_bn_network = torch.nn.parallel.DistributedDataParallel(
        >>>                         sync_bn_network,
        >>>                         device_ids=[args.local_rank],
        >>>                         output_device=args.local_rank)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: Optional[float] = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
        process_group: Optional[Any] = None,
        device=None,
        dtype=None
    ) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__(
            num_features, eps, momentum, affine, track_running_stats, **factory_kwargs
        )
        self.process_group = process_group

    def _check_input_dim(self, input):
        if input.dim() < 2:
            raise ValueError(
                f"expected at least 2D input (got {input.dim()}D input)"
            )

    def _check_non_zero_input_channels(self, input):
        if input.size(1) == 0:
            raise ValueError(
                "SyncBatchNorm number of input channels should be non-zero"
            )

    def forward(self, input: Tensor) -> Tensor:
        self._check_input_dim(input)
        self._check_non_zero_input_channels(input)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            assert self.num_batches_tracked is not None
            self.num_batches_tracked.add_(1)
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / self.num_batches_tracked.item()
            else:  # use exponential moving average
                exponential_average_factor = self.momentum

        r"""
        Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        """
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        r"""
        Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
        passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
        used for normalization (i.e. in eval mode when buffers are not None).
        """
        # If buffers are not to be tracked, ensure that they won't be updated
        running_mean = (
            self.running_mean if not self.training or self.track_running_stats else None
        )
        running_var = (
            self.running_var if not self.training or self.track_running_stats else None
        )

        # Don't sync batchnorm stats in inference mode (model.eval()).
        need_sync = (bn_training and self.training and
                     torch.distributed.is_available() and torch.distributed.is_initialized())
        if need_sync:
            # currently only GPU/PrivateUse1 input is supported
            if input.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
                raise ValueError("SyncBatchNorm expected input tensor to be on GPU or "
                                 f"{torch._C._get_privateuse1_backend_name()}")
            process_group = torch.distributed.group.WORLD
            if self.process_group:
                process_group = self.process_group
            world_size = torch.distributed.get_world_size(process_group)
            need_sync = world_size > 1

        # fallback to framework BN when synchronization is not necessary
        if not need_sync:
            return F.batch_norm(
                input,
                running_mean,
                running_var,
                self.weight,
                self.bias,
                bn_training,
                exponential_average_factor,
                self.eps,
            )
        else:
            assert bn_training
            return sync_batch_norm.apply(
                input,
                self.weight,
                self.bias,
                running_mean,
                running_var,
                self.eps,
                exponential_average_factor,
                process_group,  # type: ignore[possibly-undefined]
                world_size,  # type: ignore[possibly-undefined]
            )
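
    # ``sync_batch_norm`` (imported from ._functions) is the autograd Function that
    # aggregates the per-process batch statistics across ``process_group`` before
    # normalizing, which is what makes the statistics "synchronized" during training.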

    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        r"""Converts all :attr:`BatchNorm*D` layers in the model to :class:`torch.nn.SyncBatchNorm` layers.

        Args:
            module (nn.Module): module containing one or more :attr:`BatchNorm*D` layers
            process_group (optional): process group to scope synchronization,
                default is the whole world

        Returns:
            The original :attr:`module` with the converted :class:`torch.nn.SyncBatchNorm`
            layers. If the original :attr:`module` is a :attr:`BatchNorm*D` layer,
            a new :class:`torch.nn.SyncBatchNorm` layer object will be returned
            instead.

        Example::

            >>> # Network with nn.BatchNorm layer
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> module = torch.nn.Sequential(
            >>>            torch.nn.Linear(20, 100),
            >>>            torch.nn.BatchNorm1d(100),
            >>>          ).cuda()
            >>> # creating process group (optional)
            >>> # ranks is a list of int identifying rank ids.
            >>> ranks = list(range(8))
            >>> r1, r2 = ranks[:4], ranks[4:]
            >>> # Note: every rank calls into new_group for every
            >>> # process group created, even if that rank is not
            >>> # part of the group.
            >>> # xdoctest: +SKIP("distributed")
            >>> process_groups = [torch.distributed.new_group(pids) for pids in [r1, r2]]
            >>> process_group = process_groups[0 if dist.get_rank() <= 3 else 1]
            >>> sync_bn_module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module, process_group)
        """
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(
                module.num_features,
                module.eps,
                module.momentum,
                module.affine,
                module.track_running_stats,
                process_group,
            )
            if module.affine:
                with torch.no_grad():
                    module_output.weight = module.weight
                    module_output.bias = module.bias
            module_output.running_mean = module.running_mean
            module_output.running_var = module.running_var
            module_output.num_batches_tracked = module.num_batches_tracked
            module_output.training = module.training
            if hasattr(module, "qconfig"):
                module_output.qconfig = module.qconfig
        for name, child in module.named_children():
            module_output.add_module(
                name, cls.convert_sync_batchnorm(child, process_group)
            )
        del module
        return module_output