init.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631
  1. # mypy: allow-untyped-defs
  2. """This file contains utilities for initializing neural network parameters."""
  3. import math
  4. import warnings
  5. from torch import Tensor
  6. import torch
  7. from typing import Optional as _Optional
  8. # These no_grad_* functions are necessary as wrappers around the parts of these
  9. # functions that use `with torch.no_grad()`. The JIT doesn't support context
  10. # managers, so these need to be implemented as builtins. Using these wrappers
  11. # lets us keep those builtins small and re-usable.
  12. def _no_grad_uniform_(tensor, a, b, generator=None):
  13. with torch.no_grad():
  14. return tensor.uniform_(a, b, generator=generator)
  15. def _no_grad_normal_(tensor, mean, std, generator=None):
  16. with torch.no_grad():
  17. return tensor.normal_(mean, std, generator=generator)
  18. def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
  19. # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
  20. def norm_cdf(x):
  21. # Computes standard normal cumulative distribution function
  22. return (1. + math.erf(x / math.sqrt(2.))) / 2.
  23. if (mean < a - 2 * std) or (mean > b + 2 * std):
  24. warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
  25. "The distribution of values may be incorrect.",
  26. stacklevel=2)
  27. with torch.no_grad():
  28. # Values are generated by using a truncated uniform distribution and
  29. # then using the inverse CDF for the normal distribution.
  30. # Get upper and lower cdf values
  31. l = norm_cdf((a - mean) / std)
  32. u = norm_cdf((b - mean) / std)
  33. # Uniformly fill tensor with values from [l, u], then translate to
  34. # [2l-1, 2u-1].
  35. tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator)
  36. # Use inverse cdf transform for normal distribution to get truncated
  37. # standard normal
  38. tensor.erfinv_()
  39. # Transform to proper mean, std
  40. tensor.mul_(std * math.sqrt(2.))
  41. tensor.add_(mean)
  42. # Clamp to ensure it's in the proper range
  43. tensor.clamp_(min=a, max=b)
  44. return tensor
  45. def _no_grad_fill_(tensor, val):
  46. with torch.no_grad():
  47. return tensor.fill_(val)
  48. def _no_grad_zero_(tensor):
  49. with torch.no_grad():
  50. return tensor.zero_()
  51. def calculate_gain(nonlinearity, param=None):
  52. r"""Return the recommended gain value for the given nonlinearity function.
  53. The values are as follows:
  54. ================= ====================================================
  55. nonlinearity gain
  56. ================= ====================================================
  57. Linear / Identity :math:`1`
  58. Conv{1,2,3}D :math:`1`
  59. Sigmoid :math:`1`
  60. Tanh :math:`\frac{5}{3}`
  61. ReLU :math:`\sqrt{2}`
  62. Leaky Relu :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
  63. SELU :math:`\frac{3}{4}`
  64. ================= ====================================================
  65. .. warning::
  66. In order to implement `Self-Normalizing Neural Networks`_ ,
  67. you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
  68. This gives the initial weights a variance of ``1 / N``,
  69. which is necessary to induce a stable fixed point in the forward pass.
  70. In contrast, the default gain for ``SELU`` sacrifices the normalization
  71. effect for more stable gradient flow in rectangular layers.
  72. Args:
  73. nonlinearity: the non-linear function (`nn.functional` name)
  74. param: optional parameter for the non-linear function
  75. Examples:
  76. >>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
  77. .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
  78. """
  79. linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
  80. if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
  81. return 1
  82. elif nonlinearity == 'tanh':
  83. return 5.0 / 3
  84. elif nonlinearity == 'relu':
  85. return math.sqrt(2.0)
  86. elif nonlinearity == 'leaky_relu':
  87. if param is None:
  88. negative_slope = 0.01
  89. elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
  90. # True/False are instances of int, hence check above
  91. negative_slope = param
  92. else:
  93. raise ValueError(f"negative_slope {param} not a valid number")
  94. return math.sqrt(2.0 / (1 + negative_slope ** 2))
  95. elif nonlinearity == 'selu':
  96. return 3.0 / 4 # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
  97. else:
  98. raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
  99. def uniform_(
  100. tensor: Tensor,
  101. a: float = 0.0,
  102. b: float = 1.0,
  103. generator: _Optional[torch.Generator] = None,
  104. ) -> Tensor:
  105. r"""Fill the input Tensor with values drawn from the uniform distribution.
  106. :math:`\mathcal{U}(a, b)`.
  107. Args:
  108. tensor: an n-dimensional `torch.Tensor`
  109. a: the lower bound of the uniform distribution
  110. b: the upper bound of the uniform distribution
  111. generator: the torch Generator to sample from (default: None)
  112. Examples:
  113. >>> w = torch.empty(3, 5)
  114. >>> nn.init.uniform_(w)
  115. """
  116. if torch.overrides.has_torch_function_variadic(tensor):
  117. return torch.overrides.handle_torch_function(
  118. uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
  119. )
  120. return _no_grad_uniform_(tensor, a, b, generator)
  121. def normal_(
  122. tensor: Tensor,
  123. mean: float = 0.0,
  124. std: float = 1.0,
  125. generator: _Optional[torch.Generator] = None,
  126. ) -> Tensor:
  127. r"""Fill the input Tensor with values drawn from the normal distribution.
  128. :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
  129. Args:
  130. tensor: an n-dimensional `torch.Tensor`
  131. mean: the mean of the normal distribution
  132. std: the standard deviation of the normal distribution
  133. generator: the torch Generator to sample from (default: None)
  134. Examples:
  135. >>> w = torch.empty(3, 5)
  136. >>> nn.init.normal_(w)
  137. """
  138. if torch.overrides.has_torch_function_variadic(tensor):
  139. return torch.overrides.handle_torch_function(
  140. normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
  141. )
  142. return _no_grad_normal_(tensor, mean, std, generator)
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.,
    std: float = 1.,
    a: float = -2.,
    b: float = 2.,
    generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    The tensor is modified in place with autograd recording disabled. A
    warning is emitted when ``mean`` lies more than ``2 * std`` outside
    :math:`[a, b]`, because the resulting distribution may then be incorrect.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Returns:
        The input ``tensor``, after it has been overwritten.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    # All sampling work lives in the no-grad helper; this wrapper only adds
    # the typed, documented public signature.
    return _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=generator)
  169. def constant_(tensor: Tensor, val: float) -> Tensor:
  170. r"""Fill the input Tensor with the value :math:`\text{val}`.
  171. Args:
  172. tensor: an n-dimensional `torch.Tensor`
  173. val: the value to fill the tensor with
  174. Examples:
  175. >>> w = torch.empty(3, 5)
  176. >>> nn.init.constant_(w, 0.3)
  177. """
  178. if torch.overrides.has_torch_function_variadic(tensor):
  179. return torch.overrides.handle_torch_function(constant_, (tensor,), tensor=tensor, val=val)
  180. return _no_grad_fill_(tensor, val)
def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    The tensor is modified in place with autograd recording disabled.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        The result of the in-place fill performed by the no-grad helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    # Delegates to the shared no-grad fill helper with the constant 1.
    return _no_grad_fill_(tensor, 1.)
def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    The tensor is modified in place with autograd recording disabled.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        The result of the in-place zeroing performed by the no-grad helper.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    # Delegates to the shared no-grad zeroing helper.
    return _no_grad_zero_(tensor)
  199. def eye_(tensor):
  200. r"""Fill the 2-dimensional input `Tensor` with the identity matrix.
  201. Preserves the identity of the inputs in `Linear` layers, where as
  202. many inputs are preserved as possible.
  203. Args:
  204. tensor: a 2-dimensional `torch.Tensor`
  205. Examples:
  206. >>> w = torch.empty(3, 5)
  207. >>> nn.init.eye_(w)
  208. """
  209. if tensor.ndimension() != 2:
  210. raise ValueError("Only tensors with 2 dimensions are supported")
  211. with torch.no_grad():
  212. torch.eye(*tensor.shape, out=tensor, requires_grad=tensor.requires_grad)
  213. return tensor
  214. def dirac_(tensor, groups=1):
  215. r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.
  216. Preserves the identity of the inputs in `Convolutional`
  217. layers, where as many input channels are preserved as possible. In case
  218. of groups>1, each group of channels preserves identity
  219. Args:
  220. tensor: a {3, 4, 5}-dimensional `torch.Tensor`
  221. groups (int, optional): number of groups in the conv layer (default: 1)
  222. Examples:
  223. >>> w = torch.empty(3, 16, 5, 5)
  224. >>> nn.init.dirac_(w)
  225. >>> w = torch.empty(3, 24, 5, 5)
  226. >>> nn.init.dirac_(w, 3)
  227. """
  228. dimensions = tensor.ndimension()
  229. if dimensions not in [3, 4, 5]:
  230. raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")
  231. sizes = tensor.size()
  232. if sizes[0] % groups != 0:
  233. raise ValueError('dim 0 must be divisible by groups')
  234. out_chans_per_grp = sizes[0] // groups
  235. min_dim = min(out_chans_per_grp, sizes[1])
  236. with torch.no_grad():
  237. tensor.zero_()
  238. for g in range(groups):
  239. for d in range(min_dim):
  240. if dimensions == 3: # Temporal convolution
  241. tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2] = 1
  242. elif dimensions == 4: # Spatial convolution
  243. tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
  244. tensor.size(3) // 2] = 1
  245. else: # Volumetric convolution
  246. tensor[g * out_chans_per_grp + d, d, tensor.size(2) // 2,
  247. tensor.size(3) // 2, tensor.size(4) // 2] = 1
  248. return tensor
  249. def _calculate_fan_in_and_fan_out(tensor):
  250. dimensions = tensor.dim()
  251. if dimensions < 2:
  252. raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
  253. num_input_fmaps = tensor.size(1)
  254. num_output_fmaps = tensor.size(0)
  255. receptive_field_size = 1
  256. if tensor.dim() > 2:
  257. # math.prod is not always available, accumulate the product manually
  258. # we could use functools.reduce but that is not supported by TorchScript
  259. for s in tensor.shape[2:]:
  260. receptive_field_size *= s
  261. fan_in = num_input_fmaps * receptive_field_size
  262. fan_out = num_output_fmaps * receptive_field_size
  263. return fan_in, fan_out
  264. def xavier_uniform_(
  265. tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
  266. ) -> Tensor:
  267. r"""Fill the input `Tensor` with values using a Xavier uniform distribution.
  268. The method is described in `Understanding the difficulty of training
  269. deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
  270. The resulting tensor will have values sampled from
  271. :math:`\mathcal{U}(-a, a)` where
  272. .. math::
  273. a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}
  274. Also known as Glorot initialization.
  275. Args:
  276. tensor: an n-dimensional `torch.Tensor`
  277. gain: an optional scaling factor
  278. generator: the torch Generator to sample from (default: None)
  279. Examples:
  280. >>> w = torch.empty(3, 5)
  281. >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
  282. """
  283. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  284. std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  285. a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  286. return _no_grad_uniform_(tensor, -a, a, generator)
  287. def xavier_normal_(
  288. tensor: Tensor,
  289. gain: float = 1.0,
  290. generator: _Optional[torch.Generator] = None,
  291. ) -> Tensor:
  292. r"""Fill the input `Tensor` with values using a Xavier normal distribution.
  293. The method is described in `Understanding the difficulty of training deep feedforward
  294. neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
  295. will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where
  296. .. math::
  297. \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}
  298. Also known as Glorot initialization.
  299. Args:
  300. tensor: an n-dimensional `torch.Tensor`
  301. gain: an optional scaling factor
  302. generator: the torch Generator to sample from (default: None)
  303. Examples:
  304. >>> w = torch.empty(3, 5)
  305. >>> nn.init.xavier_normal_(w)
  306. """
  307. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  308. std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
  309. return _no_grad_normal_(tensor, 0., std, generator)
  310. def _calculate_correct_fan(tensor, mode):
  311. mode = mode.lower()
  312. valid_modes = ['fan_in', 'fan_out']
  313. if mode not in valid_modes:
  314. raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")
  315. fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
  316. return fan_in if mode == 'fan_in' else fan_out
  317. def kaiming_uniform_(
  318. tensor: Tensor,
  319. a: float = 0,
  320. mode: str = "fan_in",
  321. nonlinearity: str = "leaky_relu",
  322. generator: _Optional[torch.Generator] = None,
  323. ):
  324. r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.
  325. The method is described in `Delving deep into rectifiers: Surpassing
  326. human-level performance on ImageNet classification` - He, K. et al. (2015).
  327. The resulting tensor will have values sampled from
  328. :math:`\mathcal{U}(-\text{bound}, \text{bound})` where
  329. .. math::
  330. \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}
  331. Also known as He initialization.
  332. Args:
  333. tensor: an n-dimensional `torch.Tensor`
  334. a: the negative slope of the rectifier used after this layer (only
  335. used with ``'leaky_relu'``)
  336. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  337. preserves the magnitude of the variance of the weights in the
  338. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  339. backwards pass.
  340. nonlinearity: the non-linear function (`nn.functional` name),
  341. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  342. generator: the torch Generator to sample from (default: None)
  343. Examples:
  344. >>> w = torch.empty(3, 5)
  345. >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
  346. """
  347. if torch.overrides.has_torch_function_variadic(tensor):
  348. return torch.overrides.handle_torch_function(
  349. kaiming_uniform_,
  350. (tensor,),
  351. tensor=tensor,
  352. a=a,
  353. mode=mode,
  354. nonlinearity=nonlinearity,
  355. generator=generator)
  356. if 0 in tensor.shape:
  357. warnings.warn("Initializing zero-element tensors is a no-op")
  358. return tensor
  359. fan = _calculate_correct_fan(tensor, mode)
  360. gain = calculate_gain(nonlinearity, a)
  361. std = gain / math.sqrt(fan)
  362. bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
  363. with torch.no_grad():
  364. return tensor.uniform_(-bound, bound, generator=generator)
  365. def kaiming_normal_(
  366. tensor: Tensor,
  367. a: float = 0,
  368. mode: str = "fan_in",
  369. nonlinearity: str = "leaky_relu",
  370. generator: _Optional[torch.Generator] = None,
  371. ):
  372. r"""Fill the input `Tensor` with values using a Kaiming normal distribution.
  373. The method is described in `Delving deep into rectifiers: Surpassing
  374. human-level performance on ImageNet classification` - He, K. et al. (2015).
  375. The resulting tensor will have values sampled from
  376. :math:`\mathcal{N}(0, \text{std}^2)` where
  377. .. math::
  378. \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}
  379. Also known as He initialization.
  380. Args:
  381. tensor: an n-dimensional `torch.Tensor`
  382. a: the negative slope of the rectifier used after this layer (only
  383. used with ``'leaky_relu'``)
  384. mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
  385. preserves the magnitude of the variance of the weights in the
  386. forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
  387. backwards pass.
  388. nonlinearity: the non-linear function (`nn.functional` name),
  389. recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
  390. generator: the torch Generator to sample from (default: None)
  391. Examples:
  392. >>> w = torch.empty(3, 5)
  393. >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
  394. """
  395. if 0 in tensor.shape:
  396. warnings.warn("Initializing zero-element tensors is a no-op")
  397. return tensor
  398. fan = _calculate_correct_fan(tensor, mode)
  399. gain = calculate_gain(nonlinearity, a)
  400. std = gain / math.sqrt(fan)
  401. with torch.no_grad():
  402. return tensor.normal_(0, std, generator=generator)
  403. def orthogonal_(
  404. tensor,
  405. gain=1,
  406. generator: _Optional[torch.Generator] = None,
  407. ):
  408. r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
  409. Described in `Exact solutions to the nonlinear dynamics of learning in deep
  410. linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
  411. at least 2 dimensions, and for tensors with more than 2 dimensions the
  412. trailing dimensions are flattened.
  413. Args:
  414. tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
  415. gain: optional scaling factor
  416. generator: the torch Generator to sample from (default: None)
  417. Examples:
  418. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  419. >>> w = torch.empty(3, 5)
  420. >>> nn.init.orthogonal_(w)
  421. """
  422. if tensor.ndimension() < 2:
  423. raise ValueError("Only tensors with 2 or more dimensions are supported")
  424. if tensor.numel() == 0:
  425. # no-op
  426. return tensor
  427. rows = tensor.size(0)
  428. cols = tensor.numel() // rows
  429. flattened = tensor.new(rows, cols).normal_(0, 1, generator=generator)
  430. if rows < cols:
  431. flattened.t_()
  432. # Compute the qr factorization
  433. q, r = torch.linalg.qr(flattened)
  434. # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
  435. d = torch.diag(r, 0)
  436. ph = d.sign()
  437. q *= ph
  438. if rows < cols:
  439. q.t_()
  440. with torch.no_grad():
  441. tensor.view_as(q).copy_(q)
  442. tensor.mul_(gain)
  443. return tensor
  444. def sparse_(
  445. tensor,
  446. sparsity,
  447. std=0.01,
  448. generator: _Optional[torch.Generator] = None,
  449. ):
  450. r"""Fill the 2D input `Tensor` as a sparse matrix.
  451. The non-zero elements will be drawn from the normal distribution
  452. :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
  453. Hessian-free optimization` - Martens, J. (2010).
  454. Args:
  455. tensor: an n-dimensional `torch.Tensor`
  456. sparsity: The fraction of elements in each column to be set to zero
  457. std: the standard deviation of the normal distribution used to generate
  458. the non-zero values
  459. generator: the torch Generator to sample from (default: None)
  460. Examples:
  461. >>> w = torch.empty(3, 5)
  462. >>> nn.init.sparse_(w, sparsity=0.1)
  463. """
  464. if tensor.ndimension() != 2:
  465. raise ValueError("Only tensors with 2 dimensions are supported")
  466. rows, cols = tensor.shape
  467. num_zeros = int(math.ceil(sparsity * rows))
  468. with torch.no_grad():
  469. tensor.normal_(0, std, generator=generator)
  470. for col_idx in range(cols):
  471. row_indices = torch.randperm(rows)
  472. zero_indices = row_indices[:num_zeros]
  473. tensor[zero_indices, col_idx] = 0
  474. return tensor
  475. # for backward compatibility
  476. def _make_deprecate(meth):
  477. new_name = meth.__name__
  478. old_name = new_name[:-1]
  479. def deprecated_init(*args, **kwargs):
  480. warnings.warn(
  481. f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.",
  482. FutureWarning,
  483. stacklevel=2,
  484. )
  485. return meth(*args, **kwargs)
  486. deprecated_init.__doc__ = fr"""
  487. {old_name}(...)
  488. .. warning::
  489. This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
  490. See :func:`~torch.nn.init.{new_name}` for details."""
  491. deprecated_init.__name__ = old_name
  492. return deprecated_init
# Deprecated aliases without the trailing underscore. Each one is produced by
# ``_make_deprecate``: it emits a ``FutureWarning`` naming the replacement and
# then forwards the call to the current in-place initializer.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)