prepare.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # mypy: allow-untyped-defs
  2. from typing import List, Optional
  3. import torch
  4. from torch.backends._nnapi.serializer import _NnapiSerializer
  5. ANEURALNETWORKS_PREFER_LOW_POWER = 0
  6. ANEURALNETWORKS_PREFER_FAST_SINGLE_ANSWER = 1
  7. ANEURALNETWORKS_PREFER_SUSTAINED_SPEED = 2
  8. class NnapiModule(torch.nn.Module):
  9. """Torch Module that wraps an NNAPI Compilation.
  10. This module handles preparing the weights, initializing the
  11. NNAPI TorchBind object, and adjusting the memory formats
  12. of all inputs and outputs.
  13. """
  14. # _nnapi.Compilation is defined
  15. comp: Optional[torch.classes._nnapi.Compilation] # type: ignore[name-defined]
  16. weights: List[torch.Tensor]
  17. out_templates: List[torch.Tensor]
  18. def __init__(
  19. self,
  20. shape_compute_module: torch.nn.Module,
  21. ser_model: torch.Tensor,
  22. weights: List[torch.Tensor],
  23. inp_mem_fmts: List[int],
  24. out_mem_fmts: List[int],
  25. compilation_preference: int,
  26. relax_f32_to_f16: bool,
  27. ):
  28. super().__init__()
  29. self.shape_compute_module = shape_compute_module
  30. self.ser_model = ser_model
  31. self.weights = weights
  32. self.inp_mem_fmts = inp_mem_fmts
  33. self.out_mem_fmts = out_mem_fmts
  34. self.out_templates = []
  35. self.comp = None
  36. self.compilation_preference = compilation_preference
  37. self.relax_f32_to_f16 = relax_f32_to_f16
  38. @torch.jit.export
  39. def init(self, args: List[torch.Tensor]):
  40. assert self.comp is None
  41. self.out_templates = self.shape_compute_module.prepare(self.ser_model, args) # type: ignore[operator]
  42. self.weights = [w.contiguous() for w in self.weights]
  43. comp = torch.classes._nnapi.Compilation()
  44. comp.init2(
  45. self.ser_model,
  46. self.weights,
  47. self.compilation_preference,
  48. self.relax_f32_to_f16,
  49. )
  50. self.comp = comp
  51. def forward(self, args: List[torch.Tensor]) -> List[torch.Tensor]:
  52. if self.comp is None:
  53. self.init(args)
  54. comp = self.comp
  55. assert comp is not None
  56. outs = [torch.empty_like(out) for out in self.out_templates]
  57. assert len(args) == len(self.inp_mem_fmts)
  58. fixed_args = []
  59. for idx in range(len(args)):
  60. fmt = self.inp_mem_fmts[idx]
  61. # These constants match the values in DimOrder in serializer.py
  62. # TODO: See if it's possible to use those directly.
  63. if fmt == 0:
  64. fixed_args.append(args[idx].contiguous())
  65. elif fmt == 1:
  66. fixed_args.append(args[idx].permute(0, 2, 3, 1).contiguous())
  67. else:
  68. raise ValueError("Invalid mem_fmt")
  69. comp.run(fixed_args, outs)
  70. assert len(outs) == len(self.out_mem_fmts)
  71. for idx in range(len(self.out_templates)):
  72. fmt = self.out_mem_fmts[idx]
  73. # These constants match the values in DimOrder in serializer.py
  74. # TODO: See if it's possible to use those directly.
  75. if fmt in (0, 2):
  76. pass
  77. elif fmt == 1:
  78. outs[idx] = outs[idx].permute(0, 3, 1, 2)
  79. else:
  80. raise ValueError("Invalid mem_fmt")
  81. return outs
  82. def convert_model_to_nnapi(
  83. model,
  84. inputs,
  85. serializer=None,
  86. return_shapes=None,
  87. use_int16_for_qint16=False,
  88. compilation_preference=ANEURALNETWORKS_PREFER_SUSTAINED_SPEED,
  89. relax_f32_to_f16=False,
  90. ):
  91. (
  92. shape_compute_module,
  93. ser_model_tensor,
  94. used_weights,
  95. inp_mem_fmts,
  96. out_mem_fmts,
  97. retval_count,
  98. ) = process_for_nnapi(
  99. model, inputs, serializer, return_shapes, use_int16_for_qint16
  100. )
  101. nnapi_model = NnapiModule(
  102. shape_compute_module,
  103. ser_model_tensor,
  104. used_weights,
  105. inp_mem_fmts,
  106. out_mem_fmts,
  107. compilation_preference,
  108. relax_f32_to_f16,
  109. )
  110. class NnapiInterfaceWrapper(torch.nn.Module):
  111. """NNAPI list-ifying and de-list-ifying wrapper.
  112. NNAPI always expects a list of inputs and provides a list of outputs.
  113. This module allows us to accept inputs as separate arguments.
  114. It returns results as either a single tensor or tuple,
  115. matching the original module.
  116. """
  117. def __init__(self, mod):
  118. super().__init__()
  119. self.mod = mod
  120. wrapper_model_py = NnapiInterfaceWrapper(nnapi_model)
  121. wrapper_model = torch.jit.script(wrapper_model_py)
  122. # TODO: Maybe make these names match the original.
  123. arg_list = ", ".join(f"arg_{idx}" for idx in range(len(inputs)))
  124. if retval_count < 0:
  125. ret_expr = "retvals[0]"
  126. else:
  127. ret_expr = "".join(f"retvals[{idx}], " for idx in range(retval_count))
  128. wrapper_model.define(
  129. f"def forward(self, {arg_list}):\n"
  130. f" retvals = self.mod([{arg_list}])\n"
  131. f" return {ret_expr}\n"
  132. )
  133. return wrapper_model
  134. def process_for_nnapi(
  135. model, inputs, serializer=None, return_shapes=None, use_int16_for_qint16=False
  136. ):
  137. model = torch.jit.freeze(model)
  138. if isinstance(inputs, torch.Tensor):
  139. inputs = [inputs]
  140. serializer = serializer or _NnapiSerializer(
  141. config=None, use_int16_for_qint16=use_int16_for_qint16
  142. )
  143. (
  144. ser_model,
  145. used_weights,
  146. inp_mem_fmts,
  147. out_mem_fmts,
  148. shape_compute_lines,
  149. retval_count,
  150. ) = serializer.serialize_model(model, inputs, return_shapes)
  151. ser_model_tensor = torch.tensor(ser_model, dtype=torch.int32)
  152. # We have to create a new class here every time this function is called
  153. # because module.define adds a method to the *class*, not the instance.
  154. class ShapeComputeModule(torch.nn.Module):
  155. """Code-gen-ed module for tensor shape computation.
  156. module.prepare will mutate ser_model according to the computed operand
  157. shapes, based on the shapes of args. Returns a list of output templates.
  158. """
  159. pass
  160. shape_compute_module = torch.jit.script(ShapeComputeModule())
  161. real_shape_compute_lines = [
  162. "def prepare(self, ser_model: torch.Tensor, args: List[torch.Tensor]) -> List[torch.Tensor]:\n",
  163. ] + [f" {line}\n" for line in shape_compute_lines]
  164. shape_compute_module.define("".join(real_shape_compute_lines))
  165. return (
  166. shape_compute_module,
  167. ser_model_tensor,
  168. used_weights,
  169. inp_mem_fmts,
  170. out_mem_fmts,
  171. retval_count,
  172. )