remote_device.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # mypy: allow-untyped-defs
  2. from typing import Optional, Union
  3. import torch
  4. class _remote_device:
  5. """
  6. Represents a device on a remote worker.
  7. Args:
  8. remote_device (str or torch.device): Represents a device on a remote worker.
  9. The string format should be one of the following:
  10. 1. "<workername>/<device>", where the device field can be parsed as torch.device type.
  11. E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
  12. In addition, the device field can be optional and the default value is "cpu".
  13. 2. "rank:<rank>/<device>", where <rank> is the rank of the
  14. process and device can be parsed as torch.device type.
  15. E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
  16. 3. <workername> and <rank> are optional and formats like "cpu"
  17. and "cuda:1", just represent local devices.
  18. """
  19. def __init__(self, remote_device: Union[str, torch.device]):
  20. PARSE_ERROR = (
  21. f"Could not parse remote_device: {remote_device}. The valid format is "
  22. "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
  23. )
  24. self._worker_name = None
  25. self._rank = None
  26. self._device: Optional[Union[str, int, torch.device]] = None
  27. if isinstance(remote_device, torch.device):
  28. self._device = remote_device
  29. elif isinstance(remote_device, str):
  30. fields = remote_device.split("/")
  31. if len(fields) == 2:
  32. self._worker_name, self._device = fields
  33. elif len(fields) == 1:
  34. # Check if this is a valid device.
  35. if _remote_device._is_valid_local_device(fields[0]):
  36. self._device = fields[0]
  37. else:
  38. self._worker_name = fields[0]
  39. self._device = "cpu"
  40. else:
  41. raise ValueError(PARSE_ERROR)
  42. else:
  43. raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')
  44. # Do some basic sanity check (no empty string)
  45. if self._worker_name is not None and not self._worker_name:
  46. raise ValueError(PARSE_ERROR)
  47. # Validate the device.
  48. self._device = torch.device(self._device)
  49. # Check for rank based format.
  50. if self._worker_name is not None:
  51. fields = self._worker_name.split(":")
  52. if len(fields) == 2:
  53. # rank:<rank>/device format, extract rank
  54. if fields[0] == "rank" and fields[1].isdigit():
  55. self._rank = int(fields[1]) # type: ignore[assignment]
  56. self._worker_name = None
  57. else:
  58. raise ValueError(PARSE_ERROR)
  59. elif len(fields) > 2:
  60. raise ValueError(PARSE_ERROR)
  61. @staticmethod
  62. def _is_valid_local_device(device):
  63. # Check for torch.device
  64. try:
  65. torch.device(device)
  66. return True
  67. except Exception:
  68. return False
  69. def worker_name(self) -> Optional[str]:
  70. """Return the name of remote worker representing the remote device and ``None`` if no worker name is available."""
  71. return self._worker_name
  72. def rank(self) -> Optional[int]:
  73. """
  74. Returns the rank of remote worker representing the remote device.
  75. Returns ``None`` if no rank is available.
  76. """
  77. return self._rank
  78. def device(self) -> torch.device:
  79. """Return the local device on the remote worker."""
  80. return self._device # type: ignore[return-value]
  81. def __repr__(self):
  82. if self._device is not None:
  83. if self._worker_name is not None:
  84. return f'{self._worker_name}/{self._device}'
  85. elif self._rank is not None:
  86. return f'rank:{self._rank}/{self._device}'
  87. else:
  88. return str(self._device)
  89. else:
  90. if self._worker_name is not None:
  91. return f'{self._worker_name}'
  92. elif self._rank is not None:
  93. return f'{self._rank}'
  94. else:
  95. raise RuntimeError('Invalid state!')
  96. def __eq__(self, other):
  97. if not isinstance(other, _remote_device):
  98. return False
  99. if (
  100. self._worker_name == other._worker_name
  101. and self._device == other._device
  102. and self._rank == other._rank
  103. ):
  104. return True
  105. return False
  106. def __hash__(self):
  107. return hash(self._worker_name) ^ \
  108. hash(self._device) ^ \
  109. hash(self._rank)