processing_kosmos2.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692
  1. # coding=utf-8
  2. # Copyright 2023 Microsoft Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Processor class for KOSMOS-2."""
  16. import copy
  17. import math
  18. import re
  19. from typing import List, Optional, Tuple, Union
  20. from ...image_processing_utils import BatchFeature
  21. from ...image_utils import ImageInput, is_batched
  22. from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack
  23. from ...tokenization_utils import AddedToken
  24. from ...tokenization_utils_base import BatchEncoding, TextInput
  25. BboxInput = Union[
  26. List[Tuple[int, int]],
  27. List[Tuple[float, float, float, float]],
  28. List[List[Tuple[int, int]]],
  29. List[List[Tuple[float, float, float]]],
  30. ]
  31. class Kosmos2ImagesKwargs(ImagesKwargs, total=False):
  32. bboxes: Optional[List[float]]
  33. num_image_tokens: Optional[int]
  34. first_image_token_id: Optional[int]
  35. class Kosmos2TextKwargs(TextKwargs, total=False):
  36. add_eos_token: Optional[bool]
  37. class Kosmos2ProcessorKwargs(ProcessingKwargs, total=False):
  38. text_kwargs: Kosmos2TextKwargs
  39. images_kwargs: Kosmos2ImagesKwargs
  40. _defaults = {
  41. "text_kwargs": {
  42. "add_special_tokens": True,
  43. "padding": False,
  44. "stride": 0,
  45. "return_overflowing_tokens": False,
  46. "return_special_tokens_mask": False,
  47. "return_offsets_mapping": False,
  48. "return_token_type_ids": False,
  49. "verbose": True,
  50. "add_eos_token": False,
  51. },
  52. "images_kwargs": {
  53. "num_image_tokens": 64,
  54. },
  55. }
  56. class Kosmos2Processor(ProcessorMixin):
  57. r"""
  58. Constructs an KOSMOS-2 processor which wraps a KOSMOS-2 image processor and a KOSMOS-2 tokenizer into a single
  59. processor.
  60. [`Kosmos2Processor`] offers all the functionalities of [`CLIPImageProcessor`] and some functionalities of
  61. [`XLMRobertaTokenizerFast`]. See the docstring of [`~Kosmos2Processor.__call__`] and [`~Kosmos2Processor.decode`]
  62. for more information.
  63. Args:
  64. image_processor (`CLIPImageProcessor`):
  65. An instance of [`CLIPImageProcessor`]. The image processor is a required input.
  66. tokenizer (`XLMRobertaTokenizerFast`):
  67. An instance of ['XLMRobertaTokenizerFast`]. The tokenizer is a required input.
  68. num_patch_index_tokens (`int`, *optional*, defaults to 1024):
  69. The number of tokens that represent patch indices.
  70. """
  71. attributes = ["image_processor", "tokenizer"]
  72. valid_kwargs = ["num_patch_index_tokens"]
  73. image_processor_class = "CLIPImageProcessor"
  74. tokenizer_class = "AutoTokenizer"
  75. def __init__(self, image_processor, tokenizer, num_patch_index_tokens=1024, *kwargs):
  76. tokenizer.return_token_type_ids = False
  77. self.eod_token = "</doc>"
  78. self.boi_token = "<image>"
  79. self.eoi_token = "</image>"
  80. self.eoc_token = "</chunk>"
  81. self.eol_token = "</line>"
  82. self.bop_token = "<phrase>"
  83. self.eop_token = "</phrase>"
  84. self.boo_token = "<object>"
  85. self.eoo_token = "</object>"
  86. self.dom_token = "</delimiter_of_multi_objects/>"
  87. self.grd_token = "<grounding>"
  88. self.tag_tokens = [
  89. self.eod_token,
  90. self.boi_token,
  91. self.eoi_token,
  92. self.eoc_token,
  93. self.eol_token,
  94. self.bop_token,
  95. self.eop_token,
  96. self.boo_token,
  97. self.eoo_token,
  98. self.dom_token,
  99. self.grd_token,
  100. ]
  101. self.num_patch_index_tokens = num_patch_index_tokens
  102. patch_index_tokens = [f"<patch_index_{str(x).zfill(4)}>" for x in range(self.num_patch_index_tokens)]
  103. tokens_to_add = []
  104. for token in self.tag_tokens + patch_index_tokens:
  105. tokens_to_add.append(AddedToken(token, lstrip=True, rstrip=False, normalized=False))
  106. tokenizer.add_tokens(tokens_to_add)
  107. super().__init__(image_processor, tokenizer)
  108. def __call__(
  109. self,
  110. images: ImageInput = None,
  111. text: Union[TextInput, List[TextInput]] = None,
  112. audio=None,
  113. videos=None,
  114. **kwargs: Unpack[Kosmos2ProcessorKwargs],
  115. ) -> BatchFeature:
  116. """
  117. This method uses [`CLIPImageProcessor.__call__`] method to prepare image(s) for the model, and
  118. [`XLMRobertaTokenizerFast.__call__`] to prepare text for the model.
  119. Please refer to the docstring of the above two methods for more information.
  120. The rest of this documentation shows the arguments specific to `Kosmos2Processor`.
  121. Args:
  122. bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*):
  123. The bounding bboxes associated to `texts`.
  124. num_image_tokens (`int`, *optional* defaults to 64):
  125. The number of (consecutive) places that are used to mark the placeholders to store image information.
  126. This should be the same as `latent_query_num` in the instance of `Kosmos2Config` you are using.
  127. first_image_token_id (`int`, *optional*):
  128. The token id that will be used for the first place of the subsequence that is reserved to store image
  129. information. If unset, will default to `self.tokenizer.unk_token_id + 1`.
  130. add_eos_token (`bool`, defaults to `False`):
  131. Whether or not to include `EOS` token id in the encoding when `add_special_tokens=True`.
  132. """
  133. if images is None and text is None:
  134. raise ValueError("You have to specify either images or text.")
  135. output_kwargs = self._merge_kwargs(
  136. Kosmos2ProcessorKwargs,
  137. tokenizer_init_kwargs=self.tokenizer.init_kwargs,
  138. **kwargs,
  139. )
  140. bboxes = output_kwargs["images_kwargs"].pop("bboxes", None)
  141. num_image_tokens = output_kwargs["images_kwargs"].pop("num_image_tokens", 64)
  142. first_image_token_id = output_kwargs["images_kwargs"].pop("first_image_token_id", None)
  143. add_eos_token = output_kwargs["text_kwargs"].pop("add_eos_token", False)
  144. add_special_tokens = output_kwargs["text_kwargs"]["add_special_tokens"]
  145. padding = output_kwargs["text_kwargs"]["padding"]
  146. return_tensors = output_kwargs["text_kwargs"].setdefault("return_tensors", None)
  147. encoding = BatchFeature()
  148. if images is not None:
  149. image_encoding = self.image_processor(images, **output_kwargs["images_kwargs"])
  150. encoding.update(image_encoding)
  151. if text is not None:
  152. text = self.preprocess_examples(text, images, bboxes, num_image_tokens=num_image_tokens)
  153. if add_special_tokens and not add_eos_token:
  154. if isinstance(text, str):
  155. text = f"{self.tokenizer.bos_token}{text}"
  156. elif isinstance(text, list):
  157. text = [f"{self.tokenizer.bos_token}{s}" for s in text]
  158. output_kwargs["text_kwargs"]["add_special_tokens"] = (
  159. output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
  160. )
  161. output_kwargs["text_kwargs"]["padding"] = padding if images is None else False
  162. output_kwargs["text_kwargs"]["return_tensors"] = return_tensors if images is None else None
  163. text_encoding = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
  164. encoding.update(text_encoding)
  165. output_kwargs["text_kwargs"]["add_special_tokens"] = add_special_tokens
  166. output_kwargs["text_kwargs"]["padding"] = padding
  167. output_kwargs["text_kwargs"]["return_tensors"] = return_tensors
  168. if text is not None and images is not None:
  169. # Use the id of the first token after <unk>
  170. if first_image_token_id is None:
  171. first_image_token_id = self.tokenizer.unk_token_id + 1
  172. # To see if we need one more `0` (for `<s>`) at the beginning of `image_embeds_position_mask`.
  173. with_bos = add_special_tokens
  174. # The first (actual) `<image>` token is always at the 1st or 2nd place (after `<s>` if any). Here we look
  175. # for the second `<image>` token (which indicate the first image token).
  176. start_index = int(with_bos) + 1
  177. # Add `image_embeds_position_mask`: the leading and trailing `0` are for `boi` and `eoi` tokens. The `1` indicates
  178. # the places of image tokens.
  179. image_token_ids = list(range(first_image_token_id, first_image_token_id + num_image_tokens))
  180. base_image_embeds_position_mask = [0] + [1] * num_image_tokens + [0]
  181. # loop over `encoding["input_ids"]`
  182. input_ids = []
  183. image_embeds_position_mask = []
  184. all_input_ids = encoding["input_ids"]
  185. # not batched -> (changed to) batch of size 1
  186. if isinstance(text, str):
  187. all_input_ids = [all_input_ids]
  188. encoding["attention_mask"] = [encoding["attention_mask"]]
  189. for text_ids in all_input_ids:
  190. # change the ids for the fake `<image>` tokens in `input_ids`
  191. text_ids = text_ids[:start_index] + image_token_ids + text_ids[start_index + num_image_tokens :]
  192. input_ids.append(text_ids)
  193. mask = copy.copy(base_image_embeds_position_mask)
  194. if with_bos:
  195. # for `<s>`
  196. mask = [0] + mask
  197. # trailing part (which are not related to the image)
  198. mask += [0] * (len(text_ids) - len(mask))
  199. image_embeds_position_mask.append(mask)
  200. if isinstance(text, list):
  201. sorted_length = sorted(
  202. [(idx, len(x)) for idx, x in enumerate(text_encoding.input_ids)], key=lambda x: x[-1]
  203. )
  204. _, min_len_not_padded = sorted_length[0]
  205. idx, _ = sorted_length[-1]
  206. output_kwargs["text_kwargs"]["add_special_tokens"] = (
  207. output_kwargs["text_kwargs"]["add_special_tokens"] and add_eos_token
  208. )
  209. output_kwargs["text_kwargs"]["return_tensors"] = None
  210. text_encoding = self.tokenizer(text=[text[idx]], **output_kwargs["text_kwargs"])
  211. max_len_padded = len(text_encoding.input_ids[0])
  212. if min_len_not_padded != max_len_padded:
  213. if self.tokenizer.padding_side == "right":
  214. input_ids = [x + [self.tokenizer.pad_token_id] * (max_len_padded - len(x)) for x in input_ids]
  215. image_embeds_position_mask = [
  216. x + [0] * (max_len_padded - len(x)) for x in image_embeds_position_mask
  217. ]
  218. encoding["attention_mask"] = [
  219. x + [0] * (max_len_padded - len(x)) for x in encoding["attention_mask"]
  220. ]
  221. elif self.tokenizer.padding_side == "left":
  222. input_ids = [[self.tokenizer.pad_token_id] * (max_len_padded - len(x)) + x for x in input_ids]
  223. image_embeds_position_mask = [
  224. [0] * (max_len_padded - len(x)) + x for x in image_embeds_position_mask
  225. ]
  226. encoding["attention_mask"] = [
  227. [0] * (max_len_padded - len(x)) + x for x in encoding["attention_mask"]
  228. ]
  229. # un-batch if necessary
  230. if isinstance(text, str) and return_tensors is None:
  231. input_ids = input_ids[0]
  232. encoding["attention_mask"] = encoding["attention_mask"][0]
  233. image_embeds_position_mask = image_embeds_position_mask[0]
  234. # update (with the target tensor type if specified)
  235. encoding.update(
  236. BatchEncoding(
  237. data={
  238. "input_ids": input_ids,
  239. "attention_mask": encoding["attention_mask"],
  240. "image_embeds_position_mask": image_embeds_position_mask,
  241. },
  242. tensor_type=return_tensors,
  243. )
  244. )
  245. return encoding
  246. def _check_bboxes_for_single_text(self, bboxes):
  247. """
  248. Check `bboxes` for a single text example. It could be
  249. - `None`: no bounding box associated to a text.
  250. - A list with each element being the bounding boxes associated to one `<phrase> ... </phrase>` pair found
  251. in a text. This could be:
  252. - `None`: no bounding box associated to a `<phrase> ... </phrase>` pair.
  253. - A tuple of 2 integers: A single bounding box specified by patch indices.
  254. - A tuple of 4 float point number: A single bounding box specified by (normalized) coordinates.
  255. - A list containing the above 2 tuple types: Multiple bounding boxes for a
  256. `<phrase> ... </phrase>` pair.
  257. """
  258. if bboxes is None:
  259. return
  260. elif not isinstance(bboxes, list):
  261. raise ValueError("`bboxes` (for a single text example) should be `None` or a list.")
  262. # `bbox` is the bounding boxes for a single <phrase> </phrase> pair
  263. for bbox in bboxes:
  264. if bbox is None:
  265. continue
  266. elif not isinstance(bbox, list):
  267. bbox = [bbox]
  268. for element in bbox:
  269. if not isinstance(element, tuple) or not (
  270. (len(element) == 2 and all(isinstance(x, int) for x in element))
  271. or (len(element) == 4 and all(isinstance(x, float) for x in element))
  272. ):
  273. raise ValueError(
  274. "Each element in `bboxes` (for a single text example) should be either `None`, a tuple containing "
  275. "2 integers or 4 float point numbers, or a list containing such tuples. Also "
  276. "make sure the arguments `texts` and `bboxes` passed to `preprocess_text` are both in "
  277. "batches or both for a single example."
  278. )
  279. def _preprocess_single_example(self, text, image, bboxes, img_info_tokens):
  280. text = text.strip()
  281. if image is not None:
  282. # Add `<image> ... (fake) image tokens ... </image>`
  283. text = f"{img_info_tokens} {text}"
  284. # Add `<object> <patch_idx_xxxx> <patch_idx_yyy> </object>` after `<phrase> phrase text </phrase>`
  285. text = self._insert_patch_index_tokens(text, bboxes)
  286. return text
  287. def preprocess_examples(
  288. self,
  289. texts: Union[TextInput, List[TextInput]],
  290. images: ImageInput = None,
  291. bboxes: BboxInput = None,
  292. num_image_tokens: Optional[int] = 64,
  293. ) -> Union[str, List[str]]:
  294. """Add image and bounding box information to `texts` as image and patch index tokens.
  295. Args:
  296. texts (`Union[TextInput, List[TextInput]]`): The texts to be processed.
  297. images (`ImageInput`, *optional*): The images associated to `texts`.
  298. bboxes (`Union[List[Tuple[int]], List[Tuple[float]], List[List[Tuple[int]]], List[List[Tuple[float]]]]`, *optional*):
  299. The bounding bboxes associated to `texts`.
  300. num_image_tokens (`int`, *optional*, defaults to 64):
  301. The number of image tokens (used as latent queries). This should corresponds to the `latent_query_num`
  302. attribute in `Kosmos2Config`.
  303. Returns:
  304. `Union[TextInput, List[TextInput]]`: The processed texts with image and patch index tokens.
  305. """
  306. # These are fake `<image>` tokens enclosed between (the actual) `<image>` token and `</image>`.
  307. img_tokens = [self.boi_token] * num_image_tokens
  308. img_info_tokens = " ".join([self.boi_token] + img_tokens + [self.eoi_token])
  309. # make batch to simplify processing logic
  310. batched = True
  311. if isinstance(texts, str):
  312. batched = False
  313. texts = [texts]
  314. if images is None:
  315. images = [None] * len(texts)
  316. elif not is_batched(images):
  317. images = [images]
  318. if len(texts) != len(images):
  319. raise ValueError(
  320. f"The number of examples in `texts` and `images` should be the same. Got {len(texts)} v.s. {len(images)} instead."
  321. )
  322. if not batched:
  323. self._check_bboxes_for_single_text(bboxes)
  324. bboxes = [bboxes]
  325. elif bboxes is not None:
  326. if not isinstance(bboxes, list):
  327. raise ValueError("`bboxes` should be `None` or a list (as a batch) when `texts` is passed as a batch.")
  328. for x in bboxes:
  329. self._check_bboxes_for_single_text(x)
  330. else:
  331. bboxes = [None] * len(texts)
  332. if len(bboxes) != len(texts):
  333. raise ValueError(
  334. f"The number of examples in `texts` and `bboxes` should be the same. Got {len(texts)} v.s. {len(bboxes)} instead."
  335. )
  336. result = [
  337. self._preprocess_single_example(text, image, bbox, img_info_tokens)
  338. for text, image, bbox in zip(texts, images, bboxes)
  339. ]
  340. # un-batch if necessary
  341. if not batched:
  342. result = result[0]
  343. return result
  344. # Copied from transformers.models.blip.processing_blip.BlipProcessor.batch_decode with BertTokenizerFast->PreTrainedTokenizer
  345. def batch_decode(self, *args, **kwargs):
  346. """
  347. This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.batch_decode`]. Please
  348. refer to the docstring of this method for more information.
  349. """
  350. return self.tokenizer.batch_decode(*args, **kwargs)
  351. # Copied from transformers.models.blip.processing_blip.BlipProcessor.decode with BertTokenizerFast->PreTrainedTokenizer
  352. def decode(self, *args, **kwargs):
  353. """
  354. This method forwards all its arguments to PreTrainedTokenizer's [`~PreTrainedTokenizer.decode`]. Please refer to
  355. the docstring of this method for more information.
  356. """
  357. return self.tokenizer.decode(*args, **kwargs)
  358. def post_process_generation(self, text, cleanup_and_extract=True):
  359. caption = text.split(self.eoi_token)[-1]
  360. if cleanup_and_extract:
  361. return clean_text_and_extract_entities_with_bboxes(caption)
  362. return caption
  363. @property
  364. # Copied from transformers.models.blip.processing_blip.BlipProcessor.model_input_names
  365. def model_input_names(self):
  366. tokenizer_input_names = self.tokenizer.model_input_names
  367. image_processor_input_names = self.image_processor.model_input_names
  368. return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
  369. def _insert_patch_index_tokens(self, text: str, bboxes: Union[List[Tuple[int]], List[Tuple[float]]]) -> str:
  370. if bboxes is None or len(bboxes) == 0:
  371. return text
  372. matched_phrases = list(re.finditer(r"<phrase>.+?</phrase>", string=text))
  373. if len(matched_phrases) != len(bboxes):
  374. raise ValueError(
  375. f"The number of elements in `bboxes` should be the same as the number of `<phrase> ... </phrase>` pairs in `text`. Got {len(matched_phrases)} v.s. {len(bboxes)} instead."
  376. )
  377. # insert object's patch index tokens
  378. # the found `<phrase> ... </phrase>` pairs.
  379. curr_pos = 0
  380. buffer = []
  381. for matched, bbox in zip(matched_phrases, bboxes):
  382. _, end = matched.span()
  383. buffer.append(text[curr_pos:end])
  384. curr_pos = end
  385. # A phrase without bbox
  386. if bbox is None:
  387. continue
  388. # A phrase with a single bbox
  389. if isinstance(bbox, tuple):
  390. bbox = [bbox]
  391. patch_index_strings = []
  392. # A phrase could have multiple bboxes
  393. if not all(box is not None for box in bbox):
  394. raise ValueError(
  395. "The multiple bounding boxes for a single phrase should not contain any `None` value."
  396. )
  397. for box in bbox:
  398. patch_index_1, patch_index_2 = self._convert_bbox_to_patch_index_tokens(box)
  399. patch_index_strings.append(f"{patch_index_1} {patch_index_2}")
  400. # `bbox` being an empty list
  401. if len(patch_index_strings) == 0:
  402. continue
  403. position_str = " </delimiter_of_multi_objects/> ".join(patch_index_strings)
  404. buffer.append(f"<object> {position_str} </object>")
  405. # remaining
  406. if curr_pos < len(text):
  407. buffer.append(text[curr_pos:])
  408. text = "".join(buffer)
  409. return text
  410. def _convert_bbox_to_patch_index_tokens(
  411. self, bbox: Union[Tuple[int, int], Tuple[float, float, float, float]]
  412. ) -> Tuple[str, str]:
  413. # already computed patch indices
  414. if len(bbox) == 2:
  415. idx_1, idx_2 = bbox
  416. # bbox specified with (normalized) coordinates
  417. else:
  418. # use `self.tokenizer` to get `num_patches_per_side`
  419. num_patches_per_side = int(math.sqrt(self.num_patch_index_tokens))
  420. idx_1, idx_2 = coordinate_to_patch_index(bbox, num_patches_per_side)
  421. token_1 = f"<patch_index_{str(idx_1).zfill(4)}>"
  422. token_2 = f"<patch_index_{str(idx_2).zfill(4)}>"
  423. return token_1, token_2
  424. def coordinate_to_patch_index(bbox: Tuple[float, float, float, float], num_patches_per_side: int) -> Tuple[int, int]:
  425. """Convert a bounding box to a pair of patch indices.
  426. Args:
  427. bbox (`Tuple[float, float, float, float]`):
  428. The 4 coordinates of the bounding box, with the format being (x1, y1, x2, y2) specifying the upper-left and
  429. lower-right corners of the box. It should have x2 > x1 and y2 > y1.
  430. num_patches_per_side (`int`): the number of patches along each side.
  431. Returns:
  432. `Tuple[int, int]`: A pair of patch indices representing the upper-left patch and lower-right patch.
  433. """
  434. (x1, y1, x2, y2) = bbox
  435. if not (x2 > x1 and y2 > y1):
  436. raise ValueError("The coordinates in `bbox` should be `(x1, y1, x2, y2)` with `x2 > x1` and `y2 > y1`.")
  437. ul_x = math.floor(x1 * num_patches_per_side)
  438. ul_y = math.floor(y1 * num_patches_per_side)
  439. lr_x = math.ceil(x2 * num_patches_per_side - 1)
  440. lr_y = math.ceil(y2 * num_patches_per_side - 1)
  441. ul_idx = ul_y * num_patches_per_side + ul_x
  442. lr_idx = lr_y * num_patches_per_side + lr_x
  443. return ul_idx, lr_idx
  444. # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L35C1-L75C38
  445. # (with format modifications)
  446. def patch_index_to_coordinate(ul_idx: int, lr_idx: int, num_patches_per_side: int):
  447. """
  448. Given a grid of length `num_patches_per_side` and the indices of the upper-left and lower-right corners of a
  449. bounding box, returns the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
  450. Args:
  451. ul_idx (`int`): the index of the grid cell that corresponds to the upper-left corner of the bounding box.
  452. lr_idx (`int`): the index of the grid cell that corresponds to the lower-right corner of the bounding box.
  453. num_patches_per_side (`int`): the number of patches along each side.
  454. Returns:
  455. `Tuple[float]`: the normalized coordinates of the bounding box, in the form (x1, y1, x2, y2).
  456. """
  457. # Compute the size of each cell in the grid
  458. cell_size = 1.0 / num_patches_per_side
  459. # Compute the x and y indices of the upper-left and lower-right corners of the bounding box
  460. ul_x = ul_idx % num_patches_per_side
  461. ul_y = ul_idx // num_patches_per_side
  462. lr_x = lr_idx % num_patches_per_side
  463. lr_y = lr_idx // num_patches_per_side
  464. # Compute the normalized coordinates of the bounding box
  465. if ul_idx == lr_idx:
  466. x1 = ul_x * cell_size
  467. y1 = ul_y * cell_size
  468. x2 = lr_x * cell_size + cell_size
  469. y2 = lr_y * cell_size + cell_size
  470. elif ul_x == lr_x or ul_y == lr_y:
  471. x1 = ul_x * cell_size
  472. y1 = ul_y * cell_size
  473. x2 = lr_x * cell_size + cell_size
  474. y2 = lr_y * cell_size + cell_size
  475. else:
  476. x1 = ul_x * cell_size + cell_size / 2
  477. y1 = ul_y * cell_size + cell_size / 2
  478. x2 = lr_x * cell_size + cell_size / 2
  479. y2 = lr_y * cell_size + cell_size / 2
  480. return x1, y1, x2, y2
  481. # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L4-L33
  482. # (with format modifications)
  483. def extract_entities_with_patch_indices(text):
  484. """Extract entities contained in `text`. The bounding bboxes is given in the form of patch indices.
  485. This functioin is only intended to be used within `clean_text_and_extract_entities_with_bboxes` where further
  486. processing happens, including converting to normalized coordinates and whitespace character cleaning up.
  487. Examples:
  488. ```python
  489. >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
  490. >>> entities = extract_entities_with_patch_indices(text)
  491. >>> entities
  492. [(' a snowman', (31, 41), [(44, 863)]), (' a fire', (130, 137), [(5, 911)])]
  493. ```"""
  494. # The regular expression pattern for matching the required formats
  495. pattern = r"(?:(<phrase>([^<]+)</phrase>))?<object>((?:<patch_index_\d+><patch_index_\d+></delimiter_of_multi_objects/>)*<patch_index_\d+><patch_index_\d+>)</object>"
  496. # Find all matches in the given string
  497. matches = re.finditer(pattern, text)
  498. # Initialize an empty list to store the valid patch_index combinations
  499. entities_with_patch_indices = []
  500. for match in matches:
  501. # span of a `phrase` that is between <phrase> and </phrase>
  502. span = match.span(2)
  503. phrase_tag, phrase, match_content = match.groups()
  504. if not phrase_tag:
  505. phrase = None
  506. # We take the starting position of `<object>`
  507. span = (match.span(0)[0], match.span(0)[0])
  508. # Split the match_content by the delimiter to get individual patch_index pairs
  509. patch_index_pairs = match_content.split("</delimiter_of_multi_objects/>")
  510. entity_bboxes = []
  511. for pair in patch_index_pairs:
  512. # Extract the xxxx and yyyy values from the patch_index pair
  513. x = re.search(r"<patch_index_(\d+)>", pair)
  514. y = re.search(r"<patch_index_(\d+)>", pair[1:])
  515. if x and y:
  516. if phrase:
  517. entity_bboxes.append((int(x.group(1)), int(y.group(1))))
  518. else:
  519. entity_bboxes.append((int(x.group(1)), int(y.group(1))))
  520. if phrase:
  521. entities_with_patch_indices.append((phrase, span, entity_bboxes))
  522. else:
  523. for bbox in entity_bboxes:
  524. # fake entity name
  525. entity = f"<patch_index_{bbox[0]}><patch_index_{bbox[1]}>"
  526. entities_with_patch_indices.append((entity, span, [bbox]))
  527. return entities_with_patch_indices
  528. def adjust_entity_positions(entity, text):
  529. """Adjust the positions of the entities in `text` to be relative to the text with special fields removed."""
  530. entity_name, (start, end) = entity
  531. # computed the length of strings with special fields (tag tokens, patch index tokens, etc.) removed
  532. adjusted_start = len(re.sub("<.*?>", "", text[:start]))
  533. adjusted_end = len(re.sub("<.*?>", "", text[:end]))
  534. adjusted_entity = (entity_name, (adjusted_start, adjusted_end))
  535. return adjusted_entity
  536. def _cleanup_spaces(text, entities):
  537. """Remove the spaces around the text and the entities in it."""
  538. new_text = text.strip()
  539. leading_spaces = len(text) - len(text.lstrip())
  540. new_entities = []
  541. for entity_name, (start, end), bboxes in entities:
  542. entity_name_leading_spaces = len(entity_name) - len(entity_name.lstrip())
  543. entity_name_trailing_spaces = len(entity_name) - len(entity_name.rstrip())
  544. start = start - leading_spaces + entity_name_leading_spaces
  545. end = end - leading_spaces - entity_name_trailing_spaces
  546. entity_name = entity_name.strip()
  547. new_entities.append((entity_name, (start, end), bboxes))
  548. return new_text, new_entities
  549. # copied from https://github.com/microsoft/unilm/blob/97e4923e97d3ee10b57e97013556e3fd0d207a9b/kosmos-2/demo/decode_string.py#L77-L87
  550. # (with format modifications)
  551. def clean_text_and_extract_entities_with_bboxes(text, num_patches_per_side=32):
  552. """Remove the tag tokens from `text`, extract entities in it with some cleaning up of white characters.
  553. Examples:
  554. ```python
  555. >>> text = "<grounding> An image of<phrase> a snowman</phrase><object><patch_index_0044><patch_index_0863></object> warming himself by<phrase> a fire</phrase><object><patch_index_0005><patch_index_0911></object>."
  556. >>> clean_text, entities = clean_text_and_extract_entities_with_bboxes(text)
  557. >>> clean_text
  558. 'An image of a snowman warming himself by a fire.'
  559. >>> entities
  560. [('a snowman', (12, 21), [(0.390625, 0.046875, 0.984375, 0.828125)]), ('a fire', (41, 47), [(0.171875, 0.015625, 0.484375, 0.890625)])]
  561. ```"""
  562. # remove special fields (tag tokens, patch index tokens, etc.)
  563. processed_text = re.sub("<.*?>", "", text)
  564. entities_with_patch_indices = extract_entities_with_patch_indices(text)
  565. entities = []
  566. for item in entities_with_patch_indices:
  567. entity, bboxes = item[0:2], item[2]
  568. adjusted_entity = adjust_entity_positions(entity, text)
  569. bboxes_in_coords = [patch_index_to_coordinate(bbox[0], bbox[1], num_patches_per_side) for bbox in bboxes]
  570. entities.append(adjusted_entity + (bboxes_in_coords,))
  571. return _cleanup_spaces(processed_text, entities)