convert_parameters.py

import torch
from typing import Iterable, Optional


def parameters_to_vector(parameters: Iterable[torch.Tensor]) -> torch.Tensor:
    r"""Flatten an iterable of parameters into a single vector.

    Args:
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.

    Returns:
        The parameters represented by a single vector.
    """
    # Flag for the device where the parameters are located
    param_device = None

    vec = []
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)
        vec.append(param.view(-1))
    return torch.cat(vec)
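
# Usage sketch (illustrative, not part of the original module); the
# nn.Linear model below is an arbitrary assumption:
#
#     >>> import torch.nn as nn
#     >>> model = nn.Linear(3, 2)  # weight: 2x3 (6 elements), bias: 2
#     >>> vec = parameters_to_vector(model.parameters())
#     >>> vec.shape
#     torch.Size([8])
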
def vector_to_parameters(vec: torch.Tensor, parameters: Iterable[torch.Tensor]) -> None:
    r"""Copy slices of a vector into an iterable of parameters.

    Args:
        vec (Tensor): a single vector representing the parameters of a model.
        parameters (Iterable[Tensor]): an iterable of Tensors that are the
            parameters of a model.
    """
    # Ensure vec is a Tensor
    if not isinstance(vec, torch.Tensor):
        raise TypeError(f'expected torch.Tensor, but got: {torch.typename(vec)}')

    # Flag for the device where the parameters are located
    param_device = None

    # Pointer for slicing the vector for each parameter
    pointer = 0
    for param in parameters:
        # Ensure the parameters are located on the same device
        param_device = _check_param_device(param, param_device)

        # The length of the parameter
        num_param = param.numel()
        # Slice the vector, reshape it, and replace the old data of the parameter
        param.data = vec[pointer:pointer + num_param].view_as(param).data

        # Increment the pointer
        pointer += num_param
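
# Usage sketch (illustrative; continues the hypothetical `model` above):
#
#     >>> doubled = 2 * parameters_to_vector(model.parameters())
#     >>> vector_to_parameters(doubled, model.parameters())
#     >>> # every parameter of `model` is now scaled by 2, in place
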
def _check_param_device(param: torch.Tensor, old_param_device: Optional[int]) -> int:
    r"""Check that the parameters are located on the same device.

    Currently, the conversion between model parameters and a single vector
    is not supported for multiple allocations, e.g. parameters on different
    GPUs/PrivateUse1 devices, or a mixture of CPU and GPU/PrivateUse1.

    Args:
        param (Tensor): a Tensor of a parameter of a model.
        old_param_device (int): the device where the first parameter of a
            model is allocated.

    Returns:
        old_param_device (int): the device index of the first parameter
            (-1 for CPU).
    """
    # Meet the first parameter
    support_device_types = ["cuda", torch._C._get_privateuse1_backend_name()]
    if old_param_device is None:
        old_param_device = param.get_device() if param.device.type in support_device_types else -1
    else:
        warn = False
        if param.device.type in support_device_types:  # Check if on the same GPU/PrivateUse1 device
            warn = (param.get_device() != old_param_device)
        else:  # Check if on CPU
            warn = (old_param_device != -1)
        if warn:
            raise TypeError('Found two parameters on different devices, '
                            'this is currently not supported.')
    return old_param_device
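

# Illustrative round-trip check (an assumption added for demonstration, not
# part of the original module); nn.Linear is an arbitrary model choice.
if __name__ == "__main__":
    import torch.nn as nn

    model = nn.Linear(4, 2)
    flat = parameters_to_vector(model.parameters())
    assert flat.numel() == sum(p.numel() for p in model.parameters())

    # Writing the same vector back leaves every parameter unchanged.
    vector_to_parameters(flat.clone(), model.parameters())
    assert torch.equal(flat, parameters_to_vector(model.parameters()))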