_utils.py 2.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. from dataclasses import dataclass
  5. from typing import List, Tuple, Union
  6. import torch
  7. from torch import fx
  8. logger = logging.getLogger(__name__)
  9. def flatten_args_detach(args):
  10. """
  11. Flatten the args into a list form and detach the tensors from computational graph.
  12. """
  13. flat_detached_args = []
  14. def extract_tensor_args(a):
  15. nonlocal flat_detached_args
  16. if isinstance(a, torch.Tensor):
  17. val = a.detach().requires_grad_(a.requires_grad)
  18. flat_detached_args.append(val)
  19. return val
  20. else:
  21. flat_detached_args.append(a)
  22. return a
  23. new_args = fx.node.map_aggregate(
  24. args,
  25. extract_tensor_args,
  26. )
  27. return new_args, flat_detached_args
  28. def flatten_args(args):
  29. """
  30. Flatten the args into a list form.
  31. """
  32. flat_args = []
  33. def extract_tensor_args(a):
  34. nonlocal flat_args
  35. flat_args.append(a)
  36. return a
  37. fx.node.map_aggregate(
  38. args,
  39. extract_tensor_args,
  40. )
  41. return flat_args
  42. class PipeliningShapeError(RuntimeError):
  43. """Shape mismatch between configured and runtime values."""
  44. def validate_tensor_metadata(desc, expected, given):
  45. if not expected.shape == given.shape:
  46. raise PipeliningShapeError(
  47. f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
  48. )
  49. if not expected.dtype == given.dtype:
  50. raise PipeliningShapeError(
  51. f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
  52. )
  53. if not expected.stride() == given.stride():
  54. raise PipeliningShapeError(
  55. f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
  56. )
  57. def validate_tensors_metadata(
  58. desc,
  59. expected_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
  60. actual_tensors: Union[List[torch.Tensor], Tuple[torch.Tensor, ...]],
  61. ):
  62. if len(expected_tensors) != len(actual_tensors):
  63. raise PipeliningShapeError(
  64. f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
  65. )
  66. for i in range(len(expected_tensors)):
  67. validate_tensor_metadata(
  68. f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
  69. )
  70. @dataclass
  71. class PipeInfo:
  72. """
  73. Captures information for a pipeline (`Pipe` object).
  74. """
  75. graph: fx.Graph
  76. num_stages: int
  77. has_loss_and_backward: bool