# mypy: allow-untyped-defs
import dataclasses
import traceback
from typing import (
    Any,
    Callable,
    Container,
    Dict,
    List,
    Optional,
    OrderedDict,
    overload,
    Tuple,
    TypeVar,
)

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel._functions import _get_stream
from torch.nn.parallel.scatter_gather import _is_namedtuple
from torch.nn.utils.rnn import PackedSequence

__all__ = []  # type: ignore[var-annotated]


def _pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[Any, ...], Tuple[str, ...]]:
    """
    Turn an argument list into a flat value tuple and a key tuple (``_unpack_kwargs`` does the opposite).

    Inspiration: https://github.com/facebookresearch/fairscale/blob/eeb6684/fairscale/internal/containers.py#L70

    Usage::

        flat_args, kwarg_keys = _pack_kwargs(1, 2, a=3, b=4)
        assert flat_args == (1, 2, 3, 4)
        assert kwarg_keys == ("a", "b")
        args, kwargs = _unpack_kwargs(flat_args, kwarg_keys)
        assert args == (1, 2)
        assert kwargs == {"a": 3, "b": 4}

    Returns:
        Tuple[Tuple[Any, ...], Tuple[str, ...]]: The first tuple element gives
        both the positional args and the kwarg values, where the positional
        args precede the kwarg values and the kwarg values are ordered
        consistently with the kwarg keys. The second tuple element gives the
        kwarg keys. The second tuple element's length is at most the first
        tuple element's length.
    """
    kwarg_keys: List[str] = []
    flat_args: List[Any] = list(args)
    for k, v in kwargs.items():
        kwarg_keys.append(k)
        flat_args.append(v)

    return tuple(flat_args), tuple(kwarg_keys)


def _cast_forward_inputs(
    dtype: Optional[torch.dtype],
    *args: Any,
    **kwargs: Any,
) -> Tuple[Any, Any]:
    """
    Cast floating point tensors in ``args`` and ``kwargs`` to ``dtype``.

    This respects the existing ``requires_grad`` on the tensors.
    """
    if dtype is None:
        return args, kwargs

    def cast_fn(x: torch.Tensor) -> torch.Tensor:
        if not torch.is_floating_point(x) or x.dtype == dtype:
            return x
        return x.to(dtype)

    return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
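

# Illustrative usage sketch for ``_cast_forward_inputs`` (the tensors and
# values below are made up). Only floating point tensors are cast; integer
# tensors and non-tensor values pass through unchanged:
#
#     x = torch.randn(2, 3)                    # float32
#     mask = torch.tensor([0, 1])              # int64, left untouched
#     args, kwargs = _cast_forward_inputs(torch.float16, x, mask=mask, scale=1.0)
#     assert args[0].dtype == torch.float16
#     assert kwargs["mask"].dtype == torch.int64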


def _unpack_kwargs(
    flat_args: Tuple[Any, ...], kwarg_keys: Tuple[str, ...]
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """See ``_pack_kwargs``."""
    assert len(kwarg_keys) <= len(
        flat_args
    ), f"too many keys {len(kwarg_keys)} vs. {len(flat_args)}"
    if len(kwarg_keys) == 0:
        return flat_args, {}
    args = flat_args[: -len(kwarg_keys)]
    kwargs = dict(zip(kwarg_keys, flat_args[-len(kwarg_keys) :]))
    return args, kwargs


S = TypeVar("S", dict, list, tuple)
T = TypeVar("T", torch.Tensor, PackedSequence)


@overload
def _recursive_to(
    inputs: S, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> List[S]:
    ...


@overload
def _recursive_to(
    inputs: T, target_device: torch.device, use_side_stream_for_tensor_copies: bool
) -> Tuple[T]:
    ...


def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
    r"""Recursively move the inputs to the target device."""

    def to_map(obj):
        if isinstance(obj, (torch.Tensor, PackedSequence)):
            device = obj.data.device if isinstance(obj, PackedSequence) else obj.device
            if device == target_device:
                return (obj,)
            if not use_side_stream_for_tensor_copies:
                return (obj.to(target_device),)
            else:
                # If the device module (e.g. ``torch.cuda``) is not registered
                # with torch, streams are not used to accelerate the copy.
                device_mod = getattr(torch, device.type, None)
                if device.type == "cpu" or device_mod is None:
                    return (obj.to(target_device),)
                # Perform CPU -> target_device copies in a background stream.
                # This code is motivated by similar logic in
                # torch/nn/parallel/_functions.py
                stream = _get_stream(target_device)
                with device_mod.stream(stream):
                    output = obj.to(target_device)
                # synchronize with the copy stream
                with device_mod.device(target_device.index):
                    current_stream = device_mod.current_stream()
                    # Sync the current stream with the copy stream
                    current_stream.wait_stream(stream)
                    # Ensure tensor memory is not reused until work on
                    # the main stream is complete
                    if isinstance(obj, PackedSequence):
                        output.data.record_stream(current_stream)  # type: ignore[arg-type]
                    else:
                        assert isinstance(output, torch.Tensor)
                        output.record_stream(current_stream)  # type: ignore[arg-type]
                return (output,)
        if _is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(to_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(to_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(to_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
        return [obj]

    # Avoid reference cycle
    try:
        res = to_map(inputs)
    finally:
        to_map = None  # type: ignore[assignment]

    return res
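

# Illustrative usage sketch for ``_recursive_to`` (the container below is made
# up; with an accelerator available one would pass e.g. ``torch.device("cuda", 0)``
# and ``True`` to copy on a side stream). The nested structure is preserved and
# the result is wrapped in a single-element list:
#
#     batch = {"x": torch.ones(2), "meta": [torch.zeros(1), "tag"]}
#     moved = _recursive_to(batch, torch.device("cpu"), False)[0]
#     assert moved["meta"][1] == "tag"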


def _p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
    """Alternative to ``assert`` for use in the backward context: prints the error message ``s`` (and a stack trace), since a plain assertion's message would otherwise be swallowed."""
    if not cond:
        print(s)
        traceback.print_stack()
        if raise_assertion_error:
            raise AssertionError(s)
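

# Illustrative usage sketch for ``_p_assert`` (``handle`` is a made-up variable).
# Unlike a bare ``assert``, the message and stack trace are printed even when the
# failure occurs inside an autograd backward hook:
#
#     _p_assert(handle is not None, "Expected a handle to be set before backward")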


def _alloc_storage(tensor: torch.Tensor, size: torch.Size) -> None:
    """
    Allocate storage for ``tensor`` with the given size.

    This is a no-op if the storage is already allocated with the target size;
    otherwise, the storage is expected to be empty (freed) and is resized in
    place to ``size.numel()`` elements.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_allocated = tensor._typed_storage()._size() == size.numel()
            if not already_allocated:
                tensor_storage_size = tensor._typed_storage()._size()
                _p_assert(
                    tensor_storage_size == 0,
                    "Tensor storage should have been resized to be 0 but got "
                    f"{tensor_storage_size}",
                )
        tensor._typed_storage()._resize_(size.numel())


def _free_storage(tensor: torch.Tensor):
    """
    Free the underlying storage of ``tensor`` by resizing it to zero elements.

    This is a no-op if the storage is already freed.
    """
    with torch.no_grad():
        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
            already_freed = tensor._typed_storage()._size() == 0
            if not already_freed:
                _p_assert(
                    tensor.storage_offset() == 0,
                    "Freeing a tensor's storage is unsafe when it is not the sole occupant\n"
                    f"storage offset: {tensor.storage_offset()}\n"
                    f"storage size: {tensor._typed_storage()._size()}\n"
                    f"tensor shape: {tensor.shape}",
                )
        tensor._typed_storage()._resize_(0)
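

# Illustrative usage sketch for the storage helpers above (relies on private
# storage APIs, so treat it as a rough sketch rather than a supported recipe):
#
#     t = torch.empty(4, 4)
#     _free_storage(t)                        # storage resized to 0 elements
#     assert t._typed_storage()._size() == 0
#     _alloc_storage(t, torch.Size([4, 4]))   # storage re-allocated to 16 elements
#     assert t._typed_storage()._size() == 16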


Q = TypeVar("Q")
R = TypeVar("R", dict, list, tuple, set, OrderedDict, PackedSequence, Any)


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Q], container: torch.Tensor) -> Q:
    ...


@overload
def _apply_to_tensors(fn: Callable[[torch.Tensor], Any], container: R) -> R:
    ...


def _apply_to_tensors(fn, container):
    """Recursively apply ``fn`` to all tensors in different kinds of container types."""

    def apply(x):
        if isinstance(x, torch.Tensor):
            return fn(x)
        elif hasattr(x, "__dataclass_fields__"):
            dc = dataclasses.replace(x)
            for f in dataclasses.fields(dc):
                name = f.name
                setattr(dc, name, apply(getattr(dc, name)))
            return dc
        elif isinstance(x, OrderedDict):
            od = x.__class__()
            for key, value in x.items():
                od[key] = apply(value)
            return od
        elif isinstance(x, PackedSequence):
            apply(x.data)
            return x
        elif isinstance(x, dict):
            return {key: apply(value) for key, value in x.items()}
        elif _is_namedtuple(x):
            res = (apply(el) for el in x)
            return type(x)(*res)
        elif isinstance(x, (list, tuple, set)):
            return type(x)(apply(el) for el in x)
        else:
            return x

    return apply(container)
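

# Illustrative usage sketch for ``_apply_to_tensors`` (the container below is
# made up). ``fn`` is applied to every tensor found in the nested containers,
# while non-tensor leaves are returned unchanged:
#
#     nested = {"a": torch.ones(2), "b": [torch.zeros(3), "keep-me"]}
#     halved = _apply_to_tensors(lambda t: t * 0.5, nested)
#     assert halved["b"][1] == "keep-me"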


def _to_kwargs(
    inputs: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]],
    target_device: torch.device,
    use_side_stream_for_tensor_copies: bool,
) -> Tuple[Tuple[Any, ...], Tuple[Dict[str, Any], ...]]:
    moved_inputs = (
        _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies)
        if inputs
        else []
    )
    moved_kwargs = (
        _recursive_to(kwargs, target_device, use_side_stream_for_tensor_copies)
        if kwargs
        else []
    )
    # Pad the shorter of the two so that both have the same length.
    if len(moved_inputs) < len(moved_kwargs):
        moved_inputs.extend([() for _ in range(len(moved_kwargs) - len(moved_inputs))])
    elif len(moved_kwargs) < len(moved_inputs):
        moved_kwargs.extend([{} for _ in range(len(moved_inputs) - len(moved_kwargs))])
    return tuple(moved_inputs), tuple(moved_kwargs)
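

# Illustrative usage sketch for ``_to_kwargs`` (mirrors how DDP prepares the
# forward arguments for a single target device; the values below are made up).
# Each returned tuple has one entry per device chunk, so with a single target
# device there is exactly one:
#
#     args, kwargs = _to_kwargs(
#         (torch.randn(2),), {"flag": True}, torch.device("cpu"), False
#     )
#     assert len(args) == 1 and len(kwargs) == 1
#     assert kwargs[0]["flag"] is True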


def _verify_param_shape_across_processes(
    process_group: dist.ProcessGroup,
    tensors: List[torch.Tensor],
    logger: Optional[dist.Logger] = None,
):
    return dist._verify_params_across_processes(process_group, tensors, logger)


def _sync_module_states(
    module: nn.Module,
    process_group: dist.ProcessGroup,
    broadcast_bucket_size: int,
    src: int,
    params_and_buffers_to_ignore: Container[str],
    broadcast_buffers: bool = True,
) -> None:
    """
    Sync ``module``'s parameters and buffers state.

    Syncs ``module``'s parameters and buffers so that all ranks contain the
    same module state. Note that this API assumes that all parameter shapes
    are consistent before running the synchronization. This can be checked
    with ``_verify_param_shape_across_processes``.
    """
    module_states: List[torch.Tensor] = []
    for name, param in module.named_parameters():
        if name not in params_and_buffers_to_ignore:
            module_states.append(param.detach())

    if broadcast_buffers:
        for name, buffer in module.named_buffers():
            if name not in params_and_buffers_to_ignore:
                module_states.append(buffer.detach())

    _sync_params_and_buffers(process_group, module_states, broadcast_bucket_size, src)


def _sync_params_and_buffers(
    process_group: dist.ProcessGroup,
    module_states: List[torch.Tensor],
    broadcast_bucket_size: int,
    src: int,
) -> None:
    """Synchronize ``module_states`` (list of tensors) across all processes by broadcasting them from rank ``src``."""
    if len(module_states) > 0:
        dist._broadcast_coalesced(
            process_group, module_states, broadcast_bucket_size, src
        )
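

# Illustrative usage sketch for ``_sync_module_states`` (assumes a process group
# has already been initialized on every rank, e.g. via ``dist.init_process_group``;
# the module and bucket size below are made-up placeholders):
#
#     model = nn.Linear(8, 8)
#     _sync_module_states(
#         model,
#         process_group=dist.group.WORLD,
#         broadcast_bucket_size=250 * 1024 * 1024,
#         src=0,
#         params_and_buffers_to_ignore=set(),
#     )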


def _replace_by_prefix(
    state_dict: Dict[str, Any],
    old_prefix: str,
    new_prefix: str,
) -> None:
    """
    Replace all keys that match a given old_prefix with a new_prefix (in-place).

    Usage::

        state_dict = {"layer.xyz": torch.tensor(1)}
        _replace_by_prefix(state_dict, "layer.", "module.layer.")
        assert state_dict == {"module.layer.xyz": torch.tensor(1)}
    """
    if old_prefix == new_prefix:
        raise ValueError("old_prefix and new_prefix must be distinct")
    for key in list(state_dict.keys()):
        if not key.startswith(old_prefix):
            continue
        new_key = new_prefix + key[len(old_prefix) :]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]


def _data_ptr_allocated(tensor: torch.Tensor) -> bool:
    """Return ``True`` if ``tensor``'s underlying storage has an allocated data pointer."""
    return tensor.untyped_storage().data_ptr() > 0
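

# Illustrative check (assumes the default eager CPU allocator): a freshly
# materialized tensor has backing memory, while a tensor whose storage was
# released via ``_free_storage`` does not:
#
#     t = torch.ones(1)
#     assert _data_ptr_allocated(t)
#     _free_storage(t)
#     assert not _data_ptr_allocated(t)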