# _composable_state.py (1.1 KB)
  1. from typing import cast, Dict, Optional
  2. import torch.nn as nn
  3. class _State:
  4. pass
  5. _module_state_mapping: Dict[nn.Module, _State] = {}
  6. def _insert_module_state(module: nn.Module, state: _State) -> None:
  7. global _module_state_mapping
  8. assert module not in _module_state_mapping, f"Inserting {module} more than once."
  9. _module_state_mapping[module] = state
  10. def _get_module_state(module: nn.Module) -> Optional[_State]:
  11. """
  12. Return the ``_State`` in ``model``.
  13. Given a ``module``, this API finds out if the module is also a ``_State``
  14. instance or if the module is managed by a composable API. If the module
  15. is also a ``_State``, ``module`` will be casted to ``_State` and returned.
  16. If it is managed by a composable API, the corresponding ``_State`` will
  17. be returned.
  18. """
  19. global _module_state_mapping
  20. if isinstance(module, _State):
  21. return cast(_State, module)
  22. else:
  23. # https://github.com/pytorch/pytorch/issues/107054
  24. if module in _module_state_mapping:
  25. return _module_state_mapping[module]
  26. else:
  27. return None