# mypy: ignore-errors
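"""TorchDynamo backend that compiles FX graphs through Apache TVM.

The incoming graph is traced with torch.jit.trace, imported into TVM's Relay
IR, and built with one of three schedulers: "default" (no autotuning),
"auto_scheduler", or "meta_schedule".
"""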
import functools
import importlib
import logging
import os
import sys
import tempfile
from types import MappingProxyType
from typing import Optional

import torch

from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend

log = logging.getLogger(__name__)


@register_backend
@fake_tensor_unsupported
def tvm(
    gm,
    example_inputs,
    *,
    options: Optional[MappingProxyType] = MappingProxyType(
        {"scheduler": None, "trials": 20000, "opt_level": 3}
    ),
):
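    """Compile gm through TVM's Relay frontend and return a callable that
    runs the compiled module.

    options recognizes three keys: "scheduler" (None/"default",
    "auto_scheduler", or "meta_schedule"; when None, the TVM_SCHEDULER
    environment variable is consulted), "trials" (the autotuning budget),
    and "opt_level" (the Relay optimization level).
    """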
    import tvm  # type: ignore[import]
    from tvm import relay  # type: ignore[import]
    from tvm.contrib import graph_executor  # type: ignore[import]

    jit_mod = torch.jit.trace(gm, example_inputs)
    device = device_from_inputs(example_inputs)
    shape_list = [(f"inp_{idx}", i.shape) for idx, i in enumerate(example_inputs)]
    example_outputs = gm(*example_inputs)
    if len(example_outputs) == 0:
        log.warning("Explicitly falling back to eager: the graph has no outputs")
        return gm.forward
    mod, params = relay.frontend.from_pytorch(jit_mod, shape_list)
    if device.type == "cuda":
        dev = tvm.cuda(device.index)
        target = tvm.target.cuda()
    else:
        dev = tvm.cpu(0)
        target = tvm.target.Target(llvm_target())

    scheduler = options.get("scheduler", None)
    if scheduler is None:
        scheduler = os.environ.get("TVM_SCHEDULER", None)

    trials = options.get("trials", 20000)
    opt_level = options.get("opt_level", 3)
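
    # The scheduler value chooses one of the branches below; "default" or an
    # empty value builds the Relay module without autotuning.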
    if scheduler == "auto_scheduler":
        from tvm import auto_scheduler

        # NamedTemporaryFile is used only to reserve a unique path: the file
        # is deleted again on close, so the os.path.exists() checks below
        # correctly report whether tuning records have been written yet.
        with tempfile.NamedTemporaryFile() as tmp_log:
            log_file = tmp_log.name

        if not os.path.exists(log_file):
            tasks, task_weights = auto_scheduler.extract_tasks(
                mod["main"], params, target
            )
            if len(tasks) == 0:
                print("No tasks")
            for task in tasks:
                print(task.compute_dag)
            if len(tasks) != 0:
                tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
                if not os.path.exists(log_file):
                    assert trials > 0
                    tune_option = auto_scheduler.TuningOptions(
                        num_measure_trials=trials,
                        measure_callbacks=[auto_scheduler.RecordToFile(log_file)],
                        early_stopping=2000,
                    )
                    try:
                        tuner.tune(tune_option)
                    except Exception:
                        if os.path.exists(log_file):
                            os.unlink(log_file)
                        raise

        with auto_scheduler.ApplyHistoryBest(log_file):
            with tvm.transform.PassContext(
                opt_level=opt_level, config={"relay.backend.use_auto_scheduler": True}
            ):
                lib = relay.build(mod, target=target, params=params)
    elif scheduler == "meta_schedule":
        from tvm import meta_schedule as ms

        with tempfile.TemporaryDirectory() as work_dir:
            if device.type != "cuda":
                # meta_schedule needs num-cores to be specified;
                # here we use the maximum physical core count
                target = tvm.target.Target(
                    f"{llvm_target()} --num-cores {ms.utils.cpu_count(logical=False)}"
                )
            # TODO(shingjan): This could be replaced by tvm.contrib.torch.optimize_torch
            # once USE_PT_TVMDSOOP is updated and turned on by default in TVM.
            assert trials > 0
            database = ms.relay_integration.tune_relay(
                mod=mod,
                target=target,
                work_dir=work_dir,
                max_trials_global=trials,
                num_trials_per_iter=64,
                params=params,
                strategy="evolutionary",
                opt_level=opt_level,
            )
            lib = ms.relay_integration.compile_relay(
                database=database,
                mod=mod,
                target=target,
                params=params,
                opt_level=opt_level,
            )
    elif scheduler == "default" or not scheduler:
        # no autotuning
        with tvm.transform.PassContext(opt_level=opt_level):
            lib = relay.build(mod, target=target, params=params)
    else:
        raise NotImplementedError(
            f"Invalid scheduler {scheduler!r} for torchdynamo's TVM backend. "
            "There are three available options: default, auto_scheduler and meta_schedule."
        )
    m = graph_executor.GraphModule(lib["default"](dev))

    def to_torch_tensor(nd_tensor):
        """A helper function to transfer a NDArray to torch.tensor."""
        if nd_tensor.dtype == "bool":
            # DLPack does not support boolean so it can't be handled by
            # torch.utils.dlpack.from_dlpack. Workaround by going through
            # numpy, although this brings additional data copy overhead.
            return torch.from_numpy(nd_tensor.numpy())
        return torch.utils.dlpack.from_dlpack(nd_tensor.to_dlpack())

    def to_tvm_tensor(torch_tensor):
        """A helper function to transfer a torch.tensor to NDArray."""
        if torch_tensor.dtype == torch.bool:
            # same reason as above, fall back to numpy conversion which
            # could introduce data copy overhead
            return tvm.nd.array(torch_tensor.cpu().numpy())
        return tvm.nd.from_dlpack(torch_tensor)

    def exec_tvm(*i_args):
        """Feed the inputs to the compiled module, run it, and return
        the outputs converted back to torch tensors."""
        args = [a.contiguous() for a in i_args]
        shape_info, _ = m.get_input_info()
        active_inputs = {name for name, _ in shape_info.items()}
        for idx, arg in enumerate(args):
            if arg.dim() != 0:
                if arg.requires_grad:
                    arg = arg.detach()
                inp_name = f"inp_{idx}"
                if inp_name not in active_inputs:
                    log.warning(
                        "input %s skipped as not found in tvm's runtime library",
                        inp_name,
                    )
                    continue
                m.set_input(
                    inp_name,
                    to_tvm_tensor(arg),
                )
        m.run()
        return [to_torch_tensor(m.get_output(i)) for i in range(m.get_num_outputs())]

    return exec_tvm


# tvm() has no standalone scheduler kwarg, so the presets must be bound via
# options (trials and opt_level then fall back to the options.get defaults).
tvm_meta_schedule = functools.partial(
    tvm, options=MappingProxyType({"scheduler": "meta_schedule"})
)
tvm_auto_scheduler = functools.partial(
    tvm, options=MappingProxyType({"scheduler": "auto_scheduler"})
)
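
# Minimal usage sketch (assumes TVM is installed; @register_backend registers
# the backend under the function name "tvm"):
#
#   import torch
#
#   @torch.compile(backend="tvm")
#   def f(x):
#       return torch.relu(x) * 2
#
#   f(torch.randn(8, 8))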

def has_tvm():
    try:
        importlib.import_module("tvm")
        return True
    except ImportError:
        return False


@functools.lru_cache(None)
def llvm_target():
    if sys.platform == "linux":
        # read the CPU flags to pick a matching -mcpu; close the handle
        # promptly instead of leaking it
        with open("/proc/cpuinfo") as f:
            cpuinfo = f.read()
        if "avx512" in cpuinfo:
            return "llvm -mcpu=skylake-avx512"
        elif "avx2" in cpuinfo:
            return "llvm -mcpu=core-avx2"
    return "llvm"