# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from argparse import ArgumentParser
  15. from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline
  16. from ..utils import logging
  17. from . import BaseTransformersCLICommand
  18. logger = logging.get_logger(__name__) # pylint: disable=invalid-name
  19. def try_infer_format_from_ext(path: str):
  20. if not path:
  21. return "pipe"
  22. for ext in PipelineDataFormat.SUPPORTED_FORMATS:
  23. if path.endswith(ext):
  24. return ext
  25. raise Exception(
  26. f"Unable to determine file format from file extension {path}. "
  27. f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
  28. )
  29. def run_command_factory(args):
  30. nlp = pipeline(
  31. task=args.task,
  32. model=args.model if args.model else None,
  33. config=args.config,
  34. tokenizer=args.tokenizer,
  35. device=args.device,
  36. )
  37. format = try_infer_format_from_ext(args.input) if args.format == "infer" else args.format
  38. reader = PipelineDataFormat.from_str(
  39. format=format,
  40. output_path=args.output,
  41. input_path=args.input,
  42. column=args.column if args.column else nlp.default_input_names,
  43. overwrite=args.overwrite,
  44. )
  45. return RunCommand(nlp, reader)
  46. class RunCommand(BaseTransformersCLICommand):
  47. def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
  48. self._nlp = nlp
  49. self._reader = reader
  50. @staticmethod
  51. def register_subcommand(parser: ArgumentParser):
  52. run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
  53. run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
  54. run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
  55. run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
  56. run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
  57. run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
  58. run_parser.add_argument(
  59. "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
  60. )
  61. run_parser.add_argument(
  62. "--column",
  63. type=str,
  64. help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
  65. )
  66. run_parser.add_argument(
  67. "--format",
  68. type=str,
  69. default="infer",
  70. choices=PipelineDataFormat.SUPPORTED_FORMATS,
  71. help="Input format to read from",
  72. )
  73. run_parser.add_argument(
  74. "--device",
  75. type=int,
  76. default=-1,
  77. help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
  78. )
  79. run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
  80. run_parser.set_defaults(func=run_command_factory)
  81. def run(self):
  82. nlp, outputs = self._nlp, []
  83. for entry in self._reader:
  84. output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
  85. if isinstance(output, dict):
  86. outputs.append(output)
  87. else:
  88. outputs += output
  89. # Saving data
  90. if self._nlp.binary_output:
  91. binary_path = self._reader.save_binary(outputs)
  92. logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
  93. else:
  94. self._reader.save(outputs)