container.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913
  1. # mypy: allow-untyped-defs
  2. from collections import OrderedDict, abc as container_abcs
  3. from itertools import chain, islice
  4. import operator
  5. import torch
  6. from .module import Module
  7. from ..parameter import Parameter
  8. from torch._jit_internal import _copy_to_script_wrapper
  9. from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union
  10. from typing_extensions import Self
  11. from typing_extensions import deprecated
  12. __all__ = ['Container', 'Sequential', 'ModuleList', 'ModuleDict', 'ParameterList', 'ParameterDict']
  13. T = TypeVar('T', bound=Module)
  14. # Copied from torch.nn.modules.module, required for a custom __repr__ for ModuleList
  15. def _addindent(s_, numSpaces):
  16. s = s_.split('\n')
  17. # don't do anything for single-line stuff
  18. if len(s) == 1:
  19. return s_
  20. first = s.pop(0)
  21. s = [(numSpaces * ' ') + line for line in s]
  22. s = '\n'.join(s)
  23. s = first + '\n' + s
  24. return s
  25. @deprecated(
  26. "`nn.Container` is deprecated. "
  27. "All of it's functionality is now implemented in `nn.Module`. Subclass that instead.",
  28. category=FutureWarning,
  29. )
  30. class Container(Module):
  31. def __init__(self, **kwargs: Any) -> None:
  32. super().__init__()
  33. for key, value in kwargs.items():
  34. self.add_module(key, value)
  35. class Sequential(Module):
  36. r"""A sequential container.
  37. Modules will be added to it in the order they are passed in the
  38. constructor. Alternatively, an ``OrderedDict`` of modules can be
  39. passed in. The ``forward()`` method of ``Sequential`` accepts any
  40. input and forwards it to the first module it contains. It then
  41. "chains" outputs to inputs sequentially for each subsequent module,
  42. finally returning the output of the last module.
  43. The value a ``Sequential`` provides over manually calling a sequence
  44. of modules is that it allows treating the whole container as a
  45. single module, such that performing a transformation on the
  46. ``Sequential`` applies to each of the modules it stores (which are
  47. each a registered submodule of the ``Sequential``).
  48. What's the difference between a ``Sequential`` and a
  49. :class:`torch.nn.ModuleList`? A ``ModuleList`` is exactly what it
  50. sounds like--a list for storing ``Module`` s! On the other hand,
  51. the layers in a ``Sequential`` are connected in a cascading way.
  52. Example::
  53. # Using Sequential to create a small model. When `model` is run,
  54. # input will first be passed to `Conv2d(1,20,5)`. The output of
  55. # `Conv2d(1,20,5)` will be used as the input to the first
  56. # `ReLU`; the output of the first `ReLU` will become the input
  57. # for `Conv2d(20,64,5)`. Finally, the output of
  58. # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
  59. model = nn.Sequential(
  60. nn.Conv2d(1,20,5),
  61. nn.ReLU(),
  62. nn.Conv2d(20,64,5),
  63. nn.ReLU()
  64. )
  65. # Using Sequential with OrderedDict. This is functionally the
  66. # same as the above code
  67. model = nn.Sequential(OrderedDict([
  68. ('conv1', nn.Conv2d(1,20,5)),
  69. ('relu1', nn.ReLU()),
  70. ('conv2', nn.Conv2d(20,64,5)),
  71. ('relu2', nn.ReLU())
  72. ]))
  73. """
  74. _modules: Dict[str, Module] # type: ignore[assignment]
  75. @overload
  76. def __init__(self, *args: Module) -> None:
  77. ...
  78. @overload
  79. def __init__(self, arg: 'OrderedDict[str, Module]') -> None:
  80. ...
  81. def __init__(self, *args):
  82. super().__init__()
  83. if len(args) == 1 and isinstance(args[0], OrderedDict):
  84. for key, module in args[0].items():
  85. self.add_module(key, module)
  86. else:
  87. for idx, module in enumerate(args):
  88. self.add_module(str(idx), module)
  89. def _get_item_by_idx(self, iterator, idx) -> T: # type: ignore[misc, type-var]
  90. """Get the idx-th item of the iterator."""
  91. size = len(self)
  92. idx = operator.index(idx)
  93. if not -size <= idx < size:
  94. raise IndexError(f'index {idx} is out of range')
  95. idx %= size
  96. return next(islice(iterator, idx, None))
  97. @_copy_to_script_wrapper
  98. def __getitem__(self, idx: Union[slice, int]) -> Union['Sequential', T]:
  99. if isinstance(idx, slice):
  100. return self.__class__(OrderedDict(list(self._modules.items())[idx]))
  101. else:
  102. return self._get_item_by_idx(self._modules.values(), idx)
  103. def __setitem__(self, idx: int, module: Module) -> None:
  104. key: str = self._get_item_by_idx(self._modules.keys(), idx)
  105. return setattr(self, key, module)
  106. def __delitem__(self, idx: Union[slice, int]) -> None:
  107. if isinstance(idx, slice):
  108. for key in list(self._modules.keys())[idx]:
  109. delattr(self, key)
  110. else:
  111. key = self._get_item_by_idx(self._modules.keys(), idx)
  112. delattr(self, key)
  113. # To preserve numbering
  114. str_indices = [str(i) for i in range(len(self._modules))]
  115. self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
  116. @_copy_to_script_wrapper
  117. def __len__(self) -> int:
  118. return len(self._modules)
  119. def __add__(self, other) -> 'Sequential':
  120. if isinstance(other, Sequential):
  121. ret = Sequential()
  122. for layer in self:
  123. ret.append(layer)
  124. for layer in other:
  125. ret.append(layer)
  126. return ret
  127. else:
  128. raise ValueError('add operator supports only objects '
  129. f'of Sequential class, but {str(type(other))} is given.')
  130. def pop(self, key: Union[int, slice]) -> Module:
  131. v = self[key]
  132. del self[key]
  133. return v
  134. def __iadd__(self, other) -> Self:
  135. if isinstance(other, Sequential):
  136. offset = len(self)
  137. for i, module in enumerate(other):
  138. self.add_module(str(i + offset), module)
  139. return self
  140. else:
  141. raise ValueError('add operator supports only objects '
  142. f'of Sequential class, but {str(type(other))} is given.')
  143. def __mul__(self, other: int) -> 'Sequential':
  144. if not isinstance(other, int):
  145. raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
  146. elif (other <= 0):
  147. raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
  148. else:
  149. combined = Sequential()
  150. offset = 0
  151. for _ in range(other):
  152. for module in self:
  153. combined.add_module(str(offset), module)
  154. offset += 1
  155. return combined
  156. def __rmul__(self, other: int) -> 'Sequential':
  157. return self.__mul__(other)
  158. def __imul__(self, other: int) -> Self:
  159. if not isinstance(other, int):
  160. raise TypeError(f"unsupported operand type(s) for *: {type(self)} and {type(other)}")
  161. elif (other <= 0):
  162. raise ValueError(f"Non-positive multiplication factor {other} for {type(self)}")
  163. else:
  164. len_original = len(self)
  165. offset = len(self)
  166. for _ in range(other - 1):
  167. for i in range(len_original):
  168. self.add_module(str(i + offset), self._modules[str(i)])
  169. offset += len_original
  170. return self
  171. @_copy_to_script_wrapper
  172. def __dir__(self):
  173. keys = super().__dir__()
  174. keys = [key for key in keys if not key.isdigit()]
  175. return keys
  176. @_copy_to_script_wrapper
  177. def __iter__(self) -> Iterator[Module]:
  178. return iter(self._modules.values())
  179. # NB: We can't really type check this function as the type of input
  180. # may change dynamically (as is tested in
  181. # TestScript.test_sequential_intermediary_types). Cannot annotate
  182. # with Any as TorchScript expects a more precise type
  183. def forward(self, input):
  184. for module in self:
  185. input = module(input)
  186. return input
  187. def append(self, module: Module) -> 'Sequential':
  188. r"""Append a given module to the end.
  189. Args:
  190. module (nn.Module): module to append
  191. """
  192. self.add_module(str(len(self)), module)
  193. return self
  194. def insert(self, index: int, module: Module) -> 'Sequential':
  195. if not isinstance(module, Module):
  196. raise AssertionError(
  197. f'module should be of type: {Module}')
  198. n = len(self._modules)
  199. if not (-n <= index <= n):
  200. raise IndexError(
  201. f'Index out of range: {index}')
  202. if index < 0:
  203. index += n
  204. for i in range(n, index, -1):
  205. self._modules[str(i)] = self._modules[str(i - 1)]
  206. self._modules[str(index)] = module
  207. return self
  208. def extend(self, sequential) -> 'Sequential':
  209. for layer in sequential:
  210. self.append(layer)
  211. return self
  212. class ModuleList(Module):
  213. r"""Holds submodules in a list.
  214. :class:`~torch.nn.ModuleList` can be indexed like a regular Python list, but
  215. modules it contains are properly registered, and will be visible by all
  216. :class:`~torch.nn.Module` methods.
  217. Args:
  218. modules (iterable, optional): an iterable of modules to add
  219. Example::
  220. class MyModule(nn.Module):
  221. def __init__(self):
  222. super().__init__()
  223. self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
  224. def forward(self, x):
  225. # ModuleList can act as an iterable, or be indexed using ints
  226. for i, l in enumerate(self.linears):
  227. x = self.linears[i // 2](x) + l(x)
  228. return x
  229. """
  230. _modules: Dict[str, Module] # type: ignore[assignment]
  231. def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
  232. super().__init__()
  233. if modules is not None:
  234. self += modules
  235. def _get_abs_string_index(self, idx):
  236. """Get the absolute index for the list of modules."""
  237. idx = operator.index(idx)
  238. if not (-len(self) <= idx < len(self)):
  239. raise IndexError(f'index {idx} is out of range')
  240. if idx < 0:
  241. idx += len(self)
  242. return str(idx)
  243. @_copy_to_script_wrapper
  244. def __getitem__(self, idx: Union[int, slice]) -> Union[Module, 'ModuleList']:
  245. if isinstance(idx, slice):
  246. return self.__class__(list(self._modules.values())[idx])
  247. else:
  248. return self._modules[self._get_abs_string_index(idx)]
  249. def __setitem__(self, idx: int, module: Module) -> None:
  250. idx = self._get_abs_string_index(idx)
  251. return setattr(self, str(idx), module)
  252. def __delitem__(self, idx: Union[int, slice]) -> None:
  253. if isinstance(idx, slice):
  254. for k in range(len(self._modules))[idx]:
  255. delattr(self, str(k))
  256. else:
  257. delattr(self, self._get_abs_string_index(idx))
  258. # To preserve numbering, self._modules is being reconstructed with modules after deletion
  259. str_indices = [str(i) for i in range(len(self._modules))]
  260. self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
  261. @_copy_to_script_wrapper
  262. def __len__(self) -> int:
  263. return len(self._modules)
  264. @_copy_to_script_wrapper
  265. def __iter__(self) -> Iterator[Module]:
  266. return iter(self._modules.values())
  267. def __iadd__(self, modules: Iterable[Module]) -> Self:
  268. return self.extend(modules)
  269. def __add__(self, other: Iterable[Module]) -> 'ModuleList':
  270. combined = ModuleList()
  271. for i, module in enumerate(chain(self, other)):
  272. combined.add_module(str(i), module)
  273. return combined
  274. def __repr__(self):
  275. """Return a custom repr for ModuleList that compresses repeated module representations."""
  276. list_of_reprs = [repr(item) for item in self]
  277. if len(list_of_reprs) == 0:
  278. return self._get_name() + '()'
  279. start_end_indices = [[0, 0]]
  280. repeated_blocks = [list_of_reprs[0]]
  281. for i, r in enumerate(list_of_reprs[1:], 1):
  282. if r == repeated_blocks[-1]:
  283. start_end_indices[-1][1] += 1
  284. continue
  285. start_end_indices.append([i, i])
  286. repeated_blocks.append(r)
  287. lines = []
  288. main_str = self._get_name() + '('
  289. for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
  290. local_repr = f"({start_id}): {b}" # default repr
  291. if start_id != end_id:
  292. n = end_id - start_id + 1
  293. local_repr = f"({start_id}-{end_id}): {n} x {b}"
  294. local_repr = _addindent(local_repr, 2)
  295. lines.append(local_repr)
  296. main_str += '\n ' + '\n '.join(lines) + '\n'
  297. main_str += ')'
  298. return main_str
  299. @_copy_to_script_wrapper
  300. def __dir__(self):
  301. keys = super().__dir__()
  302. keys = [key for key in keys if not key.isdigit()]
  303. return keys
  304. def insert(self, index: int, module: Module) -> None:
  305. r"""Insert a given module before a given index in the list.
  306. Args:
  307. index (int): index to insert.
  308. module (nn.Module): module to insert
  309. """
  310. for i in range(len(self._modules), index, -1):
  311. self._modules[str(i)] = self._modules[str(i - 1)]
  312. self._modules[str(index)] = module
  313. def append(self, module: Module) -> 'ModuleList':
  314. r"""Append a given module to the end of the list.
  315. Args:
  316. module (nn.Module): module to append
  317. """
  318. self.add_module(str(len(self)), module)
  319. return self
  320. def pop(self, key: Union[int, slice]) -> Module:
  321. v = self[key]
  322. del self[key]
  323. return v
  324. def extend(self, modules: Iterable[Module]) -> Self:
  325. r"""Append modules from a Python iterable to the end of the list.
  326. Args:
  327. modules (iterable): iterable of modules to append
  328. """
  329. if not isinstance(modules, container_abcs.Iterable):
  330. raise TypeError("ModuleList.extend should be called with an "
  331. "iterable, but got " + type(modules).__name__)
  332. offset = len(self)
  333. for i, module in enumerate(modules):
  334. self.add_module(str(offset + i), module)
  335. return self
  336. # remove forward alltogether to fallback on Module's _forward_unimplemented
  337. class ModuleDict(Module):
  338. r"""Holds submodules in a dictionary.
  339. :class:`~torch.nn.ModuleDict` can be indexed like a regular Python dictionary,
  340. but modules it contains are properly registered, and will be visible by all
  341. :class:`~torch.nn.Module` methods.
  342. :class:`~torch.nn.ModuleDict` is an **ordered** dictionary that respects
  343. * the order of insertion, and
  344. * in :meth:`~torch.nn.ModuleDict.update`, the order of the merged
  345. ``OrderedDict``, ``dict`` (started from Python 3.6) or another
  346. :class:`~torch.nn.ModuleDict` (the argument to
  347. :meth:`~torch.nn.ModuleDict.update`).
  348. Note that :meth:`~torch.nn.ModuleDict.update` with other unordered mapping
  349. types (e.g., Python's plain ``dict`` before Python version 3.6) does not
  350. preserve the order of the merged mapping.
  351. Args:
  352. modules (iterable, optional): a mapping (dictionary) of (string: module)
  353. or an iterable of key-value pairs of type (string, module)
  354. Example::
  355. class MyModule(nn.Module):
  356. def __init__(self):
  357. super().__init__()
  358. self.choices = nn.ModuleDict({
  359. 'conv': nn.Conv2d(10, 10, 3),
  360. 'pool': nn.MaxPool2d(3)
  361. })
  362. self.activations = nn.ModuleDict([
  363. ['lrelu', nn.LeakyReLU()],
  364. ['prelu', nn.PReLU()]
  365. ])
  366. def forward(self, x, choice, act):
  367. x = self.choices[choice](x)
  368. x = self.activations[act](x)
  369. return x
  370. """
  371. _modules: Dict[str, Module] # type: ignore[assignment]
  372. def __init__(self, modules: Optional[Mapping[str, Module]] = None) -> None:
  373. super().__init__()
  374. if modules is not None:
  375. self.update(modules)
  376. @_copy_to_script_wrapper
  377. def __getitem__(self, key: str) -> Module:
  378. return self._modules[key]
  379. def __setitem__(self, key: str, module: Module) -> None:
  380. self.add_module(key, module)
  381. def __delitem__(self, key: str) -> None:
  382. del self._modules[key]
  383. @_copy_to_script_wrapper
  384. def __len__(self) -> int:
  385. return len(self._modules)
  386. @_copy_to_script_wrapper
  387. def __iter__(self) -> Iterator[str]:
  388. return iter(self._modules)
  389. @_copy_to_script_wrapper
  390. def __contains__(self, key: str) -> bool:
  391. return key in self._modules
  392. def clear(self) -> None:
  393. """Remove all items from the ModuleDict."""
  394. self._modules.clear()
  395. def pop(self, key: str) -> Module:
  396. r"""Remove key from the ModuleDict and return its module.
  397. Args:
  398. key (str): key to pop from the ModuleDict
  399. """
  400. v = self[key]
  401. del self[key]
  402. return v
  403. @_copy_to_script_wrapper
  404. def keys(self) -> Iterable[str]:
  405. r"""Return an iterable of the ModuleDict keys."""
  406. return self._modules.keys()
  407. @_copy_to_script_wrapper
  408. def items(self) -> Iterable[Tuple[str, Module]]:
  409. r"""Return an iterable of the ModuleDict key/value pairs."""
  410. return self._modules.items()
  411. @_copy_to_script_wrapper
  412. def values(self) -> Iterable[Module]:
  413. r"""Return an iterable of the ModuleDict values."""
  414. return self._modules.values()
  415. def update(self, modules: Mapping[str, Module]) -> None:
  416. r"""Update the :class:`~torch.nn.ModuleDict` with key-value pairs from a mapping, overwriting existing keys.
  417. .. note::
  418. If :attr:`modules` is an ``OrderedDict``, a :class:`~torch.nn.ModuleDict`, or
  419. an iterable of key-value pairs, the order of new elements in it is preserved.
  420. Args:
  421. modules (iterable): a mapping (dictionary) from string to :class:`~torch.nn.Module`,
  422. or an iterable of key-value pairs of type (string, :class:`~torch.nn.Module`)
  423. """
  424. if not isinstance(modules, container_abcs.Iterable):
  425. raise TypeError("ModuleDict.update should be called with an "
  426. "iterable of key/value pairs, but got " +
  427. type(modules).__name__)
  428. if isinstance(modules, (OrderedDict, ModuleDict, container_abcs.Mapping)):
  429. for key, module in modules.items():
  430. self[key] = module
  431. else:
  432. # modules here can be a list with two items
  433. for j, m in enumerate(modules):
  434. if not isinstance(m, container_abcs.Iterable):
  435. raise TypeError("ModuleDict update sequence element "
  436. "#" + str(j) + " should be Iterable; is" +
  437. type(m).__name__)
  438. if not len(m) == 2:
  439. raise ValueError("ModuleDict update sequence element "
  440. "#" + str(j) + " has length " + str(len(m)) +
  441. "; 2 is required")
  442. # modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
  443. # that's too cumbersome to type correctly with overloads, so we add an ignore here
  444. self[m[0]] = m[1] # type: ignore[assignment]
  445. # remove forward alltogether to fallback on Module's _forward_unimplemented
  446. class ParameterList(Module):
  447. r"""Holds parameters in a list.
  448. :class:`~torch.nn.ParameterList` can be used like a regular Python
  449. list, but Tensors that are :class:`~torch.nn.Parameter` are properly registered,
  450. and will be visible by all :class:`~torch.nn.Module` methods.
  451. Note that the constructor, assigning an element of the list, the
  452. :meth:`~torch.nn.ParameterDict.append` method and the :meth:`~torch.nn.ParameterDict.extend`
  453. method will convert any :class:`~torch.Tensor` into :class:`~torch.nn.Parameter`.
  454. Args:
  455. parameters (iterable, optional): an iterable of elements to add to the list.
  456. Example::
  457. class MyModule(nn.Module):
  458. def __init__(self):
  459. super().__init__()
  460. self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
  461. def forward(self, x):
  462. # ParameterList can act as an iterable, or be indexed using ints
  463. for i, p in enumerate(self.params):
  464. x = self.params[i // 2].mm(x) + p.mm(x)
  465. return x
  466. """
  467. def __init__(self, values: Optional[Iterable[Any]] = None) -> None:
  468. super().__init__()
  469. self._size = 0
  470. if values is not None:
  471. self += values
  472. def _get_abs_string_index(self, idx):
  473. """Get the absolute index for the list of modules."""
  474. idx = operator.index(idx)
  475. if not (-len(self) <= idx < len(self)):
  476. raise IndexError(f'index {idx} is out of range')
  477. if idx < 0:
  478. idx += len(self)
  479. return str(idx)
  480. @overload
  481. def __getitem__(self, idx: int) -> Any:
  482. ...
  483. @overload
  484. def __getitem__(self: T, idx: slice) -> T:
  485. ...
  486. def __getitem__(self, idx):
  487. if isinstance(idx, slice):
  488. start, stop, step = idx.indices(len(self))
  489. out = self.__class__()
  490. for i in range(start, stop, step):
  491. out.append(self[i])
  492. return out
  493. else:
  494. idx = self._get_abs_string_index(idx)
  495. return getattr(self, str(idx))
  496. def __setitem__(self, idx: int, param: Any) -> None:
  497. # Note that all other function that add an entry to the list part of
  498. # the ParameterList end up here. So this is the only place where we need
  499. # to wrap things into Parameter if needed.
  500. # Objects added via setattr() are not in the list part and thus won't
  501. # call into this function.
  502. idx = self._get_abs_string_index(idx)
  503. if isinstance(param, torch.Tensor) and not isinstance(param, Parameter):
  504. param = Parameter(param)
  505. return setattr(self, str(idx), param)
  506. def __len__(self) -> int:
  507. return self._size
  508. def __iter__(self) -> Iterator[Any]:
  509. return iter(self[i] for i in range(len(self)))
  510. def __iadd__(self, parameters: Iterable[Any]) -> Self:
  511. return self.extend(parameters)
  512. def __dir__(self):
  513. keys = super().__dir__()
  514. keys = [key for key in keys if not key.isdigit()]
  515. return keys
  516. def append(self, value: Any) -> 'ParameterList':
  517. """Append a given value at the end of the list.
  518. Args:
  519. value (Any): value to append
  520. """
  521. new_idx = len(self)
  522. self._size += 1
  523. self[new_idx] = value
  524. return self
  525. def extend(self, values: Iterable[Any]) -> Self:
  526. """Append values from a Python iterable to the end of the list.
  527. Args:
  528. values (iterable): iterable of values to append
  529. """
  530. # Tensor is an iterable but we never want to unpack it here
  531. if not isinstance(values, container_abcs.Iterable) or isinstance(values, torch.Tensor):
  532. raise TypeError("ParameterList.extend should be called with an "
  533. "iterable, but got " + type(values).__name__)
  534. for value in values:
  535. self.append(value)
  536. return self
  537. def extra_repr(self) -> str:
  538. child_lines = []
  539. for k, p in enumerate(self):
  540. if isinstance(p, torch.Tensor):
  541. size_str = 'x'.join(str(size) for size in p.size())
  542. if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
  543. device_str = f' ({p.device})'
  544. else:
  545. device_str = ''
  546. parastr = '{} containing: [{} of size {}{}]'.format(
  547. "Parameter" if isinstance(p, Parameter) else "Tensor",
  548. p.dtype, size_str, device_str)
  549. child_lines.append(' (' + str(k) + '): ' + parastr)
  550. else:
  551. child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
  552. tmpstr = '\n'.join(child_lines)
  553. return tmpstr
  554. def __call__(self, *args, **kwargs):
  555. raise RuntimeError('ParameterList should not be called.')
  556. class ParameterDict(Module):
  557. r"""Holds parameters in a dictionary.
  558. ParameterDict can be indexed like a regular Python dictionary, but Parameters it
  559. contains are properly registered, and will be visible by all Module methods.
  560. Other objects are treated as would be done by a regular Python dictionary
  561. :class:`~torch.nn.ParameterDict` is an **ordered** dictionary.
  562. :meth:`~torch.nn.ParameterDict.update` with other unordered mapping
  563. types (e.g., Python's plain ``dict``) does not preserve the order of the
  564. merged mapping. On the other hand, ``OrderedDict`` or another :class:`~torch.nn.ParameterDict`
  565. will preserve their ordering.
  566. Note that the constructor, assigning an element of the dictionary and the
  567. :meth:`~torch.nn.ParameterDict.update` method will convert any :class:`~torch.Tensor` into
  568. :class:`~torch.nn.Parameter`.
  569. Args:
  570. values (iterable, optional): a mapping (dictionary) of
  571. (string : Any) or an iterable of key-value pairs
  572. of type (string, Any)
  573. Example::
  574. class MyModule(nn.Module):
  575. def __init__(self):
  576. super().__init__()
  577. self.params = nn.ParameterDict({
  578. 'left': nn.Parameter(torch.randn(5, 10)),
  579. 'right': nn.Parameter(torch.randn(5, 10))
  580. })
  581. def forward(self, x, choice):
  582. x = self.params[choice].mm(x)
  583. return x
  584. """
  585. def __init__(self, parameters: Any = None) -> None:
  586. super().__init__()
  587. self._keys: Dict[str, None] = {}
  588. if parameters is not None:
  589. self.update(parameters)
  590. def _key_to_attr(self, key: str) -> str:
  591. if not isinstance(key, str):
  592. raise TypeError("Index given to ParameterDict cannot be used as a key as it is "
  593. f"not a string (type is '{type(key).__name__}'). Open an issue on "
  594. "github if you need non-string keys.")
  595. else:
  596. # Use the key as-is so that `.named_parameters()` returns the right thing
  597. return key
  598. def __getitem__(self, key: str) -> Any:
  599. attr = self._key_to_attr(key)
  600. return getattr(self, attr)
  601. def __setitem__(self, key: str, value: Any) -> None:
  602. # Note that all other function that add an entry to the dictionary part of
  603. # the ParameterDict end up here. So this is the only place where we need
  604. # to wrap things into Parameter if needed.
  605. # Objects added via setattr() are not in the dictionary part and thus won't
  606. # call into this function.
  607. self._keys[key] = None
  608. attr = self._key_to_attr(key)
  609. if isinstance(value, torch.Tensor) and not isinstance(value, Parameter):
  610. value = Parameter(value)
  611. setattr(self, attr, value)
  612. def __delitem__(self, key: str) -> None:
  613. del self._keys[key]
  614. attr = self._key_to_attr(key)
  615. delattr(self, attr)
  616. def __len__(self) -> int:
  617. return len(self._keys)
  618. def __iter__(self) -> Iterator[str]:
  619. return iter(self._keys)
  620. def __reversed__(self) -> Iterator[str]:
  621. return reversed(list(self._keys))
  622. def copy(self) -> 'ParameterDict':
  623. """Return a copy of this :class:`~torch.nn.ParameterDict` instance."""
  624. # We have to use an OrderedDict because the ParameterDict constructor
  625. # behaves differently on plain dict vs OrderedDict
  626. return ParameterDict(OrderedDict((k, self[k]) for k in self._keys))
  627. def __contains__(self, key: str) -> bool:
  628. return key in self._keys
  629. def setdefault(self, key: str, default: Optional[Any] = None) -> Any:
  630. """Set the default for a key in the Parameterdict.
  631. If key is in the ParameterDict, return its value.
  632. If not, insert `key` with a parameter `default` and return `default`.
  633. `default` defaults to `None`.
  634. Args:
  635. key (str): key to set default for
  636. default (Any): the parameter set to the key
  637. """
  638. if key not in self:
  639. self[key] = default
  640. return self[key]
  641. def clear(self) -> None:
  642. """Remove all items from the ParameterDict."""
  643. for k in self._keys.copy():
  644. del self[k]
  645. def pop(self, key: str) -> Any:
  646. r"""Remove key from the ParameterDict and return its parameter.
  647. Args:
  648. key (str): key to pop from the ParameterDict
  649. """
  650. v = self[key]
  651. del self[key]
  652. return v
  653. def popitem(self) -> Tuple[str, Any]:
  654. """Remove and return the last inserted `(key, parameter)` pair from the ParameterDict."""
  655. k, _ = self._keys.popitem()
  656. # We need the key in the _keys to be able to access/del
  657. self._keys[k] = None
  658. val = self[k]
  659. del self[k]
  660. return k, val
  661. def get(self, key: str, default: Optional[Any] = None) -> Any:
  662. r"""Return the parameter associated with key if present. Otherwise return default if provided, None if not.
  663. Args:
  664. key (str): key to get from the ParameterDict
  665. default (Parameter, optional): value to return if key not present
  666. """
  667. return self[key] if key in self else default
  668. def fromkeys(self, keys: Iterable[str], default: Optional[Any] = None) -> 'ParameterDict':
  669. r"""Return a new ParameterDict with the keys provided.
  670. Args:
  671. keys (iterable, string): keys to make the new ParameterDict from
  672. default (Parameter, optional): value to set for all keys
  673. """
  674. return ParameterDict((k, default) for k in keys)
  675. def keys(self) -> Iterable[str]:
  676. r"""Return an iterable of the ParameterDict keys."""
  677. return self._keys.keys()
  678. def items(self) -> Iterable[Tuple[str, Any]]:
  679. r"""Return an iterable of the ParameterDict key/value pairs."""
  680. return ((k, self[k]) for k in self._keys)
  681. def values(self) -> Iterable[Any]:
  682. r"""Return an iterable of the ParameterDict values."""
  683. return (self[k] for k in self._keys)
  684. def update(self, parameters: Union[Mapping[str, Any], 'ParameterDict']) -> None:
  685. r"""Update the :class:`~torch.nn.ParameterDict` with key-value pairs from ``parameters``, overwriting existing keys.
  686. .. note::
  687. If :attr:`parameters` is an ``OrderedDict``, a :class:`~torch.nn.ParameterDict`, or
  688. an iterable of key-value pairs, the order of new elements in it is preserved.
  689. Args:
  690. parameters (iterable): a mapping (dictionary) from string to
  691. :class:`~torch.nn.Parameter`, or an iterable of
  692. key-value pairs of type (string, :class:`~torch.nn.Parameter`)
  693. """
  694. if not isinstance(parameters, container_abcs.Iterable):
  695. raise TypeError("ParametersDict.update should be called with an "
  696. "iterable of key/value pairs, but got " +
  697. type(parameters).__name__)
  698. if isinstance(parameters, (OrderedDict, ParameterDict)):
  699. for key, parameter in parameters.items():
  700. self[key] = parameter
  701. elif isinstance(parameters, container_abcs.Mapping):
  702. for key, parameter in sorted(parameters.items()):
  703. self[key] = parameter
  704. else:
  705. for j, p in enumerate(parameters):
  706. if not isinstance(p, container_abcs.Iterable):
  707. raise TypeError("ParameterDict update sequence element "
  708. "#" + str(j) + " should be Iterable; is" +
  709. type(p).__name__)
  710. if not len(p) == 2:
  711. raise ValueError("ParameterDict update sequence element "
  712. "#" + str(j) + " has length " + str(len(p)) +
  713. "; 2 is required")
  714. # parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
  715. self[p[0]] = p[1] # type: ignore[assignment]
  716. def extra_repr(self) -> str:
  717. child_lines = []
  718. for k, p in self.items():
  719. if isinstance(p, torch.Tensor):
  720. size_str = 'x'.join(str(size) for size in p.size())
  721. if p.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
  722. device_str = f' ({p.device})'
  723. else:
  724. device_str = ''
  725. parastr = '{} containing: [{} of size {}{}]'.format(
  726. "Parameter" if isinstance(p, Parameter) else "Tensor",
  727. torch.typename(p), size_str, device_str)
  728. child_lines.append(' (' + str(k) + '): ' + parastr)
  729. else:
  730. child_lines.append(' (' + str(k) + '): Object of type: ' + type(p).__name__)
  731. tmpstr = '\n'.join(child_lines)
  732. return tmpstr
  733. def __call__(self, input):
  734. raise RuntimeError('ParameterDict should not be called.')
  735. def __or__(self, other: 'ParameterDict') -> 'ParameterDict':
  736. copy = self.copy()
  737. copy.update(other)
  738. return copy
  739. def __ror__(self, other: 'ParameterDict') -> 'ParameterDict':
  740. copy = other.copy()
  741. copy.update(self)
  742. return copy
  743. def __ior__(self, other : 'ParameterDict') -> Self:
  744. self.update(other)
  745. return self