_exposed_in.py

# mypy: allow-untyped-defs
# Allows one to expose an API in a private submodule publicly as per the definition
# in PyTorch's public API policy.
#
# It is a temporary solution while we figure out if it should be the long-term solution
# or if we should amend PyTorch's public API policy. The concern is that this approach
# may not be very robust because it's not clear what __module__ is used for.
# However, both NumPy and JAX overwrite the __module__ attribute of their APIs
# without problem, so it seems fine.
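def exposed_in(module):
    # Returns a decorator that relabels `fn` as belonging to `module` by
    # overwriting fn.__module__; `fn` itself is returned unchanged (no wrapper).
    def wrapper(fn):
        fn.__module__ = module
        return fn

    return wrapper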