batch_tensor.py 667 B

12345678910111213141516171819202122232425
  1. # Copyright (c) Facebook, Inc. and its affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. from contextlib import contextmanager
  7. from torch._C._functorch import _vmap_add_layers, _vmap_remove_layers
  8. _enabled = False
  9. @contextmanager
  10. def _enable_layers(dims):
  11. global _enabled
  12. assert not _enabled
  13. input = sorted((d._level, d.size) for d in dims if not isinstance(d, int))
  14. n = len(input)
  15. try:
  16. _vmap_add_layers(input)
  17. _enabled = True
  18. yield
  19. finally:
  20. _enabled = False
  21. _vmap_remove_layers(n)