traceback.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # mypy: allow-untyped-defs
  2. import traceback
  3. from contextlib import contextmanager
  4. from typing import List, Any, Dict
  5. from ._compatibility import compatibility
  6. __all__ = ['preserve_node_meta', 'has_preserved_node_meta',
  7. 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
  8. 'format_stack', 'set_current_meta', 'get_current_meta']
  9. current_meta: Dict[str, Any] = {}
  10. should_preserve_node_meta = False
  11. @compatibility(is_backward_compatible=False)
  12. @contextmanager
  13. def preserve_node_meta():
  14. global should_preserve_node_meta
  15. global current_meta
  16. saved_should_preserve_node_meta = should_preserve_node_meta
  17. # Shallow copy is OK since fields of current_meta are not mutated
  18. saved_current_meta = current_meta.copy()
  19. try:
  20. should_preserve_node_meta = True
  21. yield
  22. finally:
  23. should_preserve_node_meta = saved_should_preserve_node_meta
  24. current_meta = saved_current_meta
  25. @compatibility(is_backward_compatible=False)
  26. def set_stack_trace(stack : List[str]):
  27. global current_meta
  28. if should_preserve_node_meta and stack:
  29. current_meta["stack_trace"] = "".join(stack)
  30. @compatibility(is_backward_compatible=False)
  31. def set_grad_fn_seq_nr(seq_nr):
  32. global current_meta
  33. if should_preserve_node_meta:
  34. # The seq_nr is captured by eager mode in the grad_fn during forward
  35. current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
  36. current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
  37. @compatibility(is_backward_compatible=False)
  38. def reset_grad_fn_seq_nr():
  39. # NB: reset state properly, this would be helpful towards supporting
  40. # reentrant autograd if we actually wanted to do that.
  41. global current_meta
  42. if should_preserve_node_meta:
  43. current_level = current_meta.get("in_grad_fn", 0)
  44. assert current_level > 0
  45. if current_level == 1:
  46. del current_meta["in_grad_fn"]
  47. del current_meta["grad_fn_seq_nr"]
  48. else:
  49. current_meta["in_grad_fn"] = current_level - 1
  50. current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
  51. @compatibility(is_backward_compatible=False)
  52. def format_stack() -> List[str]:
  53. if should_preserve_node_meta:
  54. return [current_meta.get("stack_trace", "")]
  55. else:
  56. # fallback to traceback.format_stack()
  57. return traceback.format_list(traceback.extract_stack()[:-1])
  58. @compatibility(is_backward_compatible=False)
  59. def has_preserved_node_meta() -> bool:
  60. return should_preserve_node_meta
  61. @compatibility(is_backward_compatible=False)
  62. @contextmanager
  63. def set_current_meta(node):
  64. global current_meta
  65. if should_preserve_node_meta and node.meta:
  66. saved_meta = current_meta
  67. try:
  68. current_meta = node.meta.copy()
  69. # Append (node.name, node.target) onto "from_node" for provenance tracking
  70. if "from_node" not in current_meta:
  71. current_meta["from_node"] = [(node.name, node.target)]
  72. elif current_meta["from_node"][-1][0] != node.name:
  73. current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)]
  74. yield
  75. finally:
  76. current_meta = saved_meta
  77. else:
  78. yield
  79. @compatibility(is_backward_compatible=False)
  80. def get_current_meta() -> Dict[str, Any]:
  81. return current_meta