# torch/utils/dlpack.py
  1. from typing import Any
  2. import torch
  3. import enum
  4. from torch._C import _from_dlpack
  5. from torch._C import _to_dlpack as to_dlpack
  6. class DLDeviceType(enum.IntEnum):
  7. # Enums as in DLPack specification (aten/src/ATen/dlpack.h)
  8. kDLCPU = 1,
  9. kDLGPU = 2,
  10. kDLCPUPinned = 3,
  11. kDLOpenCL = 4,
  12. kDLVulkan = 7,
  13. kDLMetal = 8,
  14. kDLVPI = 9,
  15. kDLROCM = 10,
  16. kDLExtDev = 12,
  17. kDLOneAPI = 14,
# ``to_dlpack`` is implemented in C (imported from torch._C above), so its
# user-facing docstring is attached after import via the internal
# ``torch._C._add_docstr`` helper rather than written inline.
torch._C._add_docstr(to_dlpack, r"""to_dlpack(tensor) -> PyCapsule

Returns an opaque object (a "DLPack capsule") representing the tensor.

.. note::
  ``to_dlpack`` is a legacy DLPack interface. The capsule it returns
  cannot be used for anything in Python other than use it as input to
  ``from_dlpack``. The more idiomatic use of DLPack is to call
  ``from_dlpack`` directly on the tensor object - this works when that
  object has a ``__dlpack__`` method, which PyTorch and most other
  libraries indeed have now.

.. warning::
  Only call ``from_dlpack`` once per capsule produced with ``to_dlpack``.
  Behavior when a capsule is consumed multiple times is undefined.

Args:
    tensor: a tensor to be exported

The DLPack capsule shares the tensor's memory.
""")
  34. # TODO: add a typing.Protocol to be able to tell Mypy that only objects with
  35. # __dlpack__ and __dlpack_device__ methods are accepted.
  36. def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
  37. """from_dlpack(ext_tensor) -> Tensor
  38. Converts a tensor from an external library into a ``torch.Tensor``.
  39. The returned PyTorch tensor will share the memory with the input tensor
  40. (which may have come from another library). Note that in-place operations
  41. will therefore also affect the data of the input tensor. This may lead to
  42. unexpected issues (e.g., other libraries may have read-only flags or
  43. immutable data structures), so the user should only do this if they know
  44. for sure that this is fine.
  45. Args:
  46. ext_tensor (object with ``__dlpack__`` attribute, or a DLPack capsule):
  47. The tensor or DLPack capsule to convert.
  48. If ``ext_tensor`` is a tensor (or ndarray) object, it must support
  49. the ``__dlpack__`` protocol (i.e., have a ``ext_tensor.__dlpack__``
  50. method). Otherwise ``ext_tensor`` may be a DLPack capsule, which is
  51. an opaque ``PyCapsule`` instance, typically produced by a
  52. ``to_dlpack`` function or method.
  53. Examples::
  54. >>> import torch.utils.dlpack
  55. >>> t = torch.arange(4)
  56. # Convert a tensor directly (supported in PyTorch >= 1.10)
  57. >>> t2 = torch.from_dlpack(t)
  58. >>> t2[:2] = -1 # show that memory is shared
  59. >>> t2
  60. tensor([-1, -1, 2, 3])
  61. >>> t
  62. tensor([-1, -1, 2, 3])
  63. # The old-style DLPack usage, with an intermediate capsule object
  64. >>> capsule = torch.utils.dlpack.to_dlpack(t)
  65. >>> capsule
  66. <capsule object "dltensor" at ...>
  67. >>> t3 = torch.from_dlpack(capsule)
  68. >>> t3
  69. tensor([-1, -1, 2, 3])
  70. >>> t3[0] = -9 # now we're sharing memory between 3 tensors
  71. >>> t3
  72. tensor([-9, -1, 2, 3])
  73. >>> t2
  74. tensor([-9, -1, 2, 3])
  75. >>> t
  76. tensor([-9, -1, 2, 3])
  77. """
  78. if hasattr(ext_tensor, '__dlpack__'):
  79. device = ext_tensor.__dlpack_device__()
  80. # device is either CUDA or ROCm, we need to pass the current
  81. # stream
  82. if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
  83. stream = torch.cuda.current_stream(f'cuda:{device[1]}')
  84. # cuda_stream is the pointer to the stream and it is a public
  85. # attribute, but it is not documented
  86. # The array API specify that the default legacy stream must be passed
  87. # with a value of 1 for CUDA
  88. # https://data-apis.org/array-api/latest/API_specification/array_object.html?dlpack-self-stream-none#dlpack-self-stream-none
  89. is_cuda = device[0] == DLDeviceType.kDLGPU
  90. # Since pytorch is not using PTDS by default, lets directly pass
  91. # the legacy stream
  92. stream_ptr = 1 if is_cuda and stream.cuda_stream == 0 else stream.cuda_stream
  93. dlpack = ext_tensor.__dlpack__(stream=stream_ptr)
  94. else:
  95. dlpack = ext_tensor.__dlpack__()
  96. else:
  97. # Old versions just call the converter
  98. dlpack = ext_tensor
  99. return _from_dlpack(dlpack)