_ir_utils.py 677 B

1234567891011121314151617181920212223242526
  1. # mypy: allow-untyped-defs
  2. from typing import Union
  3. import torch
  4. class _InsertPoint:
  5. def __init__(
  6. self,
  7. insert_point_graph: torch._C.Graph,
  8. insert_point: Union[torch._C.Node, torch._C.Block],
  9. ):
  10. self.insert_point = insert_point
  11. self.g = insert_point_graph
  12. self.guard = None
  13. def __enter__(self):
  14. self.prev_insert_point = self.g.insertPoint()
  15. self.g.setInsertPoint(self.insert_point)
  16. def __exit__(self, *args):
  17. self.g.setInsertPoint(self.prev_insert_point)
  18. def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
  19. return _InsertPoint(self, insert_point)