# mypy: allow-untyped-defs
from enum import Enum, auto

import torch
from torch import Tensor

from ..utils import parametrize
from ..modules import Module
from .. import functional as F

from typing import Optional

__all__ = ['orthogonal', 'spectral_norm', 'weight_norm']


def _is_orthogonal(Q, eps=None):
    n, k = Q.size(-2), Q.size(-1)
    Id = torch.eye(k, dtype=Q.dtype, device=Q.device)
    # A reasonable eps, but not too large
    eps = 10. * n * torch.finfo(Q.dtype).eps
    return torch.allclose(Q.mH @ Q, Id, atol=eps)


def _make_orthogonal(A):
    """Assume that A is a tall matrix.

    Compute the Q factor s.t. A = QR (A may be complex) and diag(R) is real and non-negative.
    """
    X, tau = torch.geqrf(A)
    Q = torch.linalg.householder_product(X, tau)
    # The diagonal of X is the diagonal of R (which is always real) so we normalise by its signs
    Q *= X.diagonal(dim1=-2, dim2=-1).sgn().unsqueeze(-2)
    return Q
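
# A hedged usage sketch for the helper above (illustrative only, not part of the public API):
# for a tall input, the returned Q has orthonormal columns.
#
#   >>> A = torch.randn(5, 3)                              # tall matrix
#   >>> Q = _make_orthogonal(A)
#   >>> torch.allclose(Q.mH @ Q, torch.eye(3), atol=1e-5)
#   True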


class _OrthMaps(Enum):
    matrix_exp = auto()
    cayley = auto()
    householder = auto()


class _Orthogonal(Module):
    base: Tensor

    def __init__(self,
                 weight,
                 orthogonal_map: _OrthMaps,
                 *,
                 use_trivialization=True) -> None:
        super().__init__()

        # Note [Householder complex]
        # For complex tensors, it is not possible to compute the tensor `tau` necessary for
        # linalg.householder_product from the reflectors.
        # To see this, note that the reflectors have a shape like:
        # 0 0 0
        # * 0 0
        # * * 0
        # which, for complex matrices, give n(n-1) (real) parameters. Now, you need n^2 parameters
        # to parametrize the unitary matrices. Saving tau on its own does not work either, because
        # not every combination of `(A, tau)` gives a unitary matrix, meaning that if we optimise
        # them as independent tensors we would not maintain the constraint
        # An equivalent reasoning holds for rectangular matrices
        if weight.is_complex() and orthogonal_map == _OrthMaps.householder:
            raise ValueError("The householder parametrization does not support complex tensors.")

        self.shape = weight.shape
        self.orthogonal_map = orthogonal_map
        if use_trivialization:
            self.register_buffer("base", None)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        n, k = X.size(-2), X.size(-1)
        transposed = n < k
        if transposed:
            X = X.mT
            n, k = k, n
        # Here n > k and X is a tall matrix
        if self.orthogonal_map == _OrthMaps.matrix_exp or self.orthogonal_map == _OrthMaps.cayley:
            # We just need n x k - k(k-1)/2 parameters
            X = X.tril()
            if n != k:
                # Embed into a square matrix
                X = torch.cat([X, X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            A = X - X.mH
            # A is skew-symmetric (or skew-hermitian)
            if self.orthogonal_map == _OrthMaps.matrix_exp:
                Q = torch.matrix_exp(A)
            elif self.orthogonal_map == _OrthMaps.cayley:
                # Computes the Cayley retraction (I+A/2)(I-A/2)^{-1}
                Id = torch.eye(n, dtype=A.dtype, device=A.device)
                Q = torch.linalg.solve(torch.add(Id, A, alpha=-0.5), torch.add(Id, A, alpha=0.5))
            # Q is now orthogonal (or unitary) of size (..., n, n)
            if n != k:
                Q = Q[..., :k]
            # Q is now the size of the X (albeit perhaps transposed)
        else:
            # X is real here, as we do not support householder with complex numbers
            A = X.tril(diagonal=-1)
            tau = 2. / (1. + (A * A).sum(dim=-2))
            Q = torch.linalg.householder_product(A, tau)
            # The diagonal of X is 1's and -1's
            # We do not want to differentiate through this or update the diagonal of X hence the casting
            Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

        if hasattr(self, "base"):
            Q = self.base @ Q
        if transposed:
            Q = Q.mT
        return Q  # type: ignore[possibly-undefined]

    @torch.autograd.no_grad()
    def right_inverse(self, Q: torch.Tensor) -> torch.Tensor:
        if Q.shape != self.shape:
            raise ValueError(f"Expected a matrix or batch of matrices of shape {self.shape}. "
                             f"Got a tensor of shape {Q.shape}.")

        Q_init = Q
        n, k = Q.size(-2), Q.size(-1)
        transpose = n < k
        if transpose:
            Q = Q.mT
            n, k = k, n

        # We make sure to always copy Q in every path
        if not hasattr(self, "base"):
            # Note [right_inverse expm cayley]
            # If we do not have use_trivialization=True, we just implement the inverse of the forward
            # map for the Householder. To see why, think that for the Cayley map,
            # we would need to find the matrix X \in R^{n x k} such that:
            # Y = torch.cat([X.tril(), X.new_zeros(n, n - k).expand(*X.shape[:-2], -1, -1)], dim=-1)
            # A = Y - Y.mH
            # cayley(A)[:, :k]
            # gives the original tensor. It is not clear how to do this.
            # Perhaps via some algebraic manipulation involving the QR like that of
            # Corollary 2.2 in Edelman, Arias and Smith?
            if self.orthogonal_map == _OrthMaps.cayley or self.orthogonal_map == _OrthMaps.matrix_exp:
                raise NotImplementedError("It is not possible to assign to the matrix exponential "
                                          "or the Cayley parametrizations when use_trivialization=False.")

            # If parametrization == _OrthMaps.householder, make Q orthogonal via the QR decomposition.
            # Here Q is always real because we do not support householder with complex matrices.
            # See note [Householder complex]
            A, tau = torch.geqrf(Q)
            # We want to have a decomposition X = QR with diag(R) > 0, as otherwise we could
            # decompose an orthogonal matrix Q as Q = (-Q)@(-Id), which is a valid QR decomposition
            # The diagonal of A is the diagonal of R from the QR decomposition
            A.diagonal(dim1=-2, dim2=-1).sign_()
            # Equality with zero is ok because LAPACK returns exactly zero when it does not want
            # to use a particular reflection
            A.diagonal(dim1=-2, dim2=-1)[tau == 0.] *= -1
            return A.mT if transpose else A
        else:
            if n == k:
                # We check whether Q is orthogonal
                if not _is_orthogonal(Q):
                    Q = _make_orthogonal(Q)
                else:  # Is orthogonal
                    Q = Q.clone()
            else:
                # Complete Q into a full n x n orthogonal matrix
                N = torch.randn(*(Q.size()[:-2] + (n, n - k)), dtype=Q.dtype, device=Q.device)
                Q = torch.cat([Q, N], dim=-1)
                Q = _make_orthogonal(Q)
            self.base = Q

            # It is necessary to return the -Id, as we use the diagonal for the
            # Householder parametrization. Using -Id makes:
            #     householder(torch.zeros(m, n)) == torch.eye(m, n)
            # Poor man's version of eye_like
            neg_Id = torch.zeros_like(Q_init)
            neg_Id.diagonal(dim1=-2, dim2=-1).fill_(-1.)
            return neg_Id
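
# A hedged sketch of the assignment path above (illustrative only; it relies on the
# torch.nn.utils.parametrize convention that assigning to a parametrized tensor calls
# ``right_inverse``): with trivialization enabled, assigning an orthogonal value stores it
# in ``base`` and resets the unconstrained tensor to -Id, so the forward map reproduces
# the assigned value. ``orthogonal`` here refers to the helper defined below.
#
#   >>> layer = orthogonal(torch.nn.Linear(3, 3))
#   >>> with torch.no_grad():
#   ...     layer.weight = torch.eye(3)
#   >>> torch.allclose(layer.weight, torch.eye(3), atol=1e-5)
#   True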


def orthogonal(module: Module,
               name: str = 'weight',
               orthogonal_map: Optional[str] = None,
               *,
               use_trivialization: bool = True) -> Module:
    r"""Apply an orthogonal or unitary parametrization to a matrix or a batch of matrices.

    Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, the parametrized
    matrix :math:`Q \in \mathbb{K}^{m \times n}` is **orthogonal** as

    .. math::

        \begin{align*}
            Q^{\text{H}}Q &= \mathrm{I}_n \mathrlap{\qquad \text{if }m \geq n}\\
            QQ^{\text{H}} &= \mathrm{I}_m \mathrlap{\qquad \text{if }m < n}
        \end{align*}

    where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex
    and the transpose when :math:`Q` is real-valued, and
    :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix.
    In plain words, :math:`Q` will have orthonormal columns whenever :math:`m \geq n`
    and orthonormal rows otherwise.

    If the tensor has more than two dimensions, we consider it as a batch of matrices of shape `(..., m, n)`.

    The matrix :math:`Q` may be parametrized via three different ``orthogonal_map`` options in terms of the original tensor:

    - ``"matrix_exp"``/``"cayley"``:
      the :func:`~torch.matrix_exp` :math:`Q = \exp(A)` and the `Cayley map`_
      :math:`Q = (\mathrm{I}_n + A/2)(\mathrm{I}_n - A/2)^{-1}` are applied to a skew-symmetric
      :math:`A` to give an orthogonal matrix.
    - ``"householder"``: computes a product of Householder reflectors
      (:func:`~torch.linalg.householder_product`).

    ``"matrix_exp"``/``"cayley"`` often make the parametrized weight converge faster than
    ``"householder"``, but they are slower to compute for very thin or very wide matrices.

    If ``use_trivialization=True`` (default), the parametrization implements the "Dynamic Trivialization Framework",
    where an extra matrix :math:`B \in \mathbb{K}^{n \times n}` is stored under
    ``module.parametrizations.weight[0].base``. This helps the
    convergence of the parametrized layer at the expense of some extra memory use.
    See `Trivializations for Gradient-Based Optimization on Manifolds`_ .

    Initial value of :math:`Q`:
    If the original tensor is not parametrized and ``use_trivialization=True`` (default), the initial value
    of :math:`Q` is that of the original tensor if it is orthogonal (or unitary in the complex case)
    and it is orthogonalized via the QR decomposition otherwise (see :func:`torch.linalg.qr`).
    The same happens when it is not parametrized and ``orthogonal_map="householder"`` even when ``use_trivialization=False``.
    Otherwise, the initial value is the result of the composition of all the registered
    parametrizations applied to the original tensor.

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`.

    .. _`Cayley map`: https://en.wikipedia.org/wiki/Cayley_transform#Matrix_map
    .. _`Trivializations for Gradient-Based Optimization on Manifolds`: https://arxiv.org/abs/1909.09501

    Args:
        module (nn.Module): module on which to register the parametrization.
        name (str, optional): name of the tensor to make orthogonal. Default: ``"weight"``.
        orthogonal_map (str, optional): One of the following: ``"matrix_exp"``, ``"cayley"``, ``"householder"``.
            Default: ``"matrix_exp"`` if the matrix is square or complex, ``"householder"`` otherwise.
        use_trivialization (bool, optional): whether to use the dynamic trivialization framework.
            Default: ``True``.

    Returns:
        The original module with an orthogonal parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> orth_linear = orthogonal(nn.Linear(20, 40))
        >>> orth_linear
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _Orthogonal()
            )
          )
        )
        >>> # xdoctest: +IGNORE_WANT
        >>> Q = orth_linear.weight
        >>> torch.dist(Q.T @ Q, torch.eye(20))
        tensor(4.9332e-07)
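        >>> # An extra, hedged sketch (illustrative only): the map can also be chosen
        >>> # explicitly, and the parametrized weight still has orthonormal columns.
        >>> orth_cayley = orthogonal(nn.Linear(20, 40), orthogonal_map="cayley")
        >>> torch.allclose(orth_cayley.weight.T @ orth_cayley.weight, torch.eye(20), atol=1e-5)
        True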
  222. """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    # We could implement this for 1-dim tensors as the maps on the sphere
    # but I believe it'd bite more people than it'd help
    if weight.ndim < 2:
        raise ValueError("Expected a matrix or batch of matrices. "
                         f"Got a tensor of {weight.ndim} dimensions.")

    if orthogonal_map is None:
        orthogonal_map = "matrix_exp" if weight.size(-2) == weight.size(-1) or weight.is_complex() else "householder"

    orth_enum = getattr(_OrthMaps, orthogonal_map, None)
    if orth_enum is None:
        raise ValueError('orthogonal_map has to be one of "matrix_exp", "cayley", "householder". '
                         f'Got: {orthogonal_map}')
    orth = _Orthogonal(weight,
                       orth_enum,
                       use_trivialization=use_trivialization)
    parametrize.register_parametrization(module, name, orth, unsafe=True)
    return module


class _WeightNorm(Module):
    def __init__(
        self,
        dim: Optional[int] = 0,
    ) -> None:
        super().__init__()
        if dim is None:
            dim = -1
        self.dim = dim

    def forward(self, weight_g, weight_v):
        return torch._weight_norm(weight_v, weight_g, self.dim)

    def right_inverse(self, weight):
        weight_g = torch.norm_except_dim(weight, 2, self.dim)
        weight_v = weight

        return weight_g, weight_v


def weight_norm(module: Module, name: str = 'weight', dim: int = 0):
    r"""Apply weight normalization to a parameter in the given module.

    .. math::
         \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}

    Weight normalization is a reparameterization that decouples the magnitude
    of a weight tensor from its direction. This replaces the parameter specified
    by :attr:`name` with two parameters: one specifying the magnitude
    and one specifying the direction.

    By default, with ``dim=0``, the norm is computed independently per output
    channel/plane. To compute a norm over the entire weight tensor, use
    ``dim=None``.

    See https://arxiv.org/abs/1602.07868

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter
        dim (int, optional): dimension over which to compute the norm

    Returns:
        The original module with the weight norm hook

    Example::

        >>> m = weight_norm(nn.Linear(20, 40), name='weight')
        >>> m
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _WeightNorm()
            )
          )
        )
        >>> m.parametrizations.weight.original0.size()
        torch.Size([40, 1])
        >>> m.parametrizations.weight.original1.size()
        torch.Size([40, 20])
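        >>> # A hedged extra check (illustrative only): the recomputed weight matches
        >>> # g * v / ||v||, with the norm taken per output row (i.e. over dim 1 here).
        >>> g = m.parametrizations.weight.original0
        >>> v = m.parametrizations.weight.original1
        >>> torch.allclose(m.weight, g * v / v.norm(dim=1, keepdim=True), atol=1e-6)
        True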
  292. """
    _weight_norm = _WeightNorm(dim)
    parametrize.register_parametrization(module, name, _weight_norm, unsafe=True)

    def _weight_norm_compat_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        g_key = f"{prefix}{name}_g"
        v_key = f"{prefix}{name}_v"
        if g_key in state_dict and v_key in state_dict:
            original0 = state_dict.pop(g_key)
            original1 = state_dict.pop(v_key)

            state_dict[f"{prefix}parametrizations.{name}.original0"] = original0
            state_dict[f"{prefix}parametrizations.{name}.original1"] = original1

    module._register_load_state_dict_pre_hook(_weight_norm_compat_hook)
    return module


class _SpectralNorm(Module):
    def __init__(
        self,
        weight: torch.Tensor,
        n_power_iterations: int = 1,
        dim: int = 0,
        eps: float = 1e-12
    ) -> None:
        super().__init__()
        ndim = weight.ndim
        if dim >= ndim or dim < -ndim:
            raise IndexError("Dimension out of range (expected to be in range of "
                             f"[-{ndim}, {ndim - 1}] but got {dim})")

        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             f'got n_power_iterations={n_power_iterations}')
        self.dim = dim if dim >= 0 else dim + ndim
        self.eps = eps
        if ndim > 1:
            # For ndim == 1 we do not need to approximate anything (see _SpectralNorm.forward)
            self.n_power_iterations = n_power_iterations
            weight_mat = self._reshape_weight_to_matrix(weight)
            h, w = weight_mat.size()

            u = weight_mat.new_empty(h).normal_(0, 1)
            v = weight_mat.new_empty(w).normal_(0, 1)
            self.register_buffer('_u', F.normalize(u, dim=0, eps=self.eps))
            self.register_buffer('_v', F.normalize(v, dim=0, eps=self.eps))

            # Start with u, v initialized to some reasonable values by performing a number
            # of iterations of the power method
            self._power_method(weight_mat, 15)

    def _reshape_weight_to_matrix(self, weight: torch.Tensor) -> torch.Tensor:
        # Precondition
        assert weight.ndim > 1

        if self.dim != 0:
            # permute dim to front
            weight = weight.permute(self.dim, *(d for d in range(weight.dim()) if d != self.dim))

        return weight.flatten(1)

    @torch.autograd.no_grad()
    def _power_method(self, weight_mat: torch.Tensor, n_power_iterations: int) -> None:
        # See original note at torch/nn/utils/spectral_norm.py
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        #     updated in power iteration **in-place**. This is very important
        #     because in `DataParallel` forward, the vectors (being buffers) are
        #     broadcast from the parallelized module to each module replica,
        #     which is a new module object created on the fly. And each replica
        #     runs its own spectral norm power iteration. So simply assigning
        #     the updated vectors to the module this function runs on will cause
        #     the update to be lost forever. And the next time the parallelized
        #     module is replicated, the same randomly initialized vectors are
        #     broadcast and used!
        #
        #     Therefore, to make the change propagate back, we rely on two
        #     important behaviors (also enforced via tests):
        #     1. `DataParallel` doesn't clone storage if the broadcast tensor
        #        is already on correct device; and it makes sure that the
        #        parallelized module is already on `device[0]`.
        #     2. If the out tensor in `out=` kwarg has correct shape, it will
        #        just fill in the values.
        #     Therefore, since the same power iteration is performed on all
        #     devices, simply updating the tensors in-place will make sure that
        #     the module replica on `device[0]` will update the _u vector on the
        #     parallelized module (by shared storage).
        #
        #     However, after we update `u` and `v` in-place, we need to **clone**
        #     them before using them to normalize the weight. This is to support
        #     backproping through two forward passes, e.g., the common pattern in
        #     GAN training: loss = D(real) - D(fake). Otherwise, engine will
        #     complain that variables needed to do backward for the first forward
        #     (i.e., the `u` and `v` vectors) are changed in the second forward.

        # Precondition
        assert weight_mat.ndim > 1

        for _ in range(n_power_iterations):
            # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
            # are the first left and right singular vectors.
            # This power iteration produces approximations of `u` and `v`.
            self._u = F.normalize(torch.mv(weight_mat, self._v),  # type: ignore[has-type]
                                  dim=0, eps=self.eps, out=self._u)  # type: ignore[has-type]
            self._v = F.normalize(torch.mv(weight_mat.H, self._u),
                                  dim=0, eps=self.eps, out=self._v)  # type: ignore[has-type]
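
        # A hedged note on the loop above (illustrative only): after the iterations, `_u` and
        # `_v` approximate the top left/right singular vectors of `weight_mat`, so the value
        # `torch.vdot(u, torch.mv(weight_mat, v))` computed in `forward` approximates
        # `torch.linalg.matrix_norm(weight_mat, 2)`, i.e. the largest singular value.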

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        if weight.ndim == 1:
            # Faster and more exact path, no need to approximate anything
            return F.normalize(weight, dim=0, eps=self.eps)
        else:
            weight_mat = self._reshape_weight_to_matrix(weight)
            if self.training:
                self._power_method(weight_mat, self.n_power_iterations)
            # See above on why we need to clone
            u = self._u.clone(memory_format=torch.contiguous_format)
            v = self._v.clone(memory_format=torch.contiguous_format)
            # The proper way of computing this should be through F.bilinear, but
            # it seems to have some efficiency issues:
            # https://github.com/pytorch/pytorch/issues/58093
            sigma = torch.vdot(u, torch.mv(weight_mat, v))
            return weight / sigma

    def right_inverse(self, value: torch.Tensor) -> torch.Tensor:
        # we may want to assert here that the passed value already
        # satisfies constraints
        return value


def spectral_norm(module: Module,
                  name: str = 'weight',
                  n_power_iterations: int = 1,
                  eps: float = 1e-12,
                  dim: Optional[int] = None) -> Module:
    r"""Apply spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    When applied on a vector, it simplifies to

    .. math::
        \mathbf{x}_{SN} = \dfrac{\mathbf{x}}{\|\mathbf{x}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by reducing the Lipschitz constant
    of the model. :math:`\sigma` is approximated by performing one iteration of the
    `power method`_ every time the weight is accessed. If the weight tensor has
    more than 2 dimensions, it is reshaped to 2D for the power iteration that
    estimates the spectral norm.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`power method`: https://en.wikipedia.org/wiki/Power_iteration
    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    .. note::
        This function is implemented using the parametrization functionality
        in :func:`~torch.nn.utils.parametrize.register_parametrization`. It is a
        reimplementation of :func:`torch.nn.utils.spectral_norm`.

    .. note::
        When this constraint is registered, the singular vectors associated with the largest
        singular value are estimated rather than sampled at random. These are then updated
        by performing :attr:`n_power_iterations` of the `power method`_ whenever the tensor
        is accessed with the module in `training` mode.

    .. note::
        If the `_SpectralNorm` module, i.e., `module.parametrizations.weight[idx]`,
        is in training mode on removal, it will perform another power iteration.
        If you'd like to avoid this iteration, set the module to eval mode
        before its removal.

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter. Default: ``"weight"``.
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm. Default: ``1``.
        eps (float, optional): epsilon for numerical stability in
            calculating norms. Default: ``1e-12``.
        dim (int, optional): dimension corresponding to number of outputs.
            Default: ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with a new parametrization registered to the specified
        weight

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> snm = spectral_norm(nn.Linear(20, 40))
        >>> snm
        ParametrizedLinear(
          in_features=20, out_features=40, bias=True
          (parametrizations): ModuleDict(
            (weight): ParametrizationList(
              (0): _SpectralNorm()
            )
          )
        )
        >>> torch.linalg.matrix_norm(snm.weight, 2)
        tensor(1.0081, grad_fn=<AmaxBackward0>)
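        >>> # A hedged extra sketch (illustrative only): for ConvTranspose modules the
        >>> # ``dim`` argument defaults to 1 (the output-channel dimension), per the Args above.
        >>> snm_t = spectral_norm(nn.ConvTranspose2d(16, 32, 3))
        >>> snm_t.parametrizations.weight[0].dim
        1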
  467. """
    weight = getattr(module, name, None)
    if not isinstance(weight, Tensor):
        raise ValueError(
            f"Module '{module}' has no parameter or buffer with name '{name}'"
        )

    if dim is None:
        if isinstance(module, (torch.nn.ConvTranspose1d,
                               torch.nn.ConvTranspose2d,
                               torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    parametrize.register_parametrization(module, name, _SpectralNorm(weight, n_power_iterations, dim, eps))
    return module