- # mypy: allow-untyped-defs
- # Exposes an API that lives in a private submodule under a public module path,
- # as per the definition in PyTorch's public API policy.
- #
- # This is a temporary solution while we decide whether it should become the
- # long-term solution or whether PyTorch's public API policy should be amended.
- # The concern is that this approach may not be very robust, because it is not
- # clear what __module__ is used for. However, both NumPy and JAX overwrite the
- # __module__ attribute of their APIs without problems, so it seems fine.
- def exposed_in(module):
-     def wrapper(fn):
-         fn.__module__ = module
-         return fn
-     return wrapper
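As a rough illustration of how the decorator might be used, the sketch below repeats `exposed_in` so that it runs on its own and applies it to a function. The module name `mylib` and the function `scale` are hypothetical and do not come from PyTorch; they only show the effect on `__module__`.

```python
def exposed_in(module):
    # Same decorator as in the file, repeated so this sketch is self-contained.
    def wrapper(fn):
        fn.__module__ = module
        return fn
    return wrapper


# "mylib" and "scale" are hypothetical names used only for illustration.
@exposed_in("mylib")
def scale(x, factor=2):
    """Multiply x by factor."""
    return x * factor


print(scale.__module__)  # mylib
print(scale(3))          # 6
```

Only the `__module__` attribute is rewritten. The function's name, docstring, and behavior are unchanged, and the function still physically lives in the file where it was defined.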