| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- # Copyright 2020 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from argparse import ArgumentParser
- from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
- from ..utils import logging
- from . import BaseTransformersCLICommand
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def try_infer_format_from_ext(path: str):
    """
    Guess the pipeline data format from the trailing characters of `path`.

    Args:
        path (`str`): Path of the input file. May be empty or `None`.

    Returns:
        `"pipe"` when `path` is falsy, otherwise the first entry of `PipelineDataFormat.SUPPORTED_FORMATS` that
        `path` ends with.

    Raises:
        Exception: If `path` is non-empty and ends with none of the supported formats.
    """
    # Without a path there is nothing to inspect, so fall back to the "pipe" format.
    if not path:
        return "pipe"

    # Lazily scan the supported formats in their declared order and keep the first suffix match.
    matching_formats = (
        candidate for candidate in PipelineDataFormat.SUPPORTED_FORMATS if path.endswith(candidate)
    )
    inferred = next(matching_formats, None)
    if inferred is not None:
        return inferred

    raise Exception(
        f"Unable to determine file format from file extension {path}. "
        f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
    )
def run_command_factory(args):
    """
    Build a `RunCommand` from parsed command-line arguments.

    Args:
        args: Parsed arguments. The attributes read here are `task`, `model`, `config`, `tokenizer`, `device`,
            `input`, `output`, `format`, `column` and `overwrite`.

    Returns:
        `RunCommand`: A command wrapping the instantiated pipeline and the reader/writer for the chosen format.
    """
    # The pipeline is created first; a falsy model (e.g. an empty string) is forwarded as None.
    nlp = pipeline(
        task=args.task,
        model=args.model or None,
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )

    # "infer" means: derive the format from the input path instead of taking it verbatim.
    if args.format == "infer":
        data_format = try_infer_format_from_ext(args.input)
    else:
        data_format = args.format

    # When no column was given, use the input names advertised by the pipeline.
    input_columns = args.column or nlp.default_input_names

    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        column=input_columns,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
    """
    CLI command that feeds every entry of a reader to a pipeline and saves the collected outputs through the
    same reader.
    """

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        """
        Args:
            nlp (`Pipeline`): Pipeline called on each entry.
            reader (`PipelineDataFormat`): Object iterated for inputs and used to save the outputs.
        """
        self._reader = reader
        self._nlp = nlp

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """
        Register the `run` sub-command and its options on `parser`.

        Args:
            parser: Object on which `add_parser` is called. NOTE(review): despite the annotation, a plain
                `ArgumentParser` has no `add_parser`; presumably this is the object returned by
                `ArgumentParser.add_subparsers()` - confirm against the caller.
        """
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        add_option = run_parser.add_argument

        # What to run and where the data comes from / goes to.
        add_option("--task", choices=get_supported_tasks(), help="Task to run")
        add_option("--input", type=str, help="Path to the file to use for inference")
        add_option("--output", type=str, help="Path to the file that will be used post to write results.")

        # Which model, configuration and tokenizer to load.
        add_option("--model", type=str, help="Name or path to the model to instantiate.")
        add_option("--config", type=str, help="Name or path to the model's config to instantiate.")
        add_option("--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)")

        # How the input file is interpreted.
        add_option(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        add_option(
            "--format",
            type=str,
            default="infer",
            choices=PipelineDataFormat.SUPPORTED_FORMATS,
            help="Input format to read from",
        )

        # Execution options.
        add_option(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        add_option("--overwrite", action="store_true", help="Allow overwriting the output file.")

        # The factory is stored under `func` in the parsed arguments for this sub-command.
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """
        Call the pipeline on every entry of the reader, gather the outputs in one flat list and save them.

        The outputs are saved with the reader's `save_binary` when the pipeline's `binary_output` is truthy
        (a warning with the returned path is logged), and with the reader's `save` otherwise.
        """
        pipe = self._nlp
        collected = []

        for entry in self._reader:
            # Multi-column entries are expanded into keyword arguments, others are passed as-is.
            if self._reader.is_multi_columns:
                prediction = pipe(**entry)
            else:
                prediction = pipe(entry)

            # A single dict is one result; anything else is treated as an iterable of results.
            if isinstance(prediction, dict):
                collected.append(prediction)
            else:
                collected.extend(prediction)

        # Saving data
        if not self._nlp.binary_output:
            self._reader.save(collected)
            return

        binary_path = self._reader.save_binary(collected)
        logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
|