_contextlib.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. # mypy: allow-untyped-defs
  2. # Extra utilities for working with context managers that should have been
  3. # in the standard library but are not
  4. import functools
  5. import inspect
  6. import warnings
  7. import sys
  8. from typing import Any, Callable, TypeVar, cast
# Type helpers used for annotating the decorator usage of
# _DecoratorContextManager (e.g., 'no_grad' and 'enable_grad').
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
# FuncType: any callable, with arbitrary parameters and return type.
FuncType = Callable[..., Any]
# F: type variable bound to FuncType; annotating a decorator as `F -> F` lets a
# type checker keep the decorated callable's own signature.
F = TypeVar('F', bound=FuncType)
  14. def _wrap_generator(ctx_factory, func):
  15. """
  16. Wrap each generator invocation with the context manager factory.
  17. The input should be a function that returns a context manager,
  18. not a context manager itself, to handle one-shot context managers.
  19. """
  20. @functools.wraps(func)
  21. def generator_context(*args, **kwargs):
  22. gen = func(*args, **kwargs)
  23. # Generators are suspended and unsuspended at `yield`, hence we
  24. # make sure the grad mode is properly set every time the execution
  25. # flow returns into the wrapped generator and restored when it
  26. # returns through our `yield` to our caller (see PR #49017).
  27. try:
  28. # Issuing `None` to a generator fires it up
  29. with ctx_factory():
  30. response = gen.send(None)
  31. while True:
  32. try:
  33. # Forward the response to our caller and get its next request
  34. request = yield response
  35. except GeneratorExit:
  36. # Inform the still active generator about its imminent closure
  37. with ctx_factory():
  38. gen.close()
  39. raise
  40. except BaseException:
  41. # Propagate the exception thrown at us by the caller
  42. with ctx_factory():
  43. response = gen.throw(*sys.exc_info())
  44. else:
  45. # Pass the last request to the generator and get its response
  46. with ctx_factory():
  47. response = gen.send(request)
  48. # We let the exceptions raised above by the generator's `.throw` or
  49. # `.send` methods bubble up to our caller, except for StopIteration
  50. except StopIteration as e:
  51. # The generator informed us that it is done: take whatever its
  52. # returned value (if any) was and indicate that we're done too
  53. # by returning it (see docs for python's return-statement).
  54. return e.value
  55. return generator_context
  56. def context_decorator(ctx, func):
  57. """
  58. Like contextlib.ContextDecorator.
  59. But with the following differences:
  60. 1. Is done by wrapping, rather than inheritance, so it works with context
  61. managers that are implemented from C and thus cannot easily inherit from
  62. Python classes
  63. 2. Wraps generators in the intuitive way (c.f. https://bugs.python.org/issue37743)
  64. 3. Errors out if you try to wrap a class, because it is ambiguous whether
  65. or not you intended to wrap only the constructor
  66. The input argument can either be a context manager (in which case it must
  67. be a multi-shot context manager that can be directly invoked multiple times)
  68. or a callable that produces a context manager.
  69. """
  70. assert not (callable(ctx) and hasattr(ctx, '__enter__')), (
  71. f"Passed in {ctx} is both callable and also a valid context manager "
  72. "(has __enter__), making it ambiguous which interface to use. If you "
  73. "intended to pass a context manager factory, rewrite your call as "
  74. "context_decorator(lambda: ctx()); if you intended to pass a context "
  75. "manager directly, rewrite your call as context_decorator(lambda: ctx)"
  76. )
  77. if not callable(ctx):
  78. def ctx_factory():
  79. return ctx
  80. else:
  81. ctx_factory = ctx
  82. if inspect.isclass(func):
  83. raise RuntimeError(
  84. "Cannot decorate classes; it is ambiguous whether or not only the "
  85. "constructor or all methods should have the context manager applied; "
  86. "additionally, decorating a class at definition-site will prevent "
  87. "use of the identifier as a conventional type. "
  88. "To specify which methods to decorate, decorate each of them "
  89. "individually."
  90. )
  91. if inspect.isgeneratorfunction(func):
  92. return _wrap_generator(ctx_factory, func)
  93. @functools.wraps(func)
  94. def decorate_context(*args, **kwargs):
  95. with ctx_factory():
  96. return func(*args, **kwargs)
  97. return decorate_context
  98. class _DecoratorContextManager:
  99. """Allow a context manager to be used as a decorator."""
  100. def __call__(self, orig_func: F) -> F:
  101. if inspect.isclass(orig_func):
  102. warnings.warn(
  103. "Decorating classes is deprecated and will be disabled in "
  104. "future versions. You should only decorate functions or methods. "
  105. "To preserve the current behavior of class decoration, you can "
  106. "directly decorate the `__init__` method and nothing else.",
  107. FutureWarning,
  108. stacklevel=2,
  109. )
  110. func = cast(F, lambda *args, **kwargs: orig_func(*args, **kwargs))
  111. else:
  112. func = orig_func
  113. return cast(F, context_decorator(self.clone, func))
  114. def __enter__(self) -> None:
  115. raise NotImplementedError
  116. def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
  117. raise NotImplementedError
  118. def clone(self):
  119. # override this method if your children class takes __init__ parameters
  120. return self.__class__()
  121. class _NoParamDecoratorContextManager(_DecoratorContextManager):
  122. """Allow a context manager to be used as a decorator without parentheses."""
  123. def __new__(cls, orig_func=None):
  124. if orig_func is None:
  125. return super().__new__(cls)
  126. return cls()(orig_func)