document_question_answering.py 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import re
  17. import numpy as np
  18. import torch
  19. from ..models.auto import AutoProcessor
  20. from ..models.vision_encoder_decoder import VisionEncoderDecoderModel
  21. from ..utils import is_vision_available
  22. from .tools import PipelineTool
# Pillow is an optional dependency of transformers: import it only when the
# vision extras are installed, so this module can still be imported (e.g. for
# tool discovery/introspection) without Pillow present. __init__ below raises
# a clear error if the tool is actually used without it.
if is_vision_available():
    from PIL import Image
  25. class DocumentQuestionAnsweringTool(PipelineTool):
  26. default_checkpoint = "naver-clova-ix/donut-base-finetuned-docvqa"
  27. description = "This is a tool that answers a question about an document (pdf). It returns a string that contains the answer to the question."
  28. name = "document_qa"
  29. pre_processor_class = AutoProcessor
  30. model_class = VisionEncoderDecoderModel
  31. inputs = {
  32. "document": {
  33. "type": "image",
  34. "description": "The image containing the information. Can be a PIL Image or a string path to the image.",
  35. },
  36. "question": {"type": "string", "description": "The question in English"},
  37. }
  38. output_type = "string"
  39. def __init__(self, *args, **kwargs):
  40. if not is_vision_available():
  41. raise ValueError("Pillow must be installed to use the DocumentQuestionAnsweringTool.")
  42. super().__init__(*args, **kwargs)
  43. def encode(self, document: "Image", question: str):
  44. task_prompt = "<s_docvqa><s_question>{user_input}</s_question><s_answer>"
  45. prompt = task_prompt.replace("{user_input}", question)
  46. decoder_input_ids = self.pre_processor.tokenizer(
  47. prompt, add_special_tokens=False, return_tensors="pt"
  48. ).input_ids
  49. if isinstance(document, str):
  50. img = Image.open(document).convert("RGB")
  51. img_array = np.array(img).transpose(2, 0, 1)
  52. document = torch.from_numpy(img_array)
  53. pixel_values = self.pre_processor(document, return_tensors="pt").pixel_values
  54. return {"decoder_input_ids": decoder_input_ids, "pixel_values": pixel_values}
  55. def forward(self, inputs):
  56. return self.model.generate(
  57. inputs["pixel_values"].to(self.device),
  58. decoder_input_ids=inputs["decoder_input_ids"].to(self.device),
  59. max_length=self.model.decoder.config.max_position_embeddings,
  60. early_stopping=True,
  61. pad_token_id=self.pre_processor.tokenizer.pad_token_id,
  62. eos_token_id=self.pre_processor.tokenizer.eos_token_id,
  63. use_cache=True,
  64. num_beams=1,
  65. bad_words_ids=[[self.pre_processor.tokenizer.unk_token_id]],
  66. return_dict_in_generate=True,
  67. ).sequences
  68. def decode(self, outputs):
  69. sequence = self.pre_processor.batch_decode(outputs)[0]
  70. sequence = sequence.replace(self.pre_processor.tokenizer.eos_token, "")
  71. sequence = sequence.replace(self.pre_processor.tokenizer.pad_token, "")
  72. sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
  73. sequence = self.pre_processor.token2json(sequence)
  74. return sequence["answer"]