# __future__.py (3.1 KB)

  1. _overwrite_module_params_on_conversion: bool = False
  2. _swap_module_params_on_conversion: bool = False
  3. def set_overwrite_module_params_on_conversion(value: bool) -> None:
  4. """
  5. Sets whether to assign new tensors to the parameters instead of changing the
  6. existing parameters in-place when converting an ``nn.Module``.
  7. When enabled, the following methods will assign new parameters to the module:
  8. #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
  9. #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
  10. #. :meth:`nn.Module.to`
  11. #. :meth:`nn.Module.to_empty`
  12. Args:
  13. value (bool): Whether to assign new tensors or not.
  14. """
  15. global _overwrite_module_params_on_conversion
  16. _overwrite_module_params_on_conversion = value
  17. def get_overwrite_module_params_on_conversion() -> bool:
  18. """
  19. Returns whether to assign new tensors to the parameters instead of changing the
  20. existing parameters in-place when converting an :class:`torch.nn.Module`. Defaults to ``False``.
  21. See :func:`~torch.__future__.set_overwrite_module_params_on_conversion` for more information.
  22. """
  23. return _overwrite_module_params_on_conversion
  24. def set_swap_module_params_on_conversion(value: bool) -> None:
  25. """
  26. Sets whether to use :func:`~torch.utils.swap_tensors` instead of setting ``.data`` to
  27. change the existing parameters in-place when converting an ``nn.Module`` and instead
  28. of ``param.copy_(state_dict[key])`` when loading a state dict into an ``nn.Module``.
  29. .. note::
  30. This function takes precedence over :func:`~torch.__future__.get_overwrite_module_params_on_conversion`
  31. When enabled, the following methods will swap the existing parameters in-place:
  32. #. ``module.{device}()`` (e.g. :meth:`nn.Module.cuda()`) for moving a module between devices
  33. #. ``module.{dtype}()`` (e.g. :meth:`nn.Module.float()`) for converting a module to a different dtype
  34. #. :meth:`nn.Module.to`
  35. #. :meth:`nn.Module.to_empty`
  36. #. :meth:`nn.Module.load_state_dict`
  37. The semantics for :meth:`~nn.Module.load_state_dict` when this is set are as follows:
  38. #. For each parameter/buffer, its corresponding ``state_dict['key']`` is transformed via
  39. :meth:`~torch.Tensor.module_load` (i.e. ``res = param.module_load(state_dict['key'])``)
  40. #. If necessary, ``res`` will be wrapped in an :class:`~nn.Parameter`
  41. #. The parameter/buffer in the module will be swapped via :func:`~torch.utils.swap_tensors`
  42. with ``res``
  43. Args:
  44. value (bool): Whether to use :func:`~torch.utils.swap_tensors` or not.
  45. """
  46. global _swap_module_params_on_conversion
  47. _swap_module_params_on_conversion = value
  48. def get_swap_module_params_on_conversion() -> bool:
  49. """
  50. Returns whether to use :func:`~torch.utils.swap_tensors` instead of setting .data to
  51. change the existing parameters in-place when converting an ``nn.Module``. Defaults to ``False``.
  52. See :func:`~torch.__future__.set_swap_module_params_on_conversion` for more information.
  53. """
  54. return _swap_module_params_on_conversion