computation.py 919 B

123456789101112131415161718192021222324252627
  1. # mypy: allow-untyped-defs
  2. import torch._C._lazy
  3. import torch._C._lazy_ts_backend
  4. def get_tensors_ts_device_data_node(tensors):
  5. """Return tensor ids and eager tensors for DeviceData nodes in the
  6. IR for the passed in lazy tensors.
  7. TODO: This API is currently ts backend specific. We are working on
  8. generalizing it to all backends including XLA.
  9. """
  10. return torch._C._lazy_ts_backend._get_tensors_ts_device_data_node(tensors)
  11. def get_graph_hash(tensors):
  12. """Return the graph hash for the passed in lazy tensors"""
  13. return torch._C._lazy._get_graph_hash(tensors)
  14. def run_cached_graph(hash_str, graph_inputs):
  15. """Running the cached computation graph with the given inputs
  16. TODO: This API is currently ts backend specific. We are working on
  17. generalizing it to all backends including XLA.
  18. """
  19. return torch._C._lazy_ts_backend._run_cached_graph(hash_str, graph_inputs)