static_module.py 893 B

123456789101112131415161718192021222324252627
  1. # mypy: allow-untyped-defs
  2. # Owner(s): ["module: unknown"]
  3. import torch
  4. class StaticModule:
  5. def __init__(self, scripted):
  6. # this is an nn.Module
  7. if hasattr(scripted, "_c"):
  8. self.static_module = torch._C._jit_to_static_module(scripted._c)
  9. else:
  10. self.static_module = torch._C._jit_to_static_module(scripted.graph)
  11. def __call__(self, *args, **kwargs):
  12. return self.static_module(*args, **kwargs)
  13. def benchmark(self, args, kwargs, warmup_runs, main_runs):
  14. self.static_module.benchmark(args, kwargs, warmup_runs, main_runs)
  15. def runAsync(self, args, kwargs):
  16. return self.static_module.runAsync(args, kwargs)
  17. def benchmark_individual_ops(self, args, kwargs, warmup_runs, main_runs):
  18. return self.static_module.benchmark_individual_ops(
  19. args, kwargs, warmup_runs, main_runs
  20. )