_debug.py 557 B

123456789101112131415161718192021
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import torch
  4. def friendly_debug_info(v):
  5. """
  6. Helper function to print out debug info in a friendly way.
  7. """
  8. if isinstance(v, torch.Tensor):
  9. return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})"
  10. else:
  11. return str(v)
  12. def map_debug_info(a):
  13. """
  14. Helper function to apply `friendly_debug_info` to items in `a`.
  15. `a` may be a list, tuple, or dict.
  16. """
  17. return torch.fx.node.map_aggregate(a, friendly_debug_info)