# fusion.py

from __future__ import annotations

import copy
from typing import Optional, Tuple, TypeVar

import torch

__all__ = ['fuse_conv_bn_eval', 'fuse_conv_bn_weights', 'fuse_linear_bn_eval', 'fuse_linear_bn_weights']

ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")

def fuse_conv_bn_eval(conv: ConvT, bn: torch.nn.modules.batchnorm._BatchNorm, transpose: bool = False) -> ConvT:
    r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.

    Args:
        conv (torch.nn.modules.conv._ConvNd): A convolutional module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
        transpose (bool, optional): If True, transpose the convolutional weight. Defaults to False.

    Returns:
        torch.nn.modules.conv._ConvNd: The fused convolutional module.

    .. note::
        Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    assert not (conv.training or bn.training), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    assert bn.running_mean is not None and bn.running_var is not None
    fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
        fused_conv.weight, fused_conv.bias,
        bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias, transpose)

    return fused_conv
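

# Illustrative usage sketch (not part of the upstream module): fuse an eval-mode
# Conv2d/BatchNorm2d pair and check that the fused module reproduces the unfused
# output. Module sizes, buffer values, and tolerances are demo assumptions.
def _demo_fuse_conv_bn_eval() -> None:
    conv = torch.nn.Conv2d(3, 8, kernel_size=3).eval()
    bn = torch.nn.BatchNorm2d(8).eval()
    # Give the running buffers non-trivial statistics so the fold is exercised.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)
    fused = fuse_conv_bn_eval(conv, bn)
    x = torch.randn(2, 3, 16, 16)
    torch.testing.assert_close(fused(x), bn(conv(x)), rtol=1e-4, atol=1e-5)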


def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: Optional[torch.Tensor],
    bn_b: Optional[torch.Tensor],
    transpose: bool = False
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    Args:
        conv_w (torch.Tensor): Convolutional weight.
        conv_b (Optional[torch.Tensor]): Convolutional bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight.
        bn_b (Optional[torch.Tensor]): BatchNorm bias.
        transpose (bool, optional): If True, transpose the conv weight. Defaults to False.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
    """
    conv_weight_dtype = conv_w.dtype
    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
    # Missing parameters default to the BatchNorm identity (zero bias, unit scale).
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
    # The per-channel scale broadcasts over the output-channel dim of the conv
    # weight: dim 1 for transposed convolutions, dim 0 otherwise.
    if transpose:
        shape = [1, -1] + [1] * (len(conv_w.shape) - 2)
    else:
        shape = [-1, 1] + [1] * (len(conv_w.shape) - 2)
    fused_conv_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype)
    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype)
    return (
        torch.nn.Parameter(fused_conv_w, conv_w.requires_grad), torch.nn.Parameter(fused_conv_b, conv_b.requires_grad)
    )
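

# Minimal numeric sketch (illustrative, not upstream code): call
# fuse_conv_bn_weights directly and check it against the closed form
# w' = w * gamma / sqrt(var + eps) and b' = (b - mean) * gamma / sqrt(var + eps) + beta.
# All shapes and values here are demo assumptions.
def _demo_fuse_conv_bn_weights() -> None:
    conv_w = torch.randn(8, 3, 3, 3)  # (out_channels, in_channels, kH, kW)
    bn_rm = torch.randn(8)
    bn_rv = torch.rand(8) + 0.1
    bn_eps = 1e-5
    # With a None conv bias and identity BN affine parameters, the fold reduces
    # to scaling the weight and shifting by the normalized running mean.
    fused_w, fused_b = fuse_conv_bn_weights(
        conv_w, None, bn_rm, bn_rv, bn_eps, torch.ones(8), torch.zeros(8))
    scale = torch.rsqrt(bn_rv + bn_eps)
    torch.testing.assert_close(fused_w.detach(), conv_w * scale.reshape(8, 1, 1, 1))
    torch.testing.assert_close(fused_b.detach(), -bn_rm * scale)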


def fuse_linear_bn_eval(linear: LinearT, bn: torch.nn.modules.batchnorm._BatchNorm) -> LinearT:
    r"""Fuse a linear module and a BatchNorm module into a single, new linear module.

    Args:
        linear (torch.nn.Linear): A Linear module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.

    Returns:
        torch.nn.Linear: The fused linear module.

    .. note::
        Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    assert not (linear.training or bn.training), "Fusion only for eval!"
    fused_linear = copy.deepcopy(linear)

    """
    Linear-BN needs to be fused while preserving the shapes of the linear weight/bias.
    To preserve those shapes, the channel dim of bn must be broadcastable with the
    last dim of linear: bn operates over the channel dim of (N, C_in, H, W), while
    linear operates over the last dim of (*, H_in).
    To be broadcastable, the number of features in bn and the number of output
    features from linear must satisfy one of the following conditions:
    1. they are equal, or
    2. the number of features in bn is 1.
    Otherwise, the fold is not possible and the assertion below fires.
    """
    assert (
        linear.out_features == bn.num_features or bn.num_features == 1
    ), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"

    assert bn.running_mean is not None and bn.running_var is not None
    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight, fused_linear.bias,
        bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_linear
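

# Illustrative usage sketch (not part of the upstream module): fuse an eval-mode
# Linear/BatchNorm1d pair where bn.num_features matches linear.out_features, as
# the assertion above requires. Sizes and tolerances are demo assumptions.
def _demo_fuse_linear_bn_eval() -> None:
    linear = torch.nn.Linear(16, 8).eval()
    bn = torch.nn.BatchNorm1d(8).eval()
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)
    fused = fuse_linear_bn_eval(linear, bn)
    x = torch.randn(4, 16)
    torch.testing.assert_close(fused(x), bn(linear(x)), rtol=1e-4, atol=1e-5)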


def fuse_linear_bn_weights(
    linear_w: torch.Tensor,
    linear_b: Optional[torch.Tensor],
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor,
    bn_b: torch.Tensor,
) -> Tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.

    Args:
        linear_w (torch.Tensor): Linear weight.
        linear_b (Optional[torch.Tensor]): Linear bias.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (torch.Tensor): BatchNorm weight.
        bn_b (torch.Tensor): BatchNorm bias.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
    """
    if linear_b is None:
        linear_b = torch.zeros_like(bn_rm)
    # Per-feature scale gamma / sqrt(var + eps), folded into the weight rows.
    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
    fused_w = linear_w * bn_scale.unsqueeze(-1)
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
    return torch.nn.Parameter(fused_w, linear_w.requires_grad), torch.nn.Parameter(fused_b, linear_b.requires_grad)
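

# Minimal numeric sketch (illustrative, not upstream code): the linear fold is a
# per-row rescale of the weight, W' = diag(scale) @ W with
# scale = gamma / sqrt(var + eps), plus a shifted bias
# b' = (b - mean) * scale + beta. Shapes and values are demo assumptions.
def _demo_fuse_linear_bn_weights() -> None:
    linear_w = torch.randn(8, 16)
    linear_b = torch.randn(8)
    bn_rm, bn_rv = torch.randn(8), torch.rand(8) + 0.1
    bn_w, bn_b = torch.rand(8) + 0.5, torch.randn(8)
    bn_eps = 1e-5
    fused_w, fused_b = fuse_linear_bn_weights(
        linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b)
    scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
    torch.testing.assert_close(fused_w.detach(), linear_w * scale.unsqueeze(-1))
    torch.testing.assert_close(fused_b.detach(), (linear_b - bn_rm) * scale + bn_b)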