# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional

from ..pipelines import Pipeline, get_supported_tasks, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand

try:
    from fastapi import Body, FastAPI, HTTPException
    from fastapi.routing import APIRoute
    from pydantic import BaseModel
    from starlette.responses import JSONResponse
    from uvicorn import run

    _serve_dependencies_installed = True
except (ImportError, AttributeError):
    # Fall back to inert placeholders so the module can still be imported
    # when the optional serving dependencies are missing.
    BaseModel = object

    def Body(*x, **y):
        pass

    _serve_dependencies_installed = False


logger = logging.get_logger("transformers-cli/serving")

def serve_command_factory(args: Namespace):
    """
    Factory function used to instantiate the serving server from the provided command line arguments.

    Returns: ServeCommand
    """
    nlp = pipeline(
        task=args.task,
        model=args.model if args.model else None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )
    return ServeCommand(nlp, args.host, args.port, args.workers)
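
# A minimal usage sketch (illustrative, not part of the module): the factory
# mirrors what `transformers-cli serve` does internally when the parsed
# arguments are handed to it. The values below are hypothetical placeholders.
#
#     from argparse import Namespace
#
#     args = Namespace(
#         task="sentiment-analysis",
#         model=None,       # fall back to the task's default model
#         config=None,
#         tokenizer=None,
#         device=-1,        # CPU
#         host="localhost",
#         port=8888,
#         workers=1,
#     )
#     serve_command_factory(args).run()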

class ServeModelInfoResult(BaseModel):
    """
    Expose model information
    """

    infos: dict


class ServeTokenizeResult(BaseModel):
    """
    Tokenize result model
    """

    tokens: List[str]
    tokens_ids: Optional[List[int]]


class ServeDeTokenizeResult(BaseModel):
    """
    DeTokenize result model
    """

    text: str


class ServeForwardResult(BaseModel):
    """
    Forward result model
    """

    output: Any

class ServeCommand(BaseTransformersCLICommand):
    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register this command to argparse so it's available for the transformers-cli

        Args:
            parser: Root parser to register command-specific arguments
        """
        serve_parser = parser.add_parser(
            "serve", help="CLI tool to run inference requests through REST endpoints."
        )
        serve_parser.add_argument(
            "--task",
            type=str,
            choices=get_supported_tasks(),
            help="The task to run the pipeline on",
        )
        serve_parser.add_argument("--host", type=str, default="localhost", help="Interface the server will listen on.")
        serve_parser.add_argument("--port", type=int, default=8888, help="Port the server will listen on.")
        serve_parser.add_argument("--workers", type=int, default=1, help="Number of HTTP workers")
        serve_parser.add_argument("--model", type=str, help="Model's name or path to stored model.")
        serve_parser.add_argument("--config", type=str, help="Model's config name or path to a stored config.")
        serve_parser.add_argument("--tokenizer", type=str, help="Tokenizer name to use.")
        serve_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Device to run the pipeline on: -1 for CPU, >= 0 for the corresponding GPU (default: -1)",
        )
        serve_parser.set_defaults(func=serve_command_factory)
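
    # Example invocation (a sketch; task and address are illustrative):
    #
    #     transformers-cli serve --task sentiment-analysis --host 0.0.0.0 --port 8888
    #
    # Parsing these arguments calls `serve_command_factory`, which builds the
    # pipeline and returns a ready-to-run ServeCommand.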

    def __init__(self, pipeline: Pipeline, host: str, port: int, workers: int):
        self._pipeline = pipeline

        self.host = host
        self.port = port
        self.workers = workers

        if not _serve_dependencies_installed:
            raise RuntimeError(
                "Using the serve command requires FastAPI and uvicorn. "
                'Please install transformers with [serving]: pip install "transformers[serving]". '
                "Or install FastAPI and uvicorn separately."
            )
        else:
            logger.info(f"Serving model over {host}:{port}")

            # Register the REST routes exposed by the server.
            self._app = FastAPI(
                routes=[
                    APIRoute(
                        "/",
                        self.model_info,
                        response_model=ServeModelInfoResult,
                        response_class=JSONResponse,
                        methods=["GET"],
                    ),
                    APIRoute(
                        "/tokenize",
                        self.tokenize,
                        response_model=ServeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/detokenize",
                        self.detokenize,
                        response_model=ServeDeTokenizeResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                    APIRoute(
                        "/forward",
                        self.forward,
                        response_model=ServeForwardResult,
                        response_class=JSONResponse,
                        methods=["POST"],
                    ),
                ],
                timeout=600,
            )

    def run(self):
        # Start the uvicorn HTTP server with the configured host, port and workers.
        run(self._app, host=self.host, port=self.port, workers=self.workers)

    def model_info(self):
        # Return the underlying model's configuration as a plain dict.
        return ServeModelInfoResult(infos=vars(self._pipeline.model.config))
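
    # Example request (a sketch, assuming the server listens on localhost:8888
    # and `requests` is installed):
    #
    #     import requests
    #
    #     print(requests.get("http://localhost:8888/").json())  # {"infos": {...}}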

    def tokenize(self, text_input: str = Body(None, embed=True), return_ids: bool = Body(False, embed=True)):
        """
        Tokenize the provided input and optionally return the corresponding token ids:

        - **text_input**: String to tokenize
        - **return_ids**: Boolean flag indicating whether the tokens should be converted to their integer mapping.
        """
        try:
            tokens_txt = self._pipeline.tokenizer.tokenize(text_input)

            if return_ids:
                tokens_ids = self._pipeline.tokenizer.convert_tokens_to_ids(tokens_txt)
                return ServeTokenizeResult(tokens=tokens_txt, tokens_ids=tokens_ids)
            else:
                return ServeTokenizeResult(tokens=tokens_txt)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
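
    # Example request (a sketch; host/port and text are illustrative). Because
    # each Body(...) parameter uses embed=True, every argument is a top-level
    # key in the JSON payload:
    #
    #     import requests
    #
    #     resp = requests.post(
    #         "http://localhost:8888/tokenize",
    #         json={"text_input": "Hello world", "return_ids": True},
    #     )
    #     print(resp.json())  # {"tokens": [...], "tokens_ids": [...]}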

    def detokenize(
        self,
        tokens_ids: List[int] = Body(None, embed=True),
        skip_special_tokens: bool = Body(False, embed=True),
        cleanup_tokenization_spaces: bool = Body(True, embed=True),
    ):
        """
        Detokenize the provided token ids into readable text:

        - **tokens_ids**: List of token ids
        - **skip_special_tokens**: Flag indicating whether special tokens should be dropped from the decoded text
        - **cleanup_tokenization_spaces**: Flag indicating whether leading/trailing and intermediate extra spaces should be removed.
        """
        try:
            decoded_str = self._pipeline.tokenizer.decode(
                tokens_ids,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=cleanup_tokenization_spaces,
            )
            return ServeDeTokenizeResult(text=decoded_str)
        except Exception as e:
            raise HTTPException(status_code=500, detail={"model": "", "error": str(e)})
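
    # Example request (a sketch; the token ids are illustrative and depend on
    # the loaded tokenizer):
    #
    #     import requests
    #
    #     resp = requests.post(
    #         "http://localhost:8888/detokenize",
    #         json={"tokens_ids": [101, 7592, 2088, 102], "skip_special_tokens": True},
    #     )
    #     print(resp.json())  # {"text": "..."}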

    async def forward(self, inputs=Body(None, embed=True)):
        """
        Run the underlying pipeline on the provided inputs:

        - **inputs**: Input to forward through the pipeline.
        """
        # Check we don't have an empty string
        if len(inputs) == 0:
            return ServeForwardResult(output=[])

        try:
            # Forward through the model
            output = self._pipeline(inputs)
            return ServeForwardResult(output=output)
        except Exception as e:
            raise HTTPException(500, {"error": str(e)})
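
    # Example request (a sketch, assuming a text-classification style pipeline
    # was loaded; the exact response shape depends on the pipeline):
    #
    #     import requests
    #
    #     resp = requests.post(
    #         "http://localhost:8888/forward",
    #         json={"inputs": "This movie was great!"},
    #     )
    #     print(resp.json())  # {"output": [{"label": ..., "score": ...}]}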