_mode_utils.py 251 B

1234567891011
  1. # mypy: allow-untyped-defs
  2. import torch
  3. from typing import TypeVar
  4. T = TypeVar('T')
  5. # returns if all are the same mode
  6. def all_same_mode(modes):
  7. return all(tuple(mode == modes[0] for mode in modes))
  8. no_dispatch = torch._C._DisableTorchDispatch