import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch._utils import ExceptionWrapper

__all__ = ['get_a_var', 'parallel_apply']
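

# Return the first Tensor found in a (possibly nested) input so that the
# target device can be inferred when none is given explicitly.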
def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None


def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Apply each `module` in :attr:`modules` in parallel on each of :attr:`devices`.

    Args:
        modules (Module): modules to be parallelized
        inputs (tensor): inputs to the modules
        devices (list of int or torch.device): CUDA devices

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object as the only argument
    to a module, or a collection of positional arguments.
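
    Example (an illustrative sketch, not part of the original docstring;
    assumes two CUDA devices are available and uses arbitrary shapes)::

        modules = [torch.nn.Linear(8, 4).cuda(0), torch.nn.Linear(8, 4).cuda(1)]
        inputs = [torch.randn(16, 8, device="cuda:0"),
                  torch.randn(16, 8, device="cuda:1")]
        outputs = parallel_apply(modules, inputs, devices=[0, 1])
        # outputs is a list of two (16, 4) tensors, one per replica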
  35. """
  36. assert len(modules) == len(inputs), f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
  37. if kwargs_tup is not None:
  38. assert len(modules) == len(kwargs_tup)
  39. else:
  40. kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
  41. if devices is not None:
  42. assert len(modules) == len(devices)
  43. else:
  44. devices = [None] * len(modules)
  45. devices = [_get_device_index(x, True) for x in devices]
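    # Each replica runs on the current stream of its target device. The lock
    # guards `results`, and the caller's grad/autocast state is captured so
    # the worker threads can mirror it.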
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()
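
    # Worker executed once per replica (in its own thread when there are
    # multiple modules): it restores the caller's grad mode, resolves the
    # target device and stream, and stores either the module's output or an
    # ExceptionWrapper under its index.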
    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
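        # If no device was given, infer it from the first tensor found in the
        # input; record an ExceptionWrapper when nothing can be inferred.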
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                        "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
        try:
            with torch.cuda.device(device), torch.cuda.stream(
                stream
            ), torch.amp.autocast("cuda", enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")
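
    # With more than one replica, run each worker in its own thread; a single
    # replica is run inline on the calling thread.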
    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])
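
    # Collect outputs in input order, re-raising in the caller any exception
    # that a worker captured.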
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs