__init__.py 938 B

1234567891011121314151617181920212223242526272829303132333435363738
# Package entry point: re-exports the FSDP wrapper class and its
# configuration / enum types from the `_flat_param` and
# `fully_sharded_data_parallel` submodules so callers can import them
# directly from this package.

# The redundant `X as X` alias is the explicit re-export form that type
# checkers honor when implicit re-exports are disabled (e.g. mypy
# `--no-implicit-reexport`). FlatParameter is NOT listed in `__all__` below,
# so `from <this package> import *` will not pick it up.
# NOTE(review): presumably intentional (importable but not advertised as
# star-exported API) — confirm before adding it to `__all__`.
from ._flat_param import FlatParameter as FlatParameter
from .fully_sharded_data_parallel import (
    BackwardPrefetch,
    CPUOffload,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    FullyShardedDataParallel,
    LocalOptimStateDictConfig,
    LocalStateDictConfig,
    MixedPrecision,
    OptimStateDictConfig,
    OptimStateKeyType,
    ShardedOptimStateDictConfig,
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictConfig,
    StateDictSettings,
    StateDictType,
)

# Names exported by `from <this package> import *`. Keep this a plain
# literal list of strings: static analyzers only understand a limited set
# of `__all__` forms. Every entry must name something imported above,
# otherwise a star import raises AttributeError.
__all__ = [
    "BackwardPrefetch",
    "CPUOffload",
    "FullOptimStateDictConfig",
    "FullStateDictConfig",
    "FullyShardedDataParallel",
    "LocalOptimStateDictConfig",
    "LocalStateDictConfig",
    "MixedPrecision",
    "OptimStateDictConfig",
    "OptimStateKeyType",
    "ShardedOptimStateDictConfig",
    "ShardedStateDictConfig",
    "ShardingStrategy",
    "StateDictConfig",
    "StateDictSettings",
    "StateDictType",
]