annotate.py 956 B

12345678910111213141516171819202122
  1. # mypy: allow-untyped-defs
  2. from torch.fx.proxy import Proxy
  3. from ._compatibility import compatibility
  4. @compatibility(is_backward_compatible=False)
  5. def annotate(val, type):
  6. # val could be either a regular value (not tracing)
  7. # or fx.Proxy (tracing)
  8. if isinstance(val, Proxy):
  9. if val.node.type:
  10. raise RuntimeError(f"Tried to annotate a value that already had a type on it!"
  11. f" Existing type is {val.node.type} "
  12. f"and new type is {type}. "
  13. f"This could happen if you tried to annotate a function parameter "
  14. f"value (in which case you should use the type slot "
  15. f"on the function signature) or you called "
  16. f"annotate on the same value twice")
  17. else:
  18. val.node.type = type
  19. return val
  20. else:
  21. return val