processing_idefics.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. Processor class for IDEFICS.
  17. """
  18. from typing import Callable, Dict, List, Optional, Union
  19. from urllib.parse import urlparse
  20. from ...feature_extraction_utils import BatchFeature
  21. from ...processing_utils import (
  22. ImagesKwargs,
  23. ProcessingKwargs,
  24. ProcessorMixin,
  25. TextKwargs,
  26. Unpack,
  27. _validate_images_text_input_order,
  28. )
  29. from ...tokenization_utils_base import PreTokenizedInput, TextInput
  30. from ...utils import is_tf_available, is_torch_available
  31. from ...utils.deprecation import deprecate_kwarg
  32. if is_torch_available():
  33. import torch
  34. if is_tf_available():
  35. import tensorflow as tf
# Literal placeholder string that stands for one image inside a prompt. Its vocabulary id is
# looked up with `tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)` wherever image positions are located.
IMAGE_TOKEN = "<image>"
class IdeficsImagesKwargs(ImagesKwargs, total=False):
    """Image-related keyword arguments recognised by the IDEFICS processor, extending `ImagesKwargs`."""

    # Optional callable. The training example in `IdeficsProcessor.__call__` passes a torchvision
    # `transforms.Compose` under this key. NOTE(review): how it is applied is up to the image processor - confirm there.
    transform: Optional[Callable]
    # Optional mapping of str -> int describing the image size.
    # NOTE(review): expected keys are not visible in this file - confirm against the image processor.
    image_size: Optional[Dict[str, int]]
    # Optional mean, either one float or a list of floats (presumably per channel - TODO confirm).
    image_mean: Optional[Union[float, List[float]]]
    # Optional standard deviation, either one float or a list of floats (presumably per channel - TODO confirm).
    image_std: Optional[Union[float, List[float]]]
class IdeficsTextKwargs(TextKwargs, total=False):
    """Text-related keyword arguments recognised by the IDEFICS processor, extending `TextKwargs`."""

    # When truthy, `IdeficsProcessor.__call__` appends the tokenizer's EOS token to every assembled prompt.
    add_eos_token: Optional[bool]
    # When truthy, `<end_of_utterance>` is inserted between consecutive text items of a prompt.
    # When left unset (None), the processor falls back to whether the tokenizer knows that token.
    add_end_of_utterance_token: Optional[bool]
class IdeficsProcessorKwargs(ProcessingKwargs, total=False):
    """Grouped keyword arguments (text / images / common) for `IdeficsProcessor.__call__`, with their defaults."""

    text_kwargs: IdeficsTextKwargs
    images_kwargs: IdeficsImagesKwargs
    # Defaults used when the caller does not override them.
    _defaults = {
        "text_kwargs": {
            # The processor prepends the BOS token to each prompt itself, so the tokenizer is told not to
            # add special tokens (presumably to avoid duplicating BOS - TODO confirm).
            "add_special_tokens": False,
            "padding": "longest",
            "add_eos_token": False,
        },
        "images_kwargs": {},
        "common_kwargs": {"return_tensors": "pt"},
    }
  57. # copied from m4.training.packing
  58. def incremental_to_binary_attention_mask(incremental_mask, return_tensors, num_classes=-1):
  59. # Set elements >= num_classes to -1
  60. if num_classes != -1:
  61. if return_tensors == "pt":
  62. incremental_mask[incremental_mask >= num_classes] = -1
  63. elif return_tensors == "tf":
  64. incremental_mask = tf.where(incremental_mask >= num_classes, -1, incremental_mask)
  65. # Create mask for negative values
  66. if return_tensors == "pt":
  67. negatives = incremental_mask == -1
  68. incremental_mask[negatives] = 0
  69. attn_mask = torch.nn.functional.one_hot(incremental_mask, num_classes=num_classes)
  70. attn_mask[negatives, :] = 0
  71. elif return_tensors == "tf":
  72. negatives = tf.equal(incremental_mask, -1)
  73. incremental_mask = tf.where(negatives, 0, incremental_mask)
  74. attn_mask = tf.one_hot(incremental_mask, depth=num_classes)
  75. # Reshape 'negatives' to add an extra dimension, making it [batch_size, seq_length, 1]
  76. negatives_expanded = tf.expand_dims(negatives, -1)
  77. attn_mask = tf.where(negatives_expanded, tf.zeros_like(attn_mask), attn_mask)
  78. return attn_mask
  79. # copied from m4.training.packing
  80. def image_attention_mask_for_packed_input_ids(input_ids, tokenizer, return_tensors):
  81. if return_tensors == "pt":
  82. return image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer)
  83. elif return_tensors == "tf":
  84. return image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer)
  85. def image_attention_mask_for_packed_input_ids_pt(input_ids, tokenizer):
  86. image_attention_mask = torch.full_like(input_ids, fill_value=-1)
  87. next_image_attention_mask = torch.full_like(input_ids, fill_value=-1)
  88. image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
  89. eod_token_id = tokenizer.eos_token_id
  90. for batch_idx in range(input_ids.size(0)):
  91. count = -1
  92. seen_eod = False
  93. for idx, token_id in enumerate(input_ids[batch_idx]):
  94. if token_id == image_token_id:
  95. count += 1
  96. image_attention_mask[batch_idx][idx] = count
  97. seen_eod = False
  98. else:
  99. image_attention_mask[batch_idx][idx] = count
  100. if seen_eod:
  101. image_attention_mask[batch_idx][idx] = -1
  102. if token_id == eod_token_id:
  103. seen_eod = True
  104. for batch_idx in range(input_ids.size(0)):
  105. count = -1
  106. seen_eod = False
  107. for idx in range(input_ids[batch_idx].size(0) - 1, -1, -1):
  108. token_id = input_ids[batch_idx][idx]
  109. if token_id == image_token_id:
  110. count += 1
  111. next_image_attention_mask[batch_idx][idx] = count
  112. seen_eod = False
  113. else:
  114. next_image_attention_mask[batch_idx][idx] = count
  115. if token_id == eod_token_id:
  116. seen_eod = True
  117. if seen_eod:
  118. next_image_attention_mask[batch_idx][idx] = -1
  119. non_negative_indices = next_image_attention_mask[batch_idx] != -1
  120. next_image_attention_mask[batch_idx][non_negative_indices] -= count
  121. next_image_attention_mask[batch_idx][non_negative_indices] *= -1
  122. return image_attention_mask, next_image_attention_mask
def image_attention_mask_for_packed_input_ids_tf(input_ids, tokenizer):
    """
    Build the two incremental image masks for a batch of packed token ids (TensorFlow).

    Both masks start filled with `-1` and are updated one element at a time while each row is scanned from
    its last position to its first.

    NOTE(review): unlike the PyTorch variant in this file, only a single right-to-left pass is made,
    `image_attention_mask` is written only at image-token positions (with the right-to-left count), and no
    final re-numbering is applied. Confirm whether this divergence is intended.

    Args:
        input_ids: token ids indexed as `[batch, position]`; elements are read with `.numpy()`, so this
            presumably requires eager execution - TODO confirm.
        tokenizer: provides `convert_tokens_to_ids` (to resolve `IMAGE_TOKEN`) and `eos_token_id`.

    Returns:
        `(image_attention_mask, next_image_attention_mask)`, both with the shape of `input_ids`.
    """
    image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
    eod_token_id = tokenizer.eos_token_id
    batch_size = tf.shape(input_ids)[0]
    image_attention_mask = tf.fill(tf.shape(input_ids), -1)
    next_image_attention_mask = tf.fill(tf.shape(input_ids), -1)

    for batch_idx in range(batch_size):
        count = -1
        seen_eod = False
        seq_length = tf.shape(input_ids)[1]

        # Walk the row from right to left.
        for idx in range(seq_length - 1, -1, -1):
            token_id = input_ids[batch_idx, idx].numpy()
            if token_id == image_token_id:
                # Image token: bump the counter and record it in both masks at this position.
                count += 1
                indices = [[batch_idx, idx]]
                updates = [count]
                image_attention_mask = tf.tensor_scatter_nd_update(image_attention_mask, indices, updates)
                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
            elif token_id == eod_token_id and not seen_eod:
                # First EOS met in this row (scanning from the right): reset the counter to 0 and store it.
                seen_eod = True
                count = 0
                indices = [[batch_idx, idx]]
                updates = [count]
                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
            if seen_eod and token_id != eod_token_id:
                # Once an EOS has been seen, every non-EOS position (image tokens included) is set to -1
                # in the "next" mask; `seen_eod` is never cleared within a row.
                indices = [[batch_idx, idx]]
                updates = [-1]
                next_image_attention_mask = tf.tensor_scatter_nd_update(next_image_attention_mask, indices, updates)
    return image_attention_mask, next_image_attention_mask
  152. def is_url(string):
  153. """Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately
  154. invalidated the url"""
  155. if " " in string:
  156. return False
  157. result = urlparse(string)
  158. return all([result.scheme, result.netloc])
  159. class IdeficsProcessor(ProcessorMixin):
  160. r"""
  161. Constructs a IDEFICS processor which wraps a LLama tokenizer and IDEFICS image processor into a single processor.
  162. [`IdeficsProcessor`] offers all the functionalities of [`IdeficsImageProcessor`] and [`LlamaTokenizerFast`]. See
  163. the docstring of [`~IdeficsProcessor.__call__`] and [`~IdeficsProcessor.decode`] for more information.
  164. Args:
  165. image_processor (`IdeficsImageProcessor`):
  166. An instance of [`IdeficsImageProcessor`]. The image processor is a required input.
  167. tokenizer (`LlamaTokenizerFast`):
  168. An instance of [`LlamaTokenizerFast`]. The tokenizer is a required input.
  169. image_size (`int`, *optional*, defaults to 224): Image size (assuming a square image)
  170. """
  171. attributes = ["image_processor", "tokenizer"]
  172. valid_kwargs = ["image_size", "add_end_of_utterance_token"]
  173. image_processor_class = "IdeficsImageProcessor"
  174. tokenizer_class = "LlamaTokenizerFast"
  175. def __init__(self, image_processor, tokenizer=None, image_size=224, add_end_of_utterance_token=None, **kwargs):
  176. if image_processor is None:
  177. raise ValueError("You need to specify an `image_processor`.")
  178. if tokenizer is None:
  179. raise ValueError("You need to specify a `tokenizer`.")
  180. super().__init__(image_processor, tokenizer)
  181. self.current_processor = self.image_processor
  182. self.image_token_id = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN)
  183. self.default_image_dims = (
  184. self.image_processor.image_num_channels,
  185. self.image_processor.image_size,
  186. self.image_processor.image_size,
  187. )
  188. self.tokenizer_was_trained_with_end_of_utterance_token = (
  189. True
  190. if "<end_of_utterance>" in self.tokenizer.special_tokens_map.get("additional_special_tokens", [])
  191. else False
  192. )
  193. @deprecate_kwarg(old_name="prompts", version="5.0.0", new_name="text", raise_if_both_names=True)
  194. def __call__(
  195. self,
  196. images=None,
  197. text: Union[
  198. TextInput,
  199. PreTokenizedInput,
  200. List[TextInput],
  201. List[PreTokenizedInput],
  202. List[List[TextInput]],
  203. List[List[PreTokenizedInput]],
  204. ] = None,
  205. audio=None,
  206. videos=None,
  207. **kwargs: Unpack[IdeficsProcessorKwargs],
  208. ) -> BatchFeature:
  209. """This method takes batched or non-batched prompts made of text and images and converts them into prompts that
  210. the model was trained on and prepares the image pixel values for the model to process.
  211. Args:
  212. images (`Union[PIL.Image, str, List[PIL.Image], List[str]]`):
  213. either a single image or a batched list of images - can be passed in when text contains only text prompts,
  214. in order to use the image-text-to-text behavior.
  215. text (`Union[List[TextInput], [List[List[TextInput]]]]`):
  216. either a single prompt or a batched list of prompts - see the detailed description immediately after
  217. the end of the arguments doc section.
  218. return_tensors (`str` or `TensorType`, *optional*, defaults to `TensorType.PYTORCH`):
  219. The type of tensors to return. Can be one of:
  220. - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
  221. Returns:
  222. a dict with entries: `input_ids`, `attention_mask`, `pixel_values`, `image_attention_mask` which can be
  223. directly passed to `model.generate`
  224. Detailed explanation:
  225. Each entry in `text` is either a text to be passed as is or an image that will be processed.
  226. An image can be either an image object (`PIL.Image`) or a url from which the image can be retrieved.
  227. When the processor encounters an image it'll inject `<fake_token_around_image><image><fake_token_around_image>`
  228. entry into the prompt.
  229. Example:
  230. ```python
  231. checkpoint = "HuggingFaceM4/idefics-9b"
  232. processor = AutoProcessor.from_pretrained(checkpoint)
  233. url = "https://hips.hearstapps.com/hmg-prod/images/cute-photos-of-cats-in-grass-1593184777.jpg"
  234. img = processor.image_processor.fetch_images([url])[0]
  235. prompts = [
  236. "User:",
  237. img,
  238. "Describe this image.\nAssistant: An image of two kittens in grass.\n",
  239. "User:",
  240. "https://hips.hearstapps.com/hmg-prod/images/dog-puns-1581708208.jpg",
  241. "Describe this image.\nAssistant:",
  242. ]
  243. inputs = processor(text=prompts, return_tensors="pt")
  244. generated_ids = model.generate(**inputs, max_length=100)
  245. generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
  246. ```
  247. In this example the `prompts` will be converted into:
  248. ```
  249. <s>User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
  250. Assistant: An image of two kittens in grass.
  251. User:<fake_token_around_image><image><fake_token_around_image>Describe this image.
  252. Assistant:'
  253. ```
  254. and the two images will be massaged using [`IdeficsImageProcessor.__call__`] method and placed inside the
  255. `pixel_values` dict entry of the return value.
  256. This example also examplifies that images can be passed as objects or as text urls. It can be seen that the
  257. first image is passed as object and the second one as a url.
  258. To do training do:
  259. ```python
  260. image_transform = transforms.Compose(
  261. [
  262. transforms.RandomResizedCrop(
  263. (w, h), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC
  264. ),
  265. transforms.ToTensor(),
  266. transforms.Normalize(mean=self.image_mean, std=self.image_std),
  267. ]
  268. )
  269. inputs = processor(text=prompts, transform=image_transform, return_tensors="pt")
  270. ```
  271. In order to help debug prompt generation enable `debug=True` which will show you what's happening.
  272. """
  273. if images is None and text is None:
  274. raise ValueError("You need to specify either `text` or `images` and `text`.")
  275. # check if images and text inputs are reversed for BC
  276. images, text = _validate_images_text_input_order(images, text)
  277. if images is None:
  278. # assuming the user wants to use the old behavior with prompts as the only argument
  279. prompts = text
  280. elif text is not None:
  281. # Assuming image-text-to-text behavior:
  282. # Check if batched images are provided
  283. if not isinstance(images, (list, tuple)):
  284. images = [images]
  285. if isinstance(text, str):
  286. text = [text]
  287. # Check if batched images and text are in the correct format
  288. if isinstance(text, (list, tuple)) and len(text) != len(images):
  289. raise ValueError(
  290. "When providing both images and text arguments, the number of text prompts should be the same as the number of images."
  291. "If you want to have several images per prompt, images should be nested as such: images=[[img1, img2], [img3, img4], ...] for text=[prompt1, prompt2, ...]."
  292. )
  293. # Check that only text is present in the prompts
  294. if not all(isinstance(i, str) for i in text):
  295. raise ValueError("When using the image-text-to-text behavior, the prompts should only contain text.")
  296. if isinstance(images[0], (list, tuple)):
  297. # if nested images, nest text as well
  298. text = [[i] for i in text]
  299. prompts = list(zip(images, text))
  300. output_kwargs = self._merge_kwargs(
  301. IdeficsProcessorKwargs,
  302. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  303. **kwargs,
  304. )
  305. add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
  306. add_end_of_utterance_token = output_kwargs["text_kwargs"].pop("add_end_of_utterance_token", None)
  307. # if the value isn't overriden by the user, check if the tokenizer was trained with this token and then use it
  308. if add_end_of_utterance_token is None:
  309. add_end_of_utterance_token = self.tokenizer_was_trained_with_end_of_utterance_token
  310. # turn non-batched prompts into batched
  311. if not any(isinstance(i, (list, tuple)) for i in prompts):
  312. prompts = [prompts]
  313. fake_token = "<fake_token_around_image>"
  314. image_token = "<image>"
  315. end_of_utterance_token = "<end_of_utterance>"
  316. def image_tokens(last_was_image):
  317. if last_was_image:
  318. return image_token + fake_token
  319. else:
  320. return fake_token + image_token + fake_token
  321. all_prompts = []
  322. all_images = []
  323. for sample in prompts:
  324. # the model was trained on samples starting with <s>
  325. full_text = f"{self.tokenizer.bos_token}"
  326. # an image can either be an image object in the item or the url, everything else is a verbatim prompt text
  327. image_objects = []
  328. last_was_image = False
  329. last_was_text = False
  330. for i, item in enumerate(sample):
  331. if i > 0:
  332. last_was_text = True if not last_was_image else False
  333. if isinstance(item, str):
  334. item = item.strip(" ")
  335. if is_url(item):
  336. image = self.image_processor.fetch_images(item)
  337. full_text += image_tokens(last_was_image)
  338. image_objects.append(image)
  339. last_was_image = True
  340. else:
  341. # we add end_of_utterance_token between each subsequent text prompts (but not at the last one!)
  342. if add_end_of_utterance_token and last_was_text:
  343. full_text += end_of_utterance_token
  344. full_text += item
  345. last_was_image = False
  346. else:
  347. # must be an image obj
  348. full_text += image_tokens(last_was_image)
  349. image_objects.append(item)
  350. last_was_image = True
  351. if add_eos_token:
  352. full_text += self.tokenizer.eos_token
  353. image_objects = self.image_processor(image_objects, **output_kwargs["images_kwargs"])
  354. all_prompts.append(full_text)
  355. all_images.append(image_objects)
  356. # For BC
  357. return_tensors = output_kwargs["text_kwargs"].pop("return_tensors", "pt")
  358. text_encoding = self.tokenizer(all_prompts, **output_kwargs["text_kwargs"])
  359. all_texts = text_encoding["input_ids"]
  360. all_attention_masks = text_encoding["attention_mask"]
  361. # max_num_images has to be at least 1 even when there are no images
  362. max_num_images = max(len(x) for x in all_images)
  363. max_num_images = max(1, max_num_images)
  364. at_least_one_image = sum(len(x) for x in all_images) > 0
  365. output_input_ids = []
  366. output_images = []
  367. output_attention_masks = []
  368. for text_single, attention_mask, extracted_images in zip(all_texts, all_attention_masks, all_images):
  369. padded_input_ids = text_single
  370. image_count = padded_input_ids.count(self.image_token_id)
  371. local_max_num_images = min(image_count, max_num_images)
  372. current_images = extracted_images[:local_max_num_images]
  373. if len(current_images) > 0:
  374. if return_tensors == "pt":
  375. padded_image_tensor = torch.zeros(max_num_images, *current_images.size()[1:])
  376. padded_image_tensor[: current_images.size(0)] = current_images
  377. elif return_tensors == "tf":
  378. # Assuming current_images is a TensorFlow tensor
  379. # Get the shape of current_images, excluding the first dimension
  380. image_shape = tf.shape(current_images)[1:]
  381. # Create a shape for the padded_image_tensor
  382. padded_shape = tf.concat([[max_num_images], image_shape], axis=0)
  383. # Create the padded_image_tensor of zeros
  384. padded_image_tensor = tf.zeros(padded_shape, dtype=current_images.dtype)
  385. # Get the number of images (assuming current_images has shape [num_images, height, width, channels])
  386. num_images = tf.shape(current_images)[0]
  387. # Update the padded_image_tensor with the values from current_images
  388. indices = tf.reshape(tf.range(num_images), (-1, 1))
  389. updates = current_images
  390. padded_image_tensor = tf.tensor_scatter_nd_update(padded_image_tensor, indices, updates)
  391. else:
  392. if return_tensors == "pt":
  393. padded_image_tensor = torch.zeros(max_num_images, *self.default_image_dims)
  394. elif return_tensors == "tf":
  395. padded_image_tensor = tf.zeros((max_num_images, *self.default_image_dims))
  396. output_images.append(padded_image_tensor)
  397. if return_tensors == "pt":
  398. output_input_ids.append(torch.tensor(padded_input_ids))
  399. output_attention_masks.append(torch.tensor(attention_mask))
  400. elif return_tensors == "tf":
  401. output_input_ids.append(tf.convert_to_tensor(padded_input_ids, dtype=tf.int32))
  402. output_attention_masks.append(attention_mask)
  403. if return_tensors == "pt":
  404. output_input_ids = torch.stack(output_input_ids)
  405. output_images = torch.stack(output_images)
  406. output_attention_masks = torch.stack(output_attention_masks)
  407. elif return_tensors == "tf":
  408. output_input_ids = tf.stack(output_input_ids)
  409. output_images = tf.stack(output_images)
  410. output_attention_masks = tf.stack(output_attention_masks)
  411. if at_least_one_image:
  412. image_attention_mask, _ = image_attention_mask_for_packed_input_ids(
  413. output_input_ids, self.tokenizer, return_tensors
  414. )
  415. image_attention_mask = incremental_to_binary_attention_mask(
  416. image_attention_mask, return_tensors, num_classes=max_num_images
  417. )
  418. else:
  419. # in full language mode we set the image mask to all-0s
  420. if return_tensors == "pt":
  421. image_attention_mask = torch.zeros(
  422. output_input_ids.shape[0], output_input_ids.shape[1], 1, dtype=torch.bool
  423. )
  424. elif return_tensors == "tf":
  425. image_attention_mask = tf.zeros(
  426. (output_input_ids.shape[0], output_input_ids.shape[1], 1), dtype=tf.bool
  427. )
  428. return BatchFeature(
  429. data={
  430. "input_ids": output_input_ids,
  431. "attention_mask": output_attention_masks,
  432. "pixel_values": output_images,
  433. "image_attention_mask": image_attention_mask,
  434. }
  435. )
  436. def batch_decode(self, *args, **kwargs):
  437. """
  438. This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
  439. refer to the docstring of this method for more information.
  440. """
  441. return self.tokenizer.batch_decode(*args, **kwargs)
  442. def decode(self, *args, **kwargs):
  443. """
  444. This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
  445. the docstring of this method for more information.
  446. """
  447. return self.tokenizer.decode(*args, **kwargs)
  448. @property
  449. def model_input_names(self):
  450. tokenizer_input_names = self.tokenizer.model_input_names
  451. image_processor_input_names = self.image_processor.model_input_names
  452. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))