# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand


try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    BaseModel = object

    def Body(*x, **y):
        pass

    _serve_dependencies_installed = False
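
# Note: when the optional serving dependencies are missing, `BaseModel` and `Body`
# are replaced with inert stand-ins so that this module can still be imported;
# `ServeCommand.__init__` raises a RuntimeError before any of them is actually used.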

logger = logging.get_logger("transformers-cli/serving")


def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate a serving server from the provided command line arguments.

    Returns: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port, args.workers)
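
# Illustrative only (hypothetical values, not part of the original module): the
# factory expects the argparse Namespace built by `ServeCommand.register_subcommand`,
# roughly equivalent to:
#
#     args = Namespace(
#         task="sentiment-analysis", model=None, config=None, tokenizer=None,
#         device=-1, host="localhost", port=8888, workers=1,
#     )
#     serve_command = serve_command_factory(args)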


class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict


class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]] = None


class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str


class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any


class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli.

        Args:
            parser: Root parser to register command-specific arguments.
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST and GraphQL endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the serving will listen to.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of HTTP workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to stored model.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)
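
    # Example invocation (assuming the transformers-cli entry point is installed):
    #
    #     transformers-cli serve --task sentiment-analysis --host localhost --port 8888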

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info(f"Serving model over {host}:{port}")
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
            )
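
    # The resulting API surface:
    #   GET  /            -> ServeModelInfoResult (the model configuration as a dict)
    #   POST /tokenize    -> ServeTokenizeResult
    #   POST /detokenize  -> ServeDeTokenizeResult
    #   POST /forward     -> ServeForwardResult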

    def run(self):
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and optionally return the corresponding token ids:

        - **text_input**: String to tokenize
        - **return_ids**: Boolean flag indicating whether the tokens should also be converted to their integer
          mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)

        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
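
    # Example request (host/port are the defaults above; the payload keys match the
    # `Body(..., embed=True)` parameters):
    #
    #     curl -X POST http://localhost:8888/tokenize \
    #          -H "Content-Type: application/json" \
    #          -d '{"text_input": "Hello world", "return_ids": true}'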

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided token ids into readable text:

        - **tokens_ids**: List of token ids
        - **skip_special_tokens**: Flag indicating whether special tokens should be dropped from the output
        - **cleanup_tokenization_spaces**: Flag indicating whether to clean up tokenization artifacts such as
          leading/trailing and duplicated spaces.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(
                tokens_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=cleanup_tokenization_spaces,
            )
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
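
    # Example request mirroring the /tokenize call above (the token ids are hypothetical):
    #
    #     curl -X POST http://localhost:8888/detokenize \
    #          -H "Content-Type: application/json" \
    #          -d '{"tokens_ids": [101, 7592, 2088, 102], "skip_special_tokens": true}'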

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the underlying pipeline on the provided inputs:

        - **inputs**: String (or list of strings) to run the pipeline on.
        """
        # Check we don't have an empty string
        if len(inputs) == 0:
            return ServeForwardResult(output=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
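
    # Example request (the shape of `output` depends on the task chosen at startup):
    #
    #     curl -X POST http://localhost:8888/forward \
    #          -H "Content-Type: application/json" \
    #          -d '{"inputs": "This movie was great!"}'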