__init__.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
  1. # mypy: allow-untyped-defs
  2. import os
  3. import torch
  4. from torch.jit._serialization import validate_map_location
  def _load_for_lite_interpreter(f, map_location=None):
      r"""
      Load a :class:`LiteScriptModule` saved with :func:`torch.jit._save_for_lite_interpreter`.

      Args:
          f: a string or ``os.PathLike`` naming the model file, or a file-like
              object (only its ``read()`` method is called here)
          map_location: a string or torch.device used to dynamically remap
              storages to an alternative set of devices. It is normalised with
              ``validate_map_location`` before being handed to the loader.

      Returns:
          A :class:`LiteScriptModule` object.

      Raises:
          ValueError: if ``f`` is a path that does not exist or is a directory.

      Example:
      .. testcode::

          import torch
          import io

          # Load LiteScriptModule from saved file path
          torch.jit.mobile._load_for_lite_interpreter('lite_script_module.pt')

          # Load LiteScriptModule from io.BytesIO object
          with open('lite_script_module.pt', 'rb') as f:
              buffer = io.BytesIO(f.read())

          # Load all tensors to the original device
          torch.jit.mobile._load_for_lite_interpreter(buffer)
      """
      # Decide once whether we were handed a path or a readable object.
      given_path = isinstance(f, (str, os.PathLike))

      # Reject unusable paths up front so the caller gets a readable message.
      if given_path:
          if not os.path.exists(f):
              raise ValueError(f"The provided filename {f} does not exist")
          if os.path.isdir(f):
              raise ValueError(f"The provided filename {f} is a directory")

      map_location = validate_map_location(map_location)

      if not given_path:
          # File-like input: the whole content is read and passed as a buffer.
          cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
              f.read(), map_location
          )
          return LiteScriptModule(cpp_module)

      cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
      return LiteScriptModule(cpp_module)
  class LiteScriptModule:
      """Thin Python wrapper that forwards calls to a wrapped C++ module object."""

      def __init__(self, cpp_module):
          # The wrapped object is kept on ``_c``; every method below delegates to it.
          self._c = cpp_module
          super().__init__()

      def __call__(self, *input):
          """Call the wrapped module's ``forward`` with the positional args as one tuple."""
          wrapped = self._c
          return wrapped.forward(input)

      def find_method(self, method_name):
          """Return whatever the wrapped module's ``find_method`` reports for ``method_name``."""
          wrapped = self._c
          return wrapped.find_method(method_name)

      def forward(self, *input):
          """Run the wrapped module's ``forward``; the positional args are passed as one tuple."""
          wrapped = self._c
          return wrapped.forward(input)

      def run_method(self, method_name, *input):
          """Run the named method on the wrapped module, passing the remaining args as one tuple."""
          wrapped = self._c
          return wrapped.run_method(method_name, input)
  def _export_operator_list(module: LiteScriptModule):
      r"""Return a set of root operator names (with overload name) that are used by any method in this mobile module.

      Args:
          module: a :class:`LiteScriptModule`; its wrapped ``_c`` object is what
              gets passed to the underlying ``torch._C`` binding.
      """
      cpp_module = module._c
      return torch._C._export_operator_list(cpp_module)
  def _get_model_bytecode_version(f_input) -> int:
      r"""Return the bytecode version of a model given as a file name or a file-like object.

      Args:
          f_input: a string or ``os.PathLike`` containing a file name, or a
              file-like object (only its ``read()`` method is called here)

      Returns:
          version: An integer. If the integer is -1, the version is invalid. A warning
              will show in the log.

      Raises:
          ValueError: if ``f_input`` is a path that does not exist or is a directory.

      Example:
      .. testcode::

          from torch.jit.mobile import _get_model_bytecode_version

          # Get bytecode version from a saved file path
          version = _get_model_bytecode_version("path/to/model.ptl")
      """
      if not isinstance(f_input, (str, os.PathLike)):
          # File-like input: read everything and use the buffer-based binding.
          return torch._C._get_model_bytecode_version_from_buffer(f_input.read())

      # Path input: fail early with a readable message for unusable paths.
      if not os.path.exists(f_input):
          raise ValueError(f"The provided filename {f_input} does not exist")
      if os.path.isdir(f_input):
          raise ValueError(f"The provided filename {f_input} is a directory")

      model_path = os.fspath(f_input)
      return torch._C._get_model_bytecode_version(model_path)
  def _get_mobile_model_contained_types(f_input) -> "set[str]":
      r"""Take a file name or a file-like object and return a set of strings, like {"int", "Optional"}.

      Args:
          f_input: a string or ``os.PathLike`` containing a file name, or a
              file-like object (only its ``read()`` method is called here)

      Returns:
          type_list: A set of strings, like {"int", "Optional"}. These are types used in bytecode.

      Raises:
          ValueError: if ``f_input`` is a path that does not exist or is a directory.

      Example:
      .. testcode::

          from torch.jit.mobile import _get_mobile_model_contained_types

          # Get type list from a saved file path
          type_list = _get_mobile_model_contained_types("path/to/model.ptl")
      """
      if isinstance(f_input, (str, os.PathLike)):
          # Fail early with a readable message instead of an opaque loader error.
          if not os.path.exists(f_input):
              raise ValueError(f"The provided filename {f_input} does not exist")
          if os.path.isdir(f_input):
              raise ValueError(f"The provided filename {f_input} is a directory")
          return torch._C._get_mobile_model_contained_types(os.fspath(f_input))

      # File-like input: read everything and use the buffer-based binding.
      return torch._C._get_mobile_model_contained_types_from_buffer(f_input.read())
  def _backport_for_mobile(f_input, f_output, to_version):
      r"""Backport a model to bytecode version ``to_version`` and write it to ``f_output``.

      Args:
          f_input: a string or ``os.PathLike`` containing a file name, or a
              file-like object (only its ``read()`` method is called here)
          f_output: path to new model destination
          to_version: the expected output model bytecode version

      Returns:
          success: A boolean. If backport success, return true, otherwise false

      Raises:
          ValueError: if ``f_input`` is a path that does not exist or is a directory.
      """
      input_is_path = isinstance(f_input, (str, os.PathLike))

      if input_is_path:
          # Fail early with a readable message instead of an opaque loader error.
          if not os.path.exists(f_input):
              raise ValueError(f"The provided filename {f_input} does not exist")
          if os.path.isdir(f_input):
              raise ValueError(f"The provided filename {f_input} is a directory")

      # Normalise the destination once. Path-likes go through os.fspath; anything
      # else is stringified, which is what the buffer branch has always done.
      if isinstance(f_output, (str, os.PathLike)):
          output_path = os.fspath(f_output)
      else:
          output_path = str(f_output)

      # Dispatch on the *input* kind only. Previously a path input paired with a
      # non-path output fell through to the buffer branch and called .read() on
      # a string, raising a misleading AttributeError.
      if input_is_path:
          return torch._C._backport_for_mobile(
              os.fspath(f_input), output_path, to_version
          )
      return torch._C._backport_for_mobile_from_buffer(
          f_input.read(), output_path, to_version
      )
  def _backport_for_mobile_to_buffer(f_input, to_version):
      r"""Backport a model given as a file name or a file-like object, returning the result in memory.

      Args:
          f_input: a string or ``os.PathLike`` containing a file name, or a
              file-like object (only its ``read()`` method is called here)
          to_version: passed unchanged to the underlying binding; presumably the
              target bytecode version, as in ``_backport_for_mobile`` (confirm)

      Returns:
          Whatever the underlying ``torch._C`` binding returns; presumably the
          backported model as a buffer (TODO confirm against the C++ binding).

      Raises:
          ValueError: if ``f_input`` is a path that does not exist or is a directory.
      """
      if not isinstance(f_input, (str, os.PathLike)):
          # File-like input: read everything and use the buffer-to-buffer binding.
          return torch._C._backport_for_mobile_from_buffer_to_buffer(
              f_input.read(), to_version
          )

      # Path input: fail early with a readable message for unusable paths.
      if not os.path.exists(f_input):
          raise ValueError(f"The provided filename {f_input} does not exist")
      if os.path.isdir(f_input):
          raise ValueError(f"The provided filename {f_input} is a directory")

      model_path = os.fspath(f_input)
      return torch._C._backport_for_mobile_to_buffer(model_path, to_version)
  def _get_model_ops_and_info(f_input):
      r"""Retrieve the root (top level) operators of a model and their corresponding compatibility info.

      These root operators can call other operators within them (traced ops), and
      a root op can call many different traced ops depending on internal code paths in the root op.
      These traced ops are not returned by this function. Those operators are abstracted into the
      runtime as an implementation detail (and the traced ops themselves can also call other operators)
      making retrieving them difficult and their value from this api negligible since they will differ
      between which runtime version the model is run on. Because of this, there is a false positive this
      api can't prevent in a compatibility usecase. All the root ops of a model are present in a
      target runtime, but not all the traced ops are which prevents a model from being able to run.

      Args:
          f_input: a string or ``os.PathLike`` containing a file name, or a
              file-like object (only its ``read()`` method is called here)

      Returns:
          Operators and info: A Dictionary mapping strings (the qualified names of the root operators)
          of the model to their OperatorInfo structs.

      Raises:
          ValueError: if ``f_input`` is a path that does not exist or is a directory.

      Example:
      .. testcode::

          from torch.jit.mobile import _get_model_ops_and_info

          # Get root operators and their info from a saved file path
          ops_and_info = _get_model_ops_and_info("path/to/model.ptl")
      """
      if isinstance(f_input, (str, os.PathLike)):
          # Fail early with a readable message instead of an opaque loader error.
          if not os.path.exists(f_input):
              raise ValueError(f"The provided filename {f_input} does not exist")
          if os.path.isdir(f_input):
              raise ValueError(f"The provided filename {f_input} is a directory")
          return torch._C._get_model_ops_and_info(os.fspath(f_input))

      # File-like input must go through the buffer-based binding. Passing the
      # bytes to the filename-based binding (as before) would make it treat the
      # model content as a file name.
      return torch._C._get_model_ops_and_info_from_buffer(f_input.read())