| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104 |
- # mypy: allow-untyped-defs
- import traceback
- from contextlib import contextmanager
- from typing import List, Any, Dict
- from ._compatibility import compatibility
- __all__ = ['preserve_node_meta', 'has_preserved_node_meta',
- 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
- 'format_stack', 'set_current_meta', 'get_current_meta']
- current_meta: Dict[str, Any] = {}
- should_preserve_node_meta = False
- @compatibility(is_backward_compatible=False)
- @contextmanager
- def preserve_node_meta():
- global should_preserve_node_meta
- global current_meta
- saved_should_preserve_node_meta = should_preserve_node_meta
- # Shallow copy is OK since fields of current_meta are not mutated
- saved_current_meta = current_meta.copy()
- try:
- should_preserve_node_meta = True
- yield
- finally:
- should_preserve_node_meta = saved_should_preserve_node_meta
- current_meta = saved_current_meta
- @compatibility(is_backward_compatible=False)
- def set_stack_trace(stack : List[str]):
- global current_meta
- if should_preserve_node_meta and stack:
- current_meta["stack_trace"] = "".join(stack)
- @compatibility(is_backward_compatible=False)
- def set_grad_fn_seq_nr(seq_nr):
- global current_meta
- if should_preserve_node_meta:
- # The seq_nr is captured by eager mode in the grad_fn during forward
- current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
- current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
- @compatibility(is_backward_compatible=False)
- def reset_grad_fn_seq_nr():
- # NB: reset state properly, this would be helpful towards supporting
- # reentrant autograd if we actually wanted to do that.
- global current_meta
- if should_preserve_node_meta:
- current_level = current_meta.get("in_grad_fn", 0)
- assert current_level > 0
- if current_level == 1:
- del current_meta["in_grad_fn"]
- del current_meta["grad_fn_seq_nr"]
- else:
- current_meta["in_grad_fn"] = current_level - 1
- current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
- @compatibility(is_backward_compatible=False)
- def format_stack() -> List[str]:
- if should_preserve_node_meta:
- return [current_meta.get("stack_trace", "")]
- else:
- # fallback to traceback.format_stack()
- return traceback.format_list(traceback.extract_stack()[:-1])
- @compatibility(is_backward_compatible=False)
- def has_preserved_node_meta() -> bool:
- return should_preserve_node_meta
- @compatibility(is_backward_compatible=False)
- @contextmanager
- def set_current_meta(node):
- global current_meta
- if should_preserve_node_meta and node.meta:
- saved_meta = current_meta
- try:
- current_meta = node.meta.copy()
- # Append (node.name, node.target) onto "from_node" for provenance tracking
- if "from_node" not in current_meta:
- current_meta["from_node"] = [(node.name, node.target)]
- elif current_meta["from_node"][-1][0] != node.name:
- current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)]
- yield
- finally:
- current_meta = saved_meta
- else:
- yield
- @compatibility(is_backward_compatible=False)
- def get_current_meta() -> Dict[str, Any]:
- return current_meta
|