collate.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317
  1. # mypy: allow-untyped-defs
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers.

These methods are used to collate samples fetched from a dataset into Tensor(s).
They **need** to be in global scope since Py2 doesn't support serializing
static methods.

`default_collate` and `default_convert` are exposed to users via 'dataloader.py'.
"""
  8. import collections
  9. import contextlib
  10. import copy
  11. import re
  12. import torch
  13. from typing import Callable, Dict, Optional, Tuple, Type, Union
  14. np_str_obj_array_pattern = re.compile(r'[SaUO]')
  15. def default_convert(data):
  16. r"""
  17. Convert each NumPy array element into a :class:`torch.Tensor`.
  18. If the input is a `Sequence`, `Collection`, or `Mapping`, it tries to convert each element inside to a :class:`torch.Tensor`.
  19. If the input is not an NumPy array, it is left unchanged.
  20. This is used as the default function for collation when both `batch_sampler` and `batch_size`
  21. are NOT defined in :class:`~torch.utils.data.DataLoader`.
  22. The general input type to output type mapping is similar to that
  23. of :func:`~torch.utils.data.default_collate`. See the description there for more details.
  24. Args:
  25. data: a single data point to be converted
  26. Examples:
  27. >>> # xdoctest: +SKIP
  28. >>> # Example with `int`
  29. >>> default_convert(0)
  30. 0
  31. >>> # Example with NumPy array
  32. >>> default_convert(np.array([0, 1]))
  33. tensor([0, 1])
  34. >>> # Example with NamedTuple
  35. >>> Point = namedtuple('Point', ['x', 'y'])
  36. >>> default_convert(Point(0, 0))
  37. Point(x=0, y=0)
  38. >>> default_convert(Point(np.array(0), np.array(0)))
  39. Point(x=tensor(0), y=tensor(0))
  40. >>> # Example with List
  41. >>> default_convert([np.array([0, 1]), np.array([2, 3])])
  42. [tensor([0, 1]), tensor([2, 3])]
  43. """
  44. elem_type = type(data)
  45. if isinstance(data, torch.Tensor):
  46. return data
  47. elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
  48. and elem_type.__name__ != 'string_':
  49. # array of string classes and object
  50. if elem_type.__name__ == 'ndarray' \
  51. and np_str_obj_array_pattern.search(data.dtype.str) is not None:
  52. return data
  53. return torch.as_tensor(data)
  54. elif isinstance(data, collections.abc.Mapping):
  55. try:
  56. if isinstance(data, collections.abc.MutableMapping):
  57. # The mapping type may have extra properties, so we can't just
  58. # use `type(data)(...)` to create the new mapping.
  59. # Create a clone and update it if the mapping type is mutable.
  60. clone = copy.copy(data)
  61. clone.update({key: default_convert(data[key]) for key in data})
  62. return clone
  63. else:
  64. return elem_type({key: default_convert(data[key]) for key in data})
  65. except TypeError:
  66. # The mapping type may not support `copy()` / `update(mapping)`
  67. # or `__init__(iterable)`.
  68. return {key: default_convert(data[key]) for key in data}
  69. elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
  70. return elem_type(*(default_convert(d) for d in data))
  71. elif isinstance(data, tuple):
  72. return [default_convert(d) for d in data] # Backwards compatibility.
  73. elif isinstance(data, collections.abc.Sequence) and not isinstance(data, (str, bytes)):
  74. try:
  75. if isinstance(data, collections.abc.MutableSequence):
  76. # The sequence type may have extra properties, so we can't just
  77. # use `type(data)(...)` to create the new sequence.
  78. # Create a clone and update it if the sequence type is mutable.
  79. clone = copy.copy(data) # type: ignore[arg-type]
  80. for i, d in enumerate(data):
  81. clone[i] = default_convert(d)
  82. return clone
  83. else:
  84. return elem_type([default_convert(d) for d in data])
  85. except TypeError:
  86. # The sequence type may not support `copy()` / `__setitem__(index, item)`
  87. # or `__init__(iterable)` (e.g., `range`).
  88. return [default_convert(d) for d in data]
  89. else:
  90. return data
  91. default_collate_err_msg_format = (
  92. "default_collate: batch must contain tensors, numpy arrays, numbers, "
  93. "dicts or lists; found {}")
  94. def collate(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  95. r"""
  96. General collate function that handles collection type of element within each batch.
  97. The function also opens function registry to deal with specific element types. `default_collate_fn_map`
  98. provides default collate functions for tensors, numpy arrays, numbers and strings.
  99. Args:
  100. batch: a single batch to be collated
  101. collate_fn_map: Optional dictionary mapping from element type to the corresponding collate function.
  102. If the element type isn't present in this dictionary,
  103. this function will go through each key of the dictionary in the insertion order to
  104. invoke the corresponding collate function if the element type is a subclass of the key.
  105. Examples:
  106. >>> def collate_tensor_fn(batch, *, collate_fn_map):
  107. ... # Extend this function to handle batch of tensors
  108. ... return torch.stack(batch, 0)
  109. >>> def custom_collate(batch):
  110. ... collate_map = {torch.Tensor: collate_tensor_fn}
  111. ... return collate(batch, collate_fn_map=collate_map)
  112. >>> # Extend `default_collate` by in-place modifying `default_collate_fn_map`
  113. >>> default_collate_fn_map.update({torch.Tensor: collate_tensor_fn})
  114. Note:
  115. Each collate function requires a positional argument for batch and a keyword argument
  116. for the dictionary of collate functions as `collate_fn_map`.
  117. """
  118. elem = batch[0]
  119. elem_type = type(elem)
  120. if collate_fn_map is not None:
  121. if elem_type in collate_fn_map:
  122. return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  123. for collate_type in collate_fn_map:
  124. if isinstance(elem, collate_type):
  125. return collate_fn_map[collate_type](batch, collate_fn_map=collate_fn_map)
  126. if isinstance(elem, collections.abc.Mapping):
  127. try:
  128. if isinstance(elem, collections.abc.MutableMapping):
  129. # The mapping type may have extra properties, so we can't just
  130. # use `type(data)(...)` to create the new mapping.
  131. # Create a clone and update it if the mapping type is mutable.
  132. clone = copy.copy(elem)
  133. clone.update({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  134. return clone
  135. else:
  136. return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
  137. except TypeError:
  138. # The mapping type may not support `copy()` / `update(mapping)`
  139. # or `__init__(iterable)`.
  140. return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
  141. elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
  142. return elem_type(*(collate(samples, collate_fn_map=collate_fn_map) for samples in zip(*batch)))
  143. elif isinstance(elem, collections.abc.Sequence):
  144. # check to make sure that the elements in batch have consistent size
  145. it = iter(batch)
  146. elem_size = len(next(it))
  147. if not all(len(elem) == elem_size for elem in it):
  148. raise RuntimeError('each element in list of batch should be of equal size')
  149. transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
  150. if isinstance(elem, tuple):
  151. return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed] # Backwards compatibility.
  152. else:
  153. try:
  154. if isinstance(elem, collections.abc.MutableSequence):
  155. # The sequence type may have extra properties, so we can't just
  156. # use `type(data)(...)` to create the new sequence.
  157. # Create a clone and update it if the sequence type is mutable.
  158. clone = copy.copy(elem) # type: ignore[arg-type]
  159. for i, samples in enumerate(transposed):
  160. clone[i] = collate(samples, collate_fn_map=collate_fn_map)
  161. return clone
  162. else:
  163. return elem_type([collate(samples, collate_fn_map=collate_fn_map) for samples in transposed])
  164. except TypeError:
  165. # The sequence type may not support `copy()` / `__setitem__(index, item)`
  166. # or `__init__(iterable)` (e.g., `range`).
  167. return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]
  168. raise TypeError(default_collate_err_msg_format.format(elem_type))
  169. def collate_tensor_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  170. elem = batch[0]
  171. out = None
  172. if elem.is_nested:
  173. raise RuntimeError(
  174. "Batches of nested tensors are not currently supported by the default collate_fn; "
  175. "please provide a custom collate_fn to handle them appropriately."
  176. )
  177. if elem.layout in {torch.sparse_coo, torch.sparse_csr, torch.sparse_bsr, torch.sparse_csc, torch.sparse_bsc}:
  178. raise RuntimeError(
  179. "Batches of sparse tensors are not currently supported by the default collate_fn; "
  180. "please provide a custom collate_fn to handle them appropriately."
  181. )
  182. if torch.utils.data.get_worker_info() is not None:
  183. # If we're in a background process, concatenate directly into a
  184. # shared memory tensor to avoid an extra copy
  185. numel = sum(x.numel() for x in batch)
  186. storage = elem._typed_storage()._new_shared(numel, device=elem.device)
  187. out = elem.new(storage).resize_(len(batch), *list(elem.size()))
  188. return torch.stack(batch, 0, out=out)
  189. def collate_numpy_array_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  190. elem = batch[0]
  191. # array of string classes and object
  192. if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
  193. raise TypeError(default_collate_err_msg_format.format(elem.dtype))
  194. return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)
  195. def collate_numpy_scalar_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  196. return torch.as_tensor(batch)
  197. def collate_float_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  198. return torch.tensor(batch, dtype=torch.float64)
  199. def collate_int_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  200. return torch.tensor(batch)
  201. def collate_str_fn(batch, *, collate_fn_map: Optional[Dict[Union[Type, Tuple[Type, ...]], Callable]] = None):
  202. return batch
# Registry used by ``default_collate``: maps an element type (or a tuple of
# types) to the function that collates a batch of such elements.
# ``collate`` first looks up the exact element type. Otherwise it tries the keys
# in insertion order with ``isinstance``, so the ORDER of the entries below is
# significant: for an element that is an instance of several keys, the
# earliest-inserted key wins. Keep that in mind when reordering or extending.
default_collate_fn_map: Dict[Union[Type, Tuple[Type, ...]], Callable] = {torch.Tensor: collate_tensor_fn}
# NumPy is optional: if it cannot be imported, the NumPy entries are simply not
# registered and the ImportError is suppressed.
with contextlib.suppress(ImportError):
    import numpy as np
    # For both ndarray and memmap (subclass of ndarray)
    default_collate_fn_map[np.ndarray] = collate_numpy_array_fn
    # See scalars hierarchy: https://numpy.org/doc/stable/reference/arrays.scalars.html
    # Skip string scalars (they are intentionally not part of this tuple key).
    default_collate_fn_map[(np.bool_, np.number, np.object_)] = collate_numpy_scalar_fn
# Python builtins are registered after the NumPy entries, so they are tried later
# in collate()'s isinstance fallback.
default_collate_fn_map[float] = collate_float_fn
default_collate_fn_map[int] = collate_int_fn
default_collate_fn_map[str] = collate_str_fn
default_collate_fn_map[bytes] = collate_str_fn
  215. def default_collate(batch):
  216. r"""
  217. Take in a batch of data and put the elements within the batch into a tensor with an additional outer dimension - batch size.
  218. The exact output type can be a :class:`torch.Tensor`, a `Sequence` of :class:`torch.Tensor`, a
  219. Collection of :class:`torch.Tensor`, or left unchanged, depending on the input type.
  220. This is used as the default function for collation when
  221. `batch_size` or `batch_sampler` is defined in :class:`~torch.utils.data.DataLoader`.
  222. Here is the general input type (based on the type of the element within the batch) to output type mapping:
  223. * :class:`torch.Tensor` -> :class:`torch.Tensor` (with an added outer dimension batch size)
  224. * NumPy Arrays -> :class:`torch.Tensor`
  225. * `float` -> :class:`torch.Tensor`
  226. * `int` -> :class:`torch.Tensor`
  227. * `str` -> `str` (unchanged)
  228. * `bytes` -> `bytes` (unchanged)
  229. * `Mapping[K, V_i]` -> `Mapping[K, default_collate([V_1, V_2, ...])]`
  230. * `NamedTuple[V1_i, V2_i, ...]` -> `NamedTuple[default_collate([V1_1, V1_2, ...]),
  231. default_collate([V2_1, V2_2, ...]), ...]`
  232. * `Sequence[V1_i, V2_i, ...]` -> `Sequence[default_collate([V1_1, V1_2, ...]),
  233. default_collate([V2_1, V2_2, ...]), ...]`
  234. Args:
  235. batch: a single batch to be collated
  236. Examples:
  237. >>> # xdoctest: +SKIP
  238. >>> # Example with a batch of `int`s:
  239. >>> default_collate([0, 1, 2, 3])
  240. tensor([0, 1, 2, 3])
  241. >>> # Example with a batch of `str`s:
  242. >>> default_collate(['a', 'b', 'c'])
  243. ['a', 'b', 'c']
  244. >>> # Example with `Map` inside the batch:
  245. >>> default_collate([{'A': 0, 'B': 1}, {'A': 100, 'B': 100}])
  246. {'A': tensor([ 0, 100]), 'B': tensor([ 1, 100])}
  247. >>> # Example with `NamedTuple` inside the batch:
  248. >>> Point = namedtuple('Point', ['x', 'y'])
  249. >>> default_collate([Point(0, 0), Point(1, 1)])
  250. Point(x=tensor([0, 1]), y=tensor([0, 1]))
  251. >>> # Example with `Tuple` inside the batch:
  252. >>> default_collate([(0, 1), (2, 3)])
  253. [tensor([0, 2]), tensor([1, 3])]
  254. >>> # Example with `List` inside the batch:
  255. >>> default_collate([[0, 1], [2, 3]])
  256. [tensor([0, 2]), tensor([1, 3])]
  257. >>> # Two options to extend `default_collate` to handle specific type
  258. >>> # Option 1: Write custom collate function and invoke `default_collate`
  259. >>> def custom_collate(batch):
  260. ... elem = batch[0]
  261. ... if isinstance(elem, CustomType): # Some custom condition
  262. ... return ...
  263. ... else: # Fall back to `default_collate`
  264. ... return default_collate(batch)
  265. >>> # Option 2: In-place modify `default_collate_fn_map`
  266. >>> def collate_customtype_fn(batch, *, collate_fn_map=None):
  267. ... return ...
  268. >>> default_collate_fn_map.update(CustomType, collate_customtype_fn)
  269. >>> default_collate(batch) # Handle `CustomType` automatically
  270. """
  271. return collate(batch, collate_fn_map=default_collate_fn_map)