rref_proxy.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # mypy: allow-untyped-defs
  2. from functools import partial
  3. from . import functions
  4. from . import rpc_async
  5. import torch
  6. from .constants import UNSET_RPC_TIMEOUT
  7. from torch.futures import Future
  8. def _local_invoke(rref, func_name, args, kwargs):
  9. return getattr(rref.local_value(), func_name)(*args, **kwargs)
  10. @functions.async_execution
  11. def _local_invoke_async_execution(rref, func_name, args, kwargs):
  12. return getattr(rref.local_value(), func_name)(*args, **kwargs)
  13. def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs):
  14. def _rref_type_cont(rref_fut):
  15. rref_type = rref_fut.value()
  16. _invoke_func = _local_invoke
  17. # Bypass ScriptModules when checking for async function attribute.
  18. bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass(
  19. rref_type, torch._C.ScriptModule
  20. )
  21. if not bypass_type:
  22. func = getattr(rref_type, func_name)
  23. if hasattr(func, "_wrapped_async_rpc_function"):
  24. _invoke_func = _local_invoke_async_execution
  25. return rpc_api(
  26. rref.owner(),
  27. _invoke_func,
  28. args=(rref, func_name, args, kwargs),
  29. timeout=timeout
  30. )
  31. rref_fut = rref._get_type(timeout=timeout, blocking=False)
  32. if rpc_api != rpc_async:
  33. rref_fut.wait()
  34. return _rref_type_cont(rref_fut)
  35. else:
  36. # A little explanation on this.
  37. # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]`
  38. # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]`
  39. # To address that, we return a Future that is completed with the result of the async call.
  40. result: Future = Future()
  41. def _wrap_rref_type_cont(fut):
  42. try:
  43. _rref_type_cont(fut).then(_complete_op)
  44. except BaseException as ex:
  45. result.set_exception(ex)
  46. def _complete_op(fut):
  47. try:
  48. result.set_result(fut.value())
  49. except BaseException as ex:
  50. result.set_exception(ex)
  51. rref_fut.then(_wrap_rref_type_cont)
  52. return result
  53. # This class manages proxied RPC API calls for RRefs. It is entirely used from
  54. # C++ (see python_rpc_handler.cpp).
  55. class RRefProxy:
  56. def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT):
  57. self.rref = rref
  58. self.rpc_api = rpc_api
  59. self.rpc_timeout = timeout
  60. def __getattr__(self, func_name):
  61. return partial(_invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout)